import datetime
import os
import pprint
import time
import threading
import torch as th
import json

from types import SimpleNamespace as SN
from utils.logging import Logger
from utils.timehelper import time_left, time_str
from os.path import dirname, abspath
from os import makedirs

from learners import REGISTRY as le_REGISTRY
from runners import REGISTRY as r_REGISTRY
from controllers import REGISTRY as mac_REGISTRY
from components.episode_buffer import ReplayBuffer
from components.transforms import OneHot


def run(_run, _config, _log):
    """Run one experiment: sanitise the config, set up logging, train/evaluate, then exit.

    Args:
        _run: run object handed to ``logger.setup_sacred``.
        _config: experiment config dict; normalised by ``args_sanity_check`` and
            turned into an attribute namespace (``args``).
        _log: logger used for the experiment parameter dump (``info``) and passed
            to ``Logger`` / ``args_sanity_check``.

    After ``run_sequential`` returns, this joins (with a 1 s timeout each) every
    other live thread and then terminates the whole process via ``os._exit``,
    so it does not return to the caller on success.
    """
    # check args sanity
    _config = args_sanity_check(_config, _log)

    args = SN(**_config)
    args.device = "cuda" if args.use_cuda else "cpu"

    # Create the local results directory; default is the parent of the
    # directory that contains this file.
    if args.local_results_path == "":
        args.local_results_path = dirname(dirname(abspath(__file__)))
    makedirs(args.local_results_path, exist_ok=True)

    # setup loggers
    logger = Logger(_log)

    _log.info("Experiment Parameters:")
    experiment_params = pprint.pformat(_config,
                                       indent=4,
                                       width=1)
    _log.info("\n\n" + experiment_params + "\n")

    # Build the unique run token: name, env, label, selected env/alg args, seed, timestamp.
    date_time = datetime.datetime.now().strftime("%m-%d-%H-%M-%S")
    envargs_list = []  # env args we want to appear in name
    if args.env == "mod_act":
        envargs_list += [f"env-act={args.env_args['action_mod']}"]
        if args.env_args["action_mod"] == "sticky":
            envargs_list += [f"sticky-prob={args.env_args['sticky_prob']}"]
        elif args.env_args["action_mod"] == "permute":
            envargs_list += [f"permute={args.env_args['permutation_type']}"]

    algargs_list = []  # alg args we want to appear in name, e.g. f"act={args.action_selector}"

    namelist = [args.name, args.env, args.label, *envargs_list, *algargs_list, f"seed={args.seed}", date_time]
    # "_" is the field separator of unique_token, so it is replaced inside each field.
    namelist = [name.replace("_", "-") for name in namelist if name is not None]
    args.unique_token = "_".join(namelist)

    # env_args names the map under "map_name", falling back to "key".
    try:
        map_name = _config["env_args"]["map_name"]
    except KeyError:
        map_name = _config["env_args"]["key"]

    if args.use_tensorboard:
        tb_logs_direc = os.path.join(
            args.local_results_path, "tb_logs", args.env, map_name
        )
        # Join directly (no str.format) so braces in the results path are left untouched.
        tb_exp_direc = os.path.join(tb_logs_direc, args.unique_token)
        logger.setup_tb(tb_exp_direc)

        # write config file; make sure the directory exists before opening it
        os.makedirs(tb_exp_direc, exist_ok=True)
        config_str = json.dumps(vars(args), indent=4, sort_keys=True)
        with open(os.path.join(tb_exp_direc, "config.json"), "w") as f:
            f.write(config_str)

    # sacred is on by default
    logger.setup_sacred(_run)

    # Run and train
    run_sequential(args=args, logger=logger)

    # Clean up after finishing
    print("Exiting Main")

    print("Stopping all threads")
    main_thread = threading.main_thread()
    current_thread = threading.current_thread()
    for t in threading.enumerate():
        # Never try to join the main thread or ourselves (join() on self raises RuntimeError).
        if t is main_thread or t is current_thread:
            continue
        print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon))
        t.join(timeout=1)
        print("Thread joined")

    print("Exiting script")

    # Making sure framework really exits. 0 == os.EX_OK, but os.EX_OK is Unix-only.
    os._exit(0)


