import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(base_path)

import torch
import json
import shutil
import numpy as np
import wandb
import setproctitle
import math
import pprint
import time
from environment.wrapper import LMPromptEnv
from environment.DDP_wrapper import DDP_LMPromptEnv
from typing import Any, List, Dict
from dataloader.code.problem_loader import DDP_ProblemLoader
from argparse import Namespace
from evaluate_test.evaluate_utils import evalute_one_episode, evalute_batch_episode
from dataloader.code.dataset import BlendableDataset, RLFullDataset
from dataloader.code.data_samplers import build_training_data_loader
from dataloader.code.tokenizer import ContinuousScalarTokenizer
from train_test.config import parse_args
from train_test.optimizer_param_scheduler import OptimizerParamScheduler
from pathlib import Path
from tqdm import tqdm
from model import Gato
from environment.used.Env_bp_v1 import BP_V1, DDP_BP_V1
from environment.used.Env_bp_v2 import BP_V2, DDP_BP_V2
from environment.used.Env_cvrp_v1 import CVRP_V1, DDP_CVRP_V1
from environment.used.Env_cvrp_v2 import CVRP_V2, DDP_CVRP_V2
from environment.used.Env_cvrp_v3 import CVRP_V3, DDP_CVRP_V3
from environment.used.Env_ffsp_v1 import DDP_FFSP_V1
from environment.used.Env_ffsp_v2 import DDP_FFSP_V2
from environment.used.Env_atsp_v1 import ATSP_V1, DDP_ATSP_V1
from environment.used.Env_atsp_v2 import ATSP_V2, DDP_ATSP_V2
from environment.used.Env_tsp_v1 import TSP_V1, DDP_TSP_V1
from environment.used.Env_tsp_v2 import TSP_V2, DDP_TSP_V2
from environment.used.Env_tsp_v3 import TSP_V3, DDP_TSP_V3
from environment.used.Env_tsp_v4 import TSP_V4, DDP_TSP_V4
from environment.used.Env_op_v1 import OP_V1, DDP_OP_V1
from environment.used.Env_op_v2 import OP_V2, DDP_OP_V2
from environment.used.Env_op_v3 import OP_V3, DDP_OP_V3
from environment.used.Env_op_v4 import OP_V4, DDP_OP_V4
from environment.used.Env_pctsp_v1 import PCTSP_V1, DDP_PCTSP_V1
from environment.used.Env_pctsp_v2 import PCTSP_V2, DDP_PCTSP_V2
from environment.used.Env_pctsp_v3 import PCTSP_V3, DDP_PCTSP_V3
from environment.used.Env_spctsp_v2 import SPCTSP_V2, DDP_SPCTSP_V2
from environment.used.Env_spctsp_v3 import SPCTSP_V3, DDP_SPCTSP_V3
from utils.utils import set_seed, create_folder_overwrite_if_exist, create_folder_if_not_exist
from data.used.make_data import *
# Give the process a recognizable title (runs as an import-time side effect of this module).
setproctitle.setproctitle("GATO-train@XXX")

