import datetime
import os
import pprint
import time
import threading
import torch as th
from types import SimpleNamespace as SN
from utils.logging import Logger
from utils.timehelper import time_left, time_str
from os.path import dirname, abspath

from learners import REGISTRY as le_REGISTRY
from runners import REGISTRY as r_REGISTRY
from controllers import REGISTRY as mac_REGISTRY
from components.episode_buffer import ReplayBuffer
from components.transforms import OneHot
from modules.models.model import MI_Action, MI_Obs

import socket
import numpy as np
import random

def _set_runner_mode(runner, eval_mode: bool):
    candidates = [
        getattr(runner, "mac", None),
        getattr(runner, "mi_model", None),
    ]

    for m in candidates:
        if m is None:
            continue
        if hasattr(m, "train") and hasattr(m, "eval"):
            if eval_mode:
                try:
                    m.eval()
                except TypeError:
                    pass
            else:
                try:
                    m.train(True)
                except TypeError:
                    try:
                        m.train()
                    except Exception:
                        pass

def run(_run, _config, _log):
    """Sacred entry point: sanitise the config, set up logging and train.

    After ``run_sequential`` returns, this joins any lingering non-main
    threads (1 s timeout each) and hard-exits the process with
    ``os._exit``. The hard exit skips interpreter cleanup, so stuck
    non-daemon threads cannot keep the process alive.

    Args:
        _run: Sacred run object, forwarded to ``logger.setup_sacred``.
        _config: Raw experiment config dict.
        _log: Logger used for experiment output.
    """
    _config = args_sanity_check(_config, _log)

    args = SN(**_config)
    args.device = "cuda" if args.use_cuda else "cpu"

    logger = Logger(_log)

    _log.info("Experiment Parameters:")
    experiment_params = pprint.pformat(_config, indent=4, width=1)
    _log.info("\n\n" + experiment_params + "\n")

    unique_token = "{}__{}".format(args.name, datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    args.unique_token = unique_token
    if args.use_tensorboard:
        tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", "tb_logs")
        # Join the token directly: calling str.format on the joined path would
        # raise or mangle it whenever a directory name contains "{" or "}".
        tb_exp_direc = os.path.join(tb_logs_direc, unique_token)
        logger.setup_tb(tb_exp_direc)

    logger.setup_sacred(_run)

    run_sequential(args=args, logger=logger)

    print("Exiting Main")

    print("Stopping all threads")
    # Identify the main/current thread by object rather than by name; joining
    # the current thread would raise RuntimeError.
    skip = {threading.main_thread(), threading.current_thread()}
    for t in threading.enumerate():
        if t in skip:
            continue
        print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon))
        t.join(timeout=1)
        if t.is_alive():
            print("Thread {} did not finish within the timeout".format(t.name))
        else:
            print("Thread joined")

    print("Exiting script")

    # os.EX_OK is only defined on Unix; 0 is the conventional success code elsewhere.
    os._exit(getattr(os, "EX_OK", 0))


def evaluate_sequential(args, runner):
    """Play one fully attacked evaluation episode and return its win rate.

    The episode runs with ``fixed_K=args.K`` and both observation and action
    attacks enabled, under ``torch.no_grad``, with the runner's components in
    eval mode. A replay is saved when ``args.save_replay`` is set. Train mode
    is restored afterwards, even if an episode raises.

    Args:
        args: Experiment namespace (reads ``K`` and ``save_replay``).
        runner: Episode runner; ``runner.run`` must return
            ``(episode_batch, battle_won)``.

    Returns:
        float: Fraction of evaluation episodes won. Previously this value was
        computed and discarded; callers that ignore the return are unaffected.
    """
    n_episodes = 1
    wins = 0.0
    _set_runner_mode(runner, eval_mode=True)
    try:
        with th.no_grad():
            for _ in range(n_episodes):
                _, battle_won = runner.run(
                    fixed_K=args.K, eval_mode=True, test_mode=True,
                    obs_attack=True, action_attack=True, remove_attack=False
                )
                wins += float(battle_won)
        if args.save_replay:
            runner.save_replay()
    finally:
        _set_runner_mode(runner, eval_mode=False)
    return wins / n_episodes