def evaluate_sequential(args, runner, learner):
    """Run ``args.test_nepisode`` evaluation rollouts, then close the runner's environment.

    Args:
        args: experiment namespace. Reads ``save_eval_traj``, ``test_nepisode`` and
            ``save_replay``; when ``save_eval_traj`` is set it also reads
            ``gail_buffer_size``, ``local_results_path`` and ``unique_token``.
        runner: object providing ``run``, ``save_replay`` and ``close_env``.
        learner: object providing ``insert_episode_batch_gail_only`` and
            ``save_traj_data`` (only used when ``save_eval_traj`` is set).

    With ``save_eval_traj`` set, each rollout is run with ``test_mode=False`` and
    inserted into the learner's GAIL buffer only. The data is then saved under
    ``<local_results_path>/agents_batches/<unique_token>/_eval``. Otherwise the
    rollouts run with ``test_mode=True``. ``runner.close_env()`` is always called,
    even if a rollout raises.

    Raises:
        ValueError: if ``save_eval_traj`` is set and ``gail_buffer_size`` is
            smaller than ``test_nepisode``.
    """
    try:
        if args.save_eval_traj:
            # Explicit check instead of assert: asserts are stripped under ``python -O``.
            if args.gail_buffer_size < args.test_nepisode:
                raise ValueError("Cannot store all of test_nepisodes in gail buffer")
            for _ in range(args.test_nepisode):
                episode_batch = runner.run(test_mode=False)  # when saving traj, want to sample from ippo policy
                learner.insert_episode_batch_gail_only(episode_batch)

            save_path = os.path.join(args.local_results_path, "agents_batches", args.unique_token, "_eval")
            os.makedirs(save_path, exist_ok=True)
            learner.save_traj_data(save_path)
        else:
            for _ in range(args.test_nepisode):
                runner.run(test_mode=True)

        if args.save_replay:
            runner.save_replay()
    finally:
        # Release environment resources on both success and failure.
        runner.close_env()

def get_model_path(checkpoint_path, load_step):
    """Return the path of the checkpoint sub-directory to load from ``checkpoint_path``.

    Only sub-directories whose names are decimal integers (timesteps) are
    considered; all other entries are ignored.

    Args:
        checkpoint_path: directory holding one sub-directory per saved timestep.
        load_step: ``0`` selects the largest timestep. Any other value selects the
            timestep closest to ``load_step``; ties go to the smaller timestep.

    Returns:
        The joined path of the selected sub-directory. Its original name is kept,
        so zero-padded names such as ``"0500"`` resolve correctly.

    Raises:
        ValueError: if ``checkpoint_path`` contains no timestep sub-directories.
    """
    # (timestep, directory name) pairs, sorted so selection is independent of listdir order.
    # isdecimal (not isdigit) guarantees int() succeeds on the name.
    candidates = sorted(
        (int(name), name)
        for name in os.listdir(checkpoint_path)
        if name.isdecimal() and os.path.isdir(os.path.join(checkpoint_path, name))
    )
    if not candidates:
        raise ValueError("No timestep sub-directories found in checkpoint path {}".format(checkpoint_path))

    if load_step == 0:
        # choose the max timestep
        _, dir_name = candidates[-1]
    else:
        # choose the timestep closest to load_step; min() keeps the first (smallest) on ties
        _, dir_name = min(candidates, key=lambda cand: abs(cand[0] - load_step))
    model_path = os.path.join(checkpoint_path, dir_name)
    print("MODEL PATH IS ", model_path)
    return model_path

