/*
 * dgrambuf.c - C datagrams buffer.
 * 
 * Copyright 2016 by Ludovic Pouzenc <ludovic@pouzenc.fr>
 */
#define _GNU_SOURCE /* See feature_test_macros(7) */

#include "dgrambuf.h"

#include <sys/socket.h> /* recvmmsg() _GNU_SOURCE */
#include <stdlib.h> /* calloc(), free(), qsort() */
#include <stdio.h> /* perror() */
#include <string.h> /* memset() */
#include <sys/uio.h> /* writev() */
#include <sys/param.h> /* MIN() */

/* (slot index, sequence number) pair; an array of these is sorted by value
 * with _compare_uint_pair() to visit buffer slots in sequence order. */
struct uint_pair {
	unsigned int index, value;
};

/* Internal state of a datagram reordering buffer (opaque to callers). */
struct dgrambuf_t {
	/* Slot geometry: slot count, number of currently empty slots,
	 * bytes per slot, and header bytes skipped when writing payloads. */
	size_t dgram_slots, dgram_free_count, dgram_max_size, dgram_header_size;

	/* Scratch arrays (iovec_slots entries) for recvmmsg() and writev() */
	size_t iovec_slots;
	struct iovec *iov_recv, *iov_write;
	struct mmsghdr *msgs;

	/* Next sequence number that dgrambuf_write() expects to emit */
	unsigned int dgram_seq_base;
	/* Per-slot received length, and per-slot decoded sequence number
	 * (0 marks an empty slot) */
	unsigned int *dgram_len, *dgram_seq_numbers;
	/* Scratch (slot index, seq) pairs sorted by seq in dgrambuf_write() */
	struct uint_pair *dgram_ordered_seq_numbers;

	/* Datagram storage: dgram_slots consecutive slots of dgram_max_size bytes */
	void *buf;

	/* Maps (length, data) of a received datagram to its sequence number;
	 * 0 means invalid and (unsigned int)-1 means end of stream */
	unsigned int (*validate_func)(unsigned int, void *);
	//TODO pthread_mutex_lock
};

/* qsort() comparator ordering struct uint_pair elements by ascending value;
 * forward-declared here because dgrambuf_write() uses it before its definition. */
int _compare_uint_pair(const void *pa, const void *pb);

/*
 * Register the callback that dgrambuf_recvmmsg() uses to decode each received
 * datagram (length, data) into its sequence number. dgrambuf_recvmmsg()
 * refuses to run until one is set.
 */
void dgrambuf_set_validate_func(dgrambuf_t dbuf, unsigned int (*func)(unsigned int, void *) ) {
	(*dbuf).validate_func = func;
}

/* Number of slots that are currently empty and can receive a datagram. */
size_t dgrambuf_free_count(const dgrambuf_t dbuf) {
	const size_t empty_slots = dbuf->dgram_free_count;
	return empty_slots;
}

/*
 * Receive a batch of datagrams into the free slots of dbuf.
 *
 * Up to iovec_slots free slots (dgram_seq_numbers[i] == 0) are handed to a
 * single recvmmsg() call with MSG_WAITFORONE. Each received datagram is passed
 * to validate_func(len, data), whose return value is interpreted as:
 *   0                 invalid datagram, the slot stays free;
 *   (unsigned int)-1  end marker, the slot stays free and 0 is returned;
 *   other             sequence number, kept only when it lies in
 *                     [dgram_seq_base, dgram_seq_base + dgram_slots).
 * An in-range datagram is still dropped when the kernel truncated it
 * (MSG_TRUNC: it did not fit in dgram_max_size bytes) or when it is shorter
 * than dgram_header_size. dgrambuf_write() would otherwise emit a partial
 * payload or underflow the payload length.
 *
 * Returns 1 when datagrams were processed, 0 if an end marker was seen,
 * -1 if the buffer is full, -2 if no validate function is set, or the
 * negative recvmmsg() result on syscall failure (after perror()).
 */
