import os
import json
import pprint
import time
import threading
import torch as th
from types import SimpleNamespace as SN
from utils.logging import Logger
from utils.timer import Timer

from learners import REGISTRY as le_REGISTRY
from runners import REGISTRY as r_REGISTRY
from controllers import REGISTRY as mac_REGISTRY
from components.episode_buffer import ReplayBuffer
from components.transforms import OneHot


def run(_run, _config, _log):
    """Entry point for one experiment: configure logging, run, then hard-exit.

    Sanity-checks ``_config`` (see ``args_sanity_check``), converts it to a
    namespace, sets up console / TensorBoard / W&B / sacred logging, writes the
    resolved config to ``<results_save_dir>/config.json`` and hands off to
    ``run_sequential``. Afterwards it waits (up to 30 s each) for any leftover
    background threads and terminates the process with ``os._exit(0)``.

    Args:
        _run: run object forwarded to ``logger.setup_sacred`` (presumably the
            sacred ``Run`` — confirm against the caller).
        _config: experiment configuration dict; may be mutated by
            ``args_sanity_check``.
        _log: logger used for console output.

    This function does not return normally; the process exits at the end.
    """
    # check args sanity
    _config = args_sanity_check(_config, _log)

    args = SN(**_config)
    args.device = "cuda" if args.use_cuda else "cpu"

    results_save_dir = args.results_save_dir
    # config.json is written unconditionally below, so the output directory
    # must exist even when neither tensorboard nor wandb creates it.
    if results_save_dir:
        os.makedirs(results_save_dir, exist_ok=True)

    # setup loggers
    logger = Logger(_log, results_save_dir)

    _log.info("Experiment Parameters:")
    experiment_params = pprint.pformat(_config,
                                       indent=4,
                                       width=1)
    _log.info("\n\n" + experiment_params + "\n")

    # Dashboards are only set up for training runs; evaluation-only runs
    # (args.evaluate) skip them.
    if args.use_tensorboard and not args.evaluate:
        tb_exp_direc = os.path.join(results_save_dir, 'tb_logs')
        logger.setup_tb(tb_exp_direc)

    if vars(args).get("use_wandb", False) and not args.evaluate:
        wandb_exp_direc = os.path.join(results_save_dir, 'wandb_logs')
        logger.setup_wandb(wandb_exp_direc, args)

    # set model save dir
    args.save_dir = os.path.join(results_save_dir, 'models')

    # write config file (serialize first so a failure leaves no partial file)
    config_str = json.dumps(vars(args), indent=4)
    with open(os.path.join(results_save_dir, "config.json"), "w", encoding="utf-8") as f:
        f.write(config_str)

    # sacred is on by default
    logger.setup_sacred(_run)

    # Run and train
    run_sequential(args=args, logger=logger)

    # Clean up after finishing
    print("Exiting Main")

    print("Stopping all threads")
    current_thread = threading.current_thread()
    main_thread = threading.main_thread()
    for t in threading.enumerate():
        # Joining the current thread raises RuntimeError, and the main thread
        # cannot finish while we are still running in it.
        if t is current_thread or t is main_thread:
            continue
        print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon))
        t.join(timeout=30)
        if t.is_alive():
            print("Thread {} did not finish within 30s".format(t.name))
        else:
            print("Thread joined")

    print("Exiting script")

    # Making sure framework really exits (0 == os.EX_OK, but portable).
    os._exit(0)


def evaluate_sequential(args, runner):
    """Evaluate a loaded policy without any training.

    Performs ``args.test_nepisode`` calls of ``runner.run`` in test mode with
    the ``'eval'`` tag. If ``args.save_replay`` is truthy, it then asks the
    runner to save a replay. The runner's environment is always closed at
    the end.
    """
    eval_kwargs = dict(test_mode=True, tag='eval')
    n_eval_calls = args.test_nepisode
    for _eval_idx in range(n_eval_calls):
        runner.run(**eval_kwargs)

    # Replays are saved only after all evaluation runs have completed.
    wants_replay = args.save_replay
    if wants_replay:
        runner.save_replay()

    runner.close_env()

