import os
import pprint
import time
import threading
import torch as th
from types import SimpleNamespace as SN
from utils.logging import Logger
from utils.timehelper import time_left, time_str
import json

from learners import REGISTRY as le_REGISTRY
from runners import REGISTRY as r_REGISTRY
from controllers import REGISTRY as mac_REGISTRY
from components.episode_buffer import ReplayBuffer
from components.offline_buffer import OfflineBufferWoAgent
from components.transforms import OneHot
from modules.clusters.vqvae import VQVAE
from modules.clusters.RBS_cluster import RBSCLUSTER
from learners.prior_role_vqvae_learner import PriorRoleVqvaeLearner
import numpy as np


def _tracker_run_name(results_save_dir):
    """Derive the experiment-tracker run name from the results directory.

    The name is everything after the first ``results`` path component
    (split on ``/``). If the path has no ``results`` component the whole
    path is returned instead of raising ``ValueError``.
    """
    parts = results_save_dir.split('/')
    if "results" in parts:
        parts = parts[parts.index("results") + 1:]
    return "/".join(parts)


def run(_run, _config, _log):
    """Entry point: configure logging, run training/evaluation, then exit.

    Args:
        _run: sacred run object, handed to ``logger.setup_sacred``.
        _config: experiment configuration dict; it is sanity-checked
            (and possibly mutated) by ``args_sanity_check`` and then turned
            into an attribute namespace.
        _log: logger used for console output and wrapped by ``Logger``.

    This function does not return: it terminates the process with
    ``os._exit`` once training and thread clean-up are done.
    """
    # check args sanity
    print(_config["use_cuda"])
    _config = args_sanity_check(_config, _log)

    args = SN(**_config)
    args.device = "cuda" if args.use_cuda else "cpu"

    logger = Logger(_log)

    _log.info("Experiment Parameters:")
    experiment_params = pprint.pformat(_config,
                                       indent=4,
                                       width=1)
    _log.info("\n\n" + experiment_params + "\n")

    results_save_dir = args.results_save_dir

    # wandb / swanlab replace tensorboard when either is enabled
    if args.use_wandb or args.use_swanlab:
        args.use_tensorboard = False

    # Logging backends are only attached in training mode
    if not args.evaluate:
        log_dir = os.path.join(results_save_dir, 'logs')

        if args.use_tensorboard:
            logger.setup_tb(log_dir)

        if args.use_wandb:
            logger.setup_wandb(log_dir, project=args.wandb_project_name,
                               name=_tracker_run_name(results_save_dir),
                               run_id=args.resume_id, config=args)

        if args.use_swanlab:
            logger.setup_swanlab(log_dir, project=args.wandb_project_name,
                                 name=_tracker_run_name(results_save_dir),
                                 run_id=args.resume_id, config=args, config_dict=_config)

    # write config file
    # Serialise first so a failure does not leave a truncated file behind.
    config_str = json.dumps(vars(args), indent=4)
    # Nothing above is guaranteed to have created the directory
    # (e.g. in evaluate mode no logging backend is set up).
    os.makedirs(results_save_dir, exist_ok=True)
    with open(os.path.join(results_save_dir, "config.json"), "w") as f:
        f.write(config_str)

    # set model save dir (added after the dump, so it is not part of config.json)
    args.model_save_dir = os.path.join(results_save_dir, 'models')

    # sacred is on by default
    logger.setup_sacred(_run)

    # Run and train
    run_sequential(args=args, logger=logger)

    logger.finish()

    # Clean up after finishing
    print("Exiting Main")

    print("Stopping all threads")
    for t in threading.enumerate():
        if t.name != "MainThread":
            print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon))
            # Bounded wait: a stuck worker must not block shutdown forever.
            t.join(timeout=1)
            print("Thread joined")

    print("Exiting script")

    # Making sure framework really exits
    os._exit(os.EX_OK)

def evaluate_sequential(args, runner):
    """Run ``args.test_nepisode`` test episodes, then close the environment.

    Args:
        args: namespace providing ``test_nepisode`` and ``save_replay``.
        runner: object exposing ``run(test_mode=...)``, ``save_replay()``
            and ``close_env()``.

    The environment is closed even when an episode or the replay save
    raises; the exception still propagates to the caller.
    """
    try:
        for _ in range(args.test_nepisode):
            runner.run(test_mode=True)

        if args.save_replay:
            runner.save_replay()
    finally:
        runner.close_env()


