#!/usr/bin/env python3.7

import argparse
import json
import shutil
import numpy as np
from http.server import HTTPServer, BaseHTTPRequestHandler
import os
import yaml
import signal
import threading
import sys

# Count of datapoints accepted so far by all workers (a module-level global
# shared by every server thread); only used for the progress line printed by
# the request handler.
saved_point = 0
# Per-file write buffers: maps an output ``.npy`` path to the list of vectors
# received for it that have not been written to disk yet.
cache = dict()


def append_to_file(filename, array, max_cached=100):
    """Buffer ``array`` for ``filename`` and write the buffer out when it grows.

    Vectors are accumulated in ``cache[filename]``.  Once more than
    ``max_cached`` vectors are buffered, the whole buffer is appended to
    ``filename`` as one additional ``np.save`` record and the buffer is
    emptied.  The file is opened in append mode, so it can hold several
    consecutive arrays; read them back with repeated ``np.load`` calls on a
    single open file handle.

    Vectors still buffered when the process exits are NOT written by this
    function.

    Args:
        filename: Path of the ``.npy`` file to append to.
        array: The vector to record.
        max_cached: Write once the buffer holds more than this many vectors.
            Defaults to 100.
    """
    pending = cache.setdefault(filename, [])
    pending.append(array)
    if len(pending) > max_cached:
        with open(filename, 'ab') as file:
            np.save(file, pending)
        cache[filename] = []
        print('saved datapoints')


def get_handler_class(args, worker_id):
    """Build a request-handler class bound to one worker.

    Args:
        args: Parsed command-line options; only ``args.saving_dir`` is read.
        worker_id: Index of the worker; its datapoints are buffered for
            ``<saving_dir>/<worker_id>.npy``.

    Returns:
        A ``BaseHTTPRequestHandler`` subclass suitable for ``HTTPServer``.
    """
    # Paths accepted by POST.  NOTE(review): both currently write to the same
    # per-worker file -- confirm that sharing one file is intended.
    valid_endpoints = ("/raw-input", "/ttp-hidden2")

    class HandlerClass(BaseHTTPRequestHandler):
        def do_GET(self):
            """Reply 200 with the plain body ``hi!`` to any GET."""
            body = b"hi!"
            self.send_response(200)
            self.send_header("Content-Length", str(len(body)))
            # Headers must be terminated (and flushed) before the body is
            # written; writing to wfile first would put the body ahead of
            # the status line and produce a malformed response.
            self.end_headers()
            self.wfile.write(body)

        def do_POST(self):
            """Record one JSON datapoint and reply 200, or reply 400 on error.

            The body must be a JSON object with a ``"datapoint"`` entry
            (flattened to 1-D) and a ``"format"`` entry that is appended as
            the last element of the stored vector (presumably a scalar).
            """
            global saved_point
            try:
                content_len = int(self.headers.get('Content-Length'))
                if content_len < 0:
                    # rfile.read(-1) would block until the client closes.
                    raise ValueError("negative Content-Length")
                data = self.rfile.read(content_len)
                parsed_data = json.loads(data)

                if self.path not in valid_endpoints:
                    raise Exception("Invalid endpoint")
                saving_dir = args.saving_dir

                vec = np.concatenate((np.array(parsed_data["datapoint"]).reshape(-1),
                                      np.array([parsed_data["format"]])))
                append_to_file("{}/{}.npy".format(saving_dir, worker_id), vec)

                saved_point += 1
                print(saved_point, vec.shape, vec[:5])
                sys.stdout.flush()

                self.send_response(200, "ok")
                self.end_headers()
            except Exception as e:
                print(e)
                # The message becomes the HTTP reason phrase, so CR/LF must
                # not reach the status line.
                reason = "error occurred " + str(e)
                reason = reason.replace("\r", " ").replace("\n", " ")
                self.send_response(400, reason)
                self.end_headers()

    return HandlerClass


