/**
 * layer_mode.c
 * 
 * Emulates the interaction layer side of the communication, which connects to
 * an agent, sending observation packets and receiving action packets back.
 */

#define _POSIX_C_SOURCE 200809L

#include <errno.h>
#include <inttypes.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/select.h>
#include <sys/socket.h>

#include "common.h"


int run_layer(int sock, const struct sockaddr_in *addr, long interval_ms, long n_packets);


/**
 * Parses s as a base-10 long.
 *
 * Rejects empty strings, strings with trailing non-numeric characters,
 * out-of-range values (ERANGE) and values smaller than min.
 *
 * Returns 0 and stores the value in *out on success, -1 otherwise
 * (*out is left untouched on failure).
 */
static int parse_long_min(const char *s, long min, long *out)
{
    char *end;

    errno = 0;
    long v = strtol(s, &end, 10);
    if (end == s || *end != '\0' || errno == ERANGE || v < min) {
        return -1;
    }
    *out = v;
    return 0;
}


/**
 * Entry point: parses the command line, creates a TCP socket and hands it to
 * run_layer(), which performs the actual exchange with the agent.
 *
 * Usage: mock_layer <IP>:<PORT> INTERVAL_MS N_PACKETS
 *   <IP>:<PORT>  address of the agent, parsed by atoip()
 *   INTERVAL_MS  milliseconds between sent packets (integer >= 1)
 *   N_PACKETS    number of packets to send (integer >= 1)
 *
 * Returns 0 on success, 1 on invalid arguments or on any error.
 */
int main(int argc, char **argv)
{
    struct sockaddr_in addr;
    long interval_ms;
    long n_packets;

    if (argc != 4) {
        fprintf(stderr, "usage: mock_layer <IP>:<PORT> INTERVAL_MS N_PACKETS\n");
        return 1;
    }

    if (atoip(argv[1], &addr)) {
        fprintf(stderr, "Invalid IP:PORT spec\n");
        return 1;
    }

    if (parse_long_min(argv[2], 1, &interval_ms)) {
        fprintf(stderr, "Invalid ms interval\n");
        return 1;
    }
    if (parse_long_min(argv[3], 1, &n_packets)) {
        fprintf(stderr, "Invalid packet count\n");
        return 1;
    }

    // s_addr is stored in network byte order, so bytes 0..3 are the dotted-quad octets.
    const uint8_t *a = (const uint8_t *) &addr.sin_addr.s_addr;
    fprintf(stderr, "IP address: %hhu.%hhu.%hhu.%hhu\n", a[0], a[1], a[2], a[3]);
    fprintf(stderr, "Port:       %hu\n", ntohs(addr.sin_port));
    fprintf(stderr, "Interval:   %ld ms\n", interval_ms);
    fprintf(stderr, "N Packets:  %ld\n", n_packets);

    // Create a TCP socket
    int sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock < 0) {
        perror("Could not create TCP socket");
        return 1;
    }

    int code = run_layer(sock, &addr, interval_ms, n_packets);

    // If run_layer() failed before the socket got connected, shutdown()
    // fails with ENOTCONN; that error was already reported, so ignore it.
    if (shutdown(sock, SHUT_RDWR) && errno != ENOTCONN) {
        perror("Error shutting down socket");
        code = 1;
    }
    if (close(sock)) {
        perror("Error closing socket");
        code = 1;
    }

    return code;
}


/** Nanoseconds in one second; used to normalise and convert struct timespec values. */
enum { LAYER_NS_PER_SEC = 1000000000 };

/**
 * Advances *t by ms milliseconds, keeping tv_nsec within [0, LAYER_NS_PER_SEC).
 *
 * The interval is split into whole seconds and a sub-second remainder, so the
 * nanosecond arithmetic stays below ~2e9 and cannot overflow a 32-bit long.
 */
static void timespec_add_ms(struct timespec *t, long ms)
{
    t->tv_sec += (time_t) (ms / 1000);
    t->tv_nsec += (ms % 1000) * 1000000L;
    if (t->tv_nsec >= LAYER_NS_PER_SEC) {
        t->tv_sec += 1;
        t->tv_nsec -= LAYER_NS_PER_SEC;
    } else if (t->tv_nsec < 0) {
        t->tv_sec -= 1;
        t->tv_nsec += LAYER_NS_PER_SEC;
    }
}

