import os
import json
import pprint
import time
import threading
import torch as th
from types import SimpleNamespace as SN
from utils.logging import Logger
from utils.timer import Timer

from learners import REGISTRY as le_REGISTRY
from runners import REGISTRY as r_REGISTRY
from controllers import REGISTRY as mac_REGISTRY
from controllers.meta_controller import MetaMAC
from controllers.defender_controller import DefenderMAC
from components.episode_buffer import ReplayBuffer
from components.transforms import OneHot

import pickle


def run_ad(_run, _config, _log):
    """Entry point for an attacker/defender experiment.

    Sanity-checks the config, sets up logging (console, optional TensorBoard
    and wandb when not evaluating, and sacred), writes the resolved arguments
    to ``<results_save_dir>/config.json`` and runs :func:`run_sequential`.
    Afterwards it joins every other live thread (5 s timeout each) and
    hard-exits the process with status 0.

    Args:
        _run: sacred run object, handed to ``logger.setup_sacred``.
        _config: experiment configuration dict.
        _log: logger used for console output.
    """
    # check args sanity
    _config = args_sanity_check(_config, _log)

    args = SN(**_config)
    args.device = "cuda" if args.use_cuda else "cpu"

    results_save_dir = args.results_save_dir

    # setup loggers
    logger = Logger(_log, results_save_dir)

    _log.info("Experiment Parameters:")
    experiment_params = pprint.pformat(_config,
                                       indent=4,
                                       width=1)
    _log.info("\n\n" + experiment_params + "\n")

    if args.use_tensorboard and not args.evaluate:
        # only log tensorboard when in training mode
        tb_exp_direc = os.path.join(results_save_dir, 'tb_logs')
        logger.setup_tb(tb_exp_direc)

    if args.use_wandb and not args.evaluate:
        wandb_exp_direc = os.path.join(results_save_dir, 'wandb_logs')
        logger.setup_wandb(wandb_exp_direc, args)

    # set model save dir
    args.save_dir = os.path.join(results_save_dir, 'models')

    # write config file
    config_str = json.dumps(vars(args), indent=4)
    with open(os.path.join(results_save_dir, "config.json"), "w") as f:
        f.write(config_str)

    # sacred is on by default
    logger.setup_sacred(_run)

    # Run and train
    run_sequential(args=args, logger=logger)

    # Clean up after finishing
    print("Exiting Main")

    print("Stopping all threads")
    main_thread = threading.main_thread()
    current_thread = threading.current_thread()
    for t in threading.enumerate():
        # Never join the main thread, nor the current one: joining the current
        # thread raises RuntimeError if this function runs off the main thread.
        if t is main_thread or t is current_thread:
            continue
        print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon))
        t.join(timeout=5)
        print("Thread joined")

    print("Exiting script")

    # Making sure framework really exits. os.EX_OK is only defined on Unix,
    # so use its value (0) directly to stay portable.
    os._exit(0)


def evaluate_sequential(args, runner):
    """Run ``args.test_nepisode`` test-mode episodes tagged 'eval' on ``runner``.

    A replay is saved when ``args.save_replay`` is set; the runner's
    environment is always closed at the end.
    """
    episode_count = args.test_nepisode
    for _episode_idx in range(episode_count):
        runner.run(test_mode=True, tag='eval')

    wants_replay = args.save_replay
    if wants_replay:
        runner.save_replay()

    runner.close_env()

