/*
 * dgrambuf.c - C datagrams buffer.
 *
 * Copyright 2016 by Ludovic Pouzenc <ludovic@pouzenc.fr>
 */
#define _GNU_SOURCE /* See feature_test_macros(7) */
#include "dgrambuf.h"

#include "config.h"

#include <sys/socket.h> /* recvmmsg() _GNU_SOURCE */
#include <stdlib.h> /* calloc(), free() */
#include <stdio.h> /* perror() */
#include <errno.h> /* errno */
#include <string.h> /* memset() */
#include <sys/uio.h> /* writev() */
#include <stdint.h> /* uint8_t, uint64_t */
#include <inttypes.h> /* PRIu64 */
#include <signal.h> /* sigaction() */
#include <unistd.h> /* alarm() */
#include <limits.h> /* SSIZE_MAX */
#include "gl_rbtree_list.h" /* Red-Black Tree backed Sorted list, gnulib-tool --import rbtree-list */

/* One buffer slot: its row in buf[][] and the sequence number of the dgram
   it holds. value is 0 when the slot is created and when dgrambuf_write()
   recycles it after writing it out. */
struct indexed_uint {
	size_t index;       /* slot number, row of buf[dgram_slots][dgram_max_size] */
	unsigned int value; /* dgram sequence number stored in this slot */
};

/* Counters accumulated by dgrambuf_recvmmsg()/dgrambuf_write() and
   reported by dgrambuf_stats(). */
struct dgrambuf_stats_t {
	/* Receive side */
	uint64_t dgrambuf_read_on_full; /* receive attempts made while no slot was free */
	uint64_t recvmmsg_calls;
	uint64_t recv_dgrams;
	uint64_t recv_byte;

	/* Outcome of validating received dgrams */
	uint64_t dgram_invalid;
	uint64_t dgram_past;
	uint64_t dgram_future;
	uint64_t dgram_dup;
	uint64_t dgram_end_marker;

	/* Write side */
	uint64_t writev_calls;
	uint64_t write_partial;
	uint64_t write_byte;
};

/* Internal state behind a dgrambuf_t handle (allocated by dgrambuf_new()). */
struct dgrambuf_t {
	/* dgram validation after receive: called as validate_func(dgram_len, dgram_data, &seq).
	   Must return 1 for a data dgram, 2 for an end marker and any other value
	   for an invalid dgram; the dgram sequence number is stored through the
	   last argument (see dgrambuf_recvmmsg()) */
	int (*validate_func)(unsigned int, void *, unsigned int*);

	/* Counters reported by dgrambuf_stats() */
	struct dgrambuf_stats_t stats;
	/* SIGALRM action (empty handler) installed around recvmmsg() when a timeout is requested */
	struct sigaction sa_sigalrm;

	/* Number of dgram slots in buf */
	size_t dgram_slots;
	/* Size in bytes of one slot, i.e. the largest dgram that fits */
	size_t dgram_max_size;
	/* Leading bytes of each dgram that dgrambuf_write() does not output */
	size_t dgram_header_size;

	/* Max number of dgrams handled by one recvmmsg() or writev() call */
	size_t iovec_slots;
	/* recvmmsg() message headers, iovec_slots items */
	struct mmsghdr *msgs;
	/* One receive iovec per msgs[] item, each covering a whole slot */
	struct iovec *iov_recv;
	/* Payload iovecs of the current writev() batch */
	struct iovec *iov_write; /* malloc'ed array */

	/* State of an unfinished writev() batch, resumed by the next dgrambuf_write() */
	struct iovec *partial_write_iov; /* Pointer to an item of iov_write[] */
	size_t partial_write_remaining_iovcnt;
	size_t partial_write_remaining_bytes;

	/* Sequence number carried by the end marker, 0 until one is received */
	unsigned int dgram_seq_last;
	/* Sequence number of the next dgram to write out (starts at 1) */
	unsigned int dgram_seq_base;
	/* Received length (header included) of the dgram held in each slot */
	unsigned int *dgram_len;

