#!/usr/bin/env python3.7

import argparse
import json
import threading
import shutil
import os
import yaml
import traceback
import sys
import wandb
from exp3.exp3_kmeans import Exp3KMeans
from exp3.exp3_boggart import Exp3Boggart
from socketserver import BaseRequestHandler, ThreadingMixIn, TCPServer
import threading
import traceback
import json
import signal

# Per-run checkpoint directory; ``{}`` is filled with the ``--run`` name.
CHECKPOINTS_DIR = "./weights/models/{}/"
# File in that directory whose contents are the path handed to trainer.restore().
CHECKPOINT_FILE = CHECKPOINTS_DIR + "last_checkpoint.out"
# Trainer instance shared with the request handlers; set by train()/inference().
trainer = None


def get_exp3_config(args, yaml_settings):
    """Assemble the Exp3 trainer configuration.

    Settings are taken from the first experiment's
    ``fingerprint.abr_config`` section of the YAML settings. The action
    space is ``num_of_actions`` when ``actions_are_delta`` is truthy, and
    the top-level ``num_of_qualities`` otherwise.

    Args:
        args: parsed CLI namespace; ``run`` and ``type`` are copied in.
        yaml_settings: loaded settings mapping.

    Returns:
        dict with keys weights_dir, lr, observation_space, action_space,
        actions_are_delta, saving_time, abr, type.
    """
    n_qualities = yaml_settings["num_of_qualities"]
    experiment = yaml_settings["experiments"][0]
    abr_cfg = experiment["fingerprint"]["abr_config"]
    delta_actions = abr_cfg["actions_are_delta"]
    # Delta-style actions use their own action count instead of quality levels.
    n_actions = abr_cfg["num_of_actions"] if delta_actions else n_qualities

    return {
        "weights_dir": abr_cfg["weights_dir"],
        "lr": 0.01,
        "observation_space": abr_cfg["num_of_contexts"],
        "action_space": n_actions,
        "actions_are_delta": delta_actions,
        "saving_time": experiment["num_servers"],
        "abr": args.run,
        "type": args.type,
    }


def get_trainer(args, yaml_settings):
    """Instantiate the Exp3 trainer selected by the settings.

    ``Exp3Boggart`` is used when the first experiment's
    ``fingerprint.abr_config.use_boggart`` flag is truthy, ``Exp3KMeans``
    otherwise. Both receive the dict built by ``get_exp3_config``.
    """
    config = get_exp3_config(args, yaml_settings)
    abr_cfg = yaml_settings["experiments"][0]["fingerprint"]["abr_config"]
    trainer_cls = Exp3Boggart if abr_cfg["use_boggart"] else Exp3KMeans
    return trainer_cls(config)


