#!/usr/bin/env python
import argparse
import shutil
import signal
import sys
import gymnasium as gym
import os
import wandb
import ray
from ray import tune
from ray.tune.registry import get_trainable_cls
from ray.tune.logger import pretty_print
import yaml
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from socketServerInputEval import SocketServerInputEval
from socketServerInput import SocketServerInput
from ray.tune.schedulers import PopulationBasedTraining
import time
import numpy as np
import threading
from ray.rllib.algorithms.dqn import DQN

# Host passed to every SocketServerInput / SocketServerInputEval created in this file.
SERVER_ADDRESS = "localhost"
# First port of the server port range; each server gets this port plus a
# per-worker / per-server offset.
SERVER_BASE_PORT = 9900

# File that stores the path of the most recent checkpoint. "{}" is filled with
# the algorithm name (args.run).
CHECKPOINT_FILE = "./weights/models/{}/last_checkpoint.out"
# Directory that receives the checkpoints of one algorithm ("{}" = args.run).
CHECKPOINTS_DIR = "./weights/models/{}/"
# Observation vector length used when the configured ABR features are "input".
INPUT_SIZE = 3152
# Observation vector length used for every other features setting.
HIDDEN_SIZE = 102


def get_cli_args():
    """Parse the command line and return the resulting namespace.

    Recognised options:
        --yaml-settings  path of the YAML settings file
                         (default ``./src/settings.yml``)
        --run            algorithm name (default ``DQN``)
        --resume         flag, off unless given
        --type           one of ``tune``, ``train``, ``inference``
                         (default ``inference``)
    """
    cli = argparse.ArgumentParser()

    cli.add_argument("--yaml-settings", default='./src/settings.yml')
    # Algorithms the rest of this file has dedicated settings for:
    # APPO, DQN, BanditLinTS (PPO is accepted but gets no extra settings).
    cli.add_argument("--run", default="DQN")
    # "store_true" already defaults to False when the flag is absent.
    cli.add_argument("--resume", action="store_true")
    cli.add_argument(
        "--type",
        choices=['tune', 'train', 'inference'],
        default='inference',
    )

    return cli.parse_args()


def get_base_config(args, yaml_settings):
    """Build the config entries shared by training and inference.

    Args:
        args: parsed CLI namespace; ``args.run`` selects the
            algorithm-specific additions and ``args.resume`` zeroes the DQN
            exploration warm-up / epsilon schedule lengths.
        yaml_settings: parsed settings file. Reads ``num_of_qualities`` and,
            from the first experiment, ``fingerprint.abr_config.features``
            plus (depending on the algorithm) ``num_of_contexts`` or
            ``num_servers``.

    Returns:
        A plain dict with env/framework/observation/action-space entries,
        extended with DQN, BanditLinTS or APPO specific keys when
        ``args.run`` names one of those.
    """
    features_algorithm = yaml_settings["experiments"][0]["fingerprint"]["abr_config"]["features"]
    obs_size = INPUT_SIZE if features_algorithm == "input" else HIDDEN_SIZE

    config = {
        "env": None,
        "framework": 'torch',
        "observation_space": gym.spaces.Box(float("-inf"), float("inf"), (obs_size,)),
        "action_space": gym.spaces.Discrete(yaml_settings["num_of_qualities"]),
    }

    algorithm = args.run
    if algorithm == "DQN":
        # A resumed run skips both the warm-up and the epsilon decay phase.
        schedule_timesteps = 0 if args.resume else 30000
        overrides = {
            "exploration_config": {
                "warmup_timesteps": schedule_timesteps,
                "epsilon_timesteps": schedule_timesteps,
                "final_epsilon": 0.02,
                # Alternative exploration tried previously:
                # "type": "SoftQ", "temperature": 0.5
            },
            "model": {
                "fcnet_hiddens": [],
                "fcnet_activation": "tanh",
            },
            "target_network_update_freq": 8000,
            # DQN specific settings ("grad_clip" is left at its default).
            "train_batch_size": 32,
            "double_q": False,
            "dueling": True,
            "n_step": 20,  # https://paperswithcode.com/method/n-step-returns
            "gamma": 0.8,
        }
    elif algorithm == "BanditLinTS":
        # Bandits observe a discrete context id instead of a feature vector.
        overrides = {
            "observation_space": gym.spaces.Discrete(
                yaml_settings["experiments"][0]["fingerprint"]["abr_config"]["num_of_contexts"]
            ),
        }
    elif algorithm == "APPO":
        overrides = {
            "model": {
                "fcnet_hiddens": [],
            },
            "minibatch_buffer_size": yaml_settings["experiments"][0]["num_servers"],
        }
    else:
        # Any other algorithm runs with the shared entries only.
        overrides = {}

    config.update(overrides)
    return config


