import datetime
from hashlib import sha1
import os
import pprint
import time
import threading
from numpy import e
import torch as th
from types import SimpleNamespace as SN
from utils.logging import Logger
from utils.timehelper import time_left, time_str
from os.path import dirname, abspath
import json
import numpy as np

from learners import REGISTRY as le_REGISTRY
from runners import REGISTRY as r_REGISTRY
from controllers import REGISTRY as mac_REGISTRY
from components.episode_buffer import ReplayBuffer
from components.transforms import OneHot


def run(_run, _config, _log):
    """Set up output directories and logging, then run the experiment.

    Args:
        _run: Run object supplied by the caller (not used in this function).
        _config: Raw experiment config dict; sanitised by ``args_sanity_check``.
        _log: Logger used for console output.

    Side effects: creates
    ``collect/<task>/<train_tasks joined by '-'>/<unique_token>/<task_id>/``
    under the parent of this file's directory. That directory receives a
    trajectory sub-directory, optional model-error and tensorboard
    sub-directories, ``config/config.json`` and ``task_repre/``. The function
    then calls ``run_sequential`` and finally terminates the process with
    ``os._exit``.

    Raises:
        ValueError: if ``args.env`` is not one of the supported navigation envs.
    """
    # check args sanity
    _config = args_sanity_check(_config, _log)

    args = SN(**_config)
    args.device = "cuda" if args.use_cuda else "cpu"

    # setup loggers
    logger = Logger(_log)

    _log.info("Experiment Parameters:")
    experiment_params = pprint.pformat(_config,
                                       indent=4,
                                       width=1)
    _log.info("\n\n" + experiment_params + "\n")

    remark_str = getattr(args, "remark", "nop")
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    unique_token = "{}__{}_{}".format(args.name, remark_str, timestamp)
    args.unique_token = unique_token

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O`` and this is a config-validation error, not an invariant.
    supported_envs = ("p2_navigation", "dense_p2_navigation")
    if args.env not in supported_envs:
        raise ValueError(
            "Unsupported env {!r}; expected one of {}".format(args.env, supported_envs)
        )

    training_pop_name = "-".join(args.train_tasks)
    logs_direc = os.path.join(dirname(dirname(abspath(__file__))), "collect", args.task,
                              training_pop_name, unique_token, str(args.env_args["task_id"]))

    # Source-task and target-task trajectories go to different sub-directories.
    if getattr(args, "collect_src", False):
        args.traj_save_direc = os.path.join(logs_direc, "src_traj")
    else:
        args.traj_save_direc = os.path.join(logs_direc, "target_traj")
    os.makedirs(args.traj_save_direc, exist_ok=True)

    if getattr(args, "compute_model_error", False):
        args.model_error_save_direc = os.path.join(logs_direc, "model_error")
        os.makedirs(args.model_error_save_direc, exist_ok=True)

    if args.use_tensorboard:
        args.tb_direc = os.path.join(logs_direc, "tb_logs")
        logger.setup_tb(args.tb_direc)

    # write config file
    config_str = json.dumps(vars(args), indent=4)
    args.config_direc = os.path.join(logs_direc, "config")
    os.makedirs(args.config_direc, exist_ok=True)
    with open(os.path.join(args.config_direc, "config.json"), "w") as f:
        f.write(config_str)

    args.repre_save_direc = os.path.join(logs_direc, "task_repre")
    os.makedirs(args.repre_save_direc, exist_ok=True)

    # Run and train
    run_sequential(args=args, logger=logger)

    # Clean up after finishing
    print("Exiting Main")

    print("Stopping all threads")
    for t in threading.enumerate():
        if t.name != "MainThread":
            print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon))
            t.join(timeout=1)
            print("Thread joined")

    print("Exiting script")

    # Making sure framework really exits. ``os.EX_OK`` (== 0) only exists on
    # Unix, so use the literal 0 for portability.
    os._exit(0)

def collect_sequential(args, runner):
    """Roll out evaluation episodes and save every visited state to disk.

    Calls ``runner.run(test_mode=True, evaluate_mode=True)``
    ``max(1, args.test_nepisode // runner.batch_size)`` times. The final call's
    return value is unpacked as a 2-tuple whose first item is the batch. Every
    earlier call is expected to return the batch itself. The ``"state"``
    tensors of all batches are concatenated along dim 0 and written to
    ``<args.traj_save_direc>/states.npy``. Afterwards a replay is saved if
    ``args.save_replay`` is set, and the runner's environment is closed.
    """
    num_rounds = max(1, args.test_nepisode // runner.batch_size)
    collected = []
    for round_idx in range(num_rounds):
        outcome = runner.run(test_mode=True, evaluate_mode=True)
        # Only the last call hands back an extra value alongside the batch.
        if round_idx == num_rounds - 1:
            outcome, _ = outcome
        collected.append(outcome["state"])

    all_states = th.cat(collected, dim=0)
    out_path = os.path.join(args.traj_save_direc, "states.npy")
    np.save(out_path, all_states.cpu().numpy())

    if args.save_replay:
        runner.save_replay()

    runner.close_env()

def model_error_sequential(args, mac, runner):
    """Measure the task encoder's reward-prediction error on fresh rollouts.

    Performs ``max(1, args.test_nepisode // runner.batch_size)`` calls of
    ``runner.run(test_mode=True, evaluate_mode=True, model_error_collect=True)``.
    The last call's result is unpacked as a 2-tuple whose first item is the
    batch.

    For each batch:

    - ``mac.task_encoder_forward(batch, t=t)`` is evaluated at every timestep.
    - The final step's prediction is dropped.
    - The remaining predictions are compared to ``batch["reward"][:, :-1]``
      with a squared error.
    - The error is averaged over steps that are filled and not after a
      termination.

    The per-batch scalars are stacked and saved to
    ``<args.model_error_save_direc>/error.npy``. All of this runs under
    ``torch.no_grad()``. Afterwards a replay is saved if ``args.save_replay``
    is set, and the runner's environment is closed.
    """
    num_rounds = max(1, args.test_nepisode // runner.batch_size)
    per_round_errors = []

    for round_idx in range(num_rounds):
        with th.no_grad():
            batch = runner.run(test_mode=True, evaluate_mode=True, model_error_collect=True)
            # Only the last call hands back an extra value alongside the batch.
            if round_idx == num_rounds - 1:
                batch, _ = batch

            terminated = batch["terminated"][:, :-1].float()
            valid = batch["filled"][:, :-1].float()
            # A step only counts if the episode had not terminated on the previous step.
            valid[:, 1:] = valid[:, 1:] * (1 - terminated[:, :-1])

            # Query the encoder at every timestep, then drop the last one to align with rewards.
            step_preds = [mac.task_encoder_forward(batch, t=step)
                          for step in range(batch.max_seq_length)]
            predictions = th.stack(step_preds, dim=1)[:, :-1]
            targets = batch["reward"][:, :-1].detach().clone()

            squared_err = (predictions - targets) ** 2
            valid = valid.expand_as(squared_err)
            masked_mean = (squared_err * valid).sum() / valid.sum()
            per_round_errors.append(masked_mean)

    stacked_errors = th.stack(per_round_errors, dim=0)
    np.save(os.path.join(args.model_error_save_direc, "error.npy"), stacked_errors.cpu().numpy())

    if args.save_replay:
        runner.save_replay()

    runner.close_env()
    
def _select_checkpoint_timestep(checkpoint_path, load_step):
    """Choose which numbered sub-directory of ``checkpoint_path`` to load.

    Sub-directories whose names consist only of digits are treated as saved
    timesteps. When ``load_step`` is 0 the largest timestep is returned.
    Otherwise the timestep closest to ``load_step`` is returned. Returns
    ``None`` when no such sub-directory exists.
    """
    timesteps = [
        int(name)
        for name in os.listdir(checkpoint_path)
        if name.isdigit() and os.path.isdir(os.path.join(checkpoint_path, name))
    ]
    if not timesteps:
        return None
    if load_step == 0:
        # choose the max timestep
        return max(timesteps)
    # choose the timestep closest to load_step
    return min(timesteps, key=lambda step: abs(step - load_step))


def run_sequential(args, logger):
    """Build the runner, MAC and learner, load a trained checkpoint, then run one mode.

    Modes, checked in order:

    - ``args.compute_model_error``: load ``args.target_repre_path`` into the
      MAC and evaluate reward-prediction error (``model_error_sequential``).
    - ``args.collect_src``: set the MAC's task representation to
      ``args.task_id`` and collect trajectories (``collect_sequential``).
    - Otherwise: train with ``learner.train`` until it returns a truthy
      value, then collect trajectories.

    Also sets ``args.n_agents``, ``args.n_actions`` and ``args.state_shape``
    from the environment info.

    Returns early, after logging, if the checkpoint directory is missing or
    contains no numbered checkpoints.

    Raises:
        Exception: if ``args.checkpoint_path`` is empty, since a trained model
            is required.
    """
    # Init runner so we can get env info
    runner = r_REGISTRY[args.runner](args=args, logger=logger)

    # Set up schemes and groups here
    env_info = runner.get_env_info()
    args.n_agents = env_info["n_agents"]
    args.n_actions = env_info["n_actions"]
    args.state_shape = env_info["state_shape"]

    # Default/Base scheme
    scheme = {
        "state": {"vshape": env_info["state_shape"]},
        "obs": {"vshape": env_info["obs_shape"], "group": "agents"},
        "actions": {"vshape": (1,), "group": "agents", "dtype": th.long},
        "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
        "reward": {"vshape": (1,)},
        "terminated": {"vshape": (1,), "dtype": th.uint8},
    }
    groups = {
        "agents": args.n_agents
    }
    preprocess = {
        "actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)])
    }

    buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1,
                          preprocess=preprocess,
                          device="cpu" if args.buffer_cpu_only else args.device)

    # Setup multiagent controller here
    mac = mac_REGISTRY[args.mac](buffer.scheme, groups, args)

    # Give runner the scheme
    runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac)

    # Learner
    learner = le_REGISTRY[args.learner](mac, buffer.scheme, logger, args)

    if args.use_cuda:
        learner.cuda()

    # This script only evaluates or collects data, so a trained model is mandatory.
    if args.checkpoint_path == "":
        raise Exception("A trained model is required for meta_test: set checkpoint_path!")

    if not os.path.isdir(args.checkpoint_path):
        logger.console_logger.info("Checkpoint directiory {} doesn't exist".format(args.checkpoint_path))
        return

    timestep_to_load = _select_checkpoint_timestep(args.checkpoint_path, args.load_step)
    if timestep_to_load is None:
        # Previously crashed with an opaque "max() arg is an empty sequence";
        # handle it like the missing-directory case above.
        logger.console_logger.info(
            "No numbered checkpoint directories found in {}".format(args.checkpoint_path))
        return

    model_path = os.path.join(args.checkpoint_path, str(timestep_to_load))
    logger.console_logger.info("Loading model from {}".format(model_path))
    learner.load_models(model_path)
    # We don't continue from the loaded timestep.

    if getattr(args, "compute_model_error", False):
        mac.load_task_repre(args.target_repre_path)
        model_error_sequential(args, mac, runner)
        return

    if getattr(args, "collect_src", False):
        mac.set_task_repre(task_name=args.task_id)
        logger.console_logger.info("Collecting {} test episodes".format(args.test_nepisode))
        collect_sequential(args, runner)
        return

    # start training
    episode = 0
    pretrain_phase = getattr(args, "pretrain", True)
    terminated = False
    while not terminated:
        # Run for a whole episode at a time
        episode_batch = runner.run(test_mode=False, pretrain_phase=pretrain_phase)
        buffer.insert_episode_batch(episode_batch)

        if buffer.can_sample(args.batch_size):
            # balance between parallel and episode run
            for _ in range(runner.batch_size):
                episode_sample = buffer.sample(args.batch_size)

                # Truncate batch to only filled timesteps
                max_ep_t = episode_sample.max_t_filled()
                episode_sample = episode_sample[:, :max_ep_t]

                if episode_sample.device != args.device:
                    episode_sample.to(args.device)

                terminated = learner.train(episode_sample, runner.t_env, episode)
                if terminated:
                    break

        # Advance the episode counter by the number of episodes just collected.
        # Previously it stayed 0, so the learner always saw episode 0.
        episode += runner.batch_size

    # We have no rl training phase
    collect_sequential(args, runner)
    return


def args_sanity_check(config, _log):
    """Normalise the experiment config in place and return it.

    - If ``use_cuda`` is requested but ``torch.cuda.is_available()`` is False,
      switch it off and emit a warning through ``_log``.
    - Make ``test_nepisode`` a multiple of ``batch_size_run``. Values below
      one batch are raised to exactly one batch. Larger values are rounded
      down to the nearest multiple.
    """
    # set CUDA flags
    wants_cuda = config["use_cuda"]
    if wants_cuda and not th.cuda.is_available():
        config["use_cuda"] = False
        _log.warning("CUDA flag use_cuda was switched OFF automatically because no CUDA devices are available!")

    per_run = config["batch_size_run"]
    n_test = config["test_nepisode"]
    # At least one full parallel batch; otherwise truncate to whole batches.
    config["test_nepisode"] = per_run if n_test < per_run else (n_test // per_run) * per_run

    return config