def evaluate_sequential_test(args, defender_eval_runner, n_rounds=5):
    """Evaluate the defender under the 'def1', 'def2' and 'def3' tags and
    append a one-line summary to ``args.checkpoint_path + '.csv'``.

    For each of ``n_rounds`` rounds, ``args.test_nepisode`` test-mode episodes
    are run per tag, after which the runner's aggregated
    ``<tag>/battle_won_mean`` and ``<tag>/return_mean`` stats are collected.
    The CSV row holds, for each tag in order, ``mean±std`` (3 significant
    digits) of battle_won_mean and then of return_mean across rounds. The
    runner's environment is closed afterwards.

    Args:
        args: namespace providing ``checkpoint_path`` and ``test_nepisode``.
        defender_eval_runner: runner exposing ``run``, ``log_info`` and
            ``close_env``.
        n_rounds: number of evaluation rounds (default 5, the previous
            hard-coded value).
    """
    import csv

    import numpy as np

    tags = ('def1', 'def2', 'def3')
    stat_names = ('battle_won_mean', 'return_mean')
    file_name = args.checkpoint_path + '.csv'

    eval_stats = {tag: {stat: [] for stat in stat_names} for tag in tags}

    for round_idx in range(n_rounds):
        for episode_idx in range(args.test_nepisode):
            for tag in tags:
                print('running ' + str(tag) + ' episode ' + str(round_idx))
                # only the first episode of each round prints attacked observations
                defender_eval_runner.run(test_mode=True, tag=tag,
                                         print_attack_obs=(episode_idx == 0))

        for tag in tags:
            defender_log_info = defender_eval_runner.log_info(tag=tag)
            for stat in stat_names:
                eval_stats[tag][stat].append(defender_log_info[tag + '/' + stat])

    row = []
    for tag in tags:
        for stat in stat_names:
            values = np.asarray(eval_stats[tag][stat])
            row.append('{:.3}±{:.3}'.format(np.mean(values), np.std(values)))

    # newline='' is required by the csv module; explicit utf-8 keeps the '±'
    # character writable regardless of the platform's default locale.
    with open(file_name, "a", newline="", encoding="utf-8") as csvfile:
        csv.writer(csvfile).writerow(row)
    defender_eval_runner.close_env()

def evaluate_defender(args, runner):
    """Run ``args.test_nepisode`` test-mode episodes tagged 'def' on ``runner``.

    A replay is saved when ``args.save_replay`` is set; the runner's
    environment is always closed at the end.
    """
    episode_count = args.test_nepisode
    for _episode_idx in range(episode_count):
        runner.run(test_mode=True, tag='def')

    wants_replay = args.save_replay
    if wants_replay:
        runner.save_replay()

    runner.close_env()

def _list_checkpoint_timesteps(checkpoint_dir):
    """Return the integer names of the sub-directories of ``checkpoint_dir`` whose names are all digits."""
    timesteps = []
    for name in os.listdir(checkpoint_dir):
        # Check if they are dirs the names of which are numbers
        if os.path.isdir(os.path.join(checkpoint_dir, name)) and name.isdigit():
            timesteps.append(int(name))
    return timesteps


def _closest_timestep(timesteps, target):
    """Return the entry of non-empty ``timesteps`` nearest to ``target`` (first one wins on ties)."""
    return min(timesteps, key=lambda x: abs(x - target))


def _collect_pretrain_data(runner, buffer, timer, logger, pretrain_tmax, batch_size_run, start_t=0):
    """Run test-mode 'pretrain' episodes until ``runner.t_env`` exceeds ``start_t + pretrain_tmax``.

    Each episode batch is inserted into the buffer's tuple storage, and
    progress plus the running win rate is logged after every batch.
    """
    win_episodes = 0
    n_episodes = 0
    while runner.t_env <= start_t + pretrain_tmax:
        # Run for a whole episode at a time
        with timer.timer("pretrain_sample"):
            with th.no_grad():
                episode_batch = runner.run(test_mode=True, tag='pretrain')
            buffer.insert_tuple_batch(episode_batch)
            win_episodes += runner.battle_info['battle_won']
        n_episodes += batch_size_run
        logger.console_logger.info("Data collecting for {}/{}, pretrain win rate {}%".format(
            runner.t_env - start_t, pretrain_tmax, round(win_episodes / n_episodes * 100, 2)))


def run_sequential(args, logger):
    """Build runners, buffer, controllers and learners, optionally load checkpoints,
    then pretrain and run alternating attacker/defender training.

    Flow:
      1. Create a training runner driven by ``MetaMAC`` and a separate
         evaluation runner driven by ``DefenderMAC``, both of type ``args.runner``.
      2. If ``defense_loaded`` is set and ``defense_load_path`` is non-empty,
         load defense modules from the numbered sub-directory selected by
         ``defense_load_step`` (0 = latest, otherwise closest).
      3. If ``args.checkpoint_path`` is set, load defender models from it; when
         ``args.evaluate`` or ``args.save_replay`` is set, run
         :func:`evaluate_sequential_test` and return.
      4. Collect ``pretrain_tmax`` steps of pretraining data and call
         ``attacker_learner.vae_train`` on it.
      5. Train until ``args.t_max``: while ``(t_env // switch_interval)`` is even
         and ``t_env < attacker_stop`` the attacker learner is trained,
         otherwise the defender learner. Models are saved periodically, and
         test runs (tags 'test', 'def1', 'def2', 'def3') are logged every
         ``test_interval`` steps.

    Returns early (after logging) when a requested checkpoint directory is
    missing or contains no numbered sub-directories.
    """
    # Init runner so we can get env info
    runner = r_REGISTRY[args.runner](args=args, logger=logger)
    # Separate runner of the same type, driven by DefenderMAC, for def1/def2/def3 evaluation
    defender_eval_runner = r_REGISTRY[args.runner](args=args, logger=logger)

    # Set up schemes and groups here
    env_info = runner.get_env_info()
    args.n_agents = env_info["n_agents"]
    args.n_actions = env_info["n_actions"]
    args.state_shape = env_info["state_shape"]

    # Default/Base scheme
    scheme = {
        "state": {"vshape": env_info["state_shape"]},
        "obs_origin": {"vshape": env_info["obs_shape"], "group": "agents"},
        "obs": {"vshape": env_info["obs_shape"], "group": "agents"},
        "actions": {"vshape": (1,), "group": "agents", "dtype": th.long},
        "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
        "reward": {"vshape": (1,)},
        "ad_discrete_actions": {"vshape": 1, "dtype": th.long},
        "ad_continuous_actions": {"vshape": env_info["obs_shape"]},
        "ad_discrete_emb": {"vshape": args.discrete_action_dim},
        "ad_continuous_emb": {"vshape": args.parameter_action_dim},
        "ad_reward": {"vshape": (1,)},
        "terminated": {"vshape": (1,), "dtype": th.uint8},
    }
    if vars(args).get("defense_loaded", False):
        scheme["obs_perturbed"] = {"vshape": env_info["obs_shape"], "group": "agents"}
    groups = {
        "agents": args.n_agents,
    }
    preprocess = {
        "actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)]),
    }

    buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1,
                          preprocess=preprocess,
                          device="cpu" if args.buffer_cpu_only else args.device)

    # Setup multiagent controller here
    mac = MetaMAC(buffer.scheme, groups, args)

    # Give runner the scheme
    runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac)

    defender_mac = DefenderMAC(buffer.scheme, groups, args)
    defender_eval_runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=defender_mac)

    # Learner
    attacker_learner = le_REGISTRY[mac.attacker_args.learner](
        mac.controller_dict["attacker"][0], buffer.scheme, logger, mac.attacker_args)
    defender_learner = le_REGISTRY[mac.defender_args.learner](
        mac.controller_dict["defender"], buffer.scheme, logger, mac.defender_args)

    if vars(args).get("defense_loaded", False) and vars(mac.defender_args).get("defense_load_path", "") != "":
        defense_dir = mac.defender_args.defense_load_path

        if not os.path.isdir(defense_dir):
            logger.console_logger.info(
                "Checkpoint directiory {} doesn't exist".format(defense_dir))
            return

        timesteps = _list_checkpoint_timesteps(defense_dir)
        if not timesteps:
            logger.console_logger.info(
                "No numbered checkpoint directories found in {}".format(defense_dir))
            return

        if mac.defender_args.defense_load_step == 0:
            # choose the max timestep
            timestep_to_load = max(timesteps)
        else:
            # choose the timestep closest to defense_load_step (the same field
            # tested above; previously this mistakenly read `load_step`)
            timestep_to_load = _closest_timestep(timesteps, mac.defender_args.defense_load_step)

        model_path = os.path.join(defense_dir, str(timestep_to_load))

        logger.console_logger.info("Loading defense module from {}".format(model_path))
        mac.controller_dict["defender"].load_defense_models(model_path)

    if args.use_cuda:
        attacker_learner.cuda()
        defender_learner.cuda()

    if args.checkpoint_path != "":
        if not os.path.isdir(args.checkpoint_path):
            logger.console_logger.info(
                "Checkpoint directiory {} doesn't exist".format(args.checkpoint_path))
            return

        timesteps = _list_checkpoint_timesteps(args.checkpoint_path)
        if not timesteps:
            logger.console_logger.info(
                "No numbered checkpoint directories found in {}".format(args.checkpoint_path))
            return

        latest = max(timesteps)
        if args.load_step == 0:
            # Choose the max timestep, except for run_ad checkpoints whose latest
            # save falls in an odd switch interval: then take the save closest
            # to one switch_interval earlier. (Odd intervals are the ones in
            # which the training loop below trains the defender.)
            if "run_ad" in args.checkpoint_path and (latest // args.switch_interval) % 2 == 1:
                timestep_to_load = _closest_timestep(timesteps, latest - args.switch_interval)
            else:
                timestep_to_load = latest
        else:
            # choose the timestep closest to load_step
            timestep_to_load = _closest_timestep(timesteps, args.load_step)

        model_path = os.path.join(args.checkpoint_path, str(timestep_to_load))

        logger.console_logger.info("Loading model from {}".format(model_path))
        defender_learner.load_models(model_path)
        runner.t_env = timestep_to_load

        if args.evaluate or args.save_replay:
            defender_mac.defender = mac.controller_dict["defender"]
            evaluate_sequential_test(args, defender_eval_runner)
            runner.close_env()
            return

    # start pretraining
    runner.t_env = 0

    timer = Timer()

    logger.console_logger.info(
        "Beginning pretraining for {} timesteps".format(mac.attacker_args.pretrain_tmax))
    _collect_pretrain_data(runner, buffer, timer, logger,
                           mac.attacker_args.pretrain_tmax, args.batch_size_run)

    with timer.timer("pretrain"):
        if runner.t_env > mac.attacker_args.vae_get_c_rate_batch_size:
            attacker_learner.vae_train(buffer, mac.attacker_args.vae_batch_size,
                                       mac.attacker_args.training_vae_steps, log_vae_loss=True)

    # start training
    episode = 0
    last_test_T = -args.test_interval - 1
    model_save_time = 0
    last_train_vae_ep = 0
    last_save_attacker_step = -args.switch_interval - 1
    last_save_defender_step = 0
    runner.t_env = 0
    defender_eval_runner.t_env = 0

    logger.console_logger.info(
        "Beginning training for {} timesteps".format(args.t_max))

    while runner.t_env < args.t_max:

        with timer.timer("sample"):
            with th.no_grad():
                episode_batch = runner.run(test_mode=False, tag='train')
            buffer.insert_episode_batch(episode_batch)
            # Even switch intervals (before attacker_stop) are attacker turns;
            # runner.t_env does not change again until the next runner.run call
            # below, so this flag is valid for the whole training step.
            attacker_turn = (runner.t_env // args.switch_interval) % 2 == 0 and runner.t_env < args.attacker_stop
            if attacker_turn:
                buffer.insert_tuple_batch(episode_batch)
        episode += args.batch_size_run

        with timer.timer("train"):
            if buffer.can_sample(args.batch_size) and buffer.can_sample_tuple(mac.attacker_args.vae_get_c_rate_batch_size):
                for _batch_idx in range(args.batch_size_run):
                    episode_sample = buffer.sample(args.batch_size)

                    # Truncate batch to only filled timesteps
                    max_ep_t = episode_sample.max_t_filled()
                    episode_sample = episode_sample[:, :max_ep_t]

                    if episode_sample.device != args.device:
                        episode_sample.to(args.device)
                    if attacker_turn:
                        # the attacker learner trains from the buffer itself
                        for _repeat in range(args.attacker_update_repeat):
                            attacker_learner.train(
                                buffer, runner.t_env, episode)
                    else:
                        for _repeat in range(args.defender_update_repeat):
                            defender_learner.train(
                                episode_sample, runner.t_env, episode)
                if attacker_turn and (episode - last_train_vae_ep) / mac.attacker_args.vae_update_episode_interval >= 1.0:
                    attacker_learner.vae_train(buffer, mac.attacker_args.vae_batch_size, 1, log_vae_loss=False)
                    last_train_vae_ep = episode
            # attacker training ends for a turn
            if (runner.t_env - last_save_attacker_step) / 2 / args.switch_interval >= 1.0 and runner.t_env < args.attacker_stop:
                if args.add_history_attacker:
                    model_save_time = runner.t_env
                    save_path = os.path.join(
                        args.results_save_dir, "models", str(runner.t_env))
                    os.makedirs(save_path, exist_ok=True)
                    logger.console_logger.info(
                        "Saving models to {}".format(save_path))
                    attacker_learner.save_models(save_path)
                    mac.add_test_attacker()
                last_save_attacker_step = runner.t_env
            # defender training ends for a turn
            if (runner.t_env - last_save_defender_step) / 2 / args.switch_interval >= 1.0 and runner.t_env < args.attacker_stop:
                last_save_defender_step = runner.t_env
                model_save_time = runner.t_env
                save_path = os.path.join(
                    args.results_save_dir, "models", str(runner.t_env))
                os.makedirs(save_path, exist_ok=True)
                logger.console_logger.info(
                    "Saving defenders to {}".format(save_path))
                defender_learner.save_models(save_path)
                buffer.reset_tuple_self(0)
                if args.reinit_attacker:
                    # reinit a new attacker: collect fresh pretraining data,
                    # rebuild the attacker learner and retrain its VAE
                    buffer.reset_tuple_self(0)
                    _collect_pretrain_data(runner, buffer, timer, logger,
                                           mac.attacker_args.pretrain_tmax, args.batch_size_run,
                                           start_t=last_save_defender_step)
                    mac.reinit()
                    attacker_learner = le_REGISTRY[mac.attacker_args.learner](
                        mac.controller_dict["attacker"][0], buffer.scheme, logger, mac.attacker_args)
                    attacker_learner.vae_train(buffer, mac.attacker_args.vae_batch_size,
                                               mac.attacker_args.training_vae_steps, log_vae_loss=True)
                mac.reset_attacker_exploration()

        # Execute test runs once in a while & test when finished
        if (runner.t_env - last_test_T) / args.test_interval >= 1.0 or runner.t_env >= args.t_max:
            n_test_runs = max(1, args.test_nepisode // runner.batch_size)
            logger.console_logger.info(
                "t_env: {} / {}".format(runner.t_env, args.t_max))
            logger.console_logger.info(timer.time_cost_str(
                runner.t_env, last_test_T, args.t_max))

            last_test_T = runner.t_env
            defender_eval_runner.t_env = runner.t_env
            defender_mac.defender = mac.controller_dict["defender"]

            with timer.timer("test"):
                with th.no_grad():
                    for test_idx in range(n_test_runs):
                        runner.run(test_mode=True, tag='test')
                        for def_tag in ('def1', 'def2', 'def3'):
                            defender_eval_runner.run(test_mode=True, tag=def_tag,
                                                     print_attack_obs=(test_idx == 0))

            logger.log_stat("episode", episode, runner.t_env)
            runner.log_info(tag='train')
            runner.log_info(tag='test')
            for def_tag in ('def1', 'def2', 'def3'):
                defender_eval_runner.log_info(tag=def_tag)
            logger.print_recent_stats()

            if args.save_model and (runner.t_env - model_save_time >= args.save_model_interval or runner.t_env >= args.t_max):
                model_save_time = runner.t_env
                save_path = os.path.join(
                    args.results_save_dir, "models", str(runner.t_env))
                os.makedirs(save_path, exist_ok=True)
                logger.console_logger.info(
                    "Saving models to {}".format(save_path))
                attacker_learner.save_models(save_path)
                defender_learner.save_models(save_path)

    runner.close_env()
    defender_eval_runner.close_env()
    logger.console_logger.info("Finished Training")


def args_sanity_check(config, _log):
    """Normalise a raw experiment config in place and return it.

    * ``use_cuda`` is turned off (with a warning on ``_log``) when torch
      reports no CUDA device.
    * ``test_nepisode`` is raised to ``batch_size_run`` when smaller, and
      otherwise rounded down to a multiple of ``batch_size_run``.
    """
    cuda_requested = config["use_cuda"]
    if cuda_requested and not th.cuda.is_available():
        config["use_cuda"] = False
        _log.warning(
            "CUDA flag use_cuda was switched OFF automatically because no CUDA devices are available!")

    batch_runs = config["batch_size_run"]
    requested_tests = config["test_nepisode"]
    if requested_tests >= batch_runs:
        # round down to a whole number of parallel batches
        config["test_nepisode"] = (requested_tests // batch_runs) * batch_runs
    else:
        config["test_nepisode"] = batch_runs

    return config