def get_train_config(args, yaml_settings):
    """Return the base config extended with the training-only settings.

    On top of ``get_base_config`` this adds one rollout worker per entry of
    ``experiments[0].num_servers``, a socket-server input reader factory,
    exploration, the learning rate and a callback that publishes
    length-normalised episode reward statistics.

    Args:
        args: parsed CLI namespace (forwarded to ``get_base_config``).
        yaml_settings: parsed settings file.

    Returns:
        The training config dict.
    """

    def _input(ioctx):
        """Create the input reader for one worker, or None if it needs none.

        Remote workers (index > 0) and a local worker that runs without any
        remote workers get a ``SocketServerInput``; a local worker that has
        remote workers gets ``None``.
        """
        is_remote = ioctx.worker_index > 0
        if not is_remote and ioctx.worker.num_workers != 0:
            return None
        # Remote worker k uses port offset k - 1, so worker 1 sits on the
        # base port; the lone local worker (index 0) uses the base port too.
        port_offset = ioctx.worker_index - (1 if is_remote else 0)
        return SocketServerInput(
            ioctx,
            SERVER_ADDRESS,
            SERVER_BASE_PORT + port_offset,
        )

    class MyCallbacks(DefaultCallbacks):
        """Adds mean/min/max of reward-per-step to ``custom_metrics``."""

        def on_train_result(self, *, algorithm, result: dict, **kwargs):
            history = result["hist_stats"]
            if len(history["episode_lengths"]) == 0:
                # No finished episodes in this iteration: nothing to report.
                return
            # Element-wise: total reward of each episode divided by its length.
            rewards = np.array(history["episode_reward"])
            lengths = np.array(history["episode_lengths"])
            episode_reward = rewards / lengths
            metrics = result["custom_metrics"]
            metrics["avg_episode_reward"] = np.mean(episode_reward)
            metrics["min_episode_reward"] = np.min(episode_reward)
            metrics["max_episode_reward"] = np.max(episode_reward)

    config = get_base_config(args=args, yaml_settings=yaml_settings)

    training_settings = {
        "num_rollout_workers": yaml_settings["experiments"][0]["num_servers"],
        "enable_connectors": False,
        "log_level": "ERROR",
        "explore": True,
        # Evaluation workers and "learning_starts" are left at their defaults.
        "min_sample_timesteps_per_iteration": 10000,
        "lr": .0000625,
        # An lr_schedule ([[0, .0000625], [200000, .00003]]) was tried before
        # and is currently not used.
        "input": _input,
        "callbacks": MyCallbacks,
    }
    config.update(training_settings)

    return config


def get_inference_config(args, yaml_settings):
    """Return the base config adjusted for serving a trained policy.

    Exploration is switched off and no rollout workers are requested.

    Args:
        args: parsed CLI namespace (forwarded to ``get_base_config``).
        yaml_settings: parsed settings file.

    Returns:
        The inference config dict.
    """
    inference_config = get_base_config(args=args, yaml_settings=yaml_settings)
    inference_config["explore"] = False
    inference_config["num_rollout_workers"] = 0
    return inference_config


