import os
import pprint
import time
import threading
import torch as th
from types import SimpleNamespace as SN
from utils.logging import Logger
from utils.timehelper import time_left, time_str
import json
import copy
import random
import numpy as np

from learners import REGISTRY as le_REGISTRY
from runners import REGISTRY as r_REGISTRY
from controllers import REGISTRY as mac_REGISTRY
from components.episode_buffer import ReplayBuffer
from components.offline_buffer import OfflineBuffer
from components.transforms import OneHot

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.pyplot import MultipleLocator, colorbar
from matplotlib.ticker import FuncFormatter, FormatStrFormatter
import matplotlib


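### TODO: finish the meta-train framework — mainly loading the offline datasets of multiple tasks and keeping several environment instances alive for testing.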


def _run_name_from_save_dir(results_save_dir):
    """Derive a wandb/swanlab run name from ``results_save_dir``.

    The name is the part of the path after its ``results`` component
    (``.../results/a/b`` -> ``a/b``). When the path has no ``results``
    component the whole path is used instead of raising ``ValueError``.
    """
    parts = results_save_dir.split('/')
    if "results" in parts:
        return "/".join(parts[parts.index("results") + 1:])
    return results_save_dir


def run(_run, _config, _log):
    """Sacred entry point: set up logging backends, dump the config, train, then hard-exit.

    Args:
        _run: sacred run object, handed to ``logger.setup_sacred``.
        _config: raw config dict; validated/normalized by ``args_sanity_check``.
        _log: console logger.

    Note:
        Ends with ``os._exit(os.EX_OK)``, so control never returns to the caller.
    """
    # check args sanity
    print(_config["use_cuda"])
    _config = args_sanity_check(_config, _log)

    args = SN(**_config)
    args.device = "cuda" if args.use_cuda else "cpu"

    logger = Logger(_log)

    _log.info("Experiment Parameters:")
    experiment_params = pprint.pformat(_config, indent=4, width=1)
    _log.info("\n\n" + experiment_params + "\n")

    results_save_dir = args.results_save_dir
    # wandb and swanlab each replace tensorboard as the metrics backend.
    if args.use_wandb or args.use_swanlab:
        args.use_tensorboard = False

    log_dir = os.path.join(results_save_dir, 'logs')
    if args.use_tensorboard and not args.evaluate:
        # only log tensorboard when in training mode
        logger.setup_tb(log_dir)

    args.wandb_project_name = "offpymarl_meta"
    if not args.evaluate and (args.use_wandb or args.use_swanlab):
        run_name = _run_name_from_save_dir(results_save_dir)
        if args.use_wandb:
            logger.setup_wandb(log_dir, project=args.wandb_project_name, name=run_name,
                               run_id=args.resume_id, config=args)
        if args.use_swanlab:
            logger.setup_swanlab(log_dir, project=args.wandb_project_name, name=run_name,
                                 run_id=args.resume_id, config=args, config_dict=_config)

    # Write the resolved config; create the directory in case no backend did.
    os.makedirs(results_save_dir, exist_ok=True)
    config_str = json.dumps(vars(args), indent=4)
    with open(os.path.join(results_save_dir, "config.json"), "w") as f:
        f.write(config_str)

    # Output directories for checkpoints and visualizations.
    args.model_save_dir = os.path.join(results_save_dir, 'models')
    args.vis_save_dir = os.path.join(results_save_dir, 'visulizations')

    # sacred is on by default
    logger.setup_sacred(_run)

    # Run and train
    run_sequential(args=args, logger=logger)

    logger.finish()

    # Clean up after finishing
    print("Exiting Main")

    print("Stopping all threads")
    for t in threading.enumerate():
        if t.name != "MainThread":
            print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon))
            t.join(timeout=1)
            print("Thread joined")

    print("Exiting script")

    # Making sure framework really exits
    os._exit(os.EX_OK)

def evaluate_sequential(args, runner):
    """Play ``args.test_nepisode`` test-mode episodes with ``runner``.

    Afterwards a replay is saved when ``args.save_replay`` is truthy, and the
    runner's environment is always closed.
    """
    episode_budget = args.test_nepisode
    for _episode_idx in range(episode_budget):
        runner.run(test_mode=True)

    want_replay = args.save_replay
    if want_replay:
        runner.save_replay()

    runner.close_env()


def run_sequential(args, logger):
    # In offline training, we use t_max to denote iterations
    
    # Init runner so we can get env info

    train_tasks = args.train_task_ls
    args.n_tasks = len(train_tasks)
    main_args = copy.deepcopy(args)

    total_tasks = args.total_task_ls

    task2args, task2runner, task2buffer = {}, {}, {}
    task2scheme, task2groups, task2preprocess = {}, {}, {}
    for task in total_tasks:
        task_args = copy.deepcopy(args)
        if task_args.env == "sc2":
            task_args.env_args["map_name"] = task
        elif task_args.env == "sc2_v2":
            task_args.env_args["task"] = task
            if "terran" in task:
                task_args.env_args["map_name"] = "10gen_terran"
            elif "protoss" in task:
                task_args.env_args["map_name"] = "10gen_protoss"
            elif "zerg" in task:
                task_args.env_args["map_name"] = "10gen_zerg"
            else:
                assert False
        elif task_args.env == "mt_grid_mpe":
            task_args.env_args["task_id"] = task
        else:
            raise Exception(f"Unsupported env type {task_args.env}")
        task2args[task] = task_args
        task_runner = r_REGISTRY[task_args.runner](args=task_args, logger=logger, task=task)
        task2runner[task] = task_runner

        env_info = task_runner.get_env_info()

        for k, v in env_info.items():
            setattr(task_args, k, v)
        if "n_landmarks" in task_args.env_args.keys():
            if task_args.num_tasks==4:
                setattr(task_args, "n_landmarks", env_info["n_agents"])
            else:
                setattr(task_args, "n_landmarks", task_args.env_args["n_landmarks"])
        
        scheme = {
            "state": {"vshape": env_info["state_shape"]},
            "obs": {"vshape": env_info["obs_shape"], "group": "agents"},
            "actions": {"vshape": (1,), "group": "agents", "dtype": th.long},
            "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
            "reward": {"vshape": (1,)},
            "terminated": {"vshape": (1,), "dtype": th.uint8},
        }
        groups = {
            "agents": task_args.n_agents
        }
        preprocess = {
            "actions": ("actions_onehot", [OneHot(out_dim=task_args.n_actions)])
        }

        # task_buffer = ReplayBuffer(scheme, groups, 100, env_info["episode_limit"] + 1,
        #                     preprocess=preprocess,
        #                     device="cpu" if task_args.buffer_cpu_only else task_args.device)
        # task2buffer[task] = task_buffer
        if preprocess is not None:
            for k in preprocess:
                assert k in scheme
                new_k = preprocess[k][0]
                transforms = preprocess[k][1]

                vshape = scheme[k]["vshape"]
                dtype = scheme[k]["dtype"]
                for transform in transforms:
                    vshape, dtype = transform.infer_output_info(vshape, dtype)

                scheme[new_k] = {
                    "vshape": vshape,
                    "dtype": dtype
                }
                if "group" in scheme[k]:
                    scheme[new_k]["group"] = scheme[k]["group"]
                if "episode_const" in scheme[k]:
                    scheme[new_k]["episode_const"] = scheme[k]["episode_const"]
        
        # store task information
        task2scheme[task], task2groups[task], task2preprocess[task] = scheme, groups, preprocess

    # task2buffer_scheme = {
    #     task: task2buffer[task].scheme for task in total_tasks
    # }

    # Setup multiagent controller here
    mac = mac_REGISTRY[main_args.mac](total_tasks, task2scheme, task2args, main_args)

    for task in total_tasks:
        # setup runner
        task2runner[task].setup(scheme=task2scheme[task], groups=task2groups[task], preprocess=task2preprocess[task], mac=mac)

    # Learner
    learner = le_REGISTRY[main_args.learner](mac, logger, main_args)

    if main_args.use_cuda:
        print("use cuda")
        learner.cuda()
    
    if (main_args.is_encoder_train or main_args.is_role_encoder_train) and main_args.pretrain_id>=0:
        learner.load_models(main_args.pretrain_path_ls[main_args.pretrain_id])
    
    if getattr(main_args, "is_encoder_vis", False):
        learner.load_models(main_args.vis_path_ls[main_args.vis_model_id])
    
    if getattr(main_args, "is_role_encoder_vis", False):
        learner.load_models(main_args.vis_path_ls[main_args.vis_model_id])
    
    if getattr(main_args, "is_traj_vis", False):
        learner.load_models(main_args.policy_path_ls[main_args.algo_id][main_args.policy_id])
    
    if getattr(main_args, "is_meta_adaptation", False):
        learner.load_models(main_args.policy_path_ls[main_args.algo_id][main_args.policy_id])


    if main_args.checkpoint_path != "":
        raise Exception("We don't support checkpoint loading in multi-task learning currently!")
    
    # Create Offline Data
    task2offline_buffer = {}
    for id, task in enumerate(total_tasks):
        match task2args[task].env:
            case "sc2":
                task2args[task].map_name = task2args[task].env_args["map_name"]
            case "sc2_v2":
                task2args[task].map_name = task2args[task].env_args["task"]
            case "gymma":
                env_name, task2args[task].map_name = task2args[task].env_args['key'].split(':')
                task2args[task].env = env_name
            case "mt_grid_mpe":
                task2args[task].map_name = str(task2args[task].env_args["task_id"])
            case _:
                raise NotImplementedError("Do not support such envs: {}".format(task2args[task].env))
        

        max_buffer_size = task2args[task].offline_max_buffer_size
        if task not in train_tasks and task2args[task].env=="sc2_v2":
            max_buffer_size = task2args[task].offline_max_buffer_size_test
            
        offline_buffer = OfflineBuffer(task2args[task], task2args[task].map_name, task2args[task].offline_data_quality,
                                    main_args.offline_data_ls[id], max_buffer_size, 
                                    shuffle=task2args[task].offline_data_shuffle) # device defauly cpu
        task2offline_buffer[task] = offline_buffer
    
    role2offline_buffer = {}
    if main_args.role_data_root is not None and main_args.is_role_encoder_train:
        for role in main_args.role_ls:
            role2offline_buffer[role] = {}
            for task in main_args.role2task[role]:
                match task2args[task].env:
                    case "sc2":
                        task2args[task].map_name = task2args[task].env_args["map_name"]
                    case "sc2_v2":
                        task2args[task].map_name = task2args[task].env_args["task"]
                    case "gymma":
                        env_name, task2args[task].map_name = task2args[task].env_args['key'].split(':')
                        task2args[task].env = env_name
                    case "mt_grid_mpe":
                        task2args[task].map_name = str(task2args[task].env_args["task_id"])
                    case _:
                        raise NotImplementedError("Do not support such envs: {}".format(task2args[task].env))
                path = os.path.join(main_args.role_data_root, f"role{role}_task{task}")
                offline_buffer = OfflineBuffer(task2args[task], task2args[task].map_name, task2args[task].offline_data_quality,
                                            path, task2args[task].offline_max_buffer_size, 
                                            shuffle=task2args[task].offline_data_shuffle)
                role2offline_buffer[role][task] = offline_buffer
    
    
    logger.console_logger.info("Beginning offline training with {} iterations".format(main_args.t_max))
    if getattr(main_args, "is_traj_vis", False):
        visualize_trajectory(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer)
    elif getattr(main_args, "is_role_encoder_vis", False):
        visualize_role_encoder(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer, role2offline_buffer)
    elif getattr(main_args, "is_encoder_vis", False):
        visualize_meta_encoder(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer)
    elif getattr(main_args, "is_meta_adaptation", False):
        meta_evaluate_sequential(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer, task2scheme, task2groups, task2preprocess)
    elif getattr(main_args, "is_vae_train", False):
        train_vae(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer)
    elif getattr(main_args, "is_odis", False):
        if getattr(main_args, "pretrain", False):
            train_odis(train_tasks, main_args, logger, learner, task2args, task2runner, task2offline_buffer,
                         pretrain=True, test_task2offlinedata=None)
        train_odis(train_tasks, main_args, logger, learner, task2args, task2runner, task2offline_buffer)
    elif main_args.is_prior_role_encoder_train:
        train_prior_role_encoder(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer)
    elif main_args.is_role_encoder_train:
        train_role_encoder(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer, role2offline_buffer)
    elif main_args.is_encoder_train:
        train_meta_encoder(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer)
    else:
        train_sequential(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer)

    if main_args.save_model:
        save_path = os.path.join(main_args.model_save_dir, str(main_args.t_max))
        os.makedirs(save_path, exist_ok=True)
        logger.console_logger.info("Save final model checkpoint in {}".format(save_path))
        learner.save_models(save_path)
    
    for task in total_tasks:
        task2runner[task].close_env()
    logger.console_logger.info("Finish Training")

# Target indices for each task id: every 3-element combination of the 6
# targets, in lexicographic order (task 0 -> [0, 1, 2], ..., task 19 -> [3, 4, 5]).
# `draw_traj` highlights the targets listed here for the plotted task.
task_target_ls = [
    [first, second, third]
    for first in range(6)
    for second in range(first + 1, 6)
    for third in range(second + 1, 6)
]

def draw_traj(main_args, traj, task, n, is_win, n_agents=3, n_targets=6):
    """Plot one episode's agent paths and target positions; save a PNG and an NPZ.

    The state vector is sliced as ``[agent xy] * n_agents`` followed by
    ``[target xy] * n_targets`` (the defaults reproduce the original 3-agent /
    6-target layout). Targets listed in ``task_target_ls[task]`` are drawn red,
    the rest grey. Output goes to ``<traj_save_root>/<task>/<n>[_win]/``.

    Args:
        main_args: needs ``traj_save_root``.
        traj: episode batch; ``traj['state']`` is read and its first batch entry
            plotted (assumed shape (batch, time, state_dim) — TODO confirm).
        task: task index into ``task_target_ls``.
        n: rollout index, used in the output directory name.
        is_win: appends ``_win`` to the output directory name when truthy.
        n_agents: number of agent (x, y) pairs at the start of the state.
        n_targets: number of target (x, y) pairs following the agents.
    """
    states = traj['state'].cpu().numpy()[0]
    target_start = 2 * n_agents

    xs = [[] for _ in range(n_agents)]
    ys = [[] for _ in range(n_agents)]
    for state in states:
        # Stop at the first step whose first target has a zero coordinate
        # (presumably zero padding after the episode ended — TODO confirm).
        if 0 in state[target_start:target_start + 2]:
            break
        for i in range(n_agents):
            xs[i].append(state[2 * i])
            ys[i].append(state[2 * i + 1])

    # Target positions are taken from the first step.
    first_state = states[0]
    targets = [first_state[target_start + 2 * j:target_start + 2 * j + 2] for j in range(n_targets)]
    task_target = task_target_ls[task]
    for j, target in enumerate(targets):
        plt.scatter(
            target[0], target[1],
            marker="*",
            s=200,
            color="red" if j in task_target else "grey"
        )
    for i in range(n_agents):
        plt.plot(xs[i], ys[i])

    if not is_win:
        save_root = main_args.traj_save_root + f"/{task}/{n}"
    else:
        save_root = main_args.traj_save_root + f"/{task}/{n}_win"
    os.makedirs(save_root, exist_ok=True)
    try:
        plt.savefig(save_root + "/plot.png")
    finally:
        # Always reset the shared figure so a failure cannot leak into the next plot.
        plt.clf()

    agent_arrays = {}
    for i in range(n_agents):
        agent_arrays[f"agent{i}_x"] = xs[i]
        agent_arrays[f"agent{i}_y"] = ys[i]
    np.savez(save_root + "/plot_data.npz", **agent_arrays, targets=targets, task_target=task_target)



def visualize_trajectory(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer):
    """Roll out test-mode episodes on every test task and plot each with ``draw_traj``.

    For each task, ``max(1, test_nepisode // batch_size_run)`` rollouts are run.
    Each rollout encodes the task from one offline episode, optionally augmented
    (or replaced) by the role encoding, and runs without gradients.
    ``train_tasks`` and ``task2args`` are unused; they are kept for interface parity.
    """
    t_env = 0

    batch_size_run = main_args.batch_size_run  # num of parallel envs
    n_test_runs = max(1, main_args.test_nepisode // batch_size_run)

    with th.no_grad():
        for task in main_args.test_task_ls:
            runner = task2runner[task]
            for n in range(n_test_runs):
                episode_sample = task2offline_buffer[task].sample(1)
                if episode_sample.device != main_args.device:
                    episode_sample.to(main_args.device)
                _, local_task_encoding = learner.get_task_encoding(episode_sample, task)
                if main_args.use_role_encoder:
                    role_encoding = learner.get_role_encoding(episode_sample, local_task_encoding, task)
                    if not main_args.only_role_encoding:
                        local_task_encoding = th.cat([local_task_encoding, role_encoding], dim=-1)
                    else:
                        local_task_encoding = role_encoding
                runner.t_env = t_env
                traj, _, _, _, is_win = runner.run(local_task_encoding, test_mode=True, learner=learner,
                                                   return_detail=True)
                draw_traj(main_args, traj, task, n, is_win)

    # Previously logged "Finish training sequential", which misdescribed this routine.
    logger.console_logger.info("Finish trajectory visualization")

def visualize_role_encoder(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer, role2offline_buffer):
    """Sample role-specific and task-level episodes and hand them to the learner for plotting.

    For every role in ``main_args.role_ls`` and every task in
    ``main_args.role2task[role]``, two batches of ``vis_batch_size`` episodes are
    drawn: one from ``role2offline_buffer[role][task]`` and one from
    ``task2offline_buffer[task]``. Both are moved to ``main_args.device`` and
    passed to ``learner.get_visualize_data``.
    """
    episode = 0  # episode does not matter for visualization
    train_roles = main_args.role_ls
    role2task_ls = main_args.role2task

    def _sample_on_device(buffer):
        # Draw a visualization batch and move it to the training device if needed.
        batch = buffer.sample(main_args.vis_batch_size)
        if batch.device != main_args.device:
            batch.to(main_args.device)
        return batch

    role2episode_sample_vis = {}
    role2task_sample_vis = {}
    role2encoding_sample_vis = {}
    for role in train_roles:
        role2episode_sample_vis[role] = []
        role2task_sample_vis[role] = []
        role2encoding_sample_vis[role] = []
        for task in role2task_ls[role]:
            role2task_sample_vis[role].append(task)
            # Role-specific episodes, plus task-level episodes used for the task encoding.
            role2episode_sample_vis[role].append(_sample_on_device(role2offline_buffer[role][task]))
            role2encoding_sample_vis[role].append(_sample_on_device(task2offline_buffer[task]))

    learner.get_visualize_data(train_roles, main_args, task2args, role2episode_sample_vis,
                               role2encoding_sample_vis, role2task_sample_vis, episode)
    # Previously logged "Finish training meta encoder", which misdescribed this routine.
    logger.console_logger.info("Finish role encoder visualization")

def visualize_meta_encoder(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer):
    """Collect one visualization batch per training task and let the learner dump encoder data.

    Each batch holds ``main_args.vis_batch_size`` episodes from the task's offline
    buffer, moved to ``main_args.device``. The learner writes its output to
    ``<vis_data_save_root>/<vis_model_id>.npz``. ``logger`` and ``task2runner``
    are unused.
    """
    target_device = main_args.device
    samples_by_task = {}
    for task_name in train_tasks:
        batch = task2offline_buffer[task_name].sample(main_args.vis_batch_size)
        if batch.device != target_device:
            batch.to(target_device)
        samples_by_task[task_name] = batch

    save_file = main_args.vis_data_save_root + f"/{main_args.vis_model_id}.npz"
    learner.get_visualize_local_data(train_tasks, main_args, task2args, samples_by_task, 0, save_file)

def meta_online_adaptation(task, seed, learner, runner, episode_num, scheme, groups, preprocess, n_agents, args, n_test_runs):
    """Adapt to ``task`` online: alternately collect one episode and evaluate, ``episode_num`` times.

    Collection starts from an all-zero prior encoding of shape
    ``(1, n_agents, args.encoding_dim)``. Later collections encode the context,
    using either the best episode so far (``args.collect_with_idaq``) or the mean
    over all collected episodes. Each evaluation re-encodes the context the same
    way, chosen by ``args.use_idaq``, and optionally adds the role encoding of
    the best episode (``args.use_role_encoder`` / ``args.only_role_encoding``).

    Args:
        seed: currently unused (the env re-seeding calls are disabled).
        n_test_runs: number of ``runner.run(test_mode=True)`` calls per
            evaluation; the statistics of the last call are recorded.

    Returns:
        Four numpy arrays of length ``episode_num``: evaluation returns, win
        rates, and the same two again as ``*_with_prior``. These are currently
        identical copies, because the separate prior-encoding evaluation is
        disabled.
    """
    env_info = runner.get_env_info()
    buffer = ReplayBuffer(scheme, groups, episode_num, env_info["episode_limit"] + 1,
                          preprocess=preprocess,
                          device="cpu" if args.buffer_cpu_only else args.device)
    prior_local_encoding = th.zeros(1, n_agents, args.encoding_dim, device=args.device)
    ep_returns, ep_won = [], []
    ep_returns_with_prior, ep_won_with_prior = [], []

    # -inf (instead of a finite sentinel such as -1000) makes the first collected
    # episode the best one, so `context_ls[best_episode]` is never indexed with None.
    best_return = float("-inf")
    best_won = 0
    best_episode = None
    context_ls = []

    def _mean_buffer_encoding():
        # Encode every stored episode and average the encodings over the batch dim.
        episode_sample = buffer.sample(buffer.episodes_in_buffer)
        if episode_sample.device != args.device:
            episode_sample.to(args.device)
        _, encoding = learner.get_task_encoding(episode_sample, task)
        return encoding.mean(dim=0, keepdim=True)

    def _with_best_role(local_task_encoding):
        # Role encoding of the best episode, appended to (or replacing) the task encoding.
        role_encoding = learner.get_role_encoding(context_ls[best_episode], local_task_encoding, task)
        if args.only_role_encoding:
            return role_encoding
        return th.cat([local_task_encoding, role_encoding], dim=-1)

    for ep in range(episode_num):
        # ---- collect one episode ----
        if buffer.episodes_in_buffer == 0:
            local_task_encoding = prior_local_encoding
        elif args.collect_with_idaq:
            _, local_task_encoding = learner.get_task_encoding(context_ls[best_episode], task)
        else:
            local_task_encoding = _mean_buffer_encoding()

        if not args.use_role_encoder:
            episode_batch, _, _, returns, battle_won = runner.run(
                local_task_encoding, test_mode=False, learner=learner, return_detail=True, deterministic=True)
        elif best_episode is not None and (
                best_won == 1 or (args.collect_with_best_role and ep >= args.prior_role_collect_episode)):
            # Guarded by `best_episode is not None`: before any episode exists there is
            # no best role to use (previously this indexed context_ls with None).
            episode_batch, _, _, returns, battle_won = runner.run(
                _with_best_role(local_task_encoding), test_mode=False, learner=learner,
                return_detail=True, deterministic=True)
        else:
            episode_batch, _, _, returns, battle_won = runner.run(
                local_task_encoding, test_mode=False, run_with_prior_encoding=True, learner=learner,
                return_detail=True, deterministic=True)

        if episode_batch.device != args.device:
            episode_batch.to(args.device)
        context_ls.append(episode_batch)
        buffer.insert_episode_batch(episode_batch)
        # A higher win rate always wins; otherwise require a higher return at no lower win rate.
        if (returns > best_return and battle_won >= best_won) or battle_won > best_won:
            best_episode = ep
            best_return = returns
            best_won = battle_won

        # ---- evaluate with the updated context ----
        if args.use_idaq:
            _, local_task_encoding = learner.get_task_encoding(context_ls[best_episode], task)
        else:
            local_task_encoding = _mean_buffer_encoding()
        # The role encoding is only computed when it is actually used.
        eval_encoding = _with_best_role(local_task_encoding) if args.use_role_encoder else local_task_encoding
        for _ in range(n_test_runs):
            _, return_mean, battle_won_mean, _, _ = runner.run(
                eval_encoding, test_mode=True, learner=learner, return_detail=True)
        ep_returns.append(return_mean)
        ep_won.append(battle_won_mean)

        # The separate evaluation with the prior encoding is disabled; record the same values.
        ep_returns_with_prior.append(return_mean)
        ep_won_with_prior.append(battle_won_mean)

    return np.array(ep_returns), np.array(ep_won), np.array(ep_returns_with_prior), np.array(ep_won_with_prior)

def _average_adaptation_curves(task, main_args, learner, task2runner, task2args, task2scheme,
                               task2groups, task2preprocess, n_test_runs, n_repeats):
    """Average the four per-episode curves of ``meta_online_adaptation`` over ``n_repeats`` runs on ``task``."""
    sums = None
    for _ in range(n_repeats):
        curves = meta_online_adaptation(task, 0, learner, task2runner[task],
                                        main_args.adaptation_episodes, task2scheme[task],
                                        task2groups[task], task2preprocess[task], task2args[task].n_agents,
                                        main_args, n_test_runs)
        # Cast to float so integer-valued win/return arrays can be averaged
        # (in-place `/=` on an int array raises a numpy casting error).
        curves = [np.asarray(curve, dtype=np.float64) for curve in curves]
        sums = curves if sums is None else [acc + curve for acc, curve in zip(sums, curves)]
    return [acc / n_repeats for acc in sums]


def _log_task_adaptation(logger, task, curves, n_iters):
    """Log one task's averaged adaptation curves, indexed by adaptation episode."""
    returns, won, returns_prior, won_prior = curves
    for it in range(n_iters):
        logger.log_stat(f"{task}/adaptation_return", returns[it], it)
        logger.log_stat(f"{task}/adaptation_win_rate", won[it], it)
        logger.log_stat(f"{task}/adaptation_return_with_prior", returns_prior[it], it)
        logger.log_stat(f"{task}/adaptation_win_rate_with_prior", won_prior[it], it)


def _mean_over_tasks(per_task_curves):
    """Element-wise mean of the per-task curve lists; ``None`` when there are no tasks."""
    if not per_task_curves:
        return None
    sums = per_task_curves[0]
    for curves in per_task_curves[1:]:
        sums = [acc + curve for acc, curve in zip(sums, curves)]
    return [acc / len(per_task_curves) for acc in sums]


def meta_evaluate_sequential(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer, task2scheme, task2groups, task2preprocess):
    """Evaluate online meta-adaptation on every train and test task and log the curves.

    For each task, ``meta_online_adaptation`` is repeated
    ``main_args.adaptation_repeats`` times (default 20) and averaged. The
    per-task curves are logged under ``<task>/adaptation_*``, and the means
    over the train and test task groups under ``train_task_*`` / ``test_task_*``.
    A group with no tasks is skipped instead of crashing. Finally the function
    sleeps for ``main_args.post_adaptation_sleep`` seconds (default 100).
    """
    batch_size_run = main_args.batch_size_run  # num of parallel envs
    n_test_runs = max(1, main_args.test_nepisode // batch_size_run)
    n_repeats = getattr(main_args, "adaptation_repeats", 20)
    n_iters = main_args.adaptation_episodes

    with th.no_grad():
        group_means = {}
        for group, task_ls in (("train", main_args.train_task_ls), ("test", main_args.test_task_ls)):
            per_task_curves = []
            for task in task_ls:
                curves = _average_adaptation_curves(task, main_args, learner, task2runner, task2args,
                                                    task2scheme, task2groups, task2preprocess,
                                                    n_test_runs, n_repeats)
                _log_task_adaptation(logger, task, curves, n_iters)
                per_task_curves.append(curves)
            group_means[group] = _mean_over_tasks(per_task_curves)

        # (group, stat key, curve index); the curve indices are
        # 0=return, 1=win, 2=return_with_prior, 3=win_with_prior.
        stat_layout = (
            ("train", "train_task_return_mean", 0),
            ("train", "train_task_battle_won_mean", 1),
            ("test", "test_task_return_mean", 0),
            ("test", "test_task_battle_won_mean", 1),
            ("train", "train_task_return_mean_with_prior", 2),
            ("train", "train_task_battle_won_mean_with_prior", 3),
            ("test", "test_task_return_mean_with_prior", 2),
            ("test", "test_task_battle_won_mean_with_prior", 3),
        )
        for it in range(n_iters):
            for group, key, idx in stat_layout:
                if group_means[group] is not None:
                    logger.log_stat(key, group_means[group][idx][it], it)

    # Presumably gives asynchronous logging backends time to flush — TODO confirm.
    time.sleep(getattr(main_args, "post_adaptation_sleep", 100))
    logger.console_logger.info("Finish adaptation sequential")

def _odis_evaluate(main_args, logger, task2runner, t_env, n_test_runs, pretrain):
    """Run test episodes on every train/test task and log the averaged stats at ``t_env``.

    Each runner in ``task2runner`` has its ``t_env`` set and is then called
    ``n_test_runs`` times with ``test_mode=True``; results whose return is ``None``
    are ignored. A split (train or test) that produced no results is reported on
    the console instead of being logged, rather than raising ``ZeroDivisionError``.
    """
    splits = (("train", main_args.train_task_ls), ("test", main_args.test_task_ls))
    results = []
    with th.no_grad():
        for split, task_ls in splits:
            returns, wins = [], []
            for task in task_ls:
                runner = task2runner[task]
                runner.t_env = t_env
                for _ in range(n_test_runs):
                    _, return_mean, battle_won_mean = runner.run(test_mode=True, pretrain=pretrain)
                    if return_mean is not None:
                        returns.append(return_mean)
                        wins.append(battle_won_mean)
            results.append((split, returns, wins))

        # Log only after all runs, in the order train_return, train_won, test_return, test_won.
        for split, returns, wins in results:
            if not returns:
                logger.console_logger.info(f"No {split}-task test results at t_env = {t_env}; skip logging them.")
                continue
            logger.log_stat(f"{split}_task_return_mean", sum(returns) / len(returns), t_env)
            logger.log_stat(f"{split}_task_battle_won_mean", sum(wins) / len(wins), t_env)


def train_odis(train_tasks, main_args, logger, learner, task2args, task2runner, task2offlinedata, t_start=0,
               pretrain=False, test_task2offlinedata=None):
    """Train (or pretrain) ``learner`` on offline batches drawn from several tasks.

    Each outer iteration shuffles the training tasks, draws one batch of
    ``main_args.batch_size`` episodes per task from ``task2offlinedata`` and calls
    ``learner.pretrain`` (when ``pretrain``) or ``learner.train``; ``t_env`` grows by
    one per task update. After every sweep ``learner.update(pretrain=pretrain)`` is
    called. Test episodes run once before training, whenever ``main_args.test_interval``
    steps have elapsed, and once ``t_max`` is reached. Models are saved every
    ``main_args.save_model_interval`` steps (and on the first iteration).

    Args:
        train_tasks: task names to train on. A private copy is shuffled, so the
            caller's sequence is not reordered.
        main_args: global config namespace.
        logger: project logger (``log_stat``, ``print_recent_stats``, ``console_logger``).
        learner: object exposing ``train``/``pretrain``, ``update`` and ``save_models``.
        task2args: per-task config; its ``device`` decides where batches are moved.
        task2runner: per-task runner used for test episodes.
        task2offlinedata: per-task offline buffer exposing ``sample(batch_size)``.
        t_start: initial value of the step counter.
        pretrain: use ``main_args.pretrain_steps`` / ``learner.pretrain`` /
            ``main_args.pretrain_save_dir`` instead of the training equivalents.
        test_task2offlinedata: unused; kept for interface compatibility.

    Raises:
        ValueError: if pretraining with a learner lacking ``pretrain``, or if there
            are steps left to train but ``train_tasks`` is empty (previously this
            failed with an unbound ``terminated`` variable).
    """
    t_env = t_start
    episode = 0  # bookkeeping counter only; grows by batch_size_run per update
    t_max = main_args.t_max if not pretrain else main_args.pretrain_steps
    model_save_time = 0
    last_test_T = 0
    last_log_T = 0
    start_time = time.time()
    last_time = start_time

    batch_size_train = main_args.batch_size
    batch_size_run = main_args.batch_size_run
    n_test_runs = max(1, main_args.test_nepisode // batch_size_run)

    # Baseline evaluation before any update.
    test_start_time = time.time()
    _odis_evaluate(main_args, logger, task2runner, t_env, n_test_runs, pretrain)
    test_time_total = time.time() - test_start_time

    # Shuffle a private copy: same permutation sequence, caller's list untouched.
    task_order = list(train_tasks)
    if t_env < t_max and not task_order:
        raise ValueError("train_odis needs at least one training task.")

    terminated = None
    while t_env < t_max:
        np.random.shuffle(task_order)
        for task in task_order:
            episode_sample = task2offlinedata[task].sample(batch_size_train)
            task_device = task2args[task].device
            if episode_sample.device != task_device:
                episode_sample.to(task_device)

            if pretrain:
                if not hasattr(learner, 'pretrain'):
                    raise ValueError("Do pretraining with a learner that does not have a `pretrain` method!")
                terminated = learner.pretrain(episode_sample, t_env, episode, task)
            else:
                terminated = learner.train(episode_sample, t_env, episode, task)

            if terminated:
                break

            t_env += 1
            episode += batch_size_run

        # The learner's update still runs after an early stop inside the sweep.
        learner.update(pretrain=pretrain)

        if terminated:
            logger.console_logger.info(f"Terminate training by the learner at t_env = {t_env}. Finish training.")
            break

        # Periodic evaluation, plus a final one once t_max is reached.
        if (t_env - last_test_T) / main_args.test_interval >= 1 or t_env >= t_max:
            test_start_time = time.time()
            _odis_evaluate(main_args, logger, task2runner, t_env, n_test_runs, pretrain)
            test_time_total += time.time() - test_start_time

            logger.console_logger.info("Step: {} / {}".format(t_env, t_max))
            logger.console_logger.info("Estimated time left: {}. Time passed: {}. Test time cost: {}".format(
                time_left(last_time, last_test_T, t_env, t_max), time_str(time.time() - start_time),
                time_str(test_time_total)
            ))
            last_time = time.time()
            last_test_T = t_env

        if main_args.save_model and (t_env - model_save_time >= main_args.save_model_interval or model_save_time == 0):
            save_root = main_args.pretrain_save_dir if pretrain else main_args.model_save_dir
            save_path = os.path.join(save_root, str(t_env))
            os.makedirs(save_path, exist_ok=True)
            logger.console_logger.info("Saving models to {}".format(save_path))
            learner.save_models(save_path)
            model_save_time = t_env

        if (t_env - last_log_T) >= main_args.log_interval:
            last_log_T = t_env
            logger.log_stat("episode", episode, t_env)
            logger.print_recent_stats()

def train_vae(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer):
    """Train ``learner`` on offline batches from every task in ``main_args.train_task_ls``.

    Each step draws one batch of ``main_args.offline_batch_size`` episodes per task,
    moves it to ``main_args.device`` if needed, and calls
    ``learner.train(batch, t_env, task)``. No test episodes are run here; the loop
    only checkpoints (every ``main_args.save_model_interval`` steps and on the first
    step) and logs the episode counter every ``main_args.log_interval`` steps.

    ``train_tasks``, ``task2args`` and ``task2runner`` are accepted for signature
    compatibility with the other ``train_*`` entry points but are not used.
    """
    t_env = 0
    episode = 0
    t_max = main_args.t_max
    model_save_time = 0
    last_log_T = 0

    batch_size_train = main_args.offline_batch_size
    batch_size_run = main_args.batch_size_run  # number of parallel envs; only advances the episode counter here

    while t_env < t_max:
        for task in main_args.train_task_ls:
            episode_sample = task2offline_buffer[task].sample(batch_size_train)
            if episode_sample.device != main_args.device:
                episode_sample.to(main_args.device)
            learner.train(episode_sample, t_env, task)

        t_env += 1
        episode += batch_size_run

        if main_args.save_model and (t_env - model_save_time >= main_args.save_model_interval or model_save_time == 0):
            save_path = os.path.join(main_args.model_save_dir, str(t_env))
            os.makedirs(save_path, exist_ok=True)
            logger.console_logger.info("Saving models to {}".format(save_path))
            learner.save_models(save_path)
            model_save_time = t_env

        if (t_env - last_log_T) >= main_args.log_interval:
            last_log_T = t_env
            logger.log_stat("episode", episode, t_env)
            logger.print_recent_stats()

    logger.console_logger.info("Finish training sequential")

def train_prior_role_encoder(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer):
    """Train the prior role encoder on offline data from each task in ``train_tasks``.

    Every iteration draws two independent batches of ``main_args.meta_batch_size``
    episodes per task from the same buffer: one as the training sample and one as
    the encoding sample. Both are moved to ``main_args.device`` if needed and passed to
    ``learner.train(train_tasks, main_args, task2args, task2episode_sample,
    task2encoding_sample, episode)``. Models are saved every
    ``main_args.encoder_save_model_interval`` iterations (and after the first one);
    the iteration counter is logged every ``main_args.encoder_log_interval``.

    Only the config keys actually used are read. The previous version also
    required ``role_ls``, ``role2task`` and ``encoder_vis_interval`` without using them.
    ``task2runner`` is unused and kept for interface compatibility.
    """
    episode = 0  # iteration counter
    episode_max = main_args.encoder_train_episode
    model_save_time = 0
    last_log_T = 0
    meta_batch_size = main_args.meta_batch_size

    while episode < episode_max:
        task2episode_sample = {}
        task2encoding_sample = {}

        for task in train_tasks:
            offline_buffer = task2offline_buffer[task]

            episode_sample = offline_buffer.sample(meta_batch_size)
            if episode_sample.device != main_args.device:
                episode_sample.to(main_args.device)
            task2episode_sample[task] = episode_sample

            # A second, independent draw from the same task buffer, used for encoding.
            encoding_sample = offline_buffer.sample(meta_batch_size)
            if encoding_sample.device != main_args.device:
                encoding_sample.to(main_args.device)
            task2encoding_sample[task] = encoding_sample

        learner.train(train_tasks, main_args, task2args, task2episode_sample, task2encoding_sample, episode)
        episode += 1

        # TODO: no test phase yet; a t-SNE visualisation could serve as one later.

        if main_args.save_model and (episode - model_save_time >= main_args.encoder_save_model_interval or model_save_time == 0):
            save_path = os.path.join(main_args.model_save_dir, str(episode))
            os.makedirs(save_path, exist_ok=True)
            logger.console_logger.info("Saving models to {}".format(save_path))
            learner.save_models(save_path)
            model_save_time = episode

        if (episode - last_log_T) >= main_args.encoder_log_interval:
            last_log_T = episode
            logger.log_stat("episode", episode, episode)
            logger.print_recent_stats()

    logger.console_logger.info("Finish training meta encoder")

def train_role_encoder(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer, role2offline_buffer):
    episode = 0  # episode does not matter
    episode_max = main_args.encoder_train_episode
    model_save_time = 0
    last_log_T = 0
    last_vis_T = - main_args.encoder_vis_interval
    start_time = time.time()
    last_time = start_time
    train_roles = main_args.role_ls
    role2task_ls = main_args.role2task

    meta_batch_size = main_args.meta_batch_size
    while episode < episode_max:
        role2episode_sample = {}
        role2episode_positive = {}
        role2episode_negative = {}

        role2task_sample = {}
        role2task_positive = {}
        role2task_negative = {}

        role2encoding_sample = {}
        role2encoding_positive = {}
        role2encoding_negative = {}
        
        for role in train_roles:
            task_ls = role2task_ls[role]

            task_sample = random.choice(task_ls)
            role2task_sample[role] = task_sample
            offline_buffer = role2offline_buffer[role][task_sample]
            episode_sample = offline_buffer.sample(meta_batch_size)
            if episode_sample.device != main_args.device:
                episode_sample.to(main_args.device)
            role2episode_sample[role] = episode_sample

            encoding_offline_buffer = task2offline_buffer[task_sample]
            encoding_episode = encoding_offline_buffer.sample(meta_batch_size)
            if encoding_episode.device != main_args.device:
                encoding_episode.to(main_args.device)
            role2encoding_sample[role] = encoding_episode

            task_sample_positive = random.choice(task_ls)
            role2task_positive[role] = task_sample_positive
            offline_buffer_positive = role2offline_buffer[role][task_sample_positive]
            episode_positive = offline_buffer_positive.sample(meta_batch_size)
            if episode_positive.device != main_args.device:
                episode_positive.to(main_args.device)
            role2episode_positive[role] = episode_positive

            encoding_offline_buffer_positive = task2offline_buffer[task_sample_positive]
            encoding_episode_positive = encoding_offline_buffer_positive.sample(meta_batch_size)
            if encoding_episode_positive.device != main_args.device:
                encoding_episode_positive.to(main_args.device)
            role2encoding_positive[role] = encoding_episode_positive

            role2episode_negative[role] = []
            role2task_negative[role] = []
            role2encoding_negative[role] = []

            for new_role in train_roles:
                if new_role == role:
                    continue
                task_sample_negative = random.choice(role2task_ls[new_role])
                role2task_negative[role].append(task_sample_negative)

                offline_buffer_negative = role2offline_buffer[new_role][task_sample_negative]
                episode_negative = offline_buffer_negative.sample(meta_batch_size)
                if episode_negative.device != main_args.device:
                    episode_negative.to(main_args.device)
                role2episode_negative[role].append(episode_negative)

                encoding_offline_buffer_negative = task2offline_buffer[task_sample_negative]
                encoding_episode_negative = encoding_offline_buffer_negative.sample(meta_batch_size)
                if encoding_episode_negative.device != main_args.device:
                    encoding_episode_negative.to(main_args.device)
                role2encoding_negative[role].append(encoding_episode_negative)
        
        learner.train(train_roles, main_args, task2args, role2episode_sample, role2episode_positive, role2episode_negative, role2encoding_sample, role2encoding_positive, role2encoding_negative, role2task_sample, role2task_positive, role2task_negative, episode)
        episode += 1

        ### 暂时没有test，之后可以考虑加上一个tsne可视化作为test过程

        if main_args.save_model and (episode-model_save_time >= main_args.encoder_save_model_interval or model_save_time==0):
            save_path = os.path.join(main_args.model_save_dir, str(episode))
            os.makedirs(save_path, exist_ok=True)
            logger.console_logger.info("Saving models to {}".format(save_path))
            learner.save_models(save_path)
            model_save_time = episode
        
        if (episode - last_vis_T) >= main_args.encoder_vis_interval or episode==episode_max:
            last_vis_T = episode
            role2episode_sample_vis = {}
            role2task_sample_vis = {}
            role2encoding_sample_vis = {}
            for role in train_roles:
                role2episode_sample_vis[role] = []
                role2task_sample_vis[role] = []
                role2encoding_sample_vis[role] = []
                task_ls = role2task_ls[role]
                for task in task_ls:
                    role2task_sample_vis[role].append(task)
                    # sample episode and encoding for visualization
                    offline_buffer = role2offline_buffer[role][task]
                    episode_sample = offline_buffer.sample(main_args.vis_batch_size)
                    if episode_sample.device != main_args.device:
                        episode_sample.to(main_args.device)
                    role2episode_sample_vis[role].append(episode_sample)

                    encoding_offline_buffer = task2offline_buffer[task]
                    encoding_episode = encoding_offline_buffer.sample(main_args.vis_batch_size)
                    if encoding_episode.device != main_args.device:
                        encoding_episode.to(main_args.device)
                    role2encoding_sample_vis[role].append(encoding_episode)            
            
            learner.visualize(train_roles, main_args, task2args, role2episode_sample_vis, role2encoding_sample_vis, role2task_sample_vis, episode)

        
        if (episode - last_log_T) >= main_args.encoder_log_interval:
            last_log_T = episode
            logger.log_stat("episode", episode, episode)
            logger.print_recent_stats()
    
    logger.console_logger.info("Finish training meta encoder")


def train_meta_encoder(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer):
    episode = 0  # episode does not matter
    episode_max = main_args.encoder_train_episode
    model_save_time = 0
    last_log_T = 0
    last_vis_T = - main_args.encoder_vis_interval
    start_time = time.time()
    last_time = start_time

    meta_batch_size = main_args.meta_batch_size
    while episode < episode_max:
        task2episode_sample = {}
        for task in train_tasks:
            offline_buffer = task2offline_buffer[task]
            episode_sample = offline_buffer.sample(meta_batch_size)
            if episode_sample.device != main_args.device:
                episode_sample.to(main_args.device)
            # episode_sample_ls.append(episode_sample)
            task2episode_sample[task] = episode_sample
        
        learner.train(train_tasks, main_args, task2args, task2episode_sample, episode)
        episode += 1

        ### 暂时没有test，之后可以考虑加上一个tsne可视化作为test过程

        if main_args.save_model and (episode-model_save_time >= main_args.encoder_save_model_interval or model_save_time==0):
            save_path = os.path.join(main_args.model_save_dir, str(episode))
            os.makedirs(save_path, exist_ok=True)
            logger.console_logger.info("Saving models to {}".format(save_path))
            learner.save_models(save_path)
            model_save_time = episode
        
        if (episode - last_vis_T) >= main_args.encoder_vis_interval or episode==episode_max:
            last_vis_T = episode
            task2episode_sample_vis = {}
            for task in train_tasks:
                offline_buffer = task2offline_buffer[task]
                episode_sample = offline_buffer.sample(main_args.vis_batch_size)
                if episode_sample.device != main_args.device:
                    episode_sample.to(main_args.device)
                # episode_sample_ls.append(episode_sample)
                task2episode_sample_vis[task] = episode_sample
            learner.visualize(train_tasks, main_args, task2args, task2episode_sample_vis, episode)

        
        if (episode - last_log_T) >= main_args.encoder_log_interval:
            last_log_T = episode
            logger.log_stat("episode", episode, episode)
            logger.print_recent_stats()
    
    logger.console_logger.info("Finish training meta encoder")


def _sequential_task_encoding(main_args, learner, task2offline_buffer, task, with_role):
    """Encode ``task`` from a single offline episode, optionally adding the role encoding.

    With ``with_role`` the result of ``learner.get_role_encoding`` either replaces the
    task encoding (``main_args.only_role_encoding``) or is concatenated to it on the
    last dimension.
    """
    episode_sample = task2offline_buffer[task].sample(1)
    if episode_sample.device != main_args.device:
        episode_sample.to(main_args.device)
    _, encoding = learner.get_task_encoding(episode_sample, task)
    if with_role:
        role_encoding = learner.get_role_encoding(episode_sample, encoding, task)
        if main_args.only_role_encoding:
            encoding = role_encoding
        else:
            encoding = th.cat([encoding, role_encoding], dim=-1)
    return encoding


def _sequential_eval_pass(main_args, logger, learner, task2runner, task2offline_buffer, t_env, n_test_runs,
                          use_prior):
    """Run one test sweep over train tasks then test tasks and log per-split means.

    ``use_prior=False`` runs each task with its task (+ role, if
    ``main_args.use_role_encoder``) encoding and logs under the plain stat names.
    ``use_prior=True`` runs with the task encoding and ``run_with_prior_encoding=True``
    and logs under the ``_prior_role``-suffixed names. Results whose return is
    ``None`` are ignored; a split with no results is reported on the console instead
    of raising ``ZeroDivisionError``.
    """
    suffix = "_prior_role" if use_prior else ""
    extra_run_kwargs = {"run_with_prior_encoding": True} if use_prior else {}
    results = []
    for split, task_ls in (("train", main_args.train_task_ls), ("test", main_args.test_task_ls)):
        returns, wins = [], []
        for task in task_ls:
            with_role = main_args.use_role_encoder and not use_prior
            encoding = _sequential_task_encoding(main_args, learner, task2offline_buffer, task, with_role)
            runner = task2runner[task]
            runner.t_env = t_env
            for _ in range(n_test_runs):
                _, return_mean, battle_won_mean = runner.run(encoding, test_mode=True, learner=learner,
                                                             **extra_run_kwargs)
                if return_mean is not None:
                    returns.append(return_mean)
                    wins.append(battle_won_mean)
        results.append((split, returns, wins))

    # Log only after both splits ran, in the order train_return, train_won, test_return, test_won.
    for split, returns, wins in results:
        if not returns:
            logger.console_logger.info(f"No {split}-task test results{suffix} at t_env = {t_env}; skip logging them.")
            continue
        logger.log_stat(f"{split}_task_return_mean{suffix}", sum(returns) / len(returns), t_env)
        logger.log_stat(f"{split}_task_battle_won_mean{suffix}", sum(wins) / len(wins), t_env)


def _sequential_evaluate(main_args, logger, learner, task2runner, task2offline_buffer, t_env, n_test_runs):
    """Evaluate all tasks without gradients; add a prior-role sweep when ``main_args.use_role_encoder``."""
    with th.no_grad():
        _sequential_eval_pass(main_args, logger, learner, task2runner, task2offline_buffer, t_env, n_test_runs,
                              use_prior=False)
        if main_args.use_role_encoder:
            _sequential_eval_pass(main_args, logger, learner, task2runner, task2offline_buffer, t_env,
                                  n_test_runs, use_prior=True)


def train_sequential(train_tasks, main_args, task2args, logger, learner, task2runner, task2offline_buffer):
    """Jointly train ``learner`` on every task in ``main_args.train_task_ls``.

    Each step draws two independent batches of ``main_args.offline_batch_size``
    episodes per task from the same buffer (a training batch and an encoding
    batch) and calls ``learner.train(batch, encoding_batch, t_env, episode, task)``.
    Test episodes run once before training, whenever ``main_args.test_interval`` steps
    have elapsed, and once ``t_max`` is reached; models are saved every
    ``main_args.save_model_interval`` steps (and after the first one).

    ``train_tasks`` and ``task2args`` are not used (training iterates
    ``main_args.train_task_ls``); they are kept for interface compatibility.
    """
    t_env = 0
    episode = 0
    t_max = main_args.t_max
    model_save_time = 0
    last_test_T = 0
    last_log_T = 0
    start_time = time.time()
    last_time = start_time

    batch_size_train = main_args.offline_batch_size
    batch_size_run = main_args.batch_size_run  # number of parallel envs
    n_test_runs = max(1, main_args.test_nepisode // batch_size_run)

    # Baseline evaluation before any update.
    test_start_time = time.time()
    _sequential_evaluate(main_args, logger, learner, task2runner, task2offline_buffer, t_env, n_test_runs)
    test_time_total = time.time() - test_start_time

    while t_env < t_max:
        for task in main_args.train_task_ls:
            offline_buffer = task2offline_buffer[task]
            episode_sample = offline_buffer.sample(batch_size_train)
            episode_sample_encoding = offline_buffer.sample(batch_size_train)
            if episode_sample.device != main_args.device:
                episode_sample.to(main_args.device)
            if episode_sample_encoding.device != main_args.device:
                episode_sample_encoding.to(main_args.device)
            learner.train(episode_sample, episode_sample_encoding, t_env, episode, task)

        t_env += 1
        episode += batch_size_run

        # Periodic evaluation, plus a final one once t_max is reached.
        if (t_env - last_test_T) / main_args.test_interval >= 1 or t_env >= t_max:
            test_start_time = time.time()
            _sequential_evaluate(main_args, logger, learner, task2runner, task2offline_buffer, t_env, n_test_runs)
            test_time_total += time.time() - test_start_time

            now = time.time()
            # Guard against a zero elapsed time on coarse clocks.
            fps = (t_env - last_test_T) / max(now - last_time, 1e-8)
            logger.console_logger.info("Step: {}/{}".format(t_env, t_max))
            logger.console_logger.info("Estimated time left: {}. Time passed: {}. FPS {:.2f}. Test time cost: {}".format(
                time_left(last_time, last_test_T, t_env, t_max), time_str(now - start_time), fps,
                time_str(test_time_total)
            ))
            last_time = time.time()
            last_test_T = t_env

        if main_args.save_model and (t_env - model_save_time >= main_args.save_model_interval or model_save_time == 0):
            save_path = os.path.join(main_args.model_save_dir, str(t_env))
            os.makedirs(save_path, exist_ok=True)
            logger.console_logger.info("Saving models to {}".format(save_path))
            learner.save_models(save_path)
            model_save_time = t_env

        if (t_env - last_log_T) >= main_args.log_interval:
            last_log_T = t_env
            logger.log_stat("episode", episode, t_env)
            logger.print_recent_stats()

    logger.console_logger.info("Finish training sequential")
            
def args_sanity_check(config, _log):
    """Normalise run-time flags in ``config`` in place and return the same dict.

    * ``use_cuda`` is switched off, with a warning on ``_log``, when it is requested
      but ``torch.cuda.is_available()`` reports no device.
    * ``test_nepisode`` is raised to ``batch_size_run`` when smaller; otherwise it is
      rounded down to a multiple of ``batch_size_run``.
    """
    cuda_requested = config["use_cuda"]
    if cuda_requested and not th.cuda.is_available():
        config["use_cuda"] = False
        _log.warning("CUDA flag use_cuda was switched OFF automatically because no CUDA devices are available!")

    n_parallel = config["batch_size_run"]
    n_test = config["test_nepisode"]
    config["test_nepisode"] = n_parallel if n_test < n_parallel else (n_test // n_parallel) * n_parallel

    return config