def train(
    args: Namespace,
    gato: torch.nn.Module,
    datasets_train,
    dataset_weights,
    optimizer:torch.optim.Optimizer,
    scheduler:OptimizerParamScheduler,
    envs: List[LMPromptEnv] = None,
    envs_problems: Dict = None,
    logger: Dict[str, Any] = None,
    seed: int=42
):
    """Train ``gato`` with periodic validation-loss and policy evaluation.

    For ``epoch`` in ``0 .. args.train_iters``: whenever
    ``epoch % args.eval_interval == 0`` (and once more at
    ``epoch == args.train_iters``) the validation loss is computed and every
    entry of ``policy_eval_setting`` is rolled out on ``envs``; then one
    training epoch is run (except at ``epoch == args.train_iters``) and the
    collected metrics are sent through ``wandb.log``.

    Args:
        args: Run configuration. Fields read here include ``batch_num``,
            ``eval_batch_num``, ``auto_batch_len``, ``n_position``,
            ``eval_env_names``, ``eval_dataset_names``, ``use_ddp_env``,
            ``use_mem``, ``eval_iters_COP``, ``eval_interval``,
            ``train_iters``, ``policy_logger``, ``traindata_logger``,
            ``device`` and the tokenizer settings.
        gato: Model to optimize. It is called as
            ``gato(rl_task_input, batch_dataset_name=..., batch_raw_obs=...)``
            and must return a 4-tuple whose second item is the batch loss and
            whose third item maps dataset names to their loss. It must also
            expose ``gato.transformer.same_length``.
        datasets_train: Datasets forwarded to ``build_dataloader``.
        dataset_weights: Blend weights forwarded to ``build_dataloader``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Stepped once per batch; ``scheduler.step(increment=1)``
            must return a 2-tuple whose first item is the new learning rate.
        envs: Evaluation envs, iterated in lockstep with
            ``args.eval_env_names`` / ``args.eval_dataset_names``. ``None``
            (the default) is treated as an empty list.
        envs_problems: Maps env name to its problem set; each value must have
            an ``answer_list`` attribute.
        logger: Optional mapping from env name to an object offering
            ``log_episode`` / ``log_data``.
        seed: Seed forwarded to the data loaders and loggers.

    Returns:
        None. Results are reported through wandb (``wandb.log`` and
        ``wandb.run.summary``), so an active wandb run is presumably
        required -- TODO confirm against the caller.
    """
    # ``None`` replaces the former shared mutable ``[]`` default.
    envs = [] if envs is None else envs

    def _clip_to_batch_max_len(rl_task_input):
        # Shared by the train and validation loops: trim every 2-D tensor to the
        # longest sequence of the batch when that is shorter than the context length.
        if args.auto_batch_len and rl_task_input.seq_len.max() < args.n_position:
            rl_task_input.apply(lambda x: x[:, :rl_task_input.seq_len.max()] if isinstance(x, torch.Tensor) and x.dim() == 2 else x)
            #assert (rl_task_input.tensor_seq[:,-1]==args.special_tokens['<|>']).sum() >= 1

    def _mean_or_zero(values):
        # Mean of a list, with 0 standing in for "no samples yet".
        return np.mean(values) if values else 0

    def _train_one_epoch():
        # Runs one pass over ``train_data_iterator`` and returns
        # (mean loss, {dataset name: mean loss}, last learning rate).
        epoch_losses = []
        epoch_losses_dataset = {dataset_name: [] for dataset_name in args.eval_dataset_names}
        new_lr = None   # stays None if the iterator yields no batch (was an UnboundLocalError)
        with tqdm(total=args.batch_num, desc=f'Training Epoch {epoch}') as pbar:
            for batch in train_data_iterator:
                rl_task_input, batch_data_info, batch_raw_obs = batch
                batch_dataset_name = [info[0] for info in batch_data_info]
                # bookkeeping for the "visited ratio" statistics
                for dataset_name, dataset_idx in batch_data_info:
                    data_visited[dataset_name][dataset_idx] = 1
                    data_cnt[dataset_name] += 1

                _clip_to_batch_max_len(rl_task_input)

                rl_task_input.to(device=device)
                _, loss, loss_datasets, _ = gato(
                    rl_task_input,
                    batch_dataset_name=batch_dataset_name,
                    batch_raw_obs=batch_raw_obs
                )

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                new_lr, _ = scheduler.step(increment=1)

                loss = loss.item()
                epoch_losses.append(loss)
                pbar.set_postfix({
                    'loss':'{:.2f}'.format(loss),
                    'ave loss (latest 20)': '{:.2f}'.format(np.array(epoch_losses[-20:]).mean())
                })
                pbar.update()

                # a loss equal to 0 is skipped (presumably "dataset absent from this batch")
                for dataset_name, dataset_loss in loss_datasets.items():
                    if dataset_loss != 0:
                        epoch_losses_dataset[dataset_name].append(dataset_loss.item())

        epoch_loss = np.array(epoch_losses).mean()
        epoch_loss_dataset = {name: np.array(losses).mean() for name, losses in epoch_losses_dataset.items()}
        return epoch_loss, epoch_loss_dataset, new_lr

    def _get_eval_loss():
        # Runs one pass over ``valid_data_iterator`` and returns
        # (mean loss, {dataset name: mean loss}). The caller wraps this in no_grad.
        eval_losses = []
        eval_losses_dataset = {dataset_name: [] for dataset_name in args.eval_dataset_names}
        with tqdm(total=args.eval_batch_num, desc=f'Calculating eval loss') as pbar:
            for batch in valid_data_iterator:
                rl_task_input, batch_data_info, batch_raw_obs = batch
                batch_dataset_name = [info[0] for info in batch_data_info]

                _clip_to_batch_max_len(rl_task_input)

                rl_task_input.to(device=device)
                _, loss, loss_datasets, _ = gato(rl_task_input, batch_dataset_name=batch_dataset_name, batch_raw_obs=batch_raw_obs)

                loss = loss.item()
                eval_losses.append(loss)

                for dataset_name, dataset_loss in loss_datasets.items():
                    if dataset_loss != 0:
                        eval_losses_dataset[dataset_name].append(dataset_loss.item())

                pbar.set_postfix({
                    'loss':'{:.2f}'.format(loss),
                    'ave loss': '{:.2f}'.format(np.array(eval_losses).mean())
                })
                pbar.update(1)

        eval_loss = np.array(eval_losses).mean()
        eval_loss_dataset = {name: np.array(losses).mean() for name, losses in eval_losses_dataset.items()}
        return eval_loss, eval_loss_dataset

    # NOTE: checkpoint saving is currently disabled; the code is kept for reference.
    '''
    def _save_checkpoint(best_retrun):
        ave_return = np.mean(
            [np.mean(epi_return_greedy_cst[dataset_name]) for dataset_name in args.eval_dataset_names] + \
            [np.mean(epi_return_sample_cst[dataset_name]) for dataset_name in args.eval_dataset_names]
        )
        if args.save_strategy == 'interval':
            torch.save(
                gato.state_dict(), 
                f'{args.save_dir}/interval/{seed}/{epoch}_{round(ave_return,2)}.pt'
            )
        elif args.save_strategy == 'best':
            if ave_return > best_retrun:
                best_retrun = ave_return
                torch.save(
                    gato.state_dict(), 
                    f'{args.save_dir}/best/{round(best_retrun,2)}_seed{seed}_epoch{epoch}.pt'
                )                
        else:
            raise NotImplementedError
        return best_retrun
    '''

    def _eval_policy(sample_action=False, hard_action_constraint=False, ):
        # Rolls the current policy out on every env and returns four dicts keyed
        # by dataset name: return ({'AM', 'DB1'}), objective, safe ratio, time.
        desc_sample = 'sample' if sample_action else 'greedy'
        desc_constraint = 'constraint' if hard_action_constraint else 'free'
        desc = f'{desc_sample}-{desc_constraint}'
        episode_return = {k: {'AM':0, 'DB1':0} for k in args.eval_dataset_names}
        episode_obj = {k: 0 for k in args.eval_dataset_names}
        episode_safe_ratio = {k: [] for k in args.eval_dataset_names}
        episode_time = {k: [] for k in args.eval_dataset_names}

        for env_name, dataset_name, env in zip(args.eval_env_names, args.eval_dataset_names, envs):
            if args.use_ddp_env:
                # assumes ``ave_return`` comes back as a dict with 'AM' and 'DB1'
                # entries, like the non-DDP branch builds -- TODO confirm
                ave_return, ave_obj, obj_std, ave_safe_ratio, ave_time_used, episode = evalute_batch_episode(
                    args=args,
                    model=gato,
                    env=env,
                    problemloader=problemloader_dict[env_name],
                    cont_tokenizer=cont_tokenizer,
                    sample_action=sample_action,
                    hard_action_constraint=hard_action_constraint,
                    desc=f'Evaluating on {env_name} ({desc})',
                    device=device
                )

                # render greedy policy if necessary
                # NOTE(review): unlike the non-DDP branch this ignores
                # args.policy_logger and always logs with epoch_num=0 -- confirm intent.
                if logger is not None:
                    for i, epi in enumerate(episode):
                        logger[env_name].log_episode(
                            desc=desc,
                            is_eval=False,
                            episode=epi,
                            epoch_num=0,
                            episode_num=i,
                            time_used=-1,
                            seed=seed
                        )
            else:
                problemloader = problemloader_dict[env_name]
                problemloader.reset()
                iters = eval_iters[env_name]
                epi_return, epi_obj, epi_safe, epi_time = {'AM':[], 'DB1':[]}, [], [], []
                # Defaults so the aggregates exist even when ``iters`` is 0.
                ave_return = {'AM': 0, 'DB1': 0}
                ave_obj, ave_safe_ratio, ave_time_used = 0, 0, 0
                with tqdm(total=iters, desc=f'Evaluating on {env_name} ({desc})') as pbar:
                    for i in range(iters):
                        problem_info, problem_obj = problemloader.get_problem(1)
                        # unwrap the size-1 batch returned by the problem loader
                        problem_info = (None if problem_info[0] is None else problem_info[0][0], problem_info[1][0], problem_info[2][0])
                        problem_obj = (problem_obj[0].item(), problem_obj[1].item())

                        # eval tasks
                        ep_ret, ep_obj, ep_safe, ep_len, ep_time, epi = evalute_one_episode(
                            args=args,
                            model=gato,
                            env=env,
                            cont_tokenizer=cont_tokenizer,
                            sample_action=sample_action,
                            hard_action_constraint=hard_action_constraint,
                            problem_info=problem_info,
                            problem_obj=problem_obj
                        )
                        epi_safe.append(ep_safe)
                        epi_time.append(ep_time)
                        # returns/objectives are only averaged over safe episodes
                        if ep_safe:
                            epi_return['AM'].append(ep_ret['AM'])
                            epi_return['DB1'].append(ep_ret['DB1'])
                            epi_obj.append(ep_obj)

                        # render greedy policy if necessary
                        if logger is not None and args.policy_logger:
                            logger[env_name].log_episode(
                                desc=desc,
                                is_eval=False,
                                episode=epi,
                                epoch_num=epoch,
                                episode_num=i,
                                time_used=ep_time,
                                seed=seed
                            )

                        # running aggregates, refreshed every episode for the progress bar
                        ave_safe_ratio = _mean_or_zero(epi_safe)
                        ave_time_used = _mean_or_zero(epi_time)
                        ave_obj = _mean_or_zero(epi_obj)
                        obj_std = np.std(epi_obj) if epi_obj else 0
                        ave_return_AM = _mean_or_zero(epi_return['AM'])
                        ave_return_DB1 = _mean_or_zero(epi_return['DB1'])
                        ave_return = {'AM':ave_return_AM, 'DB1':ave_return_DB1}

                        info = {
                            'ret_AM': f'{ave_return_AM:.4f}',
                            'ret_DB1': f'{ave_return_DB1:.4f}',
                            'obj' : f'{ave_obj:.4f}',
                            'std': f'{obj_std:.4f}',
                            'time': f'{ave_time_used:.2f}',
                        }
                        pbar.set_postfix(info)
                        pbar.update()

            # Record the results for BOTH branches. Previously this only happened
            # in the non-DDP branch, so DDP evaluations returned the placeholders.
            episode_return[dataset_name] = ave_return
            episode_obj[dataset_name] = ave_obj
            episode_safe_ratio[dataset_name] = ave_safe_ratio
            episode_time[dataset_name] = ave_time_used

        #print('')
        return episode_return, episode_obj, episode_safe_ratio, episode_time

    trained_time = 0
    # use the first configured GPU when it exists, otherwise fall back to CPU
    device = torch.device(f"cuda:{args.device[0]}" if torch.cuda.is_available() and torch.cuda.device_count() >= args.device[0]+1 else "cpu")
    # build train and val dataloader
    (
        train_data_iterator,
        valid_data_iterator,
        problemloader_dict,
        dataset_train,
        dataset_val
    ) = build_dataloader(args, datasets_train, dataset_weights, envs, envs_problems, seed)

    # figure out eval rollout times of each env
    # (args.eval_iters_COP == 0 means "use every problem")
    eval_iters = {
        env.env_name:
            len(envs_problems[env.env_name].answer_list) if args.eval_iters_COP == 0 else
            min(len(envs_problems[env.env_name].answer_list), args.eval_iters_COP)
        for env in envs
    }

    # build cont_tokenizer
    cont_tokenizer = ContinuousScalarTokenizer(
        args.tokenizer_ver,
        args.num_continuous_bin,
        args.discretize_mu,
        args.discretize_M
    )

    # policy eval setting
    policy_eval_setting = {
        #'greedy-free': lambda: _eval_policy(
        #        sample_action=False, hard_action_constraint=False),
        #'sample-free': lambda: _eval_policy(
        #        sample_action=True, hard_action_constraint=False),
        'greedy-constraint': lambda: _eval_policy(
                sample_action=False, hard_action_constraint=True),
        #'sample-constraint': lambda: _eval_policy(
        #        sample_action=True, hard_action_constraint=True),
    }
    # only consumed by the disabled checkpoint code
    epi_return_greedy_cst = [0] * len(args.eval_dataset_names)
    epi_return_sample_cst = [0] * len(args.eval_dataset_names)

    # dataset visited info for visited ratio
    data_visited = {dataset.dataset_name: np.zeros(len(dataset)) for dataset in dataset_train.datasets}
    data_cnt = {dataset.dataset_name: 0 for dataset in dataset_train.datasets}

    # =============================== start training ===============================
    best_retrun = 0     # only consumed by the disabled checkpoint code
    for epoch in range(0, args.train_iters + 1):
        log_dict = {}

        # Calculate validation losses & Evaluate policy performance at specified epoch intervals
        if epoch % args.eval_interval == 0 or epoch == args.train_iters:
            gato.eval()
            with torch.no_grad():
                # validation losses
                gato.transformer.same_length = False        # use normal context length when loss calculating (TransformerXL back bone)
                eval_loss, eval_loss_dataset = _get_eval_loss()

                log_dict.update({"losses/eval_loss": eval_loss})
                log_dict.update(
                    {f'losses/eval_{dataset_name[:-3]}': loss
                    for dataset_name, loss in eval_loss_dataset.items()}
                )

                # Evaluate current policy
                gato.transformer.same_length = args.use_mem # use fixed context length when rollout with mem (TransformerXL back bone)
                for setting, eval_func in policy_eval_setting.items():
                    epi_return, epi_obj, epi_safe, epi_time = eval_func()
                    for dataset_name in args.eval_dataset_names:
                        prefix = f'eval_{dataset_name[:-3]}'
                        log_dict.update({
                            f'{prefix}/return_AM({setting})': np.mean(epi_return[dataset_name]['AM']),
                            f'{prefix}/return_DB1({setting})': np.mean(epi_return[dataset_name]['DB1']),
                            f'{prefix}/obj({setting})': np.mean(epi_obj[dataset_name]),
                            f'{prefix}/safe({setting})': np.mean(epi_safe[dataset_name]),
                            f'{prefix}/time({setting})': np.mean(epi_time[dataset_name]),
                        })

                    '''
                    if setting == 'greedy-constraint':
                        epi_return_greedy_cst = epi_return
                    elif setting == 'sample-constraint':
                        epi_return_sample_cst = epi_return
                    '''
        '''
        # Save the model at the specified epoch interval
        if args.save_ckpt and epoch != 0 and epoch % args.save_interval == 0:
            best_retrun = _save_checkpoint(best_retrun)
        '''

        if epoch == args.train_iters:
            # Final pass: evaluation only. Log it before leaving, otherwise the
            # metrics of the fully trained model would be dropped.
            log_dict.update({"info/epoch": epoch, 'info/trained_time': trained_time})
            wandb.log(log_dict)
            break

        # one training epoch
        gato.train()
        gato.transformer.same_length = False        # use normal context length when loss calculating (TransformerXL back bone)
        start_time = time.time()
        epoch_loss, epoch_loss_dataset, new_lr = _train_one_epoch()
        trained_time += time.time() - start_time

        # log train data if necessary
        if logger is not None and args.traindata_logger:
            logged_data = dataset_train.get_logged_data()
            for env_name in args.eval_env_names:
                logger[env_name].log_data(logged_data, seed=seed, is_train=True)

        # update log info
        log_dict.update({"info/epoch": epoch, "losses/train_loss": epoch_loss, "info/lr": new_lr, 'info/trained_time': trained_time})
        for dataset_name in args.eval_dataset_names:
            log_name = dataset_name[:dataset_name.find('_')]
            wandb.run.summary[f"info/{log_name}_visited"] = np.sum(data_visited[dataset_name]).item()
            wandb.run.summary[f"info/{log_name}_num"] = data_cnt[dataset_name]
            log_dict.update({f'info/{log_name}_ratio': np.mean(data_visited[dataset_name]).item()})
            log_dict.update({f'eval_{dataset_name[:-3]}/train_loss': epoch_loss_dataset[dataset_name]})

        # log to wandb
        wandb.log(log_dict)