int dgrambuf_recvmmsg(dgrambuf_t dbuf, int sockfd) {
	char *buf_base = dbuf->buf;
	char *dgram_base;
	size_t vlen, i, dgram_index;
	int recv_msg_count, res;
	unsigned int seq, dgram_len;

	/* Buffer is full, can't receive */
	if ( dbuf->dgram_free_count == 0 ) {
		return -1;
	}

	/* Validate function is mandatory */
	if ( !dbuf->validate_func ) {
		return -2;
	}

	/* Initialize recvmmsg() syscall arguments: one iovec per free slot */
	for (i=0, vlen=0; i < dbuf->dgram_slots; i++) {
		if ( dbuf->dgram_seq_numbers[i] == 0 ) {
			dbuf->iov_recv[vlen].iov_base = buf_base + i*dbuf->dgram_max_size;
			dbuf->iov_recv[vlen].iov_len = dbuf->dgram_max_size;
			memset(dbuf->msgs + vlen, 0, sizeof(struct mmsghdr));
			dbuf->msgs[vlen].msg_hdr.msg_iov = dbuf->iov_recv + vlen;
			dbuf->msgs[vlen].msg_hdr.msg_iovlen = 1;
			vlen++;
			if ( vlen == dbuf->iovec_slots )
				break;
		}
	}

	/* Do the syscall */
	recv_msg_count = recvmmsg(sockfd, dbuf->msgs, vlen, MSG_WAITFORONE, NULL);
	if (recv_msg_count < 0) {
		perror("recvmmsg()");
		return recv_msg_count;
	}

	/* Check all received messages */
	res = 1;
	for (i=0; i < (size_t)recv_msg_count; i++) {
		dgram_base = dbuf->iov_recv[i].iov_base;
		/* Recover the slot number from the address the datagram landed at */
		dgram_index = (size_t)(dgram_base - buf_base) / dbuf->dgram_max_size;
		dgram_len = dbuf->msgs[i].msg_len;
		seq = dbuf->validate_func(dgram_len, dgram_base);
		// TODO better feedback
		if ( seq == 0 ) {
			fprintf(stderr, "dgrambuf_recvmmsg(): #%zu invalid (%u)\n", i, seq);
			dbuf->dgram_seq_numbers[dgram_index] = 0;
		} else if ( seq == (unsigned int)-1 ) {
			fprintf(stderr, "dgrambuf_recvmmsg(): #%zu end\n", i);
			dbuf->dgram_seq_numbers[dgram_index] = 0;
			res = 0;
		} else if ( seq < dbuf->dgram_seq_base ) {
			fprintf(stderr, "dgrambuf_recvmmsg(): #%zu past (%u)\n", i, seq);
			dbuf->dgram_seq_numbers[dgram_index] = 0;
		} else if ( seq >= dbuf->dgram_seq_base + dbuf->dgram_slots ) {
			fprintf(stderr, "dgrambuf_recvmmsg(): #%zu future (%u)\n", i, seq);
			dbuf->dgram_seq_numbers[dgram_index] = 0;
		} else if ( dbuf->msgs[i].msg_hdr.msg_flags & MSG_TRUNC ) {
			/* Payload did not fit in dgram_max_size: the tail is lost */
			fprintf(stderr, "dgrambuf_recvmmsg(): #%zu truncated (%u)\n", i, seq);
			dbuf->dgram_seq_numbers[dgram_index] = 0;
		} else if ( dgram_len < dbuf->dgram_header_size ) {
			/* Would underflow the payload length in dgrambuf_write() */
			fprintf(stderr, "dgrambuf_recvmmsg(): #%zu too short (%u)\n", i, seq);
			dbuf->dgram_seq_numbers[dgram_index] = 0;
		} else {
			dbuf->dgram_seq_numbers[dgram_index] = seq;
			dbuf->dgram_len[dgram_index] = dgram_len;
			dbuf->dgram_free_count--;
		}
	}

	return res;
}

/*
 * Write out, in order, the run of consecutive datagrams that starts at
 * dgram_seq_base and is present in the buffer.
 *
 * Slots are sorted by sequence number. Empty slots (seq 0) sort first, so the
 * scan starts right after the dgram_free_count of them. For each datagram, the
 * payload (stored length minus dgram_header_size) is gathered into at most
 * iovec_slots iovecs and emitted with a single writev().
 *
 * The following slots are released:
 *   - written slots;
 *   - duplicates of the datagram just written;
 *   - stale datagrams older than dgram_seq_base.
 * dgram_seq_base then advances past the last written sequence number. Slots
 * are released before writev() runs, so data is lost on a failed or short
 * write (see FIXME below).
 *
 * Returns the writev() result. Returns -1 without writing when the buffer is
 * empty or the datagram numbered dgram_seq_base has not been received yet.
 */