	/* One item per slot: index = slot number, value = dgram sequence number */
	struct indexed_uint *dgram_slot_seq; /* malloc'ed array */
	/* Slots handed to the pending recvmmsg() call */
	struct indexed_uint **dgram_read_active_slots; /* malloc'd array of pointers to items of dgram_slot_seq[] */
	size_t dgram_read_active_slots_count;
	/* Slots whose payloads belong to the current writev() batch */
	struct indexed_uint **dgram_write_active_slots; /* malloc'd array of pointers to items of dgram_slot_seq[] */
	size_t dgram_write_active_slots_count;

	/* Lists of pointers to dgram_slot_seq[] items, kept sorted by value with
	   _compare_indexed_uint(): slots available for receiving, and slots
	   holding a received dgram not yet written out */
	gl_list_t dgram_empty_slots;
	gl_list_t dgram_used_slots;

	uint8_t *buf; /* malloc-ed 2d byte array : buf[dgram_slots][dgram_max_size] */
};

/* Ordering / equality of slots by sequence number, for the gnulib sorted lists */
int  _compare_indexed_uint(const void *pa, const void *pb);
bool _equals_indexed_uint(const void *pa, const void *pb);
/* SIGALRM handler used only to interrupt a blocking recvmmsg() */
void _sigalrm_handler(int signum);
/* NOTE(review): declared but not defined in the visible part of this file */
void _update_ordered_seq_numbers(dgrambuf_t dbuf);

#ifndef HAVE_MIN_SIZE_T
/* Smaller of two size_t values; compiled unless HAVE_MIN_SIZE_T is defined
   (presumably by config.h). */
size_t min_size_t(size_t a, size_t b) {
	if (b < a)
		return b;
	return a;
}
#endif /*HAVE_MIN_SIZE_T*/

/*
 * Install the callback dgrambuf_recvmmsg() uses to classify each received
 * dgram. Receiving is refused (-3) until one is set.
 */
void dgrambuf_set_validate_func(dgrambuf_t dbuf, int (*validate_func)(unsigned int, void *, unsigned int *)) {
	(*dbuf).validate_func = validate_func;
}

/* Number of slots currently available to receive a dgram. */
size_t dgrambuf_get_free_count(const dgrambuf_t dbuf) {
	const size_t free_slots = gl_list_size(dbuf->dgram_empty_slots);
	return free_slots;
}

/* Number of slots holding a received dgram that was not written out yet. */
size_t dgrambuf_get_used_count(const dgrambuf_t dbuf) {
	const size_t used_slots = gl_list_size(dbuf->dgram_used_slots);
	return used_slots;
}

/*
 * Receive a batch of datagrams from sockfd into free buffer slots.
 *
 * At most min(iovec_slots, free slots) dgrams are read with one recvmmsg()
 * call (MSG_WAITFORONE). Each one is classified by validate_func(len, data, &seq):
 *   1     data dgram; kept only if dgram_seq_base <= seq < dgram_seq_base + dgram_slots
 *         and seq is not already buffered,
 *   2     end marker; seq is remembered in dgram_seq_last,
 *   other invalid dgram.
 * Every slot that does not end up holding a kept dgram goes back to
 * dgram_empty_slots, marked empty (value 0).
 *
 * timeout: if non-zero, number of seconds passed to alarm(); the SIGALRM
 *          interrupts recvmmsg() and is reported as DGRAMBUF_RECV_EINTR.
 * info:    mandatory; reset then OR-ed with DGRAMBUF_RECV_* flags.
 *
 * Returns the number of bytes received (0 when the buffer is full or the call
 * was interrupted), -1 on recvmmsg() failure (perror() is called), -3 when no
 * validate function is set, -4 on out-of-memory.
 */
