from ast import arg
import os
import json
import pprint
import time
import threading
import torch as th
from types import SimpleNamespace as SN
from utils.logging import Logger
from utils.timer import Timer

from learners import REGISTRY as le_REGISTRY
from runners import REGISTRY as r_REGISTRY
from controllers import REGISTRY as mac_REGISTRY
from controllers.population_controller import PopulationMAC
from components.episode_buffer import ReplayBuffer
from components.transforms import OneHot

import pickle
import numpy as np
import wandb


def run_population(_run, _config, _log):
    """Top-level experiment entry point: set up logging, train, then hard-exit.

    Args:
        _run: run object handed to ``logger.setup_sacred``.
        _config: dict of experiment settings; it is sanity-checked and then
            expanded into a ``SimpleNamespace``.
        _log: logger used for console output and passed to ``Logger``.

    This function does not return: it ends with ``os._exit`` so that the
    process terminates even if non-daemon threads are still alive.
    """
    # Normalise the raw config first, then expose it as attribute-style args.
    _config = args_sanity_check(_config, _log)
    args = SN(**_config)
    args.device = "cuda" if args.use_cuda else "cpu"

    out_dir = args.results_save_dir

    # Console / file logger rooted at the results directory.
    logger = Logger(_log, out_dir)

    _log.info("Experiment Parameters:")
    _log.info("\n\n" + pprint.pformat(_config, indent=4, width=1) + "\n")

    # Tensorboard and wandb sinks are only attached outside evaluation mode.
    if args.use_tensorboard and not args.evaluate:
        logger.setup_tb(os.path.join(out_dir, 'tb_logs'))

    if args.use_wandb and not args.evaluate:
        logger.setup_wandb(os.path.join(out_dir, 'wandb_logs'), args)

    # Directory under which model checkpoints are stored.
    args.save_dir = os.path.join(out_dir, 'models')

    # Persist the effective configuration next to the results.
    serialized_config = json.dumps(vars(args), indent=4)
    config_path = os.path.join(out_dir, "config.json")
    with open(config_path, "w") as config_file:
        config_file.write(serialized_config)

    # sacred is on by default
    logger.setup_sacred(_run)

    # Run and train
    run_sequential(args=args, logger=logger)

    # Clean up after finishing
    print("Exiting Main")

    print("Stopping all threads")
    lingering = [t for t in threading.enumerate() if t.name != "MainThread"]
    for worker in lingering:
        print("Thread {} is alive! Is daemon: {}".format(worker.name, worker.daemon))
        # Bounded wait only; the message below is printed whether or not the
        # thread actually finished within the timeout.
        worker.join(timeout=1)
        print("Thread joined")

    print("Exiting script")

    # Making sure framework really exits
    os._exit(os.EX_OK)


def evaluate_sequential(args, runner):
    """Run ``args.test_nepisode`` evaluation episodes and release the environment.

    Args:
        args: namespace providing ``test_nepisode`` (number of ``runner.run``
            calls) and ``save_replay`` (whether to call ``runner.save_replay``).
        runner: object exposing ``run(test_mode=..., tag=...)``,
            ``save_replay()`` and ``close_env()``.

    The environment is closed even when an episode or the replay save raises,
    so a failed evaluation does not leave the environment open. The original
    exception still propagates to the caller.
    """
    try:
        for _ in range(args.test_nepisode):
            runner.run(test_mode=True, tag='eval')

        if args.save_replay:
            runner.save_replay()
    finally:
        runner.close_env()


# Once runner.t_env reaches args.t_max minus this many steps, the alternating
# schedule no longer hands turns to the attackers.
_FINAL_DEFENDER_ONLY_STEPS = 1000000


def _attacker_turn(args, t_env):
    """Return True when the alternating schedule gives step ``t_env`` to attackers.

    A cycle lasts ``args.attacker_train + args.defender_train`` steps and the
    first ``args.attacker_train`` steps of each cycle belong to the attackers,
    except during the last ``_FINAL_DEFENDER_ONLY_STEPS`` steps before
    ``args.t_max``.
    """
    cycle = args.attacker_train + args.defender_train
    return (t_env % cycle < args.attacker_train
            and t_env < args.t_max - _FINAL_DEFENDER_ONLY_STEPS)