def get_wandb_config(args, config):
    """Select the hyper-parameters that are recorded as the wandb run config.

    Args:
        args: parsed CLI namespace; only ``args.run`` is read.
        config: algorithm config dict. Keys that are absent are reported
            as ``None``.

    Returns:
        A dict with the keys ``algo``, ``explore``, ``n_step``,
        ``train_batch_size``, ``timesteps_per_iteration``, ``lr``,
        ``exploration_config`` and ``model``.
    """
    # The training config in this file sets "min_sample_timesteps_per_iteration";
    # the bare "timesteps_per_iteration" key is only honoured if a caller
    # still provides it. Without the fallback the value was always None.
    timesteps_per_iteration = config.get(
        'timesteps_per_iteration',
        config.get('min_sample_timesteps_per_iteration'),
    )
    return {
        "algo": args.run,
        "explore": config.get('explore'),
        "n_step": config.get('n_step'),
        "train_batch_size": config.get('train_batch_size'),
        "timesteps_per_iteration": timesteps_per_iteration,
        "lr": config.get('lr'),
        "exploration_config": config.get('exploration_config'),
        "model": config.get('model'),
    }


def train(args, yaml_settings):
    """Train ``args.run`` for up to 500 iterations, checkpointing each one.

    With ``args.resume`` the algorithm is restored from the checkpoint whose
    path is stored in the checkpoint pointer file. Otherwise the checkpoint
    directory of the algorithm is wiped and recreated. After every iteration
    the result is printed, a checkpoint is saved, the pointer file is
    rewritten and the result is logged to wandb.

    Ctrl-C (SIGINT) finishes the wandb run and exits with status 0. Any
    exception raised while training propagates to the caller after the wandb
    run has been finished.

    Args:
        args: parsed CLI namespace (``run`` and ``resume`` are read).
        yaml_settings: parsed settings file.

    Raises:
        FileNotFoundError: when resuming and the pointer file is missing.
    """
    config = get_train_config(args=args, yaml_settings=yaml_settings)
    algo = get_trainable_cls(args.run).get_default_config()
    algo.update_from_dict(config)

    trainer = algo.build()

    checkpoint_path = CHECKPOINT_FILE.format(args.run)
    checkpoints_dir = CHECKPOINTS_DIR.format(args.run)
    if args.resume:
        # The pointer file holds the path of the latest checkpoint.
        with open(checkpoint_path) as pointer_file:
            restore_path = pointer_file.read().strip()
        trainer.restore(restore_path)
    else:
        print('starting a new experiment')
        # Short grace period before the old checkpoints are deleted.
        time.sleep(3)
        shutil.rmtree(checkpoints_dir, ignore_errors=True)
        # rmtree above is best effort, so tolerate a directory it left behind.
        os.makedirs(checkpoints_dir, exist_ok=True)

    wandb.init(
        mode="online",
        project="local_puffer",
        resume=args.resume,
        sync_tensorboard=True,
        config=get_wandb_config(args, config)
    )

    def close_wandb(sig, frame):
        """SIGINT handler: flush the wandb run, then leave the process."""
        wandb.finish()
        sys.exit(0)

    signal.signal(signal.SIGINT, close_wandb)

    try:
        for _ in range(500):
            results = trainer.train()
            print(pretty_print(results))
            checkpoint = trainer.save(checkpoints_dir)
            with open(checkpoint_path, "w") as f:
                f.write(checkpoint)
            # Round-trip through the pretty-printed YAML text, as before, so
            # wandb receives plain Python values.
            wandb.log(yaml.safe_load(pretty_print(results)))
    finally:
        # Only finish the run here. Calling sys.exit(0) from this finally
        # block would replace an in-flight training exception with a
        # "successful" exit and hide the failure.
        # NOTE(review): after the SIGINT handler ran this is a second
        # wandb.finish() call, which is expected to be a no-op - confirm.
        wandb.finish()