ssize_t dgrambuf_recvmmsg(dgrambuf_t dbuf, int sockfd, int timeout, int *info) {
	uint8_t *dgram_base;
	ssize_t recv_byte;
	size_t i, dgram_index, recv_msg_count, free_count;
	int res, recv_errno, recv_failed;
	unsigned int seq, dgram_len;
	struct sigaction sa_old;
	struct indexed_uint *active_slot;
	gl_list_node_t pos;

	/* Info ptr is mandatory */
	*info = 0;

	/* Validate function is mandatory */
	if ( !dbuf->validate_func ) {
		return -3;
	}

	/* Buffer is full, can't receive */
	free_count = dgrambuf_get_free_count(dbuf);
	if ( free_count == 0 ) {
		dbuf->stats.dgrambuf_read_on_full++;
		*info |= DGRAMBUF_RECV_OVERWRITE;
		/*FIXME : this locks everything if buf full + next seq missing*/
		return 0;
	}

	/* Initialize recvmmsg() syscall arguments and keep track of active slots */
	for (i=0; i < dbuf->iovec_slots && i < free_count; i++) {
		/* Pop a free slot, ignoring const modifier from gl_list_get_at() */
		dbuf->dgram_read_active_slots[i] = (struct indexed_uint *) gl_list_get_at(dbuf->dgram_empty_slots, 0);
		gl_sortedlist_remove(dbuf->dgram_empty_slots, _compare_indexed_uint, dbuf->dgram_read_active_slots[i]);

		dgram_index = dbuf->dgram_read_active_slots[i]->index;
		dbuf->iov_recv[i].iov_base = dbuf->buf + dgram_index * dbuf->dgram_max_size;
		dbuf->iov_recv[i].iov_len = dbuf->dgram_max_size;

		memset(dbuf->msgs + i, 0, sizeof(struct mmsghdr));
		dbuf->msgs[i].msg_hdr.msg_iov = dbuf->iov_recv + i;
		dbuf->msgs[i].msg_hdr.msg_iovlen = 1;
	}
	dbuf->dgram_read_active_slots_count = i;

	/* Do the syscall with alarm() to circumvent bad behavior in recvmmsg(2) timeout */
	if (timeout) {
		sigaction(SIGALRM, &(dbuf->sa_sigalrm), &sa_old);
		alarm(timeout);
	}
	res = recvmmsg(sockfd, dbuf->msgs, dbuf->dgram_read_active_slots_count, MSG_WAITFORONE, NULL);
	/* Capture errno now: the alarm()/sigaction() cleanup below may overwrite it */
	recv_errno = errno;
	if (timeout) {
		alarm(0);
		sigaction(SIGALRM, &sa_old, NULL);
	}
	dbuf->stats.recvmmsg_calls++;

	recv_failed = 0;
	recv_msg_count = 0;
	if (res < 0) {
		if ( recv_errno == EINTR ) {
			*info |= DGRAMBUF_RECV_EINTR;
		} else {
			errno = recv_errno;
			perror("recvmmsg()");
			/* Don't return yet: the popped slots must go back to dgram_empty_slots,
			   otherwise every failure would permanently shrink the buffer */
			recv_failed = 1;
		}
	} else {
		recv_msg_count = (size_t) res;
	}

	if (recv_msg_count > 0) {
		dbuf->stats.recv_dgrams += recv_msg_count;
		if ( recv_msg_count == dbuf->dgram_read_active_slots_count ) {
			*info |= DGRAMBUF_RECV_IOVEC_FULL;
		}
	}

	/* Check all received messages */
	for (i=0, recv_byte=0; i<recv_msg_count; i++) {
		active_slot = dbuf->dgram_read_active_slots[i];
		dgram_base = dbuf->iov_recv[i].iov_base;
		dgram_len = dbuf->msgs[i].msg_len;

		/* dgrambuf_new() adjust iovec_len to prevent overflows on ssize_t*/
		recv_byte += dgram_len;

		res = dbuf->validate_func(dgram_len, dgram_base, &seq);
		switch (res) {
			case 1:
				if ( seq < dbuf->dgram_seq_base ) {
					fprintf(stderr, "dgrambuf_recvmmsg(): #%zu past (%u)\n", i, seq);
					dbuf->stats.dgram_past++;
				} else if ( seq >= dbuf->dgram_seq_base + dbuf->dgram_slots ) {
					fprintf(stderr, "dgrambuf_recvmmsg(): #%zu future (%u)\n", i, seq);
					dbuf->stats.dgram_future++;
					*info |= DGRAMBUF_RECV_FUTURE_DGRAM;
				} else {
					active_slot->value = seq;
					pos = gl_sortedlist_search(dbuf->dgram_used_slots, _compare_indexed_uint, active_slot);
					if ( pos != NULL ) {
						fprintf(stderr, "dgrambuf_recvmmsg(): #%zu duplicate (%u)\n", i, seq);
						dbuf->stats.dgram_dup++;
						*info |= DGRAMBUF_RECV_DUPLICATE_DGRAM;
					} else {
						pos = gl_sortedlist_nx_add(dbuf->dgram_used_slots, _compare_indexed_uint, active_slot);
						if ( pos == NULL ) /*TODO: better oom handling */
							return -4;
						dbuf->dgram_len[active_slot->index] = dgram_len;
						*info |= DGRAMBUF_RECV_VALID_DGRAM;
						continue; /* slot now belongs to dgram_used_slots */
					}
				}
				break;
			case 2:
				fprintf(stderr, "dgrambuf_recvmmsg(): #%zu finalize (%u)\n", i, seq);
				dbuf->stats.dgram_end_marker++;
				dbuf->dgram_seq_last = seq;
				*info |= DGRAMBUF_RECV_FINALIZE;
				break;
			default:
				fprintf(stderr, "dgrambuf_recvmmsg(): #%zu invalid\n", i);
				dbuf->stats.dgram_invalid++;
				break;
		}
		/* Dgram not kept: give the slot back as empty (value 0, like dgrambuf_write() does) */
		active_slot->value = 0;
		pos = gl_sortedlist_nx_add(dbuf->dgram_empty_slots, _compare_indexed_uint, active_slot);
		if ( !pos ) /*TODO: better oom handling */
			return -4;
	}

	/* Push remaining (unfilled) active slots back in dgram_empty_slots */
	for (/*next i*/; i < dbuf->dgram_read_active_slots_count; i++) {
		active_slot = dbuf->dgram_read_active_slots[i];
		active_slot->value = 0;
		pos = gl_sortedlist_nx_add(dbuf->dgram_empty_slots, _compare_indexed_uint, active_slot);
		if ( !pos ) /*TODO: better oom handling */
			return -4;
	}

	dbuf->dgram_read_active_slots_count = 0;

	if ( recv_failed ) {
		return -1;
	}

	dbuf->stats.recv_byte += recv_byte;

	return recv_byte;
}