def _evaluate_attacker(args, runner, mac, scheme, groups, preprocess, episode_limit, attacker_idx):
    """Run one test-mode batch against attacker ``attacker_idx`` and score it.

    Returns:
        (mean_returns, encoded_z): the batch-mean of the masked, time-summed
        ``ad_reward`` and the batch-mean output of ``mac.traj_encoder`` on the
        sampled states.
    """
    eval_buffer = ReplayBuffer(scheme, groups, args.batch_size_run, episode_limit + 1,
                               preprocess=preprocess,
                               device="cpu" if args.buffer_cpu_only else args.device)
    mac.set_attacker(attacker_idx)
    with th.no_grad():
        episode_batch = runner.run(test_mode=True, tag='test')
    eval_buffer.insert_episode_batch(episode_batch)
    episode_sample = eval_buffer.sample(args.batch_size_run)

    # Mask out unfilled steps and every step after the one flagged terminated.
    rewards = episode_sample["ad_reward"]
    terminated = episode_sample["terminated"].float()
    mask = episode_sample["filled"].float()
    mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])

    masked_rewards = mask * rewards
    mean_returns = masked_rewards.squeeze(dim=-1).sum(dim=1).mean()

    traj_states = episode_sample["state"].to(args.device)
    mask = mask.to(args.device)
    with th.no_grad():
        encoded_z = mac.traj_encoder(traj_states, mask)
    encoded_z = th.mean(encoded_z, dim=0)   # shape: (emb,)
    return mean_returns, encoded_z


def _log_embeddings_table(embeddings_log, emb_dim):
    """Log every embedding collected so far to wandb as one table."""
    # NOTE(review): rows are the tensors returned by the encoder; confirm
    # wandb.Table accepts them without conversion.
    wandb.log({
        "embeddings": wandb.Table(
            columns=[f"{i}_emb" for i in range(emb_dim)],
            data=embeddings_log
        )
    })