def run_sequential(args, logger):
    # In offline training, we use t_max to denote iterations
    
    # Init runner so we can get env info
    runner = r_REGISTRY[args.runner](args=args, logger=logger)

    # Set up schemes and groups here
    env_info = runner.get_env_info()
    for k, v in env_info.items():
        setattr(args, k, v)
    
    if "n_landmarks" in args.env_args.keys():
        setattr(args, "n_landmarks", args.env_args["n_landmarks"])

    # Default/Base scheme
    scheme = {
        "state": {"vshape": env_info["state_shape"]},
        "obs": {"vshape": env_info["obs_shape"], "group": "agents"},
        "actions": {"vshape": (1,), "group": "agents", "dtype": th.long},
        "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
        "reward": {"vshape": (1,)},
        "terminated": {"vshape": (1,), "dtype": th.uint8},
    }
    groups = {
        "agents": args.n_agents
    }
    preprocess = {
        "actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)])
    }

    if preprocess is not None:
        for k in preprocess:
            assert k in scheme
            new_k = preprocess[k][0]
            transforms = preprocess[k][1]

            vshape = scheme[k]["vshape"]
            dtype = scheme[k]["dtype"]
            for transform in transforms:
                vshape, dtype = transform.infer_output_info(vshape, dtype)

            scheme[new_k] = {
                "vshape": vshape,
                "dtype": dtype
            }
            if "group" in scheme[k]:
                scheme[new_k]["group"] = scheme[k]["group"]
            if "episode_const" in scheme[k]:
                scheme[new_k]["episode_const"] = scheme[k]["episode_const"]

    # buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1,
    #                       preprocess=preprocess,
    #                       device="cpu" if args.buffer_cpu_only else args.device)

    # Setup multiagent controller here
    mac = mac_REGISTRY[args.mac](scheme, groups, args)

    # Give runner the scheme
    runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac)


    obs_dim = scheme["obs"]["vshape"]
    vqvae = VQVAE(obs_dim, obs_dim, args)
    
    if args.use_cuda:
        vqvae.cuda()

    # Learner
    learner = PriorRoleVqvaeLearner(mac, vqvae, scheme, logger, args)
    RBS_cluster = RBSCLUSTER(args, qvectors=vqvae.emb.weight.permute(1,0).cpu().detach().numpy())

    if args.use_cuda:
        print("use cuda")
        learner.cuda()


    if args.checkpoint_path != "":
        timesteps = []
        timestep_to_load = 0

        if not os.path.isdir(args.checkpoint_path):
            logger.console_logger.info("Checkpoint directiory {} doesn't exist".format(args.checkpoint_path))
            return

        # Go through all files in args.checkpoint_path
        for name in os.listdir(args.checkpoint_path):
            full_name = os.path.join(args.checkpoint_path, name)
            # Check if they are dirs the names of which are numbers
            if os.path.isdir(full_name) and name.isdigit():
                timesteps.append(int(name))

        if args.load_step == 0:
            # choose the max timestep
            timestep_to_load = max(timesteps)
        else:
            # choose the timestep closest to load_step
            timestep_to_load = min(timesteps, key=lambda x: abs(x - args.load_step))

        model_path = os.path.join(args.checkpoint_path, str(timestep_to_load))

        logger.console_logger.info("Loading model from {}".format(model_path))
        learner.load_models(model_path)

        if args.evaluate or args.save_replay:
            evaluate_sequential(args, runner)
            return
    
    # Create Offline Data
    match args.env:
        case "sc2":
            args.map_name = args.env_args["map_name"]
            data_path = args.offline_data_ls[args.map_name]
        case "sc2_v2":
            args.map_name = args.env_args["task"]
            data_path = args.offline_data_ls[args.map_name]
        case "gymma":
            env_name, args.map_name = args.env_args['key'].split(':')
            args.env = env_name
        case "mt_grid_mpe":
            args.map_name = str(args.env_args["task_id"])
            data_path = args.offline_data_ls[int(args.map_name)]
        case _:
            raise NotImplementedError("Do not support such envs: {}".format(args.env))
        
    offline_buffer = OfflineBufferWoAgent(args, args.map_name, args.offline_data_quality,
                                   data_path, args.offline_max_buffer_size, 
                                   shuffle=False) # device defauly cpu
    
    
    logger.console_logger.info("Beginning  offline training with {} iterations".format(args.role_prior_t_max))
    train_sequential(args, logger, learner, runner, offline_buffer, vqvae, RBS_cluster)

    if args.save_model:
        save_path = os.path.join(args.model_save_dir, str(args.role_prior_t_max))
        os.makedirs(save_path, exist_ok=True)
        logger.console_logger.info("Save final model checkpoint in {}".format(save_path))
        learner.save_models(save_path)
        np.savez(save_path+"/prior_role_id.npz", prior_role_id=offline_buffer.seq_label)
    
    runner.close_env()
    logger.console_logger.info("Finish Training")