/*
 * Tell whether dgrambuf_write() has something to output right now.
 * Returns 1 when a previous writev() left bytes pending, or when the lowest
 * buffered sequence number is exactly dgram_seq_base; 0 otherwise.
 */
int dgrambuf_have_data_ready_to_write(dgrambuf_t dbuf) {
	const struct indexed_uint *head;

	/* Remainder of an interrupted/partial writev() still has to go out */
	if ( dbuf->partial_write_remaining_bytes != 0 )
		return 1;

	/* No buffered dgram at all */
	if ( !dgrambuf_get_used_count(dbuf) )
		return 0;

	/* Used slots are sorted by seq: the head is the smallest one buffered */
	head = gl_list_get_at(dbuf->dgram_used_slots, 0);
	return head->value == dbuf->dgram_seq_base;
}

/*
 * Returns 1 once an end marker has been seen (dgram_seq_last != 0) and every
 * dgram up to dgram_seq_last has been written out, 0 otherwise.
 */
int dgrambuf_have_received_everything(dgrambuf_t dbuf) {
	/*FIXME : Really implement this */
	if ( dbuf->dgram_seq_last == 0 )
		return 0;
	return dbuf->dgram_seq_base - 1 == dbuf->dgram_seq_last;
}

/*
 * Return every slot of the current write batch to dgram_empty_slots, marking
 * it empty (value 0). Slots are taken from the end of the batch one at a time
 * so that, on out-of-memory, those not yet recycled stay tracked in
 * dgram_write_active_slots instead of being lost or added twice.
 * Returns 0 on success, -4 on out-of-memory.
 */