class Handler(BaseRequestHandler):
    """Serve one client connection speaking the fixed-size JSON protocol.

    Every request is a frame of exactly ``MESSAGE_LEN`` bytes. After
    decoding and stripping, the text up to the first newline is parsed as a
    JSON command; the rest of the frame is padding and is ignored. Each
    successfully executed command gets one JSON response line back.
    ``handle()`` returns (ending the connection) after an ``END_EPISODE``
    command or when the peer disconnects.
    """

    # Size of one request frame on the wire (2**17 bytes).
    MESSAGE_LEN = 131072
    # Upper bound on the number of bytes requested per recv() call.
    RECV_CHUNK = 4096

    def _recv_frame(self, size):
        """Read exactly ``size`` bytes, or return None if the peer closes first."""
        buf = bytearray()
        while len(buf) < size:
            chunk = self.request.recv(min(self.RECV_CHUNK, size - len(buf)))
            if not chunk:
                # recv() returns b'' once the peer has closed the connection;
                # any partial frame collected so far is discarded.
                return None
            # Raw bytes are kept as-is: stripping here would shift the frame
            # boundary and desynchronise every following request.
            buf.extend(chunk)
        return buf

    def handle(self) -> None:
        """Process request frames until END_EPISODE or disconnect."""
        while True:
            data = self._recv_frame(self.MESSAGE_LEN)
            if data is None:
                print('disconnected')
                return

            args = {}
            try:
                body = data.decode("utf-8").strip()
                # Only the first line carries the command. partition() keeps
                # the whole body when there is no newline (e.g. it was removed
                # by strip()), unlike slicing with find() == -1.
                body = body.partition('\n')[0]
                args = json.loads(body)
                response = self.execute_command(args)
                print('debug', args, response)

                response_bytes = (json.dumps(response) + '\n').encode("utf-8")
                # NOTE(review): ljust() takes a *total* width, so this pads to
                # 100 - len bytes (a 20-byte reply becomes 80 bytes, replies of
                # 50+ bytes are not padded). ljust(100, ...) was probably
                # intended, but this is the existing wire format clients see;
                # confirm with the client before changing it.
                response_bytes = response_bytes.ljust(
                    100 - len(response_bytes), b'0')
                self.request.sendall(response_bytes)
            except Exception:
                # NOTE(review): no response is sent on failure, so a client
                # blocked waiting for a reply will keep waiting.
                print('exception', traceback.format_exc())

            # END_EPISODE closes the connection even if the command failed.
            if isinstance(args, dict) and args.get("command") == "END_EPISODE":
                print('close connection END_EPISODE')
                return

    def execute_command(self, args):
        """Run one protocol command against the module-level ``trainer``.

        Args:
            args: decoded request; ``args["command"]`` selects the action,
                and the other keys read depend on the command.

        Returns:
            The response passed to ``json.dumps`` by ``handle()``:
            ``{"episode_id": ...}`` for START_EPISODE, ``{"action": int}``
            for GET_ACTION, the value of ``trainer.end_episode`` for
            END_EPISODE, and ``{}`` for LOG_RETURNS or unknown commands.
        """
        command = args["command"]
        response = {}

        if command == 'START_EPISODE':
            response["episode_id"] = trainer.start_episode(
                args["episode_id"])
        elif command == 'GET_ACTION':
            response["action"] = int(
                trainer.get_action(args["episode_id"], args["observation"]))
        elif command == "LOG_RETURNS":
            trainer.log_returns(args["episode_id"], args["reward"])
        elif command == "END_EPISODE":
            response = trainer.end_episode(args["episode_id"])
        return response


class SocketServer(ThreadingMixIn, TCPServer):
    """TCP server that handles each accepted connection on its own thread.

    Connections are served by ``Handler``. The listening socket is bound in
    the constructor; if that raises ``OSError`` a failure message is printed
    and the error is re-raised to the caller.
    """

    def __init__(self, address, port):
        # Stored for reference; not otherwise used by this class.
        self.port = port
        bind_to = (address, port)
        try:
            # ThreadingMixIn defines no __init__, so this is TCPServer.__init__.
            super().__init__(bind_to, Handler)
            print(f"Creating a PolicyServer on {address}:{port}")
        except OSError:
            print(f"Creating a PolicyServer on {address}:{port} failed!")
            raise


def start_servers(args, yaml_settings):
    """Run one threaded policy server per worker and block until they stop.

    Worker ``i`` listens on ``args.addr:args.port + i``; the worker count is
    ``yaml_settings["experiments"][0]["num_servers"]``. Exceptions raised
    while creating or running the servers are printed rather than propagated.
    On exit, every server whose loop was started is shut down, and every
    created listening socket is closed.
    """
    num_workers = yaml_settings["experiments"][0]["num_servers"]
    # (server, serve_forever thread) pairs, so cleanup knows which loops run.
    workers = []

    try:
        for i in range(num_workers):
            server = SocketServer(args.addr, args.port + i)
            # Per-connection handler threads must not keep the process alive.
            server.daemon_threads = True
            thread = threading.Thread(target=server.serve_forever, daemon=True)
            workers.append((server, thread))

        for _, thread in workers:
            thread.start()
        for _, thread in workers:
            thread.join()

    except Exception as e:
        print(e)
    finally:
        for server, thread in workers:
            # shutdown() waits for serve_forever() to finish and deadlocks if
            # the loop was never started (e.g. a later port failed to bind),
            # so only stop loops that are actually running.
            if thread.is_alive():
                server.shutdown()
            server.server_close()