def build_dataloader(args:Namespace, datasets:List[RLFullDataset], dataset_weights:List[float], envs:List, envs_problems:Dict, seed:int):
    """Build the train/validation data loaders and one problem loader per env.

    Side effect: when ``args.batch_num`` (resp. ``args.eval_batch_num``) is 0 it
    is overwritten with ``ceil(len(dataset) / batch_size)``.

    Args:
        args: Run configuration; reads ``split``, ``batch_size``,
            ``eval_batch_size``, ``batch_num``, ``eval_batch_num`` and
            ``traindata_logger`` (and is forwarded to
            ``build_training_data_loader``).
        datasets: Datasets to split; each must offer
            ``split_dataset(args.split)`` returning ``(train, val)``.
        dataset_weights: Blend weights handed to ``BlendableDataset``.
        envs: Evaluation envs, in the same order as ``envs_problems``.
        envs_problems: Maps env name to the problem set wrapped in a
            ``DDP_ProblemLoader``.
        seed: Seed forwarded to ``build_training_data_loader``.

    Returns:
        ``(dataloader_train, dataloader_val, problemloader_dict,
        dataset_train, dataset_val)`` where the last two are the blended
        datasets.

    Raises:
        AssertionError: if ``envs[i].env_name`` does not match the i-th key of
            ``envs_problems``.
    """
    # split training set and evaluation set
    datasets_train, datasets_val = [], []
    for dataset in datasets:
        dataset_train, dataset_val = dataset.split_dataset(args.split)
        datasets_train.append(dataset_train)
        datasets_val.append(dataset_val)

    # build BlendableDataset
    dataset_train = BlendableDataset(
        datasets_train,
        dataset_weights,
        batch_size=args.batch_size,
        #check_visited=True,
        log_data=args.traindata_logger,
        with_dataset_info=True
    )
    dataset_val = BlendableDataset(
        datasets_val,
        dataset_weights,
        batch_size=args.eval_batch_size,
        with_dataset_info=True
    )

    # build dataloader
    # batch_num == 0 means "derive the batch count from the dataset size"
    if args.batch_num == 0:
        args.batch_num = math.ceil(len(dataset_train) / args.batch_size)
        sample_num_per_training_epoch = None
    else:
        sample_num_per_training_epoch = args.batch_num * args.batch_size
    dataloader_train = build_training_data_loader(
        args,
        dataset_train,
        epoch_total_samples=sample_num_per_training_epoch,
        is_eval=False,
        seed=seed
    )

    if args.eval_batch_num == 0:
        args.eval_batch_num = math.ceil(len(dataset_val) / args.eval_batch_size)
        sample_num_per_evaluation_epoch = None
    else:
        sample_num_per_evaluation_epoch = args.eval_batch_num * args.eval_batch_size
    dataloader_val = build_training_data_loader(
        args,
        dataset_val,
        epoch_total_samples=sample_num_per_evaluation_epoch,
        is_eval=True,
        seed=seed
    )

    # one problem loader per env; envs and envs_problems must share the same order
    problemloader_dict = {}
    for i, (name, dataset) in enumerate(envs_problems.items()):
        env = envs[i]
        assert env.env_name == name, f'env order mismatch: envs[{i}] is {env.env_name}, expected {name}'
        problemloader_dict[name] = DDP_ProblemLoader(dataset, env)

    # summary
    print('-'*35)
    print(f'valid Dataset size:\t{len(dataset_val)}')
    print(f'valid sample num:  \t{len(dataloader_val)}')
    print(f'train sample num:  \t{len(dataloader_train)}')
    print(f'train Dataset Size:\t{len(dataset_train)}')
    pp = pprint.PrettyPrinter(indent=4)
    for i, dataset in enumerate(dataset_train.datasets):
        info = {'size': len(dataset), 'weight': dataset_weights[i]}
        print(f'\t{dataset.dataset_name:10}', end='\t')
        pp.pprint({k: round(v,3) for k,v in info.items()})
    print('-'*35)

    return dataloader_train, dataloader_val, problemloader_dict, dataset_train, dataset_val