static int recycle_write_active_slots(dgrambuf_t dbuf) {
	struct indexed_uint *slot;

	while ( dbuf->dgram_write_active_slots_count > 0 ) {
		slot = dbuf->dgram_write_active_slots[dbuf->dgram_write_active_slots_count - 1];
		slot->value = 0;
		if ( !gl_sortedlist_nx_add(dbuf->dgram_empty_slots, _compare_indexed_uint, slot) )
			return -4; /*TODO: better oom handling */
		dbuf->dgram_write_active_slots_count--;
	}
	return 0;
}

/*
 * Write buffered dgram payloads (dgram_header_size bytes stripped) to fd with
 * one writev() call.
 *
 * A new batch takes consecutive dgrams starting at dgram_seq_base from the
 * head of dgram_used_slots, at most iovec_slots of them; dgrams after a gap
 * stay buffered. If the previous call did not write its whole batch (partial
 * write, or writev() failed with EAGAIN/EWOULDBLOCK/EINTR), this call resumes
 * that batch instead. Slots of a fully written batch return to
 * dgram_empty_slots.
 *
 * info: mandatory; reset then OR-ed with DGRAMBUF_WRITE_* flags.
 *
 * Returns the number of bytes written (possibly 0), or
 *   -1 on a fatal writev() error (perror() is called),
 *   -2 if the buffer is not empty but the next dgram to write is missing,
 *   -3 if the partial-write bookkeeping is inconsistent,
 *   -4 on out-of-memory while recycling slots.
 */