def run_server(args):
    """Serve a single worker (id 0) on ``args.addr:args.port`` until interrupted.

    Runs ``serve_forever`` in the calling thread.  The listening socket is
    closed when serving stops, including when it is interrupted by an
    exception such as ``KeyboardInterrupt``.
    """
    server_address = (args.addr, args.port)
    handler_class = get_handler_class(args, 0)
    # The context manager calls server_close() on exit, releasing the port.
    with HTTPServer(server_address, handler_class) as httpd:
        print(f"Starting httpd server on {args.addr}:{args.port}")
        httpd.serve_forever()


def _flush_cache():
    """Write every datapoint still buffered in ``cache`` to its ``.npy`` file.

    ``append_to_file`` only writes in batches, so whatever is left in the
    buffers when the servers stop would otherwise be lost.  Call this only
    once no handler can still be appending.
    """
    for filename, rows in list(cache.items()):
        if rows:
            with open(filename, 'ab') as file:
                np.save(file, rows)
            cache[filename] = []


def run_servers(yaml_settings, server_args=None):
    """Run one HTTP server per worker, each in its own thread, until SIGINT.

    Worker ``i`` listens on ``addr:port + i``.  On exit, every server whose
    thread was started is shut down, every constructed server's socket is
    closed, and datapoints still buffered in ``cache`` are written out.
    Exceptions raised while starting or running the servers are printed,
    not propagated.

    Args:
        yaml_settings: Parsed settings; the worker count is read from
            ``yaml_settings["experiments"][0]["num_servers"]``.
        server_args: Options providing ``addr``, ``port`` and ``saving_dir``.
            Defaults to the module-level ``args`` parsed in ``__main__``.
    """
    if server_args is None:
        server_args = args
    num_workers = yaml_settings["experiments"][0]["num_servers"]
    servers = []   # every server constructed (its socket is bound)
    running = []   # servers whose serve_forever thread has been started

    def on_sigint(signum, frame):
        # shutdown() blocks until serve_forever() returns, so it may only be
        # called on servers that have a serving thread -- otherwise it hangs.
        for s in running:
            s.shutdown()

    signal.signal(signal.SIGINT, on_sigint)

    try:
        threads = []
        for i in range(num_workers):
            server_address = (server_args.addr, server_args.port + i)
            handler_class = get_handler_class(server_args, i)
            server = HTTPServer(server_address, handler_class)

            print(f"Starting httpd server on {server_args.addr}:{server_args.port + i}")

            servers.append(server)
            threads.append(threading.Thread(target=server.serve_forever))

        for server, t in zip(servers, threads):
            t.start()
            running.append(server)
        for t in threads:
            t.join()

    except Exception as e:
        print(e)
    finally:
        # Shutting down a server that never entered serve_forever would
        # deadlock (e.g. when binding a later port failed), hence `running`.
        for s in running:
            s.shutdown()
        for s in servers:
            s.server_close()
        # HTTPServer handles requests inside serve_forever, so once every
        # running server is shut down no handler can still be appending.
        _flush_cache()


def check_dir(args):
    """Make ``args.saving_dir`` exist as a fresh, empty directory.

    WARNING: if the directory already exists, it is deleted recursively,
    along with everything inside it, before being recreated.
    """
    target = args.saving_dir
    if not os.path.isdir(target):
        os.makedirs(target)
        return
    # Wipe earlier contents. Note that append_to_file opens its outputs in
    # append mode ('ab'), so leftover files would otherwise keep growing.
    shutil.rmtree(target)
    os.makedirs(target)


if __name__ == "__main__":
    # Command-line options. `args` must remain a module-level name because
    # run_servers reads it as a global.
    parser = argparse.ArgumentParser(description="HTTP data collect server")
    parser.add_argument("--yaml-settings", default='./src/settings.yml')
    parser.add_argument("--addr", default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--saving-dir", default="./data_points")
    args = parser.parse_args()

    # Load the experiment settings (the worker count lives here).
    with open(args.yaml_settings, 'r') as settings_file:
        yaml_settings = yaml.safe_load(settings_file)

    # Reset the output directory, then block serving until interrupted.
    check_dir(args)
    run_servers(yaml_settings)