def train(args, yaml_settings):
    """Train an Exp3 policy by serving it to clients until the servers stop.

    Builds the global ``trainer``. With ``--resume`` it restores from the
    path stored in the run's checkpoint file; otherwise the run's checkpoint
    directory is deleted and recreated. A wandb run is opened for the
    session and finished when serving ends, after which the process exits
    with status 0 via ``sys.exit(0)``.
    """
    global trainer
    trainer = get_trainer(args, yaml_settings)

    checkpoint_path = CHECKPOINT_FILE.format(args.run)
    if args.resume:
        # The checkpoint file's contents are the path handed to restore().
        with open(checkpoint_path) as fh:
            restore_path = fh.read()
        print('restore', restore_path)
        trainer.restore(restore_path.strip())
    else:
        print('starting a new experiment')
        checkpoints_dir = CHECKPOINTS_DIR.format(args.run)
        shutil.rmtree(checkpoints_dir, ignore_errors=True)
        os.makedirs(checkpoints_dir)

    config = get_exp3_config(args, yaml_settings)
    wandb_config = {
        "algo": config.get('abr'),
        # NOTE(review): get_exp3_config() sets no 'explore' key, so this is
        # always None; kept so the logged wandb config keeps the same keys.
        "explore": config.get('explore'),
        "lr": config.get('lr'),
        "context": config.get('observation_space'),
        "actions": config.get('action_space'),
        "actions_are_delta": config.get('actions_are_delta')
    }

    wandb.init(
        mode="online",
        project="clustering",
        resume=args.resume,
        config=wandb_config
    )

    # Ctrl+C raises KeyboardInterrupt in the main thread (blocked in
    # start_servers); its finally-clause stops the servers, then this one
    # closes the wandb run. sys.exit(0) makes every shutdown path exit with
    # status 0. (A SIGINT handler previously registered after this point was
    # unreachable and has been removed.)
    try:
        start_servers(args, yaml_settings)
    finally:
        wandb.finish()
        sys.exit(0)


def inference(args, yaml_settings):
    """Serve a previously trained Exp3 policy.

    Builds the global ``trainer``, restores it from the path stored in the
    run's checkpoint file (``CHECKPOINT_FILE`` formatted with ``args.run``),
    and then blocks in ``start_servers`` until the servers stop.
    """
    global trainer
    trainer = get_trainer(args, yaml_settings)

    checkpoint_path = CHECKPOINT_FILE.format(args.run)
    # Close the checkpoint file promptly instead of leaving it to the GC.
    with open(checkpoint_path) as fh:
        restore_path = fh.read()
    trainer.restore(restore_path.strip())
    print('loaded', restore_path)

    start_servers(args, yaml_settings)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Algo. train server")
    parser.add_argument("--run", default="Exp3")
    parser.add_argument("--addr", default="localhost")
    parser.add_argument("--port", type=int, default=9900)
    parser.add_argument("--yaml-settings", default='./src/settings.yml')
    parser.add_argument("--resume", action="store_true", default=False)
    parser.add_argument("--type", default='inference',
                        choices=['train', 'inference'])
    args = parser.parse_args()

    with open(args.yaml_settings, 'r') as fh:
        yaml_settings = yaml.safe_load(fh)

    # argparse restricts --type to exactly these keys, so the lookup cannot miss.
    entry_points = {'train': train, 'inference': inference}
    entry_points[args.type](args, yaml_settings)