ssize_t dgrambuf_write(dgrambuf_t dbuf, int fd, int *info) {
	size_t dgram_index, i, vlen, total, payload_len, skip, remaining;
	unsigned int curr_seq, dgram_len;
	ssize_t nwrite;
	int write_errno;
	struct iovec *iov;
	struct indexed_uint *slot;
	gl_list_node_t pos;

	/* FIXME Info ptr is mandatory */
	*info = 0;

	if ( dbuf->partial_write_remaining_bytes ) {
		/* Previous writev() was partial, continue it */
		iov = dbuf->partial_write_iov;
		vlen = dbuf->partial_write_remaining_iovcnt;
		total = dbuf->partial_write_remaining_bytes;
	} else if ( ! dgrambuf_have_data_ready_to_write(dbuf) ) {
		return 0;
	} else {
		/* The active list is normally empty here; recycle anything an earlier
		   failed call left behind so those slots are not lost */
		if ( recycle_write_active_slots(dbuf) )
			return -4;

		/* Prepare a write batch from the head of dgram_used_slots (sorted by seq) */
		iov = dbuf->iov_write;
		total = 0;
		vlen = 0;
		for (i = 0; vlen < dbuf->iovec_slots && dgrambuf_get_used_count(dbuf) > 0; i++) {
			/* Peek at the lowest seq buffered, ignoring const modifier from gl_list_get_at() */
			slot = (struct indexed_uint *) gl_list_get_at(dbuf->dgram_used_slots, 0);
			curr_seq = slot->value;

			/* Slots that can never be written: drop them without using an iovec */
			if ( curr_seq == 0 || curr_seq < dbuf->dgram_seq_base ) {
				if ( curr_seq == 0 ) {
					fprintf(stderr, "Oops : found empty slot (i==%zu)\n", i);
				} else if ( vlen > 0 && curr_seq + 1 == dbuf->dgram_seq_base ) {
					fprintf(stderr, "Oops : found duplicated dgram in buffer (%u)\n", curr_seq);
				} else {
					fprintf(stderr, "Oops : found dgram from past in buffer (%u)\n", curr_seq);
				}
				gl_sortedlist_remove(dbuf->dgram_used_slots, _compare_indexed_uint, slot);
				slot->value = 0;
				pos = gl_sortedlist_nx_add(dbuf->dgram_empty_slots, _compare_indexed_uint, slot);
				if ( !pos ) /*TODO: better oom handling */
					return -4;
				continue;
			}

			/* Next seq is missing: leave this dgram buffered for a later batch */
			if ( curr_seq != dbuf->dgram_seq_base ) {
				if ( vlen == 0 ) {
					fprintf(stderr, "Oops : nothing to write, missing %u seq\n", dbuf->dgram_seq_base);
				}
				break;
			}

			/* Normal case : curr_seq is the next dgram to write, move it to the batch */
			gl_sortedlist_remove(dbuf->dgram_used_slots, _compare_indexed_uint, slot);
			dbuf->dgram_write_active_slots[vlen] = slot;
			dbuf->dgram_write_active_slots_count = vlen + 1;

			dgram_index = slot->index;
			dgram_len = dbuf->dgram_len[dgram_index];
			/* A dgram shorter than its header would make the length underflow */
			payload_len = ( dgram_len > dbuf->dgram_header_size ) ? dgram_len - dbuf->dgram_header_size : 0;

			iov[vlen].iov_base = dbuf->buf
				+ dgram_index*dbuf->dgram_max_size + dbuf->dgram_header_size;
			iov[vlen].iov_len = payload_len;

			total += payload_len;
			dbuf->dgram_seq_base = curr_seq + 1;
			vlen++;
		}

		/* Nothing valid to write out (but buffer not empty, missing the next dgram) */
		if ( vlen == 0 ) {
			fprintf(stderr, "Oops : nothing to write at all\n");
			return -2;
		}

		if ( vlen == dbuf->iovec_slots ) {
			*info |= DGRAMBUF_WRITE_IOVEC_FULL;
		}
	}

	nwrite = writev(fd, iov, vlen);
	write_errno = errno;
	dbuf->stats.writev_calls++;
	if ( nwrite < 0 ) {
		/* Treat non fatal errors */
		if ( write_errno == EAGAIN || write_errno == EWOULDBLOCK || write_errno == EINTR ) {
			/* Keeps some state informations for retry */
			dbuf->partial_write_remaining_bytes = total;
			dbuf->partial_write_remaining_iovcnt = vlen;
			dbuf->partial_write_iov = iov;
			*info |= DGRAMBUF_WRITE_EWOULDBLOCK_OR_EINTR;
			return 0;
		}
		/* Print fatal errors and bail out */
		errno = write_errno;
		perror("writev()");
		return -1;
	}

	remaining = total - (size_t) nwrite;
	dbuf->partial_write_remaining_bytes = remaining;
	if ( nwrite > 0 ) {
		dbuf->stats.write_byte += nwrite;
		*info |= DGRAMBUF_WRITE_SUCCESS;
	}

	if ( remaining ) {
		if ( nwrite > 0 ) {
			*info |= DGRAMBUF_WRITE_PARTIAL;
			dbuf->stats.write_partial++;
		}
		/* Skip the fully written iovecs of the batch actually passed to writev() */
		skip = (size_t) nwrite;
		for (i = 0; i < vlen && skip >= iov[i].iov_len; i++) {
			skip -= iov[i].iov_len;
		}
		if ( i == vlen ) {
			fprintf(stderr, "Fatal : failed to find partial iov after partial write\n");
			return -3;
		}
		/* Trim the partially written iovec and remember where to resume */
		iov[i].iov_base = (uint8_t *) iov[i].iov_base + skip;
		iov[i].iov_len -= skip;
		dbuf->partial_write_iov = iov + i;
		dbuf->partial_write_remaining_iovcnt = vlen - i;
	} else {
		/* Full write has happened (including an all-empty-payload batch) */
		dbuf->partial_write_iov = NULL;
		dbuf->partial_write_remaining_iovcnt = 0;
		if ( recycle_write_active_slots(dbuf) )
			return -4;
	}

	return nwrite;
}

/*
 * Format dbuf's counters into a newly allocated string stored in
 * *allocated_string; the caller must free() it.
 * dgram_pending is the number of buffered dgrams not written yet;
 * dgram_missing is only computed once an end marker has been received.
 * Returns asprintf()'s result: the string length, or -1 on failure (then
 * *allocated_string is undefined and must not be used).
 */