ssize_t dgrambuf_write(dgrambuf_t dbuf, int fd) {
	size_t dgram_index, i, vlen;
	unsigned int curr_seq, prev_seq, stored_len, dgram_len;
	ssize_t nwrite, total;

	/* Buffer is empty, nothing to write */
	if ( dbuf->dgram_free_count == dbuf->dgram_slots ) {
		return -1;
	}

	/* Initialize dgram_ordered_seq_numbers from dgram_seq_numbers */
	for (i=0; i < dbuf->dgram_slots; i++) {
		dbuf->dgram_ordered_seq_numbers[i].index = i;
		dbuf->dgram_ordered_seq_numbers[i].value = dbuf->dgram_seq_numbers[i];
	}
	/* Inplace sorting of dgram_ordered_seq_numbers */
	qsort(dbuf->dgram_ordered_seq_numbers, dbuf->dgram_slots, sizeof(struct uint_pair), _compare_uint_pair);

	/* Initialize iovecs for writev, take dgram payloads following the sequence numbers */
	prev_seq=0, vlen=0, total=0;
	for (i=dbuf->dgram_free_count; i < dbuf->dgram_slots; i++) {
		curr_seq = dbuf->dgram_ordered_seq_numbers[i].value;
		dgram_index = dbuf->dgram_ordered_seq_numbers[i].index;

		/* Skip empty dgram slot */
		if ( curr_seq == 0 ) {
			fprintf(stderr, "Oops : found empty slot (i==%zu)\n", i);
			continue;
		}

		/* Drop current dgram if it is a dup of the one just written */
		if ( curr_seq == prev_seq ) {
			dbuf->dgram_seq_numbers[dgram_index] = 0;
			dbuf->dgram_free_count++;
			continue;
		}

		/* Drop dgram coming from the past. This happens when a previous call
		 * stopped at iovec_slots right after writing a seq that also had a
		 * duplicate in the buffer, so the duplicate was never reached. It must
		 * be released here, otherwise the slot would stay occupied forever. */
		if ( curr_seq < dbuf->dgram_seq_base ) {
			fprintf(stderr, "Oops : found dgram from past in buffer (%u)\n", curr_seq);
			dbuf->dgram_seq_numbers[dgram_index] = 0;
			dbuf->dgram_free_count++;
			continue;
		}

		/* Stop if first dgram to write is not in buffer at all */
		if ( ( vlen==0 ) && (curr_seq != dbuf->dgram_seq_base) ) {
			fprintf(stderr, "Oops : nothing to write, missing %u seq\n", dbuf->dgram_seq_base);
			break;
		}

		/* Stop if current seq dgram is missing */
		if ( ( vlen > 0 ) && (curr_seq > prev_seq+1 ) ) {
			break;
		}

		/* Normal case : curr_seq is the next dgram to write.
		 * Clamp so a datagram shorter than its header cannot underflow. */
		stored_len = dbuf->dgram_len[dgram_index];
		dgram_len = ( stored_len > dbuf->dgram_header_size )
			? (unsigned int)(stored_len - dbuf->dgram_header_size)
			: 0;

		dbuf->iov_write[vlen].iov_len = dgram_len; /* Setup iovecs */
		dbuf->iov_write[vlen].iov_base = (char *)dbuf->buf + dgram_index*dbuf->dgram_max_size + dbuf->dgram_header_size;
		dbuf->dgram_seq_numbers[dgram_index] = 0; /* Mark dgram slots about to be written out as empty for next read */

		total += dgram_len; /* Update counters */
		dbuf->dgram_free_count++;
		dbuf->dgram_seq_base = curr_seq + 1;
		prev_seq = curr_seq;
		vlen++;
		/* Don't plan to write more than iovec_slots slots */
		if ( vlen == dbuf->iovec_slots )
			break;
	}

	/* Nothing valid to write out (but buffer not empty, missing the next dgram) */
	if ( vlen == 0 ) {
		return -1;
	}

	nwrite = writev(fd, dbuf->iov_write, vlen);
	if ( nwrite < 0 ) {
		perror("writev()");
	} else if ( nwrite != total ) {
		//FIXME : everything break if there because all non writed data will be overwritted at next read
		// Make a loop here could make dgrambuf_writev() unbounded in run time
		fprintf(stderr, "writev() short\n");
	}

	return nwrite;
}

/*
 * Allocate a buffer of dgram_slots slots, each dgram_max_size bytes long.
 *
 * dgrambuf_write() skips the first dgram_header_size bytes of each datagram.
 * iovec_slots bounds how many datagrams a single dgrambuf_recvmmsg() or
 * dgrambuf_write() call handles, and is clamped to dgram_slots.
 *
 * Returns NULL on allocation failure, or when the geometry is unusable:
 *   - dgram_slots, dgram_max_size or iovec_slots is zero (later division by
 *     zero or out-of-bounds scratch array writes);
 *   - dgram_header_size exceeds dgram_max_size.
 * Release the result with dgrambuf_free().
 */
dgrambuf_t dgrambuf_new(size_t dgram_slots, size_t dgram_max_size, size_t dgram_header_size, size_t iovec_slots) {
	dgrambuf_t dbuf;

	if ( dgram_slots == 0 || dgram_max_size == 0 || iovec_slots == 0
			|| dgram_header_size > dgram_max_size ) {
		goto fail0;
	}

	dbuf = calloc(1, sizeof(struct dgrambuf_t));
	if (!dbuf) goto fail0;

	dbuf->dgram_slots = dgram_slots;
	dbuf->dgram_free_count = dgram_slots;
	dbuf->dgram_max_size = dgram_max_size;
	dbuf->dgram_header_size = dgram_header_size;
	dbuf->iovec_slots = MIN(iovec_slots,dgram_slots);

	/* Scratch arrays never hold more than the clamped iovec_slots entries */
	dbuf->iov_recv = calloc(dbuf->iovec_slots, sizeof(struct iovec));
	if (!dbuf->iov_recv) goto fail1;

	dbuf->iov_write = calloc(dbuf->iovec_slots, sizeof(struct iovec));
	if (!dbuf->iov_write) goto fail2;

	dbuf->msgs = calloc(dbuf->iovec_slots, sizeof(struct mmsghdr));
	if (!dbuf->msgs) goto fail3;

	/* Sequence number 0 marks an empty slot, so the first expected one is 1 */
	dbuf->dgram_seq_base = 1;
	dbuf->dgram_len = calloc(dgram_slots, sizeof(unsigned int));
	if (!dbuf->dgram_len) goto fail4;

	dbuf->dgram_seq_numbers = calloc(dgram_slots, sizeof(unsigned int));
	if (!dbuf->dgram_seq_numbers) goto fail5;

	dbuf->dgram_ordered_seq_numbers = calloc(dgram_slots, sizeof(struct uint_pair));
	if (!dbuf->dgram_ordered_seq_numbers) goto fail6;

	/* calloc() rejects dgram_slots * dgram_max_size overflow */
	dbuf->buf = calloc(dgram_slots, dgram_max_size);
	if (!dbuf->buf) goto fail7;

	return dbuf;

fail7:	free(dbuf->dgram_ordered_seq_numbers);
fail6:	free(dbuf->dgram_seq_numbers);
fail5:	free(dbuf->dgram_len);
fail4:	free(dbuf->msgs);
fail3:	free(dbuf->iov_write);
fail2:	free(dbuf->iov_recv);
fail1:	free(dbuf);
fail0:	return NULL;
}

/*
 * Release a buffer created by dgrambuf_new() and reset the caller's handle to
 * NULL. Safe to call with a NULL handle pointer or an already-NULL handle.
 */
void dgrambuf_free(dgrambuf_t *dbuf) {
	/* Previously *dbuf was written even when dbuf itself was NULL */
	if ( !dbuf ) {
		return;
	}
	if ( *dbuf ) {
		free((*dbuf)->buf);
		free((*dbuf)->dgram_ordered_seq_numbers);
		free((*dbuf)->dgram_seq_numbers);
		free((*dbuf)->dgram_len);
		free((*dbuf)->msgs);
		free((*dbuf)->iov_write);
		free((*dbuf)->iov_recv);
		free(*dbuf);
	}
	*dbuf = NULL;
}

/*
 * qsort() comparator for struct uint_pair: orders by ascending value and
 * returns -1, 0 or 1. The index field is ignored.
 */
int _compare_uint_pair(const void *pa, const void *pb) {
	const unsigned int lhs = ((const struct uint_pair *)pa)->value;
	const unsigned int rhs = ((const struct uint_pair *)pb)->value;

	/* Branch-free three-way compare; cannot overflow, unlike lhs - rhs */
	return (lhs > rhs) - (lhs < rhs);
}