def grid_search_learn(args, yaml_settings):
    """Run a Ray Tune hyper-parameter search for ``args.run``.

    Despite the name this combines grid-searched values (batch size,
    activation) with sampled ones and a PopulationBasedTraining scheduler.
    The best configuration found is printed.

    The search space touches ``model``, ``exploration_config``, ``dueling``
    and ``double_q``, i.e. entries the base config only provides for DQN.

    Args:
        args: parsed CLI namespace (``run`` and ``resume`` are read).
        yaml_settings: parsed settings file.
    """
    config = get_train_config(args=args, yaml_settings=yaml_settings)
    print("Starting grid search")
    # Values the PBT scheduler may perturb / resample during the run.
    hyperparam_mutations = {
        "lr": tune.uniform(5e-6, 1e-4),
        "train_batch_size": tune.choice([32, 64, 128, 256]),
        "exploration_config": {
            "final_epsilon": tune.uniform(0.001, 0.05)
        },
    }

    # Initial search space of the trials.
    config["lr"] = tune.uniform(lower=5e-6, upper=1e-3)
    config["train_batch_size"] = tune.grid_search([32, 64, 128, 256])
    config["adam_epsilon"] = tune.uniform(lower=1e-8, upper=1e-6)
    config["model"]["fcnet_activation"] = tune.grid_search(
        ["tanh", "relu", "elu"])
    config["exploration_config"]["final_epsilon"] = tune.uniform(0.001, 0.05)
    config["dueling"] = tune.choice([True, False])
    config["double_q"] = tune.choice([True, False])

    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        perturbation_interval=5,
        resample_probability=0.25,
        hyperparam_mutations=hyperparam_mutations
    )
    analysis = tune.run(
        args.run,
        local_dir=CHECKPOINTS_DIR.format(args.run),
        scheduler=pbt,
        metric="episode_reward_mean",
        mode="max",
        stop=dict(timesteps_total=40000),
        num_samples=1,
        config=config,
        # `resume` is the boolean "continue the previous experiment" switch;
        # `restore` expects a checkpoint path, so the --resume flag must not
        # be passed there.
        resume=args.resume,
        verbose=0,
        checkpoint_freq=1,
        keep_checkpoints_num=5,
        checkpoint_at_end=True,
    )

    best_config = analysis.best_config
    print(f"Best config: \n{best_config}")



def inference(args, yaml_settings):
    """Restore the latest checkpoint of ``args.run`` and serve it over sockets.

    One ``SocketServerInputEval`` is created per entry of
    ``experiments[0].num_servers``, on consecutive ports starting at
    ``SERVER_BASE_PORT``. Each server runs ``serve_forever`` in its own
    daemon thread and the call blocks until those threads end or the wait
    is interrupted.

    Args:
        args: parsed CLI namespace; ``args.run`` selects both the algorithm
            class and the checkpoint location.
        yaml_settings: parsed settings file.

    Raises:
        FileNotFoundError: when the checkpoint pointer file is missing.
    """
    config = get_inference_config(args=args, yaml_settings=yaml_settings)
    # Build the algorithm named by --run so it matches the checkpoint path
    # below (previously DQN was hard-coded regardless of --run).
    algo = get_trainable_cls(args.run)(config=config)

    # The pointer file holds the path of the latest checkpoint.
    checkpoint_path = CHECKPOINT_FILE.format(args.run)
    with open(checkpoint_path) as pointer_file:
        restore_path = pointer_file.read().strip()
    algo.restore(restore_path)

    num_workers = yaml_settings["experiments"][0]["num_servers"]
    # Servers whose serve_forever thread has actually been started.
    running_servers = []

    try:
        servers = []
        for i in range(num_workers):
            server = SocketServerInputEval(
                SERVER_ADDRESS,
                SERVER_BASE_PORT + i,
                algo
            )
            server.daemon_threads = True
            servers.append(server)

        threads = []
        for server in servers:
            t = threading.Thread(target=server.serve_forever, daemon=True)
            t.start()
            running_servers.append(server)
            threads.append(t)

        for t in threads:
            t.join()
    finally:
        # Only stop servers that are being served.
        # NOTE(review): assumes SocketServerInputEval follows the stdlib
        # socketserver contract, where shutdown() blocks forever when
        # serve_forever() was never started - confirm.
        for s in running_servers:
            s.shutdown()


if __name__ == "__main__":
    # Parse the command line and load the settings first so that --help,
    # an invalid flag or a missing settings file fail before Ray is started.
    args = get_cli_args()
    with open(args.yaml_settings, 'r') as fh:
        yaml_settings = yaml.safe_load(fh)

    ray.init()

    if args.type == 'tune':
        grid_search_learn(args, yaml_settings)
    elif args.type == 'train':
        train(args, yaml_settings)
    elif args.type == 'inference':
        inference(args, yaml_settings)
    else:
        # argparse `choices` already rejects anything else; kept as a guard.
        raise Exception("invalid args type choice")