def train_sequential(args, logger, learner, runner, offline_buffer, vqvae, RBS_cluster):
    """Run the offline training loop for ``args.role_prior_t_max`` iterations.

    Each iteration samples a batch from ``offline_buffer``, trains the
    learner on it and stores the sequences the learner returns back into the
    buffer. Once ``episode >= args.t_cluster_start`` the stored sequences are
    periodically clustered with ``RBS_cluster`` and the resulting labels are
    written to the buffer; until the first accepted clustering the learner
    receives ``buffer_seq_labels=None``.

    Args:
        args: experiment namespace (intervals, batch size, save options).
        logger: project logger (``log_stat``, ``print_recent_stats``,
            ``console_logger``).
        learner: object whose ``train`` returns ``(visited_seq, _)`` and
            which exposes ``save_models``.
        runner: not used in this loop; kept for interface compatibility.
        offline_buffer: offline data buffer with the sampling / insertion
            methods called below.
        vqvae: model whose ``emb.weight`` supplies the quantisation vectors.
        RBS_cluster: clustering object exposing ``forward``.
    """
    t_env = 0
    episode = 0
    t_max = args.role_prior_t_max
    model_save_time = 0
    last_log_T = 0

    sim_centroid_prev = None
    flag_cluster_begin = False
    last_cluster_update_episode = 0

    batch_size_train = args.offline_batch_size

    while t_env < t_max:
        idx, buffer_seq_labels, episode_sample = offline_buffer.sample(batch_size_train)
        if episode_sample.device != args.device:
            # presumably ``to`` moves the batch in place -- TODO confirm
            episode_sample.to(args.device)

        # Labels stored in the buffer are only meaningful after a clustering
        # result has been accepted.
        if not flag_cluster_begin:
            buffer_seq_labels = None

        visited_seq, _ = learner.train(
            episode_sample, t_env, episode,
            seq_centroid=None,
            RBS_cluster=RBS_cluster,
            buffer_seq_labels=buffer_seq_labels,
            f_classifier=None,
        )
        offline_buffer.insert_sequence_batch(idx, visited_seq)

        # (Re-)cluster on the first eligible episode and then once every
        # ``cluster_update_episode_itv`` episodes.
        cluster_due = episode >= args.t_cluster_start and (
            not flag_cluster_begin
            or episode - last_cluster_update_episode >= args.cluster_update_episode_itv
        )
        if cluster_due:
            last_cluster_update_episode = episode

            # Use the whole buffer when possible, otherwise whatever has
            # been updated so far.
            if offline_buffer.can_sample_seq(offline_buffer.buffer_size):
                training_batch_size = offline_buffer.buffer_size
            else:
                training_batch_size = offline_buffer.num_seq_updated()

            if training_batch_size >= args.min_batch_size_clustering:
                idx_seq, sampled_sequences, _sampled_seq_label = offline_buffer.sample_seq(training_batch_size)
                clusters, _reduced_seq, cluster_labels, sim_centroid, _mean_vq = RBS_cluster.forward(
                    sampled_sequences,
                    prev_centroid=sim_centroid_prev,
                    qvectors=vqvae.emb.weight.permute(1,0).cpu().detach().numpy(),
                )
                # Only accept the result when enough clusters were found
                if len(clusters) >= args.n_min_cluster:
                    sim_centroid_prev = sim_centroid
                    offline_buffer.insert_sequence_label(idx_seq[:,0], th.tensor(cluster_labels))
                    flag_cluster_begin = True

        t_env += 1
        episode += 1

        # Save on the first iteration and then every save interval
        if args.save_model and (t_env - model_save_time >= args.role_prior_save_model_interval or model_save_time == 0):
            save_path = os.path.join(args.model_save_dir, str(t_env))
            os.makedirs(save_path, exist_ok=True)
            logger.console_logger.info("Saving models to {}".format(save_path))
            learner.save_models(save_path)
            np.savez(os.path.join(save_path, "prior_role_id.npz"), prior_role_id=offline_buffer.seq_label)
            model_save_time = t_env

        if (t_env - last_log_T) >= args.log_interval:
            last_log_T = t_env
            logger.log_stat("episode", episode, t_env)
            logger.print_recent_stats()

    logger.console_logger.info("Finish training sequential")
            
            
def args_sanity_check(config, _log):
    """Normalise ``config`` in place and return it.

    * ``use_cuda`` is switched off (with a warning on ``_log``) when it is
      requested but ``torch`` reports no CUDA device.
    * ``test_nepisode`` is raised to ``batch_size_run`` when smaller,
      otherwise rounded down to a multiple of ``batch_size_run``.
    """
    cuda_requested = config["use_cuda"]
    if cuda_requested and not th.cuda.is_available():
        config["use_cuda"] = False
        _log.warning("CUDA flag use_cuda was switched OFF automatically because no CUDA devices are available!")

    n_test = config["test_nepisode"]
    n_envs = config["batch_size_run"]
    # At least one episode per parallel env, and always a whole number of rounds
    config["test_nepisode"] = n_envs if n_test < n_envs else (n_test // n_envs) * n_envs

    return config