def get_train_objs(args):
    """Load the datasets and build the evaluation envs named in ``args.dataset_weights``.

    ``args.dataset_weights`` is a JSON object mapping env names (e.g.
    ``"Env_BP_V1"``) to blend weights. For every env the evaluation problem
    set, the training dataset(s) and the prompt dataset(s) are loaded, and an
    env builder is registered.

    Side effects: sets ``args.eval_env_names`` and ``args.eval_dataset_names``;
    when ``args.policy_logger`` or ``args.traindata_logger`` is set, calls
    ``create_folder_overwrite_if_exist`` on one
    ``{base_path}/visualize/train/...`` folder per prompt dataset.

    Args:
        args: Run configuration; reads ``dataset_weights``, ``use_ddp_env``,
            ``eval_problem_set``, ``problem_batch_size``, ``mlp_emb_items``,
            ``prompt_strategy``, ``policy_logger``, ``traindata_logger`` and
            the per-problem size fields (``bp_item_num``, ``tsp_city_num``, ...).

    Returns:
        ``(datasets_train, dataset_weights, envs, envs_problems, logger)``;
        ``logger`` is ``None`` unless one of the logger flags is set.

    Raises:
        ValueError: if ``args.dataset_weights`` names an unknown env.
        NotImplementedError: if an env without a non-DDP implementation
            (FFSP) is requested while ``args.use_ddp_env`` is false.
        AssertionError: if an observation dimension is not a multiple of the
            configured MLP embedding dimension.
    """
    # load datasets & datasets weight
    weight = json.loads(args.dataset_weights)

    basic_env_builders, ddp_env_builders = [], []
    datasets_train, datasets_prompt, datasets_prompt_ddp, envs_problems = [], [], [], {}

    def _register(env_name, get_data, ddp_builder, basic_builder=None, prompt_data_type='prompt'):
        # Loads eval problems, train data and prompt data for one env (in that
        # order) and records its env builders.
        if basic_builder is None and not args.use_ddp_env:
            # Without this check the builder list and the prompt-dataset list
            # would be zipped out of step further down.
            raise NotImplementedError(
                f'{env_name} only provides a DDP environment; enable args.use_ddp_env to use it'
            )
        envs_problems[env_name] = get_data(args, data_type=args.eval_problem_set)
        datasets_train.extend(get_data(args, data_type='train')[0])
        dataset, ddp_dataset = get_data(
            args,
            data_type=prompt_data_type,
            get_dataset=not args.use_ddp_env,
            get_ddp_dataset=args.use_ddp_env
        )
        datasets_prompt.extend(dataset)
        datasets_prompt_ddp.extend(ddp_dataset)
        ddp_env_builders.append(ddp_builder)
        if basic_builder is not None:
            basic_env_builders.append(basic_builder)

    # The if/elif chain is kept (instead of a lookup table) so that each
    # ``get_*_data_*`` name is only resolved when its env is actually requested.
    for env_name in weight.keys():
        if env_name == 'Env_BP_V1':
            _register(
                env_name, get_bp_data_v1,
                lambda: DDP_BP_V1(item_num=args.bp_item_num, batch_size=args.problem_batch_size),
                lambda: BP_V1(item_num=args.bp_item_num),
            )
        elif env_name == 'Env_BP_V2':
            _register(
                env_name, get_bp_data_v2,
                lambda: DDP_BP_V2(item_num=args.bp_item_num, batch_size=args.problem_batch_size),
                lambda: BP_V2(item_num=args.bp_item_num),
            )
        elif env_name == 'Env_FFSP_V1':
            # no non-DDP FFSP env is available
            _register(
                env_name, get_ffsp_data_v1,
                lambda: DDP_FFSP_V1(job_cnt=args.ffsp_job_num, batch_size=args.problem_batch_size),
            )
        elif env_name == 'Env_FFSP_V2':
            # NOTE(review): this is the only env whose prompt set is loaded with
            # data_type='train' instead of 'prompt'; kept as-is -- confirm it is intended.
            _register(
                env_name, get_ffsp_data_v2,
                lambda: DDP_FFSP_V2(job_cnt=args.ffsp_job_num, batch_size=args.problem_batch_size),
                prompt_data_type='train',
            )
        elif env_name == 'Env_ATSP_V1':
            _register(
                env_name, get_atsp_data_v1,
                lambda: DDP_ATSP_V1(num_nodes=args.atsp_city_num, batch_size=args.problem_batch_size),
                lambda: ATSP_V1(num_nodes=args.atsp_city_num),
            )
        elif env_name == 'Env_ATSP_V2':
            _register(
                env_name, get_atsp_data_v2,
                lambda: DDP_ATSP_V2(num_nodes=args.atsp_city_num, batch_size=args.problem_batch_size),
                lambda: ATSP_V2(num_nodes=args.atsp_city_num),
            )
        elif env_name == 'Env_TSP_V1':
            _register(
                env_name, get_tsp_data_v1,
                lambda: DDP_TSP_V1(num_nodes=args.tsp_city_num, batch_size=args.problem_batch_size),
                lambda: TSP_V1(num_nodes=args.tsp_city_num),
            )
        elif env_name == 'Env_TSP_V2':
            _register(
                env_name, get_tsp_data_v2,
                lambda: DDP_TSP_V2(num_nodes=args.tsp_city_num, batch_size=args.problem_batch_size),
                lambda: TSP_V2(num_nodes=args.tsp_city_num),
            )
        elif env_name == 'Env_TSP_V3':
            _register(
                env_name, get_tsp_data_v3,
                lambda: DDP_TSP_V3(num_nodes=args.tsp_city_num, batch_size=args.problem_batch_size),
                lambda: TSP_V3(num_nodes=args.tsp_city_num),
            )
        elif env_name == 'Env_PCTSP_V1':
            _register(
                env_name, get_pctsp_data_v1,
                lambda: DDP_PCTSP_V1(node_num=args.pctsp_node_num, batch_size=args.problem_batch_size),
                lambda: PCTSP_V1(node_num=args.pctsp_node_num),
            )
        elif env_name == 'Env_PCTSP_V3':
            _register(
                env_name, get_pctsp_data_v3,
                lambda: DDP_PCTSP_V3(node_num=args.pctsp_node_num, batch_size=args.problem_batch_size),
                lambda: PCTSP_V3(node_num=args.pctsp_node_num),
            )
        elif env_name == 'Env_SPCTSP_V2':
            _register(
                env_name, get_spctsp_data_v2,
                lambda: DDP_SPCTSP_V2(node_num=args.spctsp_node_num, batch_size=args.problem_batch_size),
                lambda: SPCTSP_V2(node_num=args.spctsp_node_num),
            )
        elif env_name == 'Env_SPCTSP_V3':
            _register(
                env_name, get_spctsp_data_v3,
                lambda: DDP_SPCTSP_V3(node_num=args.spctsp_node_num, batch_size=args.problem_batch_size),
                lambda: SPCTSP_V3(node_num=args.spctsp_node_num),
            )
        elif env_name == 'Env_OP_V1':
            _register(
                env_name, get_op_data_v1,
                lambda: DDP_OP_V1(node_num=args.op_node_num, batch_size=args.problem_batch_size),
                lambda: OP_V1(node_num=args.op_node_num),
            )
        elif env_name == 'Env_OP_V2':
            _register(
                env_name, get_op_data_v2,
                lambda: DDP_OP_V2(node_num=args.op_node_num, batch_size=args.problem_batch_size),
                lambda: OP_V2(node_num=args.op_node_num),
            )
        elif env_name == 'Env_OP_V4':
            _register(
                env_name, get_op_data_v4,
                lambda: DDP_OP_V4(node_num=args.op_node_num, batch_size=args.problem_batch_size),
                lambda: OP_V4(node_num=args.op_node_num),
            )
        elif env_name == 'Env_CVRP_V1':
            _register(
                env_name, get_cvrp_data_v1,
                lambda: DDP_CVRP_V1(node_num=args.cvrp_node_num, batch_size=args.problem_batch_size),
                lambda: CVRP_V1(node_num=args.cvrp_node_num),
            )
        elif env_name == 'Env_CVRP_V3':
            _register(
                env_name, get_cvrp_data_v3,
                lambda: DDP_CVRP_V3(node_num=args.cvrp_node_num, batch_size=args.problem_batch_size),
                lambda: CVRP_V3(node_num=args.cvrp_node_num),
            )
        else:
            # was ``raise False``, which itself fails with an unrelated TypeError
            raise ValueError(f'Unsupported env name in args.dataset_weights: {env_name!r}')

    dataset_weights = list(weight.values())
    args.eval_env_names = list(weight.keys())
    args.eval_dataset_names = [dataset.dataset_name for dataset in datasets_train]

    '''
    # make sure all data are tokenized properly 
    for dataset in datasets_train:
        for i in tqdm(range(len(dataset)), total=len(dataset), desc=f'Checking tokenize format of {dataset.dataset_name}_train'):
            dataset.check_token_list_format(dataset.get(i, with_raw_obs=False))
    if not args.use_ddp_env:
        for dataset in datasets_prompt:
            for i in tqdm(range(len(dataset)), total=len(dataset), desc=f'Checking tokenize format of {dataset.dataset_name}_prompt'):
                dataset.check_token_list_format(dataset.get(i, with_raw_obs=False))
    '''
    # every observation embedded through an MLP must have a dimension that is a
    # multiple of the configured embedding dimension
    if args.mlp_emb_items != {}:
        for dataset in datasets_train:
            for obs_name, obs_dim in dataset.obs_dims_for_spec.items():
                for linear_emb_obs_info in args.mlp_emb_items.values():
                    if obs_name in linear_emb_obs_info['item_name']:
                        assert obs_dim % linear_emb_obs_info['dim'] == 0

    # build envs for evaluation
    eval_prompt_strat = args.prompt_strategy.split(";")[-1] # moving_prompt
    if args.use_ddp_env:
        envs = [DDP_LMPromptEnv(env_builer(), args, prompt_dataset, eval_prompt_strat) for env_builer, prompt_dataset in zip(ddp_env_builders, datasets_prompt_ddp)]
    else:
        envs = [LMPromptEnv(env_builer(), args, prompt_dataset, eval_prompt_strat) for env_builer, prompt_dataset in zip(basic_env_builders, datasets_prompt)]

    # build episode render if we need to check generated episodes during training
    logger = None
    if args.policy_logger or args.traindata_logger:
        logger = {env_name: EXAMPLE_RENDER[env_name]() for env_name in args.eval_env_names}
        prompts = datasets_prompt if not args.use_ddp_env else datasets_prompt_ddp
        for dataset in prompts:
            create_folder_overwrite_if_exist(f'{base_path}/visualize/train/{dataset.env_name}/{dataset.dataset_name}')

    return datasets_train, dataset_weights, envs, envs_problems, logger