def eval(args, runner, obs_attack=False, remove_attack=False):
    """Play 16 evaluation episodes in one of three attack settings.

    NOTE: this name shadows the builtin ``eval`` inside this module; it is
    kept for compatibility with existing callers.

    Settings (``remove_attack`` takes precedence over ``obs_attack``):

    * ``remove_attack``: ``fixed_K=1``, observation/action attacks off,
      ``remove_attack=True``.
    * ``obs_attack``: ``fixed_K=args.K``, observation and action attacks on.
    * otherwise: ``fixed_K=0``, no attacks.

    Episodes run under ``torch.no_grad`` with the runner's components in eval
    mode. Train mode is restored afterwards, even if an episode raises. A
    replay is saved when ``args.save_replay`` is set.

    Args:
        args: Experiment namespace (reads ``K`` and ``save_replay``).
        runner: Episode runner; ``runner.run`` must return
            ``(episode_batch, battle_won)``.
        obs_attack: Evaluate under observation + action attacks.
        remove_attack: Evaluate in the ``remove_attack`` setting.

    Returns:
        float: Fraction of the evaluation episodes won. Previously this value
        was computed and discarded; callers that ignore it are unaffected.
    """
    if remove_attack:
        run_kwargs = dict(fixed_K=1, obs_attack=False, action_attack=False, remove_attack=True)
    elif obs_attack:
        run_kwargs = dict(fixed_K=args.K, obs_attack=True, action_attack=True, remove_attack=False)
    else:
        run_kwargs = dict(fixed_K=0, obs_attack=False, action_attack=False, remove_attack=False)

    n_episodes = 16
    wins = 0.0
    _set_runner_mode(runner, eval_mode=True)
    try:
        with th.no_grad():
            for _ in range(n_episodes):
                _, battle_won = runner.run(eval_mode=True, test_mode=True, **run_kwargs)
                wins += float(battle_won)
        if args.save_replay:
            runner.save_replay()
    finally:
        _set_runner_mode(runner, eval_mode=False)
    return wins / n_episodes

def test(args, runner, p_min, p_max):
    """Probe each attack budget and adapt the per-budget upper probability.

    For every budget ``k`` from 1 to ``args.K``, play 16 evaluation episodes
    with observation and action attacks enabled, forwarding ``p_min``/``p_max``
    as the action-attack probability bounds. Then update ``p_max[k]`` from the
    observed win rate:

    * win rate >= ``args.threshold``: grow by a factor of 1.1, capped at 1.0;
    * otherwise: shrink by a factor of 1.1, floored at
      ``1 / (args.init_prob * k)``.

    ``p_max`` is modified in place; ``p_min`` is passed through untouched.

    Returns:
        tuple: ``(p_min, p_max)``.
    """
    episodes_per_budget = 16

    _set_runner_mode(runner, eval_mode=True)
    with th.no_grad():
        for budget in range(1, args.K + 1):
            victories = 0.0
            for _ in range(episodes_per_budget):
                _, won = runner.run(
                    fixed_K=budget,
                    eval_mode=True,
                    test_mode=True,
                    obs_attack=True,
                    action_attack=True,
                    remove_attack=False,
                    action_attack_pro_min=p_min,
                    action_attack_pro_max=p_max,
                )
                victories += float(won)

            success_rate = victories / episodes_per_budget
            current = p_max[budget]
            if success_rate >= args.threshold:
                p_max[budget] = min(1.0, current * 1.1)
            else:
                floor = 1.0 / (args.init_prob * budget)
                p_max[budget] = max(floor, current / 1.1)
    _set_runner_mode(runner, eval_mode=False)

    return p_min, p_max
    
def run_sequential(args, logger):
    """Build the runner/buffer/MAC/learner stack and run the training loop.

    Copies the environment's dimensions onto ``args`` (``n_agents``,
    ``n_actions``, ``state_shape``, ``obs_shape``) and builds the components.
    It then repeats the following until ``runner.t_env`` exceeds
    ``args.t_max``:

    * collect an episode with the attack settings from ``args``;
    * train on a sampled batch once the buffer is large enough;
    * evaluate periodically;
    * adapt the action-attack probability bounds via ``test``;
    * checkpoint models and log stats.

    The runner's environment is closed when the loop finishes.

    Args:
        args: Experiment namespace; mutated with env-derived fields.
        logger: Project ``Logger`` used for console output and stats.
    """
    runner = r_REGISTRY[args.runner](args=args, logger=logger)

    # Publish environment dimensions on args for the components built below.
    env_info = runner.get_env_info()
    args.n_agents = env_info["n_agents"]
    args.n_actions = env_info["n_actions"]
    args.state_shape = env_info["state_shape"]
    args.obs_shape = env_info["obs_shape"]
    
    # Fields stored per timestep in the replay buffer. "history" is sized by
    # rnn_hidden_dim and "attacked_obs" mirrors the shape of "obs". Entries
    # with group "agents" presumably get one slot per agent (see `groups`).
    scheme = {
        "state": {"vshape": env_info["state_shape"]},
        "obs": {"vshape": env_info["obs_shape"], "group": "agents"},
        "history": {"vshape": args.rnn_hidden_dim, "group": "agents"},
        "attacked_obs": {"vshape": env_info["obs_shape"], "group": "agents"},
        "actions": {"vshape": (1,), "group": "agents", "dtype": th.long},
        "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
        "reward": {"vshape": (1,)},
        "terminated": {"vshape": (1,), "dtype": th.uint8},
        "group_1": {"vshape": (env_info["n_agents"],), "dtype": th.uint8},  
        "group_2": {"vshape": (env_info["n_agents"],), "dtype": th.uint8}, 
    }

    groups = {
        "agents": args.n_agents
    }
    # Actions are additionally stored one-hot encoded as "actions_onehot".
    preprocess = {
        "actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)])
    }

    # Episode capacity is episode_limit + 1 steps (presumably room for the
    # final post-terminal step; confirm against the runner).
    buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1,
                          preprocess=preprocess,
                          device="cpu" if args.buffer_cpu_only else args.device)


    mac = mac_REGISTRY[args.mac](buffer.scheme, groups, args)

    # Mutual-information models shared between the runner and the learner.
    mi_model = MI_Action(args)
    mi_model_obs = MI_Obs(args)

    runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac, mi_model=mi_model, mi_model_obs=mi_model_obs)

    learner = le_REGISTRY[args.learner](mac, buffer.scheme, logger, args, mi_model=mi_model, mi_model_obs=mi_model_obs)

    if args.use_cuda:
        learner.cuda()

    if args.pretrain == True:
        # NOTE(review): map_name and file_name are concatenated with no
        # separator, so file_name presumably starts with "/"; confirm.
        model_path = f"pretrain_models/qmix/" + args.env_args["map_name"] + args.file_name
        learner.load_models(model_path)

    episode = 0
    # Initialised below zero so the first loop iteration triggers a test.
    last_test_T = -args.test_interval - 1
    last_log_T = 0
    model_save_time = 0
    # `test` (probability adaptation) runs every heavy_test_mult test intervals.
    heavy_test_mult = 10
    last_heavy_test_T = -args.test_interval * heavy_test_mult - 1 
    
    start_time = time.time()
    last_time = start_time

    logger.console_logger.info("Beginning training for {} timesteps".format(args.t_max))

    # Action-attack probability bounds. Index 0 is a placeholder; index i
    # starts at 1/(init_prob*i). `test` updates p_max_set[k] for k in 1..args.K.
    p_min_set=[0] + [(1/(args.init_prob*i)) for i in range(1, args.n_agents+1)]
    p_max_set=[0] + [(1/(args.init_prob*i)) for i in range(1, args.n_agents+1)]

    while runner.t_env <= args.t_max:

        # Collect one (possibly attacked) training episode.
        episode_batch, _ = runner.run(test_mode=False, obs_attack=args.obs_attack, action_attack=args.action_attack, remove_attack=args.remove_attack, action_attack_pro_min=p_min_set, action_attack_pro_max=p_max_set)
        buffer.insert_episode_batch(episode_batch)

        if buffer.can_sample(args.batch_size):
            episode_sample = buffer.sample(args.batch_size)

            # Truncate the batch to the longest filled episode.
            max_ep_t = episode_sample.max_t_filled()
            episode_sample = episode_sample[:, :max_ep_t]

            if episode_sample.device != args.device:
                # Return value is discarded, so `.to` presumably moves the
                # batch in place; verify against EpisodeBatch.
                episode_sample.to(args.device)

            learner.train(episode_sample, runner.t_env, episode)

        # NOTE(review): n_test_runs is computed but never used below.
        n_test_runs = max(1, args.test_nepisode // runner.batch_size)
        if (runner.t_env - last_test_T) / args.test_interval >= 1.0:

            logger.console_logger.info("t_env: {} / {}".format(runner.t_env, args.t_max))
            logger.console_logger.info("Estimated time left: {}. Time passed: {}".format(
                time_left(last_time, last_test_T, runner.t_env, args.t_max), time_str(time.time() - start_time)))
            last_time = time.time()

            last_test_T = runner.t_env

            # Clean evaluation, then evaluation in the remove_attack setting.
            eval(args, runner, obs_attack=False, remove_attack=False)
            eval(args, runner, obs_attack=False, remove_attack=True)

            if (runner.t_env - last_heavy_test_T) / (args.test_interval * heavy_test_mult) >= 1.0:
                last_heavy_test_T = runner.t_env
                p_min_set, p_max_set = test(args, runner, p_min_set, p_max_set)

        # Save on the first iteration (model_save_time == 0) and then every
        # save_model_interval environment steps.
        if args.save_model and (runner.t_env - model_save_time >= args.save_model_interval or model_save_time == 0):
            model_save_time = runner.t_env
            save_path = os.path.join(args.local_results_path, "models", args.env_args["map_name"], args.unique_token, str(runner.t_env))
            os.makedirs(save_path, exist_ok=True)
            logger.console_logger.info("Saving models to {}".format(save_path))

            learner.save_models(save_path)

        episode += args.batch_size_run

        if (runner.t_env - last_log_T) >= args.log_interval:
            logger.log_stat("episode", episode, runner.t_env)
            logger.print_recent_stats()
            last_log_T = runner.t_env

    runner.close_env()
    logger.console_logger.info("Finished Training")


def args_sanity_check(config, _log):
    """Normalise a raw experiment config in place and return the same dict.

    * ``use_cuda`` is turned off, with a warning on ``_log``, when torch
      reports no available CUDA device.
    * ``test_nepisode`` is raised to ``batch_size_run`` when smaller, and
      otherwise rounded down to a whole multiple of ``batch_size_run``.
    """
    wants_cuda = config["use_cuda"]
    if wants_cuda and not th.cuda.is_available():
        config["use_cuda"] = False
        _log.warning("CUDA flag use_cuda was switched OFF automatically because no CUDA devices are available!")

    requested = config["test_nepisode"]
    per_run = config["batch_size_run"]
    if requested < per_run:
        adjusted = per_run
    else:
        adjusted = (requested // per_run) * per_run
    config["test_nepisode"] = adjusted

    return config
