import os
import pprint
import time
import threading
import torch as th
import yaml
from types import SimpleNamespace as SN
from utils.logging import Logger
from utils.timehelper import time_left, time_str
from os.path import dirname, abspath
from pathlib import Path
import json
from tqdm import tqdm
import shutil
import copy
import collections
from enum import Enum
import pickle

from learners import REGISTRY as le_REGISTRY
from runners import REGISTRY as r_REGISTRY
from controllers import REGISTRY as mac_REGISTRY
from components.episode_buffer import ReplayBuffer
from components.offline_buffer import OfflineBuffer, PartialOfflineBuffer, DataSaver
from components.transforms import OneHot
# from controllers.transfer.tr_basic_controller import TrBasicMAC

import numpy as np

def recursive_sn_update(d, u):
    """Recursively merge the mapping ``u`` into ``d`` in place and return ``d``.

    ``d`` may be a ``SimpleNamespace`` (updated through attributes) or a
    dict-like object (updated through item assignment). Values of ``u`` that
    are mappings are merged into the matching nested container of ``d``; any
    other value overwrites the existing entry.

    When the nested target is missing, a fresh container is created: an empty
    ``type(d)()`` namespace if ``d`` is a namespace, otherwise an empty dict.
    The same happens when the existing entry is not a container (for example
    ``None`` or a scalar): the nested mapping replaces it instead of the merge
    failing while trying to update a non-container.

    Args:
        d: Namespace or mutable mapping to update.
        u: Mapping holding the overriding values.

    Returns:
        The updated ``d`` (the same object that was passed in).
    """
    mapping_type = collections.abc.Mapping
    on_namespace = isinstance(d, SN)

    for k, v in u.items():
        if not isinstance(v, mapping_type):
            # Leaf value: plain overwrite.
            if on_namespace:
                setattr(d, k, v)
            else:
                d[k] = v
            continue

        current = getattr(d, k, None) if on_namespace else d.get(k)
        if not isinstance(current, (SN, mapping_type)):
            # Missing or non-container entry: start from an empty container
            # so the nested override can be applied.
            current = type(d)() if on_namespace else {}

        merged = recursive_sn_update(current, v)
        if on_namespace:
            setattr(d, k, merged)
        else:
            d[k] = merged
    return d

def run(_run, _config, _log):
    """Run one experiment end to end and terminate the process.

    Steps, in order: sanity-check the config, build the ``args`` namespace,
    set up the loggers, snapshot this file's directory and the config under
    ``args.results_save_dir``, call ``run_sequential`` and finally force the
    interpreter to exit.

    Args:
        _run: Run object handed to ``logger.setup_sacred``.
        _config: Experiment configuration dict.
        _log: Logger used for console output and wrapped by ``Logger``.
    """
    # Normalise the raw config before it is turned into a namespace.
    _config = args_sanity_check(_config, _log)

    args = SN(**_config)
    if args.use_cuda:
        args.device = "cuda"
    else:
        args.device = "cpu"

    logger = Logger(_log)
    print(args.offline_data_quality)
    _log.info("Experiment Parameters:")
    pretty_config = pprint.pformat(_config, indent=4, width=1)
    _log.info("\n\n" + pretty_config + "\n")

    out_dir = args.results_save_dir

    # Tensorboard logging is only enabled in training mode.
    if args.use_tensorboard and not args.evaluate:
        run_tag = _config['wandb_note'].replace(' ', '_')
        logger.setup_tb(os.path.join(out_dir, 'tb_logs', run_tag))

    if args.use_wandb:
        logger.setup_wandb(_config, args.wandb_team, args.wandb_project, args.wandb_mode)

    # Snapshot the executing code, skipping caches, VCS data and documents.
    skip_patterns = ('__pycache__', '.pyc', '.git', '.pdf', '.ipynb')

    def _file_ignore(path, content):
        # Substring match; an entry is listed once per pattern it matches.
        return [entry for entry in content for pattern in skip_patterns if pattern in entry]

    args.code_save_dir = os.path.join(out_dir, 'code')
    source_root = dirname(abspath(__file__))
    shutil.copytree(source_root, args.code_save_dir, ignore=_file_ignore)

    # Directory the training loops save models into.
    args.save_dir = os.path.join(out_dir, 'models')

    # Serialise first so that nothing is written if the args are not JSON-able.
    config_str = json.dumps(vars(args), indent=4)
    with open(os.path.join(out_dir, "config.json"), "w") as config_file:
        config_file.write(config_str)

    # sacred is on by default
    logger.setup_sacred(_run)

    # Run and train
    run_sequential(args=args, logger=logger)

    # Clean up after finishing
    print("Exiting Main")

    print("Stopping all threads")
    for thread in threading.enumerate():
        if thread.name == "MainThread":
            continue
        print("Thread {} is alive! Is daemon: {}".format(thread.name, thread.daemon))
        thread.join(timeout=1)
        print("Thread joined")

    print("Exiting script")

    # Finish logging
    logger.finish()

    # Hard exit so that lingering threads cannot keep the process alive.
    os._exit(os.EX_OK)

def init_tasks(task_list, main_args, logger):
    """Create per-task args, runner, replay buffer and data-layout descriptions.

    For every task a deep copy of ``main_args`` is specialised (optionally
    merged with a per-map YAML config, then pointed at the task through
    ``env_args``), a runner is built from ``r_REGISTRY`` and the environment
    info it reports is copied onto the task args.

    Args:
        task_list: Iterable of task identifiers.
        main_args: Namespace with the shared experiment arguments.
        logger: Logger passed through to each runner.

    Returns:
        A 6-tuple of dicts keyed by task:
        ``(task2args, task2runner, task2buffer, task2scheme, task2groups,
        task2preprocess)``.

    Raises:
        AssertionError: If the per-map YAML file cannot be parsed or
            ``main_args.env`` is not one of the supported environments.
    """
    task2args, task2runner, task2buffer = {}, {}, {}
    task2scheme, task2groups, task2preprocess = {}, {}, {}

    for task in task_list:
        task_args = copy.deepcopy(main_args)

        if main_args.task in ["sc2_v2_large"]:
            task_map = task.split('_')[1]
            config_path = os.path.join(
                os.path.dirname(__file__), "config", "envs", f"sc2_v2_{task_map}.yaml"
            )
            with open(config_path) as f:
                try:
                    # ``yaml.load`` without an explicit Loader is rejected by
                    # PyYAML >= 6; ``safe_load`` works on every version.
                    task_config = yaml.safe_load(f)
                except yaml.YAMLError as exc:
                    # Raised explicitly (rather than ``assert False``) so the
                    # check survives ``python -O``; exception type unchanged.
                    raise AssertionError(f"{task_map}.yaml error: {exc}") from exc

            recursive_sn_update(task_args, task_config)

        # Tell the environment which task to instantiate.
        if main_args.env in ["sc2", "sc2_v2"]:
            task_args.env_args["map_name"] = task
        elif main_args.env == "gymma":
            task_args.env_args["key"] = task
        elif main_args.env == "grid_mpe":
            # Task name is split on '-': the second field is the task id and
            # an optional third field selects the layouts.
            task_plst = task.split('-')
            task_args.env_args["task_id"] = task_plst[1]
            if len(task_plst) == 3:
                task_args.env_args["layouts"] = task_plst[2]
        elif main_args.env == 'mamujoco':
            task_args.env_args["task"] = task
        else:
            # Explicit raise so an unsupported env is not silently accepted
            # when assertions are disabled.
            raise AssertionError(f"Unsupported env: {main_args.env!r}")

        task2args[task] = task_args

        task_runner = r_REGISTRY[main_args.runner](args=task_args, logger=logger, task=task)
        task2runner[task] = task_runner

        # Set up schemes and groups here
        env_info = task_runner.get_env_info()
        for k, v in env_info.items():
            setattr(task_args, k, v)

        # Default/Base scheme
        scheme = {
            "state": {"vshape": env_info["state_shape"]},
            "obs": {"vshape": env_info["obs_shape"], "group": "agents"},
            "actions": {"vshape": (1,), "group": "agents", "dtype": th.long},
            "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
            "reward": {"vshape": (1,)},
            "terminated": {"vshape": (1,), "dtype": th.uint8},
        }
        groups = {
            "agents": task_args.n_agents
        }
        preprocess = {
            "actions": ("actions_onehot", [OneHot(out_dim=task_args.n_actions)])
        }

        # Buffer of size 1 with room for ``episode_limit + 1`` steps.
        task2buffer[task] = ReplayBuffer(scheme, groups, 1, env_info["episode_limit"] + 1,
                                   preprocess=preprocess,
                                   device="cpu" if task_args.buffer_cpu_only else task_args.device)

        # store task information
        task2scheme[task], task2groups[task], task2preprocess[task] = scheme, groups, preprocess

    return task2args, task2runner, task2buffer, task2scheme, task2groups, task2preprocess

def run_sequential(args, logger):
    """Set up tasks, controller and learner, then dispatch offline training.

    The function builds the per-task runners and buffers, creates the
    multi-agent controller and the learner, optionally restores a checkpoint,
    loads one ``OfflineBuffer`` per training task and finally calls the
    training routine selected by ``main_args.name``. All environments are
    closed at the end of a completed run.

    Args:
        args: Experiment namespace; ``all_tasks`` and ``n_tasks`` are written
            onto it.
        logger: Project logger.

    Returns:
        None. Returns early, after logging, when ``checkpoint_path`` is set
        but is not a directory.

    Raises:
        ValueError: If ``checkpoint_path`` contains no numerically named
            sub-directory to load from.
        NotImplementedError: If ``main_args.name`` is not a known algorithm.
    """
    train_tasks = args.train_tasks
    trans_tasks = args.trans_tasks
    test_tasks = args.test_tasks
    # De-duplicate while keeping first-occurrence order. A plain ``set`` would
    # give an order that can change between interpreter runs (string hash
    # randomisation), which makes the task order handed to the controller
    # differ from run to run.
    all_tasks = list(dict.fromkeys(train_tasks + trans_tasks + test_tasks))
    args.all_tasks = all_tasks
    args.n_tasks = len(all_tasks)

    main_args = copy.deepcopy(args)

    task2args, task2runner, task2buffer, task2scheme, task2groups, task2preprocess = init_tasks(all_tasks, main_args, logger)
    task2buffer_scheme = { task: task2buffer[task].scheme for task in all_tasks }

    main_args.t_max = main_args.cont_train_steps

    if main_args.mac in ['tr_basic_mac', 'tr_ddpg_mac', 'tr_comad_mac', 'tr_comad_gru_mac']:
        mac = mac_REGISTRY[main_args.mac](all_tasks, train_tasks, trans_tasks, task2scheme=task2buffer_scheme, task2args=task2args, main_args=main_args)
    else:
        mac = mac_REGISTRY[main_args.mac](all_tasks, task2scheme=task2buffer_scheme, task2args=task2args, main_args=main_args)

    learner = le_REGISTRY[main_args.learner](mac, logger, main_args)
    if main_args.use_cuda:
        learner.cuda()

    for task in all_tasks:
        task2runner[task].setup(scheme=task2scheme[task], groups=task2groups[task], preprocess=task2preprocess[task], mac=mac)

    model_path = None
    if main_args.checkpoint_path != "":
        if not os.path.isdir(main_args.checkpoint_path):
            logger.console_logger.info("Checkpoint directiory {} doesn't exist".format(main_args.checkpoint_path))
            return

        # Checkpoints live in sub-directories whose names are timesteps.
        timesteps = []
        for name in os.listdir(main_args.checkpoint_path):
            full_name = os.path.join(main_args.checkpoint_path, name)
            if os.path.isdir(full_name) and name.isdigit():
                timesteps.append(int(name))

        if not timesteps:
            # ``max``/``min`` on an empty list would raise an opaque
            # ValueError; keep the type but say what is actually wrong.
            raise ValueError(
                "No checkpoint sub-directories found in {}".format(main_args.checkpoint_path)
            )

        if main_args.load_step == 0:
            # choose the max timestep
            timestep_to_load = max(timesteps)
        else:
            # choose the timestep closest to load_step
            timestep_to_load = min(timesteps, key=lambda x: abs(x - main_args.load_step))

        model_path = os.path.join(main_args.checkpoint_path, str(timestep_to_load))

        logger.console_logger.info("Loading model from {}".format(model_path))
        learner.load_models(model_path)

    logger.console_logger.info("Beginning preparing offline datasets")
    if main_args.name in ['omiga_mt']:
        trans_tasks = train_tasks
    task2offline_buffer = {}
    for task in trans_tasks:
        task2offline_buffer[task] = OfflineBuffer(
            main_args.env,
            map_name=task,
            quality=main_args.train_tasks_data_quality[task],
            offline_data_size=main_args.offline_max_buffer_size,
            random_sample=main_args.offline_data_shuffle,
            # val=True, # NOTE only for testing EBM
        )
    logger.console_logger.info("Beginning offline continual training")

    # comad; fc, ft, mt; ewc, owl ,re.
    if main_args.name in ['tr_qmix', 'updet-l', 'omiga', 'icq']:
        cont_train_baseline(trans_tasks, main_args, logger, learner, task2args, task2runner, task2offline_buffer)

    elif main_args.name in ['rehearsal', 'omiga_re']:
        parbuf = PartialOfflineBuffer(trans_tasks, task2offline_buffer, args.n_sample_per_task)
        if hasattr(learner, 'preset'):
            learner.preset('re')
        cont_train_OWL(trans_tasks, main_args, logger, learner, task2args, task2runner, task2offline_buffer, parbuf) # simulating MT, need no reinit

    elif main_args.name in ['owl', 'comad', 'comad_gru', 'omiga_cont']:
        cont_train_OWL(trans_tasks, main_args, logger, learner, task2args, task2runner, task2offline_buffer)

    elif main_args.name in ['omiga_mt']:
        main_args.t_max = main_args.offline_train_steps
        if hasattr(learner, 'preset'):
            learner.preset('mt')
        mt_train_baselines(train_tasks, main_args, logger, learner, task2args, task2runner, task2offline_buffer)

    elif main_args.name in ['odis_mt']:
        mt_train_ODIS(main_args.pretrain_tasks, main_args, logger, learner, task2args, task2runner, task2offline_buffer, pretrain=True, test_task2offlinedata=task2offline_buffer)
        logger.console_logger.info("Finished pretraining.")
        mt_train_ODIS(train_tasks, main_args, logger, learner, task2args, task2runner, task2offline_buffer)

    elif main_args.name in ['odis_ft', 'odis_re']:
        # Only the rehearsal variant gets a buffer of samples from previous tasks.
        if main_args.name.endswith('_re'):
            parbuf = PartialOfflineBuffer(trans_tasks, task2offline_buffer, args.n_sample_per_task)
        else:
            parbuf = None
        mt_train_ODIS(main_args.pretrain_tasks, main_args, logger, learner, task2args, task2runner, task2offline_buffer, pretrain=True, test_task2offlinedata=task2offline_buffer)
        logger.console_logger.info("Finished pretraining.")
        cont_train_baseline(train_tasks, main_args, logger, learner, task2args, task2runner, task2offline_buffer, parbuf=parbuf)
    else:
        raise NotImplementedError(main_args.name)

    logger.console_logger.info("Finished offline continual training.")

    for task in all_tasks:
        task2runner[task].close_env()