int dgrambuf_stats(dgrambuf_t dbuf, char **allocated_string) {
	uint64_t dgram_pending = dgrambuf_get_used_count(dbuf);
	uint64_t dgram_missing = 0;

	if ( dbuf->dgram_seq_last ) {
		dgram_missing = dbuf->dgram_seq_last - (dbuf->dgram_seq_base - 1) - dgram_pending;
	}

	/* Every argument is uint64_t: "%d" would be undefined behaviour, use PRIu64 */
	return asprintf(allocated_string,
		"dgrambuf_read_on_full==%" PRIu64 " "
		"recvmmsg_calls==%" PRIu64 ", recv_dgrams==%" PRIu64 ", recv_byte==%" PRIu64 ", "
		"dgram_invalid==%" PRIu64 ", dgram_past==%" PRIu64 ", dgram_future==%" PRIu64 ", dgram_dup==%" PRIu64 ", dgram_end_marker==%" PRIu64 ", "
		"writev_calls==%" PRIu64 ", write_partial==%" PRIu64 ", write_byte==%" PRIu64 " "
		"dgram_pending==%" PRIu64 ", dgram_missing==%" PRIu64,
		dbuf->stats.dgrambuf_read_on_full,
		dbuf->stats.recvmmsg_calls, dbuf->stats.recv_dgrams, dbuf->stats.recv_byte,
		dbuf->stats.dgram_invalid, dbuf->stats.dgram_past, dbuf->stats.dgram_future, dbuf->stats.dgram_dup, dbuf->stats.dgram_end_marker,
		dbuf->stats.writev_calls, dbuf->stats.write_partial, dbuf->stats.write_byte,
		dgram_pending, dgram_missing
	);
}

dgrambuf_t dgrambuf_new(size_t dgram_slots, size_t dgram_max_size, size_t dgram_header_size, size_t iovec_slots) {

	const void **dgram_slot_seq_ptrs = NULL;
	dgrambuf_t dbuf;
	size_t i;

	dbuf = calloc(1, sizeof(struct dgrambuf_t));
	if (!dbuf) goto fail0;

	dbuf->validate_func = NULL;
	/* Implicit with dbuf = calloc(...)
	memset(&(dbuf->stats), 0, sizeof(struct dgrambuf_stats_t));
	memset(&(dbuf->sa_sigalrm), 0, sizeof(struct sigaction));
	*/
	dbuf->sa_sigalrm.sa_handler = _sigalrm_handler;

	dbuf->dgram_slots = dgram_slots;
	dbuf->dgram_max_size = dgram_max_size;
	dbuf->dgram_header_size = dgram_header_size;

	/* writev() and dgrambuf_recvmmsg accumulates read/write bytes in ssize_t */
	iovec_slots = min_size_t(iovec_slots, SSIZE_MAX/dgram_max_size);
	dbuf->iovec_slots = iovec_slots;

	dbuf->msgs = calloc(iovec_slots, sizeof(struct mmsghdr));
	if (!dbuf->msgs) goto fail1;

	dbuf->iov_recv = calloc(iovec_slots, sizeof(struct iovec));
	if (!dbuf->iov_recv) goto fail2;

	dbuf->iov_write = calloc(iovec_slots, sizeof(struct iovec));
	if (!dbuf->iov_write) goto fail3;

	/* Implicit with dbuf = calloc(...)
	dbuf->partial_write_iov = NULL;
	dbuf->partial_write_remaining_iovcnt = 0;
	dbuf->partial_write_remaining_bytes = 0;

	dbuf->dgram_seq_last = 0;
	*/
	dbuf->dgram_seq_base = 1;
	dbuf->dgram_len = calloc(dgram_slots, sizeof(unsigned int));
	if (!dbuf->dgram_len) goto fail4;

	dbuf->dgram_slot_seq = calloc(dgram_slots, sizeof(struct indexed_uint));
	if (!dbuf->dgram_slot_seq) goto fail5;
	for (i=0; i<dgram_slots; i++) {
		dbuf->dgram_slot_seq[i].index = i;
	}

	/* Implicit with dbuf = calloc(...)
	dbuf->dgram_read_active_slots_count = 0;
	*/
	dbuf->dgram_read_active_slots = calloc(iovec_slots, sizeof(struct indexed_uint *));
	if (!dbuf->dgram_read_active_slots) goto fail6;

	/* Implicit with dbuf = calloc(...)
	dbuf->dgram_write_active_slots_count = 0;
	*/
	dbuf->dgram_write_active_slots = calloc(iovec_slots, sizeof(struct indexed_uint *));
	if (!dbuf->dgram_write_active_slots) goto fail7;

	dgram_slot_seq_ptrs = calloc(dgram_slots, sizeof(void *));
	for (i=0; i<dgram_slots; i++) {
		dbuf->dgram_slot_seq[i].index = i;
		dgram_slot_seq_ptrs[i] = &(dbuf->dgram_slot_seq[i]);
	}
	if (!dgram_slot_seq_ptrs) goto fail7;

	dbuf->dgram_empty_slots = gl_list_nx_create(GL_RBTREE_LIST, _equals_indexed_uint,
		 NULL, NULL, false, dgram_slots, dgram_slot_seq_ptrs);
	if (!dbuf->dgram_empty_slots) goto fail8;

	free(dgram_slot_seq_ptrs);
	dgram_slot_seq_ptrs=NULL;

	dbuf->dgram_used_slots = gl_list_nx_create_empty(GL_RBTREE_LIST, _equals_indexed_uint,
		 NULL, NULL, false);
  if (!dbuf->dgram_used_slots) goto fail9;

	dbuf->buf = calloc(dgram_slots, dgram_max_size);
	if (!dbuf->buf) goto fail10;

	return dbuf;

fail10: gl_list_free(dbuf->dgram_used_slots);
fail9:	gl_list_free(dbuf->dgram_empty_slots);
fail8:  free(dbuf->dgram_write_active_slots);
fail7:  free(dbuf->dgram_read_active_slots);
fail6:  free(dbuf->dgram_slot_seq);
fail5:	free(dbuf->dgram_len);
fail4:	free(dbuf->iov_write);
fail3:	free(dbuf->iov_recv);
fail2:	free(dbuf->msgs);
fail1:	free(dbuf);
fail0:	return NULL;
}