def _select_checkpoint_timestep(checkpoint_path, load_step):
    """Choose which numbered checkpoint sub-directory to load.

    ``run_sequential`` saves models under ``<results_save_dir>/models/<t_env>``.
    A checkpoint directory therefore holds one sub-directory per saved timestep.

    Args:
        checkpoint_path: Existing directory to scan.
        load_step: ``0`` selects the largest saved timestep. Any other value
            selects the saved timestep closest to it.

    Returns:
        The chosen timestep as an ``int``, or ``None`` when the directory
        contains no numerically named sub-directories.
    """
    timesteps = [
        int(name)
        for name in os.listdir(checkpoint_path)
        # isdecimal (unlike isdigit) rejects characters such as "²" that
        # int() cannot parse.
        if name.isdecimal() and os.path.isdir(os.path.join(checkpoint_path, name))
    ]
    if not timesteps:
        return None
    if load_step == 0:
        # choose the max timestep
        return max(timesteps)
    # choose the timestep closest to load_step
    return min(timesteps, key=lambda step: abs(step - load_step))


def run_sequential(args, logger):
    """Build all training components, optionally restore a checkpoint, then train.

    Setup:
    - Creates the main runner and reads its env info.
    - Writes ``n_agents``, ``n_actions`` and ``state_shape`` onto ``args``.
    - Builds the episode-buffer scheme and the replay buffer.
    - Builds the multi-agent controller (MAC) and the learner.
    - Creates a second runner, driven by a ``DefenderMAC``, used only for the
      ``'def1'``/``'def2'``/``'def3'`` test runs.

    If ``args.checkpoint_path`` is set, a numbered checkpoint is loaded. In
    ``evaluate``/``save_replay`` mode, evaluation runs and the function
    returns. Otherwise it trains until ``runner.t_env >= args.t_max``,
    periodically testing, logging and (if ``args.save_model``) saving models.
    Both runners' environments are closed on every exit path.

    Args:
        args: experiment namespace (mutated: ``n_agents``, ``n_actions``,
            ``state_shape`` are set here).
        logger: project ``Logger`` used for console output and stats.
    """
    # Init runner so we can get env info
    runner = r_REGISTRY[args.runner](args=args, logger=logger)

    # Set up schemes and groups here
    env_info = runner.get_env_info()
    args.n_agents = env_info["n_agents"]
    args.n_actions = env_info["n_actions"]
    args.state_shape = env_info["state_shape"]

    # Default/Base scheme
    scheme = {
        "state": {"vshape": env_info["state_shape"]},
        "obs": {"vshape": env_info["obs_shape"], "group": "agents"},
        "actions": {"vshape": (1,), "group": "agents", "dtype": th.long},
        "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
        "reward": {"vshape": (1,)},
        "terminated": {"vshape": (1,), "dtype": th.uint8},
    }
    if vars(args).get("defense_loaded", False):
        scheme["obs_origin"] = {"vshape": env_info["obs_shape"], "group": "agents"}
        scheme["ad_discrete_actions"] = {"vshape": 1, "dtype": th.long}
    if vars(args).get("random_perturbation", False):
        scheme["obs_perturbed"] = {"vshape": env_info["obs_shape"], "group": "agents"}
    groups = {
        "agents": args.n_agents
    }
    preprocess = {
        "actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)])
    }

    # Sequence length is episode_limit + 1 (presumably room for the final
    # state after the last action — confirm against ReplayBuffer).
    buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1,
                          preprocess=preprocess,
                          device="cpu" if args.buffer_cpu_only else args.device)

    # Setup multiagent controller here
    mac = mac_REGISTRY[args.mac](buffer.scheme, groups, args)

    # Give runner the scheme
    runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac)

    # Separate runner + controller used only for the def1/def2/def3 test runs.
    from controllers.defender_controller import DefenderMAC
    defender_eval_runner = r_REGISTRY[args.runner](args=args, logger=logger)
    defender_mac = DefenderMAC(buffer.scheme, groups, args)
    defender_eval_runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=defender_mac)
    defender_tags = ("def1", "def2", "def3")

    # Learner
    learner = le_REGISTRY[args.learner](mac, buffer.scheme, logger, args)

    if args.use_cuda:
        learner.cuda()

    if args.checkpoint_path != "":
        if not os.path.isdir(args.checkpoint_path):
            logger.console_logger.info("Checkpoint directory {} doesn't exist".format(args.checkpoint_path))
            runner.close_env()
            defender_eval_runner.close_env()
            return

        timestep_to_load = _select_checkpoint_timestep(args.checkpoint_path, args.load_step)
        if timestep_to_load is None:
            # Previously max([]) raised ValueError here.
            logger.console_logger.info(
                "No numbered checkpoint directories found in {}".format(args.checkpoint_path))
            runner.close_env()
            defender_eval_runner.close_env()
            return

        model_path = os.path.join(args.checkpoint_path, str(timestep_to_load))

        logger.console_logger.info("Loading model from {}".format(model_path))
        learner.load_models(model_path)
        runner.t_env = timestep_to_load

        if args.evaluate or args.save_replay:
            # evaluate_sequential closes `runner` itself.
            evaluate_sequential(args, runner)
            defender_eval_runner.close_env()
            return

    # start training
    episode = 0
    # Forces a test pass right after the first episode.
    last_test_T = -args.test_interval - 1
    model_save_time = 0
    defender_eval_runner.t_env = 0
    # Number of test batches per evaluation; invariant over training.
    n_test_runs = max(1, args.test_nepisode // runner.batch_size)

    timer = Timer()

    logger.console_logger.info("Beginning training for {} timesteps".format(args.t_max))

    while runner.t_env < args.t_max:

        # Run for a whole episode at a time
        with timer.timer("sample"):
            with th.no_grad():
                episode_batch = runner.run(test_mode=False, tag='train')
            buffer.insert_episode_batch(episode_batch)
        episode += args.batch_size_run

        if buffer.can_sample(args.batch_size):
            with timer.timer("train"):
                for _ in range(args.batch_size_run):
                    episode_sample = buffer.sample(args.batch_size)

                    # Truncate batch to only filled timesteps
                    max_ep_t = episode_sample.max_t_filled()
                    episode_sample = episode_sample[:, :max_ep_t]

                    if episode_sample.device != args.device:
                        episode_sample.to(args.device)

                    learner.train(episode_sample, runner.t_env, episode)

        # Execute test runs once in a while & test when finished
        if (runner.t_env - last_test_T) / args.test_interval >= 1.0 or runner.t_env >= args.t_max:
            logger.console_logger.info("t_env: {} / {}".format(runner.t_env, args.t_max))
            logger.console_logger.info(timer.time_cost_str(runner.t_env, last_test_T, args.t_max))

            last_test_T = runner.t_env
            defender_eval_runner.t_env = runner.t_env
            # Hand the current training MAC to the defender controller
            # (semantics of `defender` live in controllers.defender_controller).
            defender_mac.defender = mac

            with timer.timer("test"):
                with th.no_grad():
                    for test_idx in range(n_test_runs):
                        is_first_run = test_idx == 0
                        runner.run(test_mode=True, tag='test')
                        for tag in defender_tags:
                            defender_eval_runner.run(test_mode=True, tag=tag, print_attack_obs=is_first_run)

            logger.log_stat("episode", episode, runner.t_env)
            runner.log_info(tag='train')
            runner.log_info(tag='test')
            for tag in defender_tags:
                defender_eval_runner.log_info(tag=tag)
            logger.print_recent_stats()

            if args.save_model and (runner.t_env - model_save_time >= args.save_model_interval or runner.t_env >= args.t_max):
                model_save_time = runner.t_env
                save_path = os.path.join(args.results_save_dir, "models", str(runner.t_env))
                os.makedirs(save_path, exist_ok=True)
                logger.console_logger.info("Saving models to {}".format(save_path))
                learner.save_models(save_path)

    runner.close_env()
    defender_eval_runner.close_env()
    logger.console_logger.info("Finished Training")


def args_sanity_check(config, _log):

    # set CUDA flags
    # config["use_cuda"] = True # Use cuda whenever possible!
    if config["use_cuda"] and not th.cuda.is_available():
        config["use_cuda"] = False
        _log.warning("CUDA flag use_cuda was switched OFF automatically because no CUDA devices are available!")

    if config["test_nepisode"] < config["batch_size_run"]:
        config["test_nepisode"] = config["batch_size_run"]
    else:
        config["test_nepisode"] = (config["test_nepisode"]//config["batch_size_run"]) * config["batch_size_run"]

    return config