def run_sequential(args, logger):
    """Build runner/buffer/controller/learner, optionally load checkpoints, then train.

    Args:
        args: experiment namespace. This function writes ``n_agents``, ``n_actions``,
            ``state_shape`` and ``episode_limit`` into it from ``runner.get_env_info()``.
        logger: project ``Logger``. Uses ``console_logger``, ``log_stat`` and
            ``print_recent_stats``.

    If ``args.checkpoint_paths`` is non-empty and its first entry is not ``""``,
    models are loaded from the checkpoint nearest ``args.load_step`` in each path.
    The function returns early without training when:

    * a checkpoint directory is missing, or
    * ``args.evaluate`` or ``args.save_replay`` is set, in which case it only
      runs ``evaluate_sequential``.

    Otherwise it trains until ``runner.t_env`` exceeds ``args.t_max``. Along the
    way it periodically runs tests, saves models, saves GAIL trajectory data and
    logs stats.
    """
    # Init runner so we can get env info
    print("ENV IS : ", args.env)

    runner = r_REGISTRY[args.runner](args=args, logger=logger)

    # Set up schemes and groups here
    env_info = runner.get_env_info()
    args.n_agents = env_info["n_agents"]
    args.n_actions = env_info["n_actions"]
    args.state_shape = env_info["state_shape"]
    args.episode_limit = env_info["episode_limit"]

    print("EPISODE LIMIT IS ", env_info["episode_limit"])

    # Default/Base scheme
    scheme = {
        "state": {"vshape": env_info["state_shape"]},
        "obs": {"vshape": env_info["obs_shape"], "group": "agents"},
        "actions": {"vshape": (1,), "group": "agents", "dtype": th.long},
        # TODO: should this be here?
        "rnn_states_actors": {"vshape": (args.rnn_hidden_dim,), "group": "agents"},
        "rnn_states_critics": {"vshape": (args.rnn_hidden_dim,), "group": "agents"},

        "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
        "reward": {"vshape": (1,)},
        "terminated": {"vshape": (1,), "dtype": th.uint8},
    }
    groups = {
        "agents": args.n_agents
    }
    preprocess = {
        "actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)])
    }

    # Only buffer.scheme is used in this function; episodes are handed to the
    # learner (learner.insert_episode_batch), not inserted here.
    buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1,
                          preprocess=preprocess,
                          device="cpu" if args.buffer_cpu_only else args.device)

    # Setup multiagent controller here
    mac = mac_REGISTRY[args.mac](buffer.scheme, groups, args)

    # Learner
    learner = le_REGISTRY[args.learner](mac, buffer.scheme, logger, args, obs_info=runner.get_obs_info())

    # Give runner the scheme
    runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac, ippo_learner=learner)

    if args.use_cuda:
        learner.cuda()

    # An empty list or [""] means "no checkpoint to load".
    if args.checkpoint_paths and args.checkpoint_paths[0] != "":

        # iterate thru checkpoint paths and args.load_step to get the model paths
        model_paths = []
        for checkpoint_path in args.checkpoint_paths:
            if not os.path.isdir(checkpoint_path):
                logger.console_logger.info("Checkpoint directory {} doesn't exist".format(checkpoint_path))
                return
            # loads the nearest step to args.load_step
            model_paths.append(get_model_path(checkpoint_path, args.load_step))

        logger.console_logger.info("Loading models from {}".format(model_paths))
        learner.load_models(model_paths)
        # NOTE: runner.t_env is not restored from the checkpoint timestep here.

        if args.evaluate or args.save_replay:
            print("EVALUATING")
            with th.no_grad():
                evaluate_sequential(args, runner, learner)
            return

    # start training
    episode = 0
    # Starting below -test_interval guarantees the first loop iteration runs a test.
    last_test_T = -args.test_interval - 1
    last_log_T = 0
    model_save_time = 0
    agent_batch_save_time = 0

    start_time = time.time()
    last_time = start_time

    logger.console_logger.info("Beginning training for {} timesteps".format(args.t_max))

    while runner.t_env <= args.t_max:
        # Run for a whole episode at a time
        episode_batch = runner.run(test_mode=False)
        # TODO: switch this back so it works with other code
        # (the learner, not the local ReplayBuffer, receives the episodes)
        learner.insert_episode_batch(episode_batch)

        learner.train(runner.t_env)

        # Execute test runs once in a while
        n_test_runs = max(1, args.test_nepisode // runner.batch_size)
        if (runner.t_env - last_test_T) / args.test_interval >= 1.0:

            logger.console_logger.info("t_env: {} / {}".format(runner.t_env, args.t_max))
            logger.console_logger.info("Estimated time left: {}. Time passed: {}".format(
                time_left(last_time, last_test_T, runner.t_env, args.t_max), time_str(time.time() - start_time)))
            last_time = time.time()

            last_test_T = runner.t_env
            with th.no_grad():
                for _ in range(n_test_runs):
                    runner.run(test_mode=True)

        # save models (also on the first iteration, while model_save_time is still 0)
        if args.save_model and (runner.t_env - model_save_time >= args.save_model_interval or model_save_time == 0):
            model_save_time = runner.t_env
            try:
                map_name = args.env_args["map_name"]
            except KeyError:
                map_name = args.env_args["key"]
            save_path = os.path.join(
                args.local_results_path, "models", args.env, map_name, args.unique_token, str(runner.t_env)
            )
            os.makedirs(save_path, exist_ok=True)
            logger.console_logger.info("Saving models to {}".format(save_path))

            # learner should handle saving/loading -- delegate actor save/load to,
            # use appropriate filenames to do critics, optimizer states
            learner.save_models(save_path)

        # save traj data for gail
        if args.save_agent_batches and (runner.t_env - agent_batch_save_time >= args.save_agent_batches_interval):
            agent_batch_save_time = runner.t_env
            save_path = os.path.join(args.local_results_path, "agents_batches", args.unique_token, str(runner.t_env))
            os.makedirs(save_path, exist_ok=True)

            # get extra data into GAIL buffer if necessary; presumably the learner
            # already holds args.batch_size episodes — TODO confirm against learner
            n_additional_traj = args.gail_buffer_size - args.batch_size
            for _ in range(max(n_additional_traj, 0)):
                episode_batch = runner.run(test_mode=False, get_extra_trajs=True)
                learner.insert_episode_batch_gail_only(episode_batch)
            learner.save_traj_data(save_path)

        episode += args.batch_size_run

        if (runner.t_env - last_log_T) >= args.log_interval:
            logger.log_stat("episode", episode, runner.t_env)
            logger.print_recent_stats()
            last_log_T = runner.t_env

    runner.close_env()
    logger.console_logger.info("Finished Training")


def args_sanity_check(config, _log):
    """Normalise an experiment config dict in place and return the same dict.

    * ``use_cuda`` is forced to ``False`` (with a warning on ``_log``) when it is
      requested but torch reports no CUDA device.
    * ``test_nepisode`` becomes a whole multiple of ``batch_size_run``. It is
      rounded down, but never below one full batch of ``batch_size_run`` episodes.
    """
    wants_cuda = config["use_cuda"]
    if wants_cuda and not th.cuda.is_available():
        config["use_cuda"] = False
        _log.warning("CUDA flag use_cuda was switched OFF automatically because no CUDA devices are available!")

    per_batch = config["batch_size_run"]
    requested = config["test_nepisode"]
    config["test_nepisode"] = (
        per_batch if requested < per_batch else (requested // per_batch) * per_batch
    )
    return config