/*
 * Release everything allocated by dgrambuf_new() and reset the caller's
 * handle to NULL. A NULL pointer or NULL handle is ignored.
 */
void dgrambuf_free(dgrambuf_t *dbuf) {
	dgrambuf_t victim;

	if ( dbuf == NULL || *dbuf == NULL )
		return;

	victim = *dbuf;
	free(victim->buf);
	gl_list_free(victim->dgram_used_slots);
	gl_list_free(victim->dgram_empty_slots);
	free(victim->dgram_write_active_slots);
	free(victim->dgram_read_active_slots);
	free(victim->dgram_slot_seq);
	free(victim->dgram_len);
	free(victim->iov_write);
	free(victim->iov_recv);
	free(victim->msgs);
	free(victim);

	*dbuf = NULL;
}

/* SIGALRM handler: intentionally empty. Its only purpose is to make a
   blocking recvmmsg() return with EINTR when the alarm fires. */
void _sigalrm_handler(int signum) {
	(void) signum; /* unused */
}

/*
 * Three-way comparison of two struct indexed_uint by value only (-1, 0, 1).
 * The slot index is deliberately ignored, so a search by value finds any slot
 * holding that sequence number.
 */
int _compare_indexed_uint(const void *pa, const void *pb) {
	const unsigned int va = ((const struct indexed_uint *) pa)->value;
	const unsigned int vb = ((const struct indexed_uint *) pb)->value;

	return (va > vb) - (va < vb);
}

/* Equality of two struct indexed_uint on value only, matching
   _compare_indexed_uint() returning 0. */
bool _equals_indexed_uint(const void *pa, const void *pb) {
	const struct indexed_uint *lhs = pa, *rhs = pb;

	return lhs->value == rhs->value;
}