def run_sequential(args, logger):
    """Run the population-based attacker/defender training pipeline.

    Stages, in order:

    1. Build the runner, replay buffer, population controller, a second runner
       used only for defender evaluation, and the defender learner.
    2. If ``args.checkpoint_path`` is set, load defender models from it; when
       ``args.evaluate`` or ``args.save_replay`` is set, evaluate and return.
    3. Pretraining: for each attacker in the population, collect test-mode
       episodes until ``pretrain_tmax`` env steps (training the trajectory
       encoder whenever the buffer can be sampled), then train its VAE.
    4. Score every attacker once (mean return and mean trajectory embedding).
    5. Main loop until ``runner.t_env >= args.t_max``: attacker turns train the
       reproduction attackers in threads and periodically merge them into the
       population; defender turns train the trajectory encoder and defender.
       Testing, logging and model saving happen every ``args.test_interval``.

    Args:
        args: experiment namespace. ``n_agents``, ``n_actions``,
            ``state_shape`` and ``n_reproduction`` are written onto it here.
        logger: project logger used for stats and console output.

    Both runners' environments are closed before returning normally,
    including on the early-return checkpoint paths.
    """
    # Init runner so we can get env info
    runner = r_REGISTRY[args.runner](args=args, logger=logger)

    # Set up schemes and groups here
    env_info = runner.get_env_info()
    args.n_agents = env_info["n_agents"]
    args.n_actions = env_info["n_actions"]
    args.state_shape = env_info["state_shape"]

    # Default/Base scheme
    scheme = {
        "state": {"vshape": env_info["state_shape"]},
        "obs_origin": {"vshape": env_info["obs_shape"], "group": "agents"},
        "obs": {"vshape": env_info["obs_shape"], "group": "agents"},
        "actions": {"vshape": (1,), "group": "agents", "dtype": th.long},
        "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
        "reward": {"vshape": (1,)},
        "ad_discrete_actions": {"vshape": 1, "dtype": th.long},
        "ad_continuous_actions": {"vshape": env_info["obs_shape"]},
        "ad_discrete_emb": {"vshape": args.discrete_action_dim},
        "ad_continuous_emb": {"vshape": args.parameter_action_dim},
        "ad_reward": {"vshape": (1,)},
        "terminated": {"vshape": (1,), "dtype": th.uint8},
    }
    groups = {
        "agents": args.n_agents,
    }
    preprocess = {
        "actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)]),
    }

    buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1,
                          preprocess=preprocess,
                          device="cpu" if args.buffer_cpu_only else args.device)

    # Setup multiagent controller here
    mac = PopulationMAC(buffer.scheme, groups, args)

    # Give runner the scheme
    runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac)

    # Second runner, driven by a defender-only controller, used for the
    # 'def1'..'def6' evaluation runs in the main loop.
    from controllers.defender_controller import DefenderMAC
    defender_eval_runner = r_REGISTRY[args.runner](args=args, logger=logger)
    defender_mac = DefenderMAC(buffer.scheme, groups, args)
    defender_eval_runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=defender_mac)

    # Learner
    defender_learner = le_REGISTRY[mac.defender_args.learner](
        mac.controller_dict["defender"], buffer.scheme, logger, mac.defender_args)

    if args.use_cuda:
        defender_learner.cuda()

    if args.checkpoint_path != "":

        if not os.path.isdir(args.checkpoint_path):
            logger.console_logger.info(
                "Checkpoint directiory {} doesn't exist".format(args.checkpoint_path))
            runner.close_env()
            defender_eval_runner.close_env()
            return

        # Collect sub-directories whose names are plain numbers (timesteps).
        timesteps = []
        for name in os.listdir(args.checkpoint_path):
            full_name = os.path.join(args.checkpoint_path, name)
            if os.path.isdir(full_name) and name.isdigit():
                timesteps.append(int(name))

        if not timesteps:
            # max()/min() below would raise ValueError on an empty list.
            logger.console_logger.info(
                "No numbered checkpoint directories found in {}".format(args.checkpoint_path))
            runner.close_env()
            defender_eval_runner.close_env()
            return

        if args.load_step == 0:
            # choose the max timestep
            timestep_to_load = max(timesteps)
        else:
            # choose the timestep closest to load_step
            timestep_to_load = min(
                timesteps, key=lambda x: abs(x - args.load_step))

        model_path = os.path.join(args.checkpoint_path, str(timestep_to_load))

        logger.console_logger.info("Loading model from {}".format(model_path))
        defender_learner.load_models(model_path)
        runner.t_env = timestep_to_load

        if args.evaluate or args.save_replay:
            try:
                # evaluate_sequential closes `runner`'s environment itself.
                evaluate_sequential(args, runner)
            finally:
                defender_eval_runner.close_env()
            return

    # start pretraining
    episode = 0
    # NOTE(review): this also discards the t_env restored from a checkpoint
    # above; confirm that restarting the step counter is intended.
    runner.t_env = 0
    last_log_enc_loss_T = 0
    defender_eval_runner.t_env = 0

    timer = Timer()

    logger.console_logger.info(
        "Beginning pretraining for {} timesteps".format(mac.attacker_args.pretrain_tmax))
    for idx, attacker in enumerate(mac.controller_dict["attacker"]):
        attacker_learner = le_REGISTRY[mac.attacker_args.learner](
            attacker, buffer.scheme, logger, mac.attacker_args)
        mac.set_attacker(idx)
        if args.use_cuda:
            attacker_learner.cuda()
        pretrain_win_episodes = 0
        while runner.t_env <= mac.attacker_args.pretrain_tmax:
            # Run for a whole episode at a time
            with timer.timer("pretrain_sample"):
                with th.no_grad():
                    episode_batch = runner.run(test_mode=True, tag='pretrain')
                buffer.insert_episode_batch(episode_batch)
                buffer.insert_tuple_batch(episode_batch)
                pretrain_win_episodes += runner.battle_info['battle_won']
            episode += args.batch_size_run
            logger.console_logger.info("attacker-{}, Data collecting for {}/{}, pretrain win rate {}%".format(
                idx, runner.t_env, mac.attacker_args.pretrain_tmax, round(pretrain_win_episodes / episode * 100, 2)))
            if buffer.can_sample(args.batch_size):
                episode_sample = buffer.sample(args.batch_size)

                # Truncate batch to only filled timesteps
                max_ep_t = episode_sample.max_t_filled()
                episode_sample = episode_sample[:, :max_ep_t]
                episode_sample.to(args.device)
                enc_loss, enc_grad_norm = mac.train_traj_encoder(
                    episode_sample)
                # Encoder stats share one step axis across attackers: each
                # attacker's steps are offset by idx * pretrain_tmax.
                enc_log_t = idx * mac.attacker_args.pretrain_tmax + runner.t_env
                if enc_log_t - last_log_enc_loss_T >= args.log_interval:
                    logger.log_stat("enc/enc_loss", enc_loss, enc_log_t)
                    logger.log_stat("enc/enc_grad_norm", enc_grad_norm, enc_log_t)
                    last_log_enc_loss_T = enc_log_t

        with timer.timer("pretrain"):
            if runner.t_env > mac.attacker_args.vae_get_c_rate_batch_size:
                attacker_learner.vae_train(buffer, mac.attacker_args.vae_batch_size,
                                           mac.attacker_args.training_vae_steps, log_vae_loss=True)

        # Reset the counters before pretraining the next attacker.
        episode = 0
        runner.t_env = 0
        buffer.reset_tuple_self(0)

    # Initial score (return, embedding) of every attacker in the population.
    attackers_last_update = np.zeros(args.n_attackers)
    attackers_last_returns = []
    attackers_last_encode = []
    for idx, _ in enumerate(mac.controller_dict["attacker"]):
        mean_returns, encoded_z = _evaluate_attacker(
            args, runner, mac, scheme, groups, preprocess,
            env_info["episode_limit"], idx)
        attackers_last_returns.append(mean_returns)
        attackers_last_encode.append(encoded_z)

    # start training
    last_test_T = -args.test_interval - 1
    model_save_time = 0

    if args.use_wandb:
        embeddings_log = attackers_last_encode.copy()
        _log_embeddings_table(embeddings_log, args.enc_emb)

    logger.console_logger.info(
        "Beginning training for {} timesteps".format(args.t_max))

    # Tags under which the defender evaluation runner runs and reports.
    defender_tags = ['def{}'.format(i) for i in range(1, 7)]

    attacker_train_first = True
    reproduce_buffers = []
    reproduce_learners = []
    args.n_reproduction = int(args.n_attackers * args.reproduction_ratio)
    last_pretrain_init_attackers_T = 0
    last_train_vae_eps = [0 for _ in range(args.n_reproduction)]
    reproduce_episodes = [0 for _ in range(args.n_reproduction)]
    while runner.t_env < args.t_max:

        if runner.t_env < args.defender_start or _attacker_turn(args, runner.t_env):
            # NOTE: mutate
            if attacker_train_first:
                # Start of an attacker phase: fresh reproduction attackers,
                # each with its own single-episode buffer and learner.
                mac.init_reproduction_attackers()
                reproduce_buffers = [ReplayBuffer(scheme, groups, 1, env_info["episode_limit"] + 1,
                                                  preprocess=preprocess,
                                                  device="cpu" if args.buffer_cpu_only else args.device)
                                     for _ in range(args.n_reproduction)]
                reproduce_learners = [le_REGISTRY[mac.attacker_args.learner](
                    attacker, reproduce_buffers[idx].scheme, logger, mac.attacker_args)
                    for idx, attacker in enumerate(mac.reproduction_attackers)]
                attacker_train_first = False
            for idx, _ in enumerate(mac.reproduction_attackers):
                # Reproduction attackers are addressed after the population.
                mac.set_attacker(args.n_attackers + idx)
                with timer.timer("sample"):
                    with th.no_grad():
                        episode_batch = runner.run(
                            test_mode=False, tag='train')
                    reproduce_buffers[idx].insert_tuple_batch(episode_batch)
                episode += args.batch_size_run
                reproduce_episodes[idx] += args.batch_size_run

            if np.all([reproduce_buffers[buf_idx].can_sample_tuple(mac.attacker_args.vae_get_c_rate_batch_size)
                       for buf_idx in range(args.n_reproduction)]):
                with timer.timer("train"):
                    def parallel_train(learner_idx):
                        for _ in range(args.batch_size_run):
                            for ___ in range(args.attacker_update_repeat):
                                reproduce_learners[learner_idx].train(
                                    reproduce_buffers[learner_idx], runner.t_env, episode)
                        if (reproduce_episodes[learner_idx] - last_train_vae_eps[learner_idx]) / mac.attacker_args.vae_update_episode_interval >= 1.0:
                            reproduce_learners[learner_idx].vae_train(
                                reproduce_buffers[learner_idx], mac.attacker_args.vae_batch_size, 1, log_vae_loss=False)
                            last_train_vae_eps[learner_idx] = reproduce_episodes[learner_idx]
                    # One thread per reproduction learner.
                    threads = []
                    for idx in range(args.n_reproduction):
                        thd = threading.Thread(
                            target=parallel_train, args=[idx])
                        thd.start()
                        threads.append(thd)
                    # wait for all threading
                    for thd in threads:
                        thd.join()

            # NOTE: first condition: train attackers population for args.defender_start steps
            # second condition: alternatively training of defender and attackers
            if (runner.t_env < args.defender_start and (runner.t_env - last_pretrain_init_attackers_T) // args.switch_interval > 1.0) or \
                    not _attacker_turn(args, runner.t_env):
                # NOTE: mac.evolve()
                # compare and determine whether to keep it
                for idx, attacker in enumerate(mac.reproduction_attackers):
                    mean_returns, encoded_z = _evaluate_attacker(
                        args, runner, mac, scheme, groups, preprocess,
                        env_info["episode_limit"], args.n_attackers + idx)

                    distances = [th.norm(attackers_last_encode[x] - encoded_z)
                                 for x in range(args.n_attackers)]
                    nearest = min(range(args.n_attackers),
                                  key=lambda x: distances[x])
                    oldest = min(range(args.n_attackers),
                                 key=lambda x: attackers_last_update[x])
                    distance = distances[nearest]

                    if distance > args.distance_threshold:
                        # far instance, replace oldest
                        replace_idx = oldest
                    elif mean_returns > attackers_last_returns[nearest]:
                        # good instance, replace nearest
                        replace_idx = nearest
                    else:
                        # bad instance, discard it
                        replace_idx = None

                    if replace_idx is not None:
                        attackers_last_update[replace_idx] = runner.t_env
                        attackers_last_returns[replace_idx] = mean_returns
                        attackers_last_encode[replace_idx] = encoded_z
                        mac.controller_dict["attacker"][replace_idx].load_state(attacker)
                        if args.use_wandb:
                            embeddings_log.append(encoded_z)
                            _log_embeddings_table(embeddings_log, args.enc_emb)
                attacker_train_first = True
                last_pretrain_init_attackers_T = runner.t_env
        else:
            # Defender turn: play against a random population attacker.
            mac.set_attacker(np.random.choice(range(0, args.n_attackers)))
            with timer.timer("sample"):
                with th.no_grad():
                    episode_batch = runner.run(test_mode=False, tag='train')
                buffer.insert_episode_batch(episode_batch)
            episode += args.batch_size_run
            if buffer.can_sample(args.batch_size):
                with timer.timer("train"):
                    for _ in range(args.batch_size_run):
                        episode_sample = buffer.sample(args.batch_size)

                        # Truncate batch to only filled timesteps
                        max_ep_t = episode_sample.max_t_filled()
                        episode_sample = episode_sample[:, :max_ep_t]
                        episode_sample.to(args.device)
                        enc_loss, enc_grad_norm = mac.train_traj_encoder(
                            episode_sample)
                        # Continue the encoder step axis after all pretraining.
                        enc_log_t = args.n_attackers * mac.attacker_args.pretrain_tmax + runner.t_env
                        if enc_log_t - last_log_enc_loss_T >= args.log_interval:
                            logger.log_stat("enc/enc_loss", enc_loss, enc_log_t)
                            logger.log_stat("enc/enc_grad_norm", enc_grad_norm, enc_log_t)
                            last_log_enc_loss_T = enc_log_t
                        for ___ in range(args.defender_update_repeat):
                            defender_learner.train(
                                episode_sample, runner.t_env, episode)

        # Execute test runs once in a while & test when finished
        n_test_runs = max(1, args.test_nepisode // runner.batch_size)
        if (runner.t_env - last_test_T) / args.test_interval >= 1.0 or runner.t_env >= args.t_max:
            logger.console_logger.info(
                "t_env: {} / {}".format(runner.t_env, args.t_max))
            logger.console_logger.info(timer.time_cost_str(
                runner.t_env, last_test_T, args.t_max))

            last_test_T = runner.t_env

            defender_eval_runner.t_env = runner.t_env
            defender_mac.defender = mac.controller_dict["defender"]

            with timer.timer("test"):
                # Test reproduction when attackers are training and test population when defenders are training
                if runner.t_env < args.defender_start or _attacker_turn(args, runner.t_env):
                    mac.set_attacker(np.random.choice(range(args.n_attackers, args.n_attackers + args.n_reproduction)))
                else:
                    mac.set_attacker(np.random.choice(range(0, args.n_attackers)))
                with th.no_grad():
                    for _ in range(n_test_runs):
                        runner.run(test_mode=True, tag='test')
                        for def_tag in defender_tags:
                            defender_eval_runner.run(test_mode=True, tag=def_tag)

            logger.log_stat("episode", episode, runner.t_env)
            # Return values are not needed here; the calls are kept for
            # whatever reporting log_info performs.
            runner.log_info(tag='train')
            runner.log_info(tag='test')
            for def_tag in defender_tags:
                defender_eval_runner.log_info(tag=def_tag)
            logger.print_recent_stats()

            if args.save_model and (runner.t_env - model_save_time >= args.save_model_interval or runner.t_env >= args.t_max):
                model_save_time = runner.t_env
                save_path = os.path.join(
                    args.results_save_dir, "models", str(runner.t_env))
                os.makedirs(save_path, exist_ok=True)
                logger.console_logger.info(
                    "Saving models to {}".format(save_path))
                mac.save_attackers(save_path)
                defender_learner.save_models(save_path)

    runner.close_env()
    defender_eval_runner.close_env()
    logger.console_logger.info("Finished Training")


def args_sanity_check(config, _log):
    """Normalise ``config`` in place and return the same dict.

    * ``use_cuda`` is switched off (with a warning on ``_log``) when it is
      requested but ``th.cuda.is_available()`` reports no device.
    * ``test_nepisode`` is raised to ``batch_size_run`` when smaller,
      otherwise rounded down to a multiple of ``batch_size_run``.
    """
    # set CUDA flags
    cuda_requested = config["use_cuda"]
    if cuda_requested and not th.cuda.is_available():
        config["use_cuda"] = False
        _log.warning(
            "CUDA flag use_cuda was switched OFF automatically because no CUDA devices are available!")

    n_test = config["test_nepisode"]
    per_run = config["batch_size_run"]
    # The division stays in the else-arm so it is only evaluated when needed.
    config["test_nepisode"] = (
        per_run if n_test < per_run else (n_test // per_run) * per_run
    )

    return config