def get_args_ready():
    """Parse CLI args, overwrite them with this script's hard-coded settings, and prepare output folders.

    The namespace returned by ``parse_args()`` is filled in place with the
    hyper-parameters used by this script, a handful of consistency checks are
    asserted, and -- only when checkpoint or snapshot saving is enabled -- the
    checkpoint folder is (re)created together with a JSON dump of the config
    and a copy of ``train_test/train.py``.

    Returns:
        tuple: ``(exp_name, args)`` -- the experiment name derived from the
        model size settings, and the populated argument namespace.
    """
    # NOTE(XXX): Only a part of paras in args structure are used now
    args = parse_args()
    num_procs = int(os.getenv("WORLD_SIZE", "1"))
    args.world_size = num_procs

    # ---- core paras ----
    args.model = 'llama'            # 'llama' or 'transformer_xl'
    args.n_embed = 360              # embedding dimension
    args.n_q_head = 6               # attention query head num (for llama GQA)
    args.n_kv_head = 6              # attention key/value head num (for llama GQA)
    args.n_head = args.n_q_head     # attention head num (for TransformerXL)
    args.n_position = 1000          # model input sequence length (max context length)
    args.auto_batch_len = True      # Automatically clip the sample length to the maximum length in the batch
    args.n_layer = 8                # transformer block num
    args.rms_norm_eps = 1e-6        # RMS Norm epsilon (for llama)
    args.num_workers = 0
    # For llama, the query head num must be a multiple of the key/value head num.
    assert args.model != 'llama' or args.n_q_head % args.n_kv_head == 0

    # ---- training paras ----
    args.is_obs_pretrain = False
    args.dataset_weights = '{"Env_BP_V1":1}'
    # Alternative dataset mixes kept for reference:
    #   '{"Env_PCTSP_V3":1}'
    #   '{"Env_PCTSP_V3":1, "Env_BP_V1":1}'
    #   '{"Env_SPCTSP_V2":1,"Env_BP_V2":1,"Env_TSP_V2":1,"Env_PCTSP_V1":1,"Env_OP_V2":1,"Env_CVRP_V1":1}'
    args.train_iters = 500          # training epoch num
    args.batch_num = 48             # training batch num per epoch
    args.batch_size = 10            # training batch size

    # ---- loss eval paras ----
    args.eval_batch_size = 10       # eval batch size
    args.eval_batch_num = 3         # eval batch num
    args.eval_interval = 10         # epoch interval for eval loss calculating
    assert args.train_iters % args.eval_interval == 0
    if args.is_obs_pretrain:
        # The interval is stretched to the whole training run in obs-pretrain mode.
        args.eval_interval = args.train_iters

    # ---- policy eval paras ----
    args.problem_batch_size = 5     # eval problem batch_size (per GPU)
    args.problem_batch_num = 3
    args.eval_iters_COP = num_procs * args.problem_batch_size * args.problem_batch_num
    args.eval_max_step_size = 1000  # max rollout timestep for policy evaluation
    args.use_default_policy_obj = False  # Whether to use the default random policy obj value to calculate epi quality in evaluation
    args.use_ddp_env = True
    args.use_mem = False
    # Alternative kept for reference: args.use_mem = not args.use_prefix

    # ---- prompt paras ----
    args.prompt_strategy = "stochastic_subseq;moving_prompt"
    args.prompt_prob = 0.25
    args.prompt_ratio = 0.5
    args.prompt_at_final_transition_prob = 0.5
    args.use_prefix = True
    args.use_dynamic_prefix = True
    args.use_prompt = False
    # Prompt and prefix may not both be enabled; having neither is allowed.
    assert not (args.use_prompt and args.use_prefix)
    # A dynamic prefix is only allowed when the prefix itself is enabled.
    assert args.use_prefix or not args.use_dynamic_prefix

    # ---- optimizer paras ----
    args.lr_max = 2.5e-4
    args.lr_begin = 1.0e-5
    args.lr_warmup_ratio = 0.05
    args.lr_decay_ratio = 0.95
    args.lr_decay_factor = 10
    args.lr_decay_style = "cosine"
    args.start_weight_decay = 0.1
    args.end_weight_decay = args.start_weight_decay
    args.weight_decay_incr_style = "constant"
    args.use_checkpoint_opt_param_scheduler = True
    args.override_opt_param_scheduler = not args.use_checkpoint_opt_param_scheduler

    # ---- embedding paras ----
    args.tokenizer_ver = 'v2'
    args.discretize_mu = 15
    args.discretize_M = 4
    args.num_continuous_bin = 1800
    args.num_discrete_values = 200
    # Each item of mlp_emb_items corresponds to a linear layer for embedding, and
    # its 'item_name' entries are the obs item names of MDP episode data, e.g.
    #   {'position': {'dim': 2, 'item_name': ['position', 'pos_depot', 'pos_node']}}
    args.mlp_emb_items = {}
    # auto_batch_len is only used here without any MLP-embedded items.
    assert not args.auto_batch_len or not args.mlp_emb_items

    # ---- env paras ----
    problem_sizes = (
        ('ffsp_job_num', 20),
        ('atsp_city_num', 20),
        ('tsp_city_num', 20),
        ('op_node_num', 20),
        ('pctsp_node_num', 10),
        ('spctsp_node_num', 20),
        ('cvrp_node_num', 20),
        ('bp_item_num', 20),
    )
    for attr_name, size in problem_sizes:
        setattr(args, attr_name, size)

    # Every problem family shares the same data_num value.
    common_data_num = 50
    for family in ('bp', 'ffsp', 'atsp', 'tsp', 'op', 'cvrp', 'pctsp', 'spctsp'):
        setattr(args, f'data_num_{family}', common_data_num)

    # Special token ids start right after the discrete and continuous-bin counts.
    first_special_id = args.num_discrete_values + args.num_continuous_bin
    args.special_tokens = {
        "<|>": first_special_id,
        "<X>": first_special_id + 1,
    }
    assert not args.use_prefix or "<X>" in args.special_tokens

    # ---- ckpt paras ----
    args.exp_profile = 'test'
    if args.model == 'transformer_xl':
        head_tag = f'{args.n_head}'
    else:
        head_tag = f'{args.n_q_head}|{args.n_kv_head}'
    exp_name = f'{args.exp_profile}_{args.n_position}_{args.n_embed}_{head_tag}_{args.n_layer}'
    args.save_dir = f'{base_path}/ckpt/{exp_name}'
    args.save_strategy = 'best'
    args.save_interval = args.eval_interval
    assert args.save_interval % args.eval_interval == 0

    # ---- other paras ----
    args.traj_type = 'all'
    args.dataset_distribution = 'uniform'
    args.eval_problem_set = 'train_problem'     # train_problem or problem
    args.dataloader_type = "sequential"         # "sequential" or "random"
    args.disable_visited_obs = True
    args.device = [0, ]             # GPU device
    args.seeds = [42, ]             # random seeds
    args.wandb = False              # log the exp curve to wandb or not
    args.policy_logger = False      # whether to check the generated episodes during training
    args.traindata_logger = False   # whether to log the sample idx during training
    args.save_ckpt = False          # save model paras ckpt during training or not
    args.save_snapshot = False      # save snapshot during DDP training or not
    assert len(args.device) == 1    # this script only supports a single GPU

    if args.wandb:
        create_folder_if_not_exist(f'{base_path}/Wandb')
    else:
        os.environ['WANDB_MODE'] = 'offline'

    # Nothing is written to disk unless ckpt or snapshot saving is requested.
    if not (args.save_ckpt or args.save_snapshot):
        return exp_name, args

    # create folder to save ckpts and hyperparas
    create_folder_overwrite_if_exist(f'{args.save_dir}/{args.save_strategy}')
    with open(f'{args.save_dir}/config.json', 'w') as config_file:
        config_file.write(json.dumps(vars(args), indent=4))
    shutil.copy2(f'{base_path}/train_test/train.py', f'{args.save_dir}/train.py')

    return exp_name, args

if __name__ == "__main__":
    # Script entry point: one training run per seed, each inside its own wandb run.

    # get hyper paras ready
    exp_name, args = get_args_ready()

    # load training objs
    datasets_train, dataset_weights, envs, envs_problems, logger = get_train_objs(args)

    # train
    for seed in args.seeds:
        args.seed = seed
        if args.save_dir and args.save_strategy == 'interval':
            interval_ckpt_dir = f'{args.save_dir}/interval/{seed}'
            create_folder_overwrite_if_exist(interval_ckpt_dir)

        # Use the configured GPU when CUDA is available and the index exists; otherwise use the CPU.
        gpu_index = args.device[0]
        gpu_usable = torch.cuda.is_available() and torch.cuda.device_count() > gpu_index
        device = torch.device(f"cuda:{gpu_index}" if gpu_usable else "cpu")

        # build model instance (moved to the device, then put in train mode)
        gato = Gato(args).to(device)
        gato = gato.train()

        # optimizer & lr scheduler
        optimizer = gato.transformer.configure_optimizers()
        scheduler = OptimizerParamScheduler(args, optimizer)

        # This unique id is necessary for log resuming
        wandb_id = wandb.util.generate_id()

        wandb_run_kwargs = dict(
            project="gato-ddp-test",             # the wandb project where this run will be logged
            dir=Path(f'{base_path}/Wandb'),
            group=exp_name,
            name=f"seed_{seed}",
            id=wandb_id,
            resume='allow',
            config=args,
        )
        with wandb.init(**wandb_run_kwargs):
            wandb.watch(gato, log='all', log_freq=100)
            train(
                args, gato, datasets_train, dataset_weights,
                optimizer, scheduler, envs, envs_problems, logger, seed,
            )

    wandb.finish()