/** Returns a - b in nanoseconds, computed in signed 64-bit arithmetic. */
static int64_t timespec_diff_ns(const struct timespec *a, const struct timespec *b)
{
    return ((int64_t) a->tv_sec - (int64_t) b->tv_sec) * LAYER_NS_PER_SEC
           + ((int64_t) a->tv_nsec - (int64_t) b->tv_nsec);
}

/**
 * Writes all len bytes of buf to fd, retrying after short writes and EINTR.
 * Returns 0 on success, or -1 on failure with errno set.
 */
static int write_all(int fd, const char *buf, size_t len)
{
    while (len > 0) {
        ssize_t w = write(fd, buf, len);
        if (w < 0) {
            if (errno == EINTR) {
                continue;
            }
            return -1;
        }
        if (w == 0) {
            // A non-empty write should never return 0; fail instead of spinning.
            errno = EIO;
            return -1;
        }
        buf += w;
        len -= (size_t) w;
    }
    return 0;
}

/**
 * Prints the collected timings to stdout as a JSON array.
 *
 * There is one object per packet. Each holds its index, the send and receive
 * timestamps in seconds and the latency in milliseconds. The timestamps come
 * from CLOCK_MONOTONIC, so they are only meaningful relative to each other.
 */
static void print_results(const struct timespec *send_times,
                          const struct timespec *recv_times,
                          long n_packets)
{
    printf("[\n");
    for (long i = 0; i < n_packets; i++) {
        const struct timespec *st = &send_times[i];
        const struct timespec *rt = &recv_times[i];
        int64_t latency_ns = timespec_diff_ns(rt, st);

        // latency_ns % 1000000 is the sub-millisecond part in nanoseconds,
        // i.e. exactly six fractional digits of the millisecond value.
        printf(
            "    {\"idx\": %ld, \"t_sent\": %jd.%09ld, \"t_recv\": %jd.%09ld, \"latency_ms\": %" PRId64 ".%06" PRId64 "}%s\n",
            i,
            (intmax_t) st->tv_sec, st->tv_nsec,
            (intmax_t) rt->tv_sec, rt->tv_nsec,
            latency_ns / 1000000, latency_ns % 1000000,
            (i != n_packets - 1) ? "," : ""
        );
    }
    printf("]\n");
}

/**
 * Runs the interaction layer against the agent at addr.
 *
 * Connects sock to addr, then sends n_packets packets of PACKET_SIZE bytes,
 * one every interval_ms milliseconds. The first 4 bytes of each packet carry
 * its index in network byte order.
 *
 * Replies are read as PACKET_SIZE-byte packets whose first 4 bytes must hold
 * the index of a packet that has already been sent. Each index must come
 * back exactly once. When every reply has arrived, the send/receive
 * timestamps and latencies are printed to stdout as JSON.
 *
 * The socket is not closed here; the caller owns it.
 *
 * Returns 0 on success, or 1 on any error (a message goes to stderr).
 */
int run_layer(int sock, const struct sockaddr_in *addr, long interval_ms, long n_packets)
{
    static char sbuf[PACKET_SIZE];
    static char rbuf[PACKET_SIZE];

    int rc = 1;
    size_t recv_size = 0;
    long n_sent = 0;
    long n_recv = 0;
    struct timespec next_packet_time;
    struct timespec *send_times = NULL;
    struct timespec *recv_times = NULL;
    bool *has_recv = NULL;

    // Packet indices travel as 32-bit integers, so larger counts cannot be encoded.
    if (n_packets < 1 || (unsigned long) n_packets > UINT32_MAX) {
        fprintf(stderr, "Invalid packet count\n");
        return 1;
    }

    send_times = calloc((size_t) n_packets, sizeof *send_times);
    recv_times = calloc((size_t) n_packets, sizeof *recv_times);
    has_recv = calloc((size_t) n_packets, sizeof *has_recv);
    if (send_times == NULL || recv_times == NULL || has_recv == NULL) {
        perror("Could not allocate timing buffers");
        goto out;
    }

    if (connect(sock, (const struct sockaddr *) addr, sizeof *addr)) {
        perror("Could not connect to the agent");
        goto out;
    }

    if (clock_gettime(CLOCK_MONOTONIC, &next_packet_time)) {
        perror("Could not fetch next_packet_time");
        goto out;
    }

    while (n_sent < n_packets || n_recv < n_packets) {
        struct timespec now;
        if (clock_gettime(CLOCK_MONOTONIC, &now)) {
            perror("Could not fetch curtime");
            goto out;
        }

        // Send as soon as a packet is due, before waiting on the socket.
        // Checking the clock directly, rather than only after a select()
        // timeout, stops a steady stream of replies from delaying sends.
        if (n_sent < n_packets && timespec_diff_ns(&next_packet_time, &now) <= 0) {
            uint32_t idx_out_be = htonl((uint32_t) n_sent);
            // memcpy: sbuf has no uint32_t alignment guarantee.
            memcpy(sbuf, &idx_out_be, sizeof idx_out_be);

            if (write_all(sock, sbuf, sizeof sbuf)) {
                perror("Error sending packet");
                goto out;
            }
            if (clock_gettime(CLOCK_MONOTONIC, &send_times[n_sent])) {
                perror("Could not fetch send time");
                goto out;
            }
            fprintf(stderr, "Sent (%ld)\n", n_sent);
            n_sent++;
            timespec_add_ms(&next_packet_time, interval_ms);
            continue;
        }

        // Wait for reply data, but no longer than until the next packet is due.
        // Once every packet has been sent, block until data arrives.
        struct timespec timeout;
        struct timespec *timeout_p = NULL;
        if (n_sent < n_packets) {
            // Strictly positive here, otherwise the packet would have been sent above.
            int64_t wait_ns = timespec_diff_ns(&next_packet_time, &now);
            timeout.tv_sec = (time_t) (wait_ns / LAYER_NS_PER_SEC);
            timeout.tv_nsec = (long) (wait_ns % LAYER_NS_PER_SEC);
            timeout_p = &timeout;
        }

        fd_set rfds;
        FD_ZERO(&rfds);
        FD_SET(sock, &rfds);

        int ret = pselect(sock + 1, &rfds, NULL, NULL, timeout_p, NULL);
        if (ret == -1) {
            if (errno == EINTR) {
                continue;
            }
            perror("Error selecting fd");
            goto out;
        }
        if (ret == 0) {
            continue;  // timed out: the next iteration sends the due packet
        }

        ssize_t r = recv(sock, &rbuf[recv_size], sizeof rbuf - recv_size, 0);
        if (r == -1) {
            if (errno == EINTR) {
                continue;
            }
            perror("Error receiving packet");
            goto out;
        }
        if (r == 0) {
            // Orderly shutdown by the agent; select() would keep reporting
            // the socket readable, so bail out instead of spinning.
            fprintf(stderr, "Agent closed the connection (%ld of %ld replies received)\n",
                    n_recv, n_packets);
            goto out;
        }

        // Timestamp before logging so stderr output is not counted as latency.
        struct timespec t_recv;
        if (clock_gettime(CLOCK_MONOTONIC, &t_recv)) {
            perror("Could not fetch recv time");
            goto out;
        }

        fprintf(stderr, "received some %zd bytes of data\n", r);
        recv_size += (size_t) r;
        if (recv_size < sizeof rbuf) {
            continue;  // partial packet; keep reading
        }
        recv_size = 0;

        uint32_t idx_in_be;
        memcpy(&idx_in_be, rbuf, sizeof idx_in_be);
        uint32_t idx_recv = ntohl(idx_in_be);

        // Only indices that have actually been sent may be echoed back;
        // anything else would pair the reply with an unrecorded send time.
        if ((uint64_t) idx_recv >= (uint64_t) n_sent) {
            fprintf(stderr, "Corrupted data received, got idx_recv = %" PRIu32 "\n", idx_recv);
            goto out;
        }
        if (has_recv[idx_recv]) {
            fprintf(stderr, "Received duplicate message (%" PRIu32 "), something went wrong\n", idx_recv);
            goto out;
        }
        has_recv[idx_recv] = true;
        recv_times[idx_recv] = t_recv;
        fprintf(stderr, "Received (%" PRIu32 ")\n", idx_recv);
        n_recv++;
    }

    fprintf(stderr, "Done. Collected JSON data:\n");
    print_results(send_times, recv_times, n_packets);
    rc = 0;

out:
    free(send_times);
    free(recv_times);
    free(has_recv);
    return rc;
}