def evaluate_sequential(main_args, logger, task2runner):
    """Run test episodes on every test task, then release the environments.

    Each task in ``main_args.test_tasks`` is run
    ``max(1, test_nepisode // batch_size_run)`` times in test mode. After
    that, every evaluated task optionally saves a replay (when
    ``main_args.save_replay`` is truthy) and has its environment closed.

    Args:
        main_args: Namespace providing ``test_nepisode``, ``batch_size_run``,
            ``test_tasks`` and the ``save_replay`` flag.
        logger: Project logger used to flush the collected statistics.
        task2runner: Mapping from task name to its runner.
    """
    n_test_runs = max(1, main_args.test_nepisode // main_args.batch_size_run)
    with th.no_grad():
        for task in main_args.test_tasks:
            for _ in range(n_test_runs):
                task2runner[task].run(test_mode=True)

    # Handle every evaluated task (not just the last loop value), and each
    # one only once even if it is listed twice, so no env is closed twice.
    for task in dict.fromkeys(main_args.test_tasks):
        # ``save_replay`` is a config flag, not a callable.
        if main_args.save_replay:
            task2runner[task].save_replay()
        task2runner[task].close_env()

    logger.log_stat("episode", 0, 0)
    logger.print_recent_stats()

def mt_train_baselines(train_tasks, main_args, logger, learner, task2args, task2runner, task2offline_buffer):
    """Multi-task offline training: interleave one update per task until ``t_max``.

    Every outer iteration shuffles the task order and performs one
    ``learner.train`` call per task on a batch sampled from that task's
    offline buffer. Test episodes are run before training, every
    ``test_interval`` steps and once ``t_max`` is reached. Models are saved
    under ``<save_dir>/offline/<t_env>`` when ``save_model`` is set.

    Args:
        train_tasks: Sequence of tasks to train on. It is not modified.
        main_args: Shared experiment namespace.
        logger: Project logger.
        learner: Learner exposing ``train`` and ``save_models``.
        task2args: Unused here; kept for a uniform signature.
        task2runner: Mapping from task to runner, used for test episodes.
        task2offline_buffer: Mapping from task to its offline buffer.

    Raises:
        ValueError: If training steps are requested but ``train_tasks`` is
            empty (the loop could never make progress).
    """
    t_env = 0
    episode = 0
    t_max = main_args.offline_train_steps
    model_save_time = t_env
    last_test_T = t_env
    last_log_T = t_env

    start_time = time.time()
    last_time = start_time
    test_time_total = 0

    batch_size_train = main_args.offline_batch_size
    batch_size_run = main_args.batch_size_run  # number of parallel envs
    n_test_runs = max(1, main_args.test_nepisode // batch_size_run)

    def _run_tests(step):
        """Run the test episodes for all test tasks; return the seconds spent."""
        began = time.time()
        with th.no_grad():
            for test_task in main_args.test_tasks:
                task2runner[test_task].t_env = step
                for _ in range(n_test_runs):
                    task2runner[test_task].run(test_mode=True)
        return time.time() - began

    # Evaluate once before any training step.
    test_time_total += _run_tests(t_env)

    # Shuffle a private copy so the caller's task list keeps its order.
    task_order = list(train_tasks)
    if t_env < t_max and not task_order:
        raise ValueError("train_tasks must not be empty")

    while t_env < t_max:
        np.random.shuffle(task_order)
        for task in task_order:
            episode_sample = task2offline_buffer[task].sample(batch_size_train)

            if episode_sample.device != main_args.device:
                episode_sample.to(main_args.device)

            learner.train(episode_sample, t_env, episode, task)
            t_env += 1

            episode += batch_size_run

        # Periodic evaluation, plus a final one when training is done.
        if (t_env - last_test_T) / main_args.test_interval >= 1 or t_env >= t_max:
            test_time_total += _run_tests(t_env)

            logger.console_logger.info("Step: {} / {}".format(t_env, t_max))
            logger.console_logger.info("Estimated time left: {}. Time passed: {}. Test time cost: {}".format(
                time_left(last_time, last_test_T, t_env, t_max), time_str(time.time() - start_time), time_str(test_time_total)
            ))
            last_time = time.time()
            last_test_T = t_env

        # ``model_save_time == 0`` forces a save after the very first round.
        if main_args.save_model and (t_env - model_save_time >= main_args.save_model_interval or model_save_time == 0):
            save_path = os.path.join(main_args.save_dir, 'offline', str(t_env))
            os.makedirs(save_path, exist_ok=True)
            logger.console_logger.info("Saving models to {}".format(save_path))
            learner.save_models(save_path)
            model_save_time = t_env

        if (t_env - last_log_T) >= main_args.log_interval:
            last_log_T = t_env
            logger.log_stat("episode", episode, t_env)
            logger.print_recent_stats()

    # save the final model
    if main_args.save_model:
        save_path = os.path.join(main_args.save_dir, 'offline', str(t_max))
        os.makedirs(save_path, exist_ok=True)
        logger.console_logger.info("Saving final models to {}".format(save_path))
        learner.save_models(save_path)

def cont_train_baseline(train_tasks, main_args, logger, learner, task2args, task2runner, task2offline_buffer, parbuf=None):
    '''
    Continual offline training: train on the tasks one after another.

    Each task is handed to ``train_eval_epoch_Q`` for
    ``main_args.cont_train_steps`` steps. The global step / episode counters
    are advanced between tasks and the recent statistics are logged after
    every task.
    '''
    steps_per_task = main_args.cont_train_steps
    train_bs = main_args.offline_batch_size
    batch_size_run = main_args.batch_size_run  # read for parity; not used below

    global_step, episode_count = 0, 0
    for task in train_tasks:
        print("Train on task: ", task)
        learner.task2train_info[task]["log_stats_t"] = 0

        train_eval_epoch_Q(
            task, main_args, task2args, logger, learner, task2runner,
            task2offline_buffer, train_bs, steps_per_task,
            global_step, episode_count,
            parbuf=parbuf,
        )

        # The counters advance by one extra step per finished task.
        global_step += steps_per_task + 1
        episode_count += (steps_per_task + 1) * train_bs

        logger.log_stat("episode", episode_count, global_step)
        logger.print_recent_stats()
        th.cuda.empty_cache()
        
def cont_train_OWL(train_tasks, main_args, logger, learner, task2args, task2runner, task2offline_buffer, parbuf=None):
    """Continual offline training with optional per-task learner hooks.

    For every task, in order: if the learner has a ``preset`` hook, a batch
    sampled from the task's buffer is passed to ``learner.check_need_alloc``
    and ``learner.preset(task)`` is called; the task is then trained through
    ``train_eval_epoch_Q_OWL``; finally ``learner.postset(task)`` is called
    when available. Counters are advanced and statistics logged per task.

    When ``main_args.abla_z`` is set, the learner's ``rho_pri_records`` and
    ``rho_pos_records`` are pickled to ``analysis-skill/skill_<timestamp>.pkl``
    (relative to the working directory) after all tasks.

    Args:
        train_tasks: Tasks to train on, in order.
        main_args: Shared experiment namespace.
        logger: Project logger.
        learner: Learner object.
        task2args: Mapping from task to its args namespace.
        task2runner: Mapping from task to runner.
        task2offline_buffer: Mapping from task to its offline buffer.
        parbuf: Optional buffer of samples from previous tasks (rehearsal).
    """
    t_env = 0
    episode = 0
    train_steps = main_args.cont_train_steps
    batch_size_train = main_args.offline_batch_size

    for task in train_tasks:
        print("Train on task: ", task)
        learner.task2train_info[task]["log_stats_t"] = 0

        if hasattr(learner, 'preset'):
            # Show the learner one batch of the new task before configuring it.
            test_batch = task2offline_buffer[task].sample(batch_size_train)
            test_batch.to(main_args.device)
            learner.check_need_alloc(test_batch, task, t_env)
            learner.preset(task)

        train_eval_epoch_Q_OWL(
            task,
            main_args,
            task2args,
            logger,
            learner,
            task2runner,
            task2offline_buffer,
            batch_size_train,
            train_steps,
            t_env,
            episode,
            parbuf=parbuf,
        )

        if hasattr(learner, 'postset'):
            learner.postset(task)

        t_env += train_steps
        episode += train_steps * batch_size_train

        logger.log_stat("episode", episode, t_env)
        logger.print_recent_stats()
        th.cuda.empty_cache()

    if main_args.abla_z:
        import datetime
        time_stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        skills = {'rho_pri': learner.rho_pri_records, 'rho_pos': learner.rho_pos_records}
        path = f'analysis-skill/skill_{time_stamp}.pkl'
        # Create the output directory first: failing here would lose the
        # records only after all training has already been done.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump(skills, f)
            print(f"Successfully saved to {path}")
        
def mt_train_ODIS(train_tasks, main_args, logger, learner, task2args, task2runner, task2offline_buffer, t_start=0, pretrain=False, test_task2offlinedata=None):
    """Multi-task ODIS training loop, used for both pretraining and training.

    Every outer iteration shuffles the task order, performs one
    ``learner.pretrain`` (if ``pretrain``) or ``learner.train`` call per task
    and then calls ``learner.update``. A truthy return value from the learner
    stops training. Test episodes run before training, every
    ``test_interval`` steps and once ``t_max`` is reached; in pretrain mode
    ``learner.test_pretrain`` is additionally fed samples from
    ``test_task2offlinedata``.

    Args:
        train_tasks: Sequence of tasks to train on. It is not modified.
        main_args: Shared experiment namespace.
        logger: Project logger.
        learner: Learner object.
        task2args: Mapping from task to its args namespace (device lookup).
        task2runner: Mapping from task to runner, used for test episodes.
        task2offline_buffer: Mapping from task to its offline buffer.
        t_start: Initial value of the step counter.
        pretrain: Run the pretraining variant (``pretrain_steps`` steps,
            ``learner.pretrain``, models saved to ``pretrain_save_dir``).
        test_task2offlinedata: Optional mapping from task to data buffer used
            for ``learner.test_pretrain`` in pretrain mode.

    Raises:
        ValueError: If the learner lacks a required ``pretrain`` /
            ``test_pretrain`` method, or if training steps are requested but
            ``train_tasks`` is empty.
    """
    ########## start training ##########
    t_env = t_start
    episode = 0  # episode does not matter
    t_max = main_args.offline_train_steps if not pretrain else main_args.pretrain_steps
    model_save_time = 0
    last_test_T = 0
    last_log_T = 0
    start_time = time.time()
    last_time = start_time
    test_time_total = 0

    # get some common information
    batch_size_train = main_args.batch_size
    batch_size_run = main_args.batch_size_run

    n_test_runs = max(1, main_args.test_nepisode // batch_size_run)

    def _run_tests(step, sample_factor, episode_count):
        """Run test episodes (and pretrain tests); return the seconds spent."""
        began = time.time()
        with th.no_grad():
            for test_task in main_args.test_tasks:
                task2runner[test_task].t_env = step
                for _ in range(n_test_runs):
                    task2runner[test_task].run(test_mode=True, pretrain=pretrain)

            # test_pretrain for pretrained tasks
            if pretrain and test_task2offlinedata is not None:
                for data_task, data_buffer in test_task2offlinedata.items():
                    sample = data_buffer.sample(batch_size_train * sample_factor)

                    if sample.device != task2args[data_task].device:
                        sample.to(task2args[data_task].device)

                    if hasattr(learner, 'test_pretrain'):
                        learner.test_pretrain(sample, step, episode_count, data_task)
                    else:
                        raise ValueError("Do test_pretrain with a learner that does not have a `test_pretrain` method!")
        return time.time() - began

    # do test before training (smaller pretrain-test batch than later on)
    test_time_total += _run_tests(t_env, 3, episode)

    # Shuffle a private copy so the caller's task list keeps its order.
    task_order = list(train_tasks)
    if t_env < t_max and not task_order:
        raise ValueError("train_tasks must not be empty")

    # Defined up front so the check after the task loop is always valid.
    terminated = None

    while t_env < t_max:
        # shuffle tasks
        np.random.shuffle(task_order)
        # train each task
        for task in task_order:

            episode_sample = task2offline_buffer[task].sample(batch_size_train)

            if episode_sample.device != task2args[task].device:
                episode_sample.to(task2args[task].device)

            if pretrain:
                if hasattr(learner, 'pretrain'):
                    terminated = learner.pretrain(episode_sample, t_env, episode, task)
                else:
                    raise ValueError("Do pretraining with a learner that does not have a `pretrain` method!")
            else:
                terminated = learner.train(episode_sample, t_env, episode, task)

            if terminated is not None and terminated:
                break

            t_env += 1
            episode += batch_size_run

        learner.update(pretrain=pretrain)

        if terminated is not None and terminated:
            logger.console_logger.info(f"Terminate training by the learner at t_env = {t_env}. Finish training.")
            break

        # Execute test runs once in a while & final evaluation
        if (t_env - last_test_T) / main_args.test_interval >= 1 or t_env >= t_max:
            test_time_total += _run_tests(t_env, 10, episode)

            logger.console_logger.info("Step: {} / {}".format(t_env, t_max))
            logger.console_logger.info("Estimated time left: {}. Time passed: {}. Test time cost: {}".format(
                time_left(last_time, last_test_T, t_env, t_max), time_str(time.time() - start_time),
                time_str(test_time_total)
            ))
            last_time = time.time()
            last_test_T = t_env

        # ``model_save_time == 0`` forces a save after the very first round.
        if main_args.save_model and (t_env - model_save_time >= main_args.save_model_interval or model_save_time == 0):
            if pretrain:
                save_path = os.path.join(main_args.pretrain_save_dir, str(t_env))
            else:
                save_path = os.path.join(main_args.save_dir, str(t_env))
            os.makedirs(save_path, exist_ok=True)
            logger.console_logger.info("Saving models to {}".format(save_path))
            learner.save_models(save_path)
            model_save_time = t_env

        if (t_env - last_log_T) >= main_args.log_interval:
            last_log_T = t_env
            logger.log_stat("episode", episode, t_env)
            logger.print_recent_stats()
        
def train_eval_epoch_Q(task, main_args, task2args, logger, learner, task2runner, task2buffer, batch_size, train_steps=50000, t_env_start=0, episode_start=0, parbuf=None):
    """Train the learner on one task for ``train_steps`` steps with periodic tests.

    Test episodes for all ``main_args.test_tasks`` are run before training,
    every ``test_interval`` steps and at the final step. When ``parbuf`` is
    given and the global step has reached ``train_steps``, every even step
    trains on a rehearsal sample from ``parbuf`` instead of the current task.

    Args:
        task: Task being trained.
        main_args: Shared experiment namespace (provides ``test_tasks``).
        task2args: Mapping from task to its args namespace.
        logger: Project logger.
        learner: Learner exposing ``train``.
        task2runner: Mapping from task to runner, used for test episodes.
        task2buffer: Mapping from task to its offline buffer.
        batch_size: Number of episodes sampled per training step.
        train_steps: Number of training steps to run.
        t_env_start: Global step counter at the start of this task.
        episode_start: Episode counter at the start of this task.
        parbuf: Optional buffer of samples from previous tasks.
    """
    t_env = t_env_start
    episode = episode_start
    t_max = t_env_start + train_steps
    last_test_T = t_env
    last_log_T = t_env
    test_time_total = 0

    args = task2args[task]
    buffer = task2buffer[task]

    # NOTE(review): other loops in this file use
    # ``max(1, test_nepisode // batch_size_run)`` here; the two only agree for
    # ``batch_size_run == 1`` - confirm which one is intended.
    n_test_runs = args.test_nepisode
    rehearsal_bs = args.rehearsal_batch_size

    def _run_tests(step):
        """Run the test episodes for all test tasks; return the seconds spent."""
        began = time.time()
        with th.no_grad():
            for tt in main_args.test_tasks:
                task2runner[tt].t_env = step
                for _ in range(n_test_runs):
                    task2runner[tt].run(test_mode=True)
        return time.time() - began

    # Evaluate once before any training step.
    test_time_total += _run_tests(t_env)

    logger.console_logger.info("Start Training")
    start_time = time.time()
    last_time = start_time
    while t_env < t_max:
        episode_sample = buffer.sample(batch_size)

        if episode_sample.device != args.device:
            episode_sample.to(args.device)

        # Default: train on the current task's batch.
        train_batch, train_task = episode_sample, task
        if parbuf is not None and t_env >= train_steps and t_env % 2 == 0:
            prec_sample, prec_task = parbuf.sample(rehearsal_bs, task)
            # Without a rehearsal sample, keep the current-task batch instead
            # of handing ``None`` to the learner.
            if prec_sample is not None:
                if prec_sample.device != args.device:
                    prec_sample.to(args.device)
                train_batch, train_task = prec_sample, prec_task
        learner.train(train_batch, t_env, episode, train_task)
        t_env += 1

        episode += batch_size

        # Periodic evaluation, plus a final one at the last step.
        if (t_env - last_test_T) / args.test_interval >= 1 or t_env >= t_max:
            test_time_total += _run_tests(t_env)

            logger.console_logger.info("Step: {} / {}".format(t_env, t_max))
            logger.console_logger.info("Estimated time left: {}. Time passed: {}. Test time cost: {}".format(
                time_left(last_time, last_test_T, t_env, t_max), time_str(time.time() - start_time), time_str(test_time_total)
            ))
            last_time = time.time()
            last_test_T = t_env

        if (t_env - last_log_T) >= args.log_interval:
            last_log_T = t_env
            logger.log_stat("episode", episode, t_env)
            logger.print_recent_stats()
            
def train_eval_epoch_Q_OWL(task, main_args, task2args, logger, learner, task2runner, task2buffer, batch_size, train_steps=50000, t_env_start=0, episode_start=0, parbuf=None):
    """Train the learner on one task for ``train_steps`` steps (OWL variant).

    Compared with ``train_eval_epoch_Q`` this variant switches the learner's
    adaptor (when it has a ``switch_adaptor`` method) to each test task during
    evaluation and back to ``task`` afterwards, saves models every
    ``save_model_interval`` steps under ``<save_dir>/model/<t_env>``, and
    increments the step counter at the end of each iteration.

    When ``parbuf`` is given and the global step has reached ``train_steps``,
    every even step trains on a rehearsal sample from ``parbuf`` instead of
    the current task.

    Args:
        task: Task being trained.
        main_args: Shared experiment namespace.
        task2args: Mapping from task to its args namespace.
        logger: Project logger.
        learner: Learner exposing ``train`` and ``save_models``.
        task2runner: Mapping from task to runner, used for test episodes.
        task2buffer: Mapping from task to its offline buffer.
        batch_size: Number of episodes sampled per training step.
        train_steps: Number of training steps to run.
        t_env_start: Global step counter at the start of this task.
        episode_start: Episode counter at the start of this task.
        parbuf: Optional buffer of samples from previous tasks.
    """
    t_env = t_env_start
    episode = episode_start
    t_max = t_env_start + train_steps
    model_save_time = t_env
    last_test_T = t_env
    last_log_T = t_env
    test_time_total = 0

    args = task2args[task]
    buffer = task2buffer[task]

    val_bs = 32  # only referenced by the disabled EBM validation below

    is_mh = hasattr(learner, 'switch_adaptor')

    n_test_runs = args.test_nepisode
    rehearsal_bs = args.rehearsal_batch_size

    def _run_tests(step):
        """Evaluate all test tasks, restore the adaptor; return the seconds spent."""
        began = time.time()
        with th.no_grad():
            for tt in main_args.test_tasks:
                if is_mh:
                    learner.switch_adaptor(tt)
                task2runner[tt].t_env = step
                for _ in range(n_test_runs):
                    task2runner[tt].run(test_mode=True)

                # val_sample = task2buffer[tt].val_sample(val_bs)
                # if val_sample is not None and val_sample.device != args.device:
                #     val_sample.to(args.device)
                # learner.val_ebm(val_sample, step, 0, tt) # NOTE only for testing EBM

        # Switch back to the task that is being trained.
        if is_mh:
            learner.switch_adaptor(task)
        return time.time() - began

    # Evaluate once before any training step.
    test_time_total += _run_tests(t_env)

    logger.console_logger.info("Start Training")
    start_time = time.time()
    last_time = start_time
    while t_env < t_max:
        episode_sample = buffer.sample(batch_size)

        if episode_sample.device != args.device:
            episode_sample.to(args.device)

        # Default: train on the current task's batch.
        train_batch, train_task = episode_sample, task
        if parbuf is not None and t_env >= train_steps and t_env % 2 == 0:
            prec_sample, prec_task = parbuf.sample(rehearsal_bs, task)
            # Without a rehearsal sample, keep the current-task batch instead
            # of handing ``None`` to the learner.
            if prec_sample is not None:
                if prec_sample.device != args.device:
                    prec_sample.to(args.device)
                train_batch, train_task = prec_sample, prec_task
        learner.train(train_batch, t_env, episode, train_task)

        episode += batch_size

        # NOTE(review): ``t_env`` is incremented at the end of the iteration,
        # so ``t_env >= t_max`` cannot hold here and no test runs at the final
        # step inside this function - confirm that this is intended.
        if (t_env - last_test_T) / args.test_interval >= 1 or t_env >= t_max:
            test_time_total += _run_tests(t_env)

            logger.console_logger.info("Step: {} / {}".format(t_env, t_max))
            logger.console_logger.info("Estimated time left: {}. Time passed: {}. Test time cost: {}".format(
                time_left(last_time, last_test_T, t_env, t_max), time_str(time.time() - start_time), time_str(test_time_total)
            ))
            last_time = time.time()
            last_test_T = t_env

        if (t_env - last_log_T) >= args.log_interval:
            last_log_T = t_env
            logger.log_stat("episode", episode, t_env)
            logger.print_recent_stats()

        if main_args.save_model and t_env - model_save_time >= main_args.save_model_interval:
            save_path = os.path.join(main_args.save_dir, 'model', str(t_env))
            os.makedirs(save_path, exist_ok=True)
            logger.console_logger.info("Saving models to {}".format(save_path))
            learner.save_models(save_path)
            model_save_time = t_env

        t_env += 1

def args_sanity_check(config, _log):
    """Normalise the CUDA flag and the number of test episodes in ``config``.

    ``config`` is modified in place and returned:

    * ``use_cuda`` is switched off (with a warning) when it is requested but
      no CUDA device is available.
    * ``test_nepisode`` is raised to ``batch_size_run`` if it is smaller,
      otherwise rounded down to a multiple of ``batch_size_run``.
    """
    # set CUDA flags
    cuda_requested = config["use_cuda"]
    if cuda_requested and not th.cuda.is_available():
        config["use_cuda"] = False
        _log.warning("CUDA flag use_cuda was switched OFF automatically because no CUDA devices are available!")

    n_test = config["test_nepisode"]
    n_envs = config["batch_size_run"]
    config["test_nepisode"] = n_envs if n_test < n_envs else (n_test // n_envs) * n_envs

    return config
