import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(base_path)

from dataclasses import dataclass, asdict
from collections import OrderedDict
from typing import Any, Dict, List
import os
import torch
import wandb
import math
import time
from dataloader.code.problem_loader import DDP_ProblemLoader
from torch.nn.parallel import DistributedDataParallel as DDP
import fsspec
import pprint
import torch.distributed as dist
import numpy as np
from tqdm import tqdm
from train_test.optimizer_param_scheduler import OptimizerParamScheduler
from dataloader.code.dataset import BlendableDataset, RLFullDataset
from dataloader.code.data_samplers import build_training_data_loader
from dataloader.code.tokenizer import ContinuousScalarTokenizer
from model import Gato
from model.transformer_xl import TransformerXL
from model.llama import TrajLlama
from utils.utils import set_seed
from evaluate_test.evaluate_utils import evalute_one_episode, evalute_batch_episode
from thop import profile

@dataclass
class Snapshot:
    """Resumable training state.

    Built by Trainer._save_snapshot and stored as a plain dict (via `asdict`); rebuilt with
    `Snapshot(**data)` in Trainer._load_snapshot. Field names are therefore on-disk keys and
    must not be renamed.
    """
    model_state: 'OrderedDict[str, torch.Tensor]'  # state_dict() of the unwrapped (non-DDP) model
    optimizer_state: Dict[str, Any]                # optimizer.state_dict()
    scheduler_state: Dict[str, Any]                # OptimizerParamScheduler.state_dict()
    finished_epoch: int                            # epoch index at save time; training resumes from it
    best_retrun: float                             # best eval return so far (-inf when none); misspelled name kept for checkpoint compatibility
    best_eval_loss: float                          # lowest eval loss so far (+inf when none)
    trained_time: float                            # accumulated time.time() seconds spent in training epochs
    wandb_id: str                                  # wandb run id, restored into Trainer.wandb_id on resume

class EarlyStopping:
    """Signal early stopping when the validation loss stops improving.

    Call the instance once per validation with the new loss. `early_stop` becomes True after
    `patience` consecutive calls that fail to lower the best loss by at least `delta`.
    """
    def __init__(self, patience=7, delta=0):
        """
        Args:
            patience (int): How many non-improving calls to tolerate after the last improvement.
            delta (float): Minimum decrease of the loss that counts as an improvement.
        """
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # `np.inf`, not `np.Inf`: the capitalized alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss):
        """Record one validation loss and update `counter`, `early_stop` and `val_loss_min`."""
        # Work with the negated loss so that "higher score is better".
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.val_loss_min = val_loss
        elif score < self.best_score + self.delta:
            # Did not improve on the best loss by at least `delta`.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.val_loss_min = val_loss
            self.counter = 0

class Trainer:
    def __init__(self, args, seed, wandb_id, envs, logger, datasets, dataset_weights, envs_problems):
        """Build the DDP-wrapped model, optimizer, scheduler and data pipeline for one seed.

        Must run under torchrun: LOCAL_RANK, RANK and WORLD_SIZE are read from the environment,
        and a default process group must already exist (DDP, `dist.barrier` and
        `dist.broadcast` are used here).

        Args:
            args: experiment namespace (save_dir, model, use_amp, snapshot_save_interval,
                early_stopping_patience/delta, tokenizer settings, ... are read here).
            seed: random seed; also part of the snapshot file name.
            wandb_id: wandb run id (replaced by the one stored in a resumed snapshot).
            envs: environments; envs[i] must match the i-th key of `envs_problems`.
            logger: optional dict env_name -> episode logger (None disables episode logging).
            datasets: datasets to split into train / validation parts.
            dataset_weights: blending weight of each entry of `datasets`.
            envs_problems: mapping env_name -> problem set used for policy evaluation.
        """
        self.seed = seed
        self.args = args
        self.wandb_id = wandb_id
        self.envs = envs
        self.logger = logger
        self.datasets = datasets
        self.snapshot_path = f'{self.args.save_dir}/snapshot_seed{seed}.pt'
        set_seed(seed, self.envs)

        # Policy evaluation settings: name -> zero-arg callable running one evaluation pass.
        # Only greedy decoding with hard action constraints is currently enabled.
        self.eval_setting = {
            #'greedy-free': lambda: self._eval_policy(
            #        sample_action=False, hard_action_constraint=False),
            #'sample-free': lambda: self._eval_policy(
            #        sample_action=True, hard_action_constraint=False),
            'greedy-constraint': lambda: self._eval_policy(
                    sample_action=False, hard_action_constraint=True),
            #'sample-constraint': lambda: self._eval_policy(
            #        sample_action=True, hard_action_constraint=True),
        }

        # Process layout provided by torchrun.
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])

        # Wrap the model with DDP; this synchronizes the parameters across all processes.
        self.gato = Gato(args).to(self.local_rank).train()
        self.gato = DDP(self.gato, device_ids=[self.local_rank])
        self.raw_gato = self.gato.module if hasattr(self.gato, "module") else self.gato
        self.early_stopping = EarlyStopping(patience=args.early_stopping_patience, delta=args.early_stopping_delta)

        # Train state. The None placeholders are filled in by _load_snapshot().
        self.optimizer = self.raw_gato.transformer.configure_optimizers()
        self.scheduler = OptimizerParamScheduler(self.args, self.optimizer)
        self.save_every = self.args.snapshot_save_interval
        self.best_retrun = None
        self.best_eval_loss = None
        self.epoch_start = None
        self.epoch_now = None
        self.current_grad_accum_step = None
        self.trained_time = None

        self.grad_batch_cnt = 1         # micro-batch counter for gradient accumulation
        self.model_macs_per_batch = 0

        # Resume from a snapshot if one exists. This must run before the dataloaders are built,
        # because they receive `self.epoch_now`.
        self._load_snapshot()

        # Data pipeline.
        (
            self.dataloader_train,
            self.dataloader_val,
            self.problemloader_dict,
            self.dataset_train,
            self.dataset_val,
        ) = self._prepare_dataloader(datasets, dataset_weights, envs_problems)
        self.cont_tokenizer = ContinuousScalarTokenizer(
            self.args.tokenizer_ver,
            self.args.num_continuous_bin,
            self.args.discretize_mu,
            self.args.discretize_M
        )

        # Parameter count and MACs. This is a collective call, so every rank must reach it.
        self.model_macs_per_batch = self._profile_model_cost()

        # Mixed precision training. The attribute always exists; it is None when AMP is off.
        self.scaler = torch.cuda.amp.GradScaler() if self.args.use_amp else None

    def _profile_model_cost(self):
        """Print parameter statistics and MACs of the backbone; return MACs per global batch.

        Only global rank 0 profiles. It draws one real batch from the training loader and then
        calls `sampler.reset()`. It builds fresh backbone instances for `args.model`, counts
        their parameters, and measures MACs with `thop.profile`. The MACs are multiplied by
        world_size, since each rank processes one batch, and broadcast from global rank 0.
        Every rank returns the same float. Collective: all ranks must call this.
        """
        def _get_dummy_data():
            rl_task_input, batch_data_info, batch_raw_obs = next(iter(self.dataloader_train))
            # NOTE(review): presumably rewinds the sampler so training does not skip this batch — confirm
            self.dataloader_train.sampler.reset()
            rl_task_input.to(device=self.local_rank)
            for raw_obs in batch_raw_obs:
                for key in raw_obs.keys():
                    raw_obs[key] = raw_obs[key].to(device=self.local_rank)
            batch_dataset_name = [info[0] for info in batch_data_info]
            # Insertion order matters: the values are passed positionally to the backbone.
            return {
                'tasks_input': rl_task_input,
                'compute_loss': True,
                'mems': None,
                'batch_dataset_name': batch_dataset_name,
                'batch_raw_obs': batch_raw_obs,
            }

        def _get_dummy_model():
            if self.args.model == 'transformer_xl':
                dummy_model = TransformerXL(self.args).to(self.local_rank)
                dummy_model.same_length = False
            elif self.args.model == 'llama':
                dummy_model = TrajLlama(self.args).to(self.local_rank)
            else:
                raise NotImplementedError
            return dummy_model

        def _get_parameter_size_in_gb(param):
            return param.numel() * param.element_size() / 1024 / 1024 / 1024

        # Allocate explicitly on this rank's GPU. A bare `.cuda()` uses the *current* device,
        # which is cuda:0 on every rank unless torch.cuda.set_device was called, and that
        # breaks the NCCL broadcast below.
        macs_to_share = torch.tensor([0,], dtype=torch.float32, device=self.local_rank)
        # The broadcast source is global rank 0, so only that process needs to profile.
        # With local_rank, rank 0 of every node would do the work.
        if self.global_rank == 0:
            dummy_data = _get_dummy_data()
            dummy_model = _get_dummy_model()
            # Skip parameters whose (dotted) names contain 'embedding', 'encoding' or 'lm_head'.
            # The old `n != 'lm_head'` test never matched names such as 'lm_head.weight'.
            total_block_params = sum(
                p.numel() for n, p in dummy_model.named_parameters()
                if 'embedding' not in n and 'encoding' not in n and 'lm_head' not in n
            )
            total_size_gb = sum(_get_parameter_size_in_gb(p) for p in dummy_model.parameters())
            del dummy_model   # release it before building the instance used for profiling

            # MACs via thop on a fresh backbone instance.
            dummy_model = _get_dummy_model()
            macs, _ = profile(dummy_model, tuple(dummy_data.values()), verbose=False)
            model_macs_per_batch = macs * self.world_size
            print(f"Total block params:  \t{total_block_params}")
            print(f"Total parameter size:\t{total_size_gb:.2f} GB")
            print(f'MACs per batch data: \t{int(model_macs_per_batch /1e6)} M')
            macs_to_share = torch.tensor([model_macs_per_batch,], dtype=torch.float32, device=self.local_rank)
        dist.barrier()
        dist.broadcast(macs_to_share, src=0)
        return macs_to_share.item()

    def _prepare_dataloader(self, 
        datasets:List[RLFullDataset], 
        dataset_weights:List[float],
        envs_problems:Dict
    ):
        """Split every dataset into train/val, blend each side, and build loaders.

        Side effect: `args.batch_num` / `args.eval_batch_num` equal to 0 are replaced by
        ceil(len(blended split) / batch size) before the corresponding loader is built.
        In that case `epoch_total_samples=None` is passed to the loader.

        Returns:
            (dataloader_train, dataloader_val, problemloader_dict, dataset_train, dataset_val);
            problemloader_dict maps each env name in `envs_problems` to a DDP_ProblemLoader.
        """
        args = self.args

        # Split each dataset once, then separate the train and validation parts.
        splits = [dataset.split_dataset(args.split) for dataset in datasets]
        datasets_train = [train_part for train_part, _ in splits]
        datasets_val = [val_part for _, val_part in splits]

        # Blend the per-dataset splits with the given weights.
        dataset_train = BlendableDataset(
            datasets_train,
            dataset_weights,
            batch_size=args.batch_size,
            log_data=args.traindata_logger,
            with_dataset_info=True,
        )
        dataset_val = BlendableDataset(
            datasets_val,
            dataset_weights,
            batch_size=args.eval_batch_size,
            with_dataset_info=True,
        )

        def _epoch_samples(blended, batch_num_key, batch_size):
            """Return epoch_total_samples for a loader; resolve a 0 batch count in args first."""
            if getattr(args, batch_num_key) == 0:
                setattr(args, batch_num_key, math.ceil(len(blended) / batch_size))
                return None
            return getattr(args, batch_num_key) * batch_size * self.world_size

        # Arguments are evaluated before the call, so args is updated before the loader sees it.
        dataloader_train = build_training_data_loader(
            args,
            dataset_train,
            epoch_total_samples=_epoch_samples(dataset_train, 'batch_num', args.batch_size),
            is_eval=False,
            seed=self.seed,
            current_epoch=self.epoch_now,
        )
        dataloader_val = build_training_data_loader(
            args,
            dataset_val,
            epoch_total_samples=_epoch_samples(dataset_val, 'eval_batch_num', args.eval_batch_size),
            is_eval=True,
            seed=self.seed,
            current_epoch=self.epoch_now,
        )

        # One problem loader per env. self.envs must follow the key order of envs_problems.
        problemloader_dict = {}
        for idx, name in enumerate(envs_problems):
            env = self.envs[idx]
            assert env.env_name == name
            problemloader_dict[name] = DDP_ProblemLoader(envs_problems[name], env)

        if self.local_rank == 0:
            separator = '-' * 35
            samples_per_epoch = self.world_size * len(dataloader_train.sampler)
            print(separator)
            print(f'train sample per epoch: \t{samples_per_epoch}')
            print(f'train sample all epoch: \t{args.train_iters * samples_per_epoch}')
            print(f'train Dataset Size:     \t{len(dataset_train)}')
            printer = pprint.PrettyPrinter(indent=4)
            for idx, dataset in enumerate(dataset_train.datasets):
                loader = problemloader_dict[dataset.env_name]
                n_problems = len(loader)
                # Mean objective values over the non-zero entries of the problem set.
                summary = {
                    'size': len(dataset),
                    'weight': dataset_weights[idx],
                    f'rnd_obj[{n_problems}]': loader.random_obj_array[loader.random_obj_array != 0].mean(),
                    f'best_obj[{n_problems}]': loader.best_obj_array[loader.best_obj_array != 0].mean(),
                }
                print(f'\t{dataset.dataset_name:10}', end='\t')
                printer.pprint({key: round(value, 3) for key, value in summary.items()})
            print(separator)

        return dataloader_train, dataloader_val, problemloader_dict, dataset_train, dataset_val

    def _load_snapshot(self):
        """Resume train state from `self.snapshot_path` if it exists, else start a fresh run.

        Resumed run: model, optimizer and scheduler states, the grad-accumulation step (from
        the scheduler), best metrics, epoch counters, wandb id and trained time come from the
        snapshot. Snapshots without `best_eval_loss` default it to +inf.

        Fresh run: best return -inf, best eval loss +inf, epoch 0, grad-accumulation step
        `args.start_grad_accum`, trained time 0. If `args.pretrained_ckpt` is set, model
        weights are loaded from `<base_path>/ckpt/pretrain/<pretrained_ckpt>`.
        """
        try:
            # fsspec offers one Python interface for local files and cloud storage
            # (AWS S3, GCS, ...), so the same call works for any of them.
            with fsspec.open(self.snapshot_path) as f:
                snapshot_data = torch.load(f, map_location="cpu")
        except FileNotFoundError:
            self.best_retrun = -float('inf')
            self.best_eval_loss = float('inf')
            self.epoch_start = 0
            self.epoch_now = 0
            self.current_grad_accum_step = self.args.start_grad_accum
            self.trained_time = 0

            if self.args.pretrained_ckpt is not None:
                # Load onto the CPU first. Without map_location, every rank would materialize
                # the tensors on the GPU they were saved from before load_state_dict copies them.
                pretrained_state = torch.load(
                    f'{base_path}/ckpt/pretrain/{self.args.pretrained_ckpt}',
                    map_location="cpu",
                )
                self.raw_gato.load_state_dict(pretrained_state)
                if self.local_rank == 0:
                    print(f"Pretrained model {self.args.pretrained_ckpt} loaded.\nFine-tuning from scratch with grad_accum_step={self.current_grad_accum_step}")
            else:
                if self.local_rank == 0:
                    print(f"Snapshot not found. Training model from scratch with grad_accum_step={self.current_grad_accum_step}")
            return

        # Older snapshots were written before best_eval_loss was tracked.
        snapshot_data.setdefault('best_eval_loss', float('inf'))
        snapshot = Snapshot(**snapshot_data)
        self.raw_gato.load_state_dict(snapshot.model_state)
        self.optimizer.load_state_dict(snapshot.optimizer_state)
        self.scheduler.load_state_dict(snapshot.scheduler_state)
        self.current_grad_accum_step = self.scheduler.get_grad_accum_step()
        self.best_retrun = snapshot.best_retrun
        self.best_eval_loss = snapshot.best_eval_loss
        self.epoch_start = snapshot.finished_epoch
        self.epoch_now = snapshot.finished_epoch
        self.wandb_id = snapshot.wandb_id
        self.trained_time = snapshot.trained_time
        if self.local_rank == 0:
            print(f"Resuming training from snapshot at Epoch {self.epoch_now} with grad_accum_step={self.current_grad_accum_step}")

    def _save_snapshot(self):
        """Persist the resumable train state to `self.snapshot_path`.

        The snapshot is first written to a sibling `.tmp` file and then moved into place with
        `os.replace`. If the process dies during `torch.save`, the previous snapshot stays
        intact, instead of being replaced by a truncated file that `_load_snapshot` cannot read.
        """
        snapshot = Snapshot(
            model_state=self.raw_gato.state_dict(),
            optimizer_state=self.optimizer.state_dict(),
            scheduler_state=self.scheduler.state_dict(),
            finished_epoch=self.epoch_now,
            best_retrun=self.best_retrun,
            best_eval_loss=self.best_eval_loss,
            trained_time=self.trained_time,
            wandb_id=self.wandb_id
        )
        # Same directory as the target, so os.replace is a rename on one filesystem.
        tmp_path = f'{self.snapshot_path}.tmp'
        torch.save(asdict(snapshot), tmp_path)
        os.replace(tmp_path, self.snapshot_path)

    def _save_checkpoint(self, ave_return=-1e5, eval_loss=1e5):
        """Save the unwrapped model's weights according to `args.save_strategy`.

        'interval': always writes `<save_dir>/interval/<seed>/<return>_seed<seed>_epoch<epoch>.pt`.
        'best': writes `<save_dir>/best/<metric>_seed<seed>_epoch<epoch>.pt` only when the tracked
            metric improves, and records the new best value. The metric is the eval loss (lower
            is better) when `args.is_obs_pretrain` is set, otherwise the return (higher is better).
        Any other strategy raises NotImplementedError.
        """
        strategy = self.args.save_strategy
        if strategy == 'interval':
            ckpt_path = f'{self.args.save_dir}/interval/{self.seed}/{round(ave_return,2)}_seed{self.seed}_epoch{self.epoch_now}.pt'
            torch.save(self.raw_gato.state_dict(), ckpt_path)
            return
        if strategy != 'best':
            raise NotImplementedError

        # 'best' strategy: bail out early unless the tracked metric improved.
        if self.args.is_obs_pretrain:
            if not eval_loss < self.best_eval_loss:
                return
            self.best_eval_loss = eval_loss
            best_value = self.best_eval_loss
        else:
            if not ave_return > self.best_retrun:
                return
            self.best_retrun = ave_return
            best_value = self.best_retrun

        ckpt_path = f'{self.args.save_dir}/best/{round(best_value,3)}_seed{self.seed}_epoch{self.epoch_now}.pt'
        torch.save(self.raw_gato.state_dict(), ckpt_path)

    @staticmethod
    def _summarize_episodes(epi_return, epi_obj, epi_safe, epi_time):
        """Average the per-episode results collected by `_eval_policy`.

        Args:
            epi_return: {'AM': [...], 'DB1': [...]} returns of the safe episodes.
            epi_obj: objective values of the safe episodes.
            epi_safe: safety flag of every episode.
            epi_time: time used by every episode.

        Returns:
            (ave_return, ave_obj, obj_std, ave_safe_ratio, ave_time_used), where ave_return is a
            {'AM': ..., 'DB1': ...} dict. Any statistic over an empty list is reported as 0.
        """
        def _mean_or_zero(values):
            return 0 if len(values) == 0 else np.mean(values)

        ave_return = {
            'AM': _mean_or_zero(epi_return['AM']),
            'DB1': _mean_or_zero(epi_return['DB1']),
        }
        ave_obj = _mean_or_zero(epi_obj)
        obj_std = 0 if len(epi_obj) == 0 else np.std(epi_obj)
        return ave_return, ave_obj, obj_std, _mean_or_zero(epi_safe), _mean_or_zero(epi_time)

    def _eval_policy(self, sample_action=False, hard_action_constraint=False):
        """Roll out the current policy on every evaluation env and average the results.

        Args:
            sample_action: forwarded to the rollout helpers; labelled 'sample' (True) or
                'greedy' (False) in progress bars and logs.
            hard_action_constraint: forwarded to the rollout helpers; labelled 'constraint'
                or 'free'.

        Returns:
            (episode_return, episode_obj, episode_safe_ratio, episode_time), each a dict keyed by
            the names in `args.eval_dataset_names`; episode_return values are
            {'AM': ..., 'DB1': ...} dicts.
        """
        desc_sample = 'sample' if sample_action else 'greedy'
        desc_constraint = 'constraint' if hard_action_constraint else 'free'
        desc = f'{desc_sample}-{desc_constraint}'
        episode_return = {k: {'AM':0, 'DB1':0} for k in self.args.eval_dataset_names}
        episode_obj = {k: 0 for k in self.args.eval_dataset_names}
        episode_safe_ratio = {k: [] for k in self.args.eval_dataset_names}
        episode_time = {k: [] for k in self.args.eval_dataset_names}

        for env_name, dataset_name, env in zip(self.args.eval_env_names, self.args.eval_dataset_names, self.envs):
            if self.args.use_ddp_env:
                ave_return, ave_obj, obj_std, ave_safe_ratio, ave_time_used, episode = evalute_batch_episode(
                    args=self.args,
                    model=self.gato,
                    env=env,
                    problemloader=self.problemloader_dict[env_name],
                    cont_tokenizer=self.cont_tokenizer,
                    sample_action=sample_action,
                    hard_action_constraint=hard_action_constraint,
                    desc=f'[GPU{self.global_rank}]: Evaluating on {env_name} ({desc})',
                    device=self.local_rank
                )

                # Log the evaluated episodes if a logger is attached.
                if self.logger is not None:
                    for i, epi in enumerate(episode):
                        self.logger[env_name].log_episode(
                            desc=desc,
                            is_eval=True,
                            episode=epi,
                            epoch_num=0,
                            episode_num=i,
                            time_used=-1,
                            seed=self.seed
                        )
            else:
                problemloader = self.problemloader_dict[env_name]
                problemloader.reset()
                iters = self.args.problem_batch_num * self.args.problem_batch_size
                # Returns and objectives are collected only for safe episodes; safety flags and
                # times are collected for every episode.
                epi_return, epi_obj, epi_safe, epi_time = {'AM':[], 'DB1':[]}, [], [], []
                # Start from the empty summary (all zeros). If iters == 0 the loop never runs,
                # and the stats would otherwise be unbound (first env) or stale (later envs).
                ave_return, ave_obj, obj_std, ave_safe_ratio, ave_time_used = self._summarize_episodes(
                    epi_return, epi_obj, epi_safe, epi_time)
                with tqdm(total=iters, desc=f'[GPU{self.global_rank}]: Evaluating on {env_name} ({desc})', position=self.local_rank) as pbar:
                    for i in range(iters):
                        # get_problem(1) returns batched fields; take element 0 of each.
                        problem_info, problem_obj = problemloader.get_problem(1)
                        problem_info = (None if problem_info[0] is None else problem_info[0][0], problem_info[1][0], problem_info[2][0])
                        problem_obj = (problem_obj[0].item(), problem_obj[1].item())

                        ep_ret, ep_obj, ep_safe, _, ep_time, epi = evalute_one_episode(
                            args=self.args,
                            model=self.gato,
                            env=env,
                            cont_tokenizer=self.cont_tokenizer,
                            sample_action=sample_action,
                            hard_action_constraint=hard_action_constraint,
                            problem_info=problem_info,
                            problem_obj=problem_obj,
                            device=self.local_rank
                        )
                        epi_safe.append(ep_safe)
                        epi_time.append(ep_time)
                        if ep_safe:
                            epi_return['AM'].append(ep_ret['AM'])
                            epi_return['DB1'].append(ep_ret['DB1'])
                            epi_obj.append(ep_obj)

                        # Log the episode if policy logging is enabled.
                        if self.logger is not None and self.args.policy_logger:
                            self.logger[env_name].log_episode(
                                desc=desc,
                                is_eval=False,
                                episode=epi,
                                epoch_num=self.epoch_now,
                                episode_num=i,
                                time_used=ep_time,
                                seed=self.seed
                            )

                        # Running averages, shown in the progress bar.
                        ave_return, ave_obj, obj_std, ave_safe_ratio, ave_time_used = self._summarize_episodes(
                            epi_return, epi_obj, epi_safe, epi_time)
                        pbar.set_postfix({
                            'ret_AM': f'{ave_return["AM"]:.4f}',
                            'ret_DB1': f'{ave_return["DB1"]:.4f}',
                            'obj' : f'{ave_obj:.4f}',
                            'std': f'{obj_std:.4f}',
                            'time': f'{ave_time_used:.2f}',
                        })
                        pbar.update()

            episode_return[dataset_name] = ave_return
            episode_obj[dataset_name] = ave_obj
            episode_safe_ratio[dataset_name] = ave_safe_ratio
            episode_time[dataset_name] = ave_time_used
        return episode_return, episode_obj, episode_safe_ratio, episode_time

    def _run_batch(self, rl_task_input, batch_dataset_name, batch_raw_obs, is_train=False) -> tuple:
        """Forward one batch and, when training, back-propagate with gradient accumulation.

        Gradients are accumulated over `self.current_grad_accum_step` micro-batches.
        Intermediate micro-batches run backward under DDP `no_sync()`, so no gradient
        all-reduce happens. On the last micro-batch the gradients are synchronized and
        optionally clipped, the optimizer steps, and the scheduler returns a new lr and a
        possibly new accumulation step.

        Returns:
            (loss, loss_datasets, new_lr, total_norm): the loss as a Python float, the
            per-dataset losses from the model, and the new lr / gradient norm from this call's
            optimizer step (None when no step happened or clipping is disabled).
        """
        # NOTE: an autocast context (torch.cuda.amp.autocast, float16) would wrap this forward
        # if AMP were enabled; it is currently disabled, see the assert below.
        _, loss, loss_datasets, _ = self.gato(rl_task_input, batch_dataset_name=batch_dataset_name, batch_raw_obs=batch_raw_obs)

        new_lr = None
        total_norm = None
        if is_train:
            assert self.args.use_amp is False   # NOTE(XXX): current implement of use_amp hurt the performence too much
            if self.args.use_amp:
                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                if self.args.clip_grad != 0:
                    # Gradients are still multiplied by the loss scale here. Unscale them first,
                    # otherwise the clip threshold is applied to the scaled values.
                    self.scaler.unscale_(self.optimizer)
                    total_norm = torch.nn.utils.clip_grad_norm_(self.gato.parameters(), self.args.clip_grad)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                # Divide the loss so the accumulated gradient is the mean over micro-batches.
                loss_ave = loss / self.current_grad_accum_step
                if self.grad_batch_cnt % self.current_grad_accum_step != 0:
                    # Intermediate micro-batch: accumulate locally and skip the DDP all-reduce.
                    with self.gato.no_sync():
                        loss_ave.backward()
                else:
                    # Last micro-batch: this backward synchronizes gradients, then step.
                    loss_ave.backward()
                    if self.args.clip_grad != 0:
                        total_norm = torch.nn.utils.clip_grad_norm_(self.gato.parameters(), self.args.clip_grad)
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                    # The scheduler may also change the accumulation step for later batches.
                    new_lr, grad_accum_step = self.scheduler.step(self.current_grad_accum_step)
                    self.current_grad_accum_step = grad_accum_step
                    self.grad_batch_cnt = 0
            self.grad_batch_cnt += 1

        return loss.item(), loss_datasets, new_lr, total_norm

    def _run_epoch(self, is_train=False):
        """Run one pass over the training or validation loader.

        For each batch this method:
        - trims every 2-D tensor of the batch to the longest sequence when
          `args.auto_batch_len` is set;
        - records the visited dataset indices;
        - runs `_run_batch` and updates the progress bar;
        - gathers [lr, grad norm] from all ranks onto global rank 0, which logs the mean grad
          norm and the tokens per optimizer step to wandb.

        Collectives run on every batch, so all ranks must iterate the same number of batches.

        Returns:
            (epoch_loss, epoch_loss_dataset, new_lr, data_info): the mean batch loss, the mean
            loss per dataset name (0 when no batch contributed), the last lr returned by the
            scheduler (0 if none), and dataset name -> list of visited sample indices.
        """
        epoch_losses = []
        epoch_losses_dataset = {dataset_name: [] for dataset_name in self.args.eval_dataset_names}
        data_info = {dataset_name: [] for dataset_name in self.args.eval_dataset_names}
        desc = f'Training Epoch {self.epoch_now}' if is_train else 'Calculating eval loss'
        total = self.args.batch_num if is_train else self.args.eval_batch_num
        dataloader = self.dataloader_train if is_train else self.dataloader_val
        dataloader.sampler.set_epoch(self.epoch_now, is_train)  # when using DistributedSampler, its necessary to set_epoch to shuffle
        self.optimizer.zero_grad()

        desc = f'[GPU{self.global_rank}]: ' + desc
        # dist.gather(dst=0) addresses *global* rank 0. Gating on local_rank would make rank 0
        # of every other node pass a gather_list (rejected on non-destination ranks) and log.
        is_log_rank = self.global_rank == 0
        new_lr = 0
        total_norm = 0
        with tqdm(total=total, desc=desc, position=self.local_rank) as pbar:
            for batch in dataloader:
                # Unpack the batch.
                rl_task_input, batch_data_info, batch_raw_obs = batch
                if self.args.auto_batch_len:
                    max_len = rl_task_input.seq_len.max()
                    if max_len < self.args.n_position:
                        rl_task_input.apply(lambda x: x[:, :max_len] if isinstance(x, torch.Tensor) and x.dim() == 2 else x)
                        assert (rl_task_input.tensor_seq[:,-1]==self.args.special_tokens['<|>']).sum() >= 1

                # Record which samples of which dataset were visited.
                batch_dataset_name = []
                for dataset_name, dataset_idx in batch_data_info:
                    data_info[dataset_name].append(dataset_idx)
                    batch_dataset_name.append(dataset_name)

                # Run the batch.
                rl_task_input.to(device=self.local_rank)
                loss, loss_datasets, lr, norm = self._run_batch(rl_task_input, batch_dataset_name, batch_raw_obs, is_train)
                new_lr = lr if lr is not None else new_lr
                total_norm = norm.item() if norm is not None else total_norm

                # Record the total loss and the per-dataset losses.
                epoch_losses.append(loss)
                for dataset_name, dataset_loss in loss_datasets.items():
                    if dataset_loss != 0:
                        epoch_losses_dataset[dataset_name].append(dataset_loss.item())

                # Millions of tokens per optimizer step (accumulation step may change per batch).
                batch_token = self.current_grad_accum_step * self.args.batch_size * self.world_size * self.args.n_position/1e6

                # Update the progress bar.
                pbar.set_postfix({
                    'total norm': '{:.4f}'.format(total_norm),
                    'batch_token': '{:.4f}'.format(batch_token),
                    'loss':'{:.2f}'.format(loss),
                    'ave loss (latest 20)': '{:.2f}'.format(np.array(epoch_losses[-20:]).mean()),
                })
                pbar.update()

                # Gather per-batch info from all ranks and log it to wandb from global rank 0.
                batch_info_tensor = torch.tensor([new_lr, total_norm]).to(self.local_rank)
                batch_info_gather_list = [torch.zeros_like(batch_info_tensor) for _ in range(self.world_size)]
                dist.barrier()
                dist.gather(
                    batch_info_tensor,
                    batch_info_gather_list if is_log_rank else None,
                    dst=0
                )
                if is_log_rank:
                    log_value_tensor = torch.mean(torch.stack(batch_info_gather_list), dim=0, dtype=torch.float32)
                    batch_log_dict = {
                        'info/total_norm': log_value_tensor[1].item(),
                        'info/batch_token': batch_token
                    }
                    wandb.log(batch_log_dict)

        epoch_loss = np.array(epoch_losses).mean()
        epoch_loss_dataset = {name: 0 if len(losses) == 0 else np.array(losses).mean() for name, losses in epoch_losses_dataset.items()}
        return epoch_loss, epoch_loss_dataset, new_lr, data_info

    def train(self):    
        if self.local_rank == 0:
            data_visited_cnt = {dataset.dataset_name: np.zeros(len(dataset), dtype=np.int8) for dataset in self.dataset_train.datasets}   
        early_stop_signal = torch.tensor([0,]).to(self.local_rank)

        # start training
        for epoch in range(self.epoch_start, self.args.train_iters + 1):  
            self.epoch_now = epoch
            cst_ave_return = 0
            log_dict = {}

            # Save snapshot before the epoch training process, so that there is no overlap when recover from the snapshot
            if self.args.save_snapshot and epoch % self.args.snapshot_save_interval == 0 \
                and self.local_rank == 0 and epoch != 0 :
                self._save_snapshot()
            
            # Calculate validation losses at specified epoch intervals
            if (epoch % self.args.eval_interval == 0 or epoch == self.args.train_iters) and (epoch != 0 or not self.args.skip_first_eval):
                self.gato.eval()
                with torch.no_grad():
                #with torch.inference_mode():
                    # validation losses
                    self.raw_gato.transformer.same_length = False               # use normal context length when loss calculating (TransformerXL back bone)
                    eval_loss, eval_loss_dataset, _, _ = self._run_epoch(is_train=False)
                    log_dict.update({"losses/eval_loss": eval_loss})
                    log_dict.update(
                        {f'eval_{dataset_name[:-3]}/eval_loss': loss
                        for dataset_name, loss in eval_loss_dataset.items()}
                    )

            # Evaluate policy performance at specified epoch intervals
            if (epoch % self.args.eval_policy_interval == 0 or epoch == self.args.train_iters) and (epoch != 0 or not self.args.skip_first_eval):
                self.gato.eval()
                with torch.inference_mode():  
                    self.raw_gato.transformer.same_length = self.args.use_mem   # use fixed context length when rollout with mem (TransformerXL back bone)
                    cst_returns = []
                    for setting, eval_func in self.eval_setting.items():
                        epi_return, epi_obj, epi_safe, epi_time = eval_func()
                        log_dict.update(
                            {f'eval_{dataset_name[:-3]}/return_AM({setting})': np.mean(epi_return[dataset_name]['AM'])
                            for dataset_name in self.args.eval_dataset_names}
                        )
                        log_dict.update(
                            {f'eval_{dataset_name[:-3]}/return_DB1({setting})': np.mean(epi_return[dataset_name]['DB1'])
                            for dataset_name in self.args.eval_dataset_names}
                        )
                        log_dict.update(
                            {f'eval_{dataset_name[:-3]}/obj({setting})': np.mean(epi_obj[dataset_name])
                            for dataset_name in self.args.eval_dataset_names}
                        )
                        log_dict.update(
                            {f'eval_{dataset_name[:-3]}/safe({setting})': np.mean(epi_safe[dataset_name])
                            for dataset_name in self.args.eval_dataset_names}
                        )
                        log_dict.update(
                            {f'eval_{dataset_name[:-3]}/time({setting})': np.mean(epi_time[dataset_name])
                            for dataset_name in self.args.eval_dataset_names}
                        )

                        if setting.endswith('constraint'):
                            cst_returns.extend([v['AM'] for v in epi_return.values()])  # 用 AM return 作为 ckpt 质量指标
                    cst_ave_return = np.mean(cst_returns)
                    
            if epoch == self.args.train_iters:
                break

            # one training epoch
            self.gato.train()
            self.raw_gato.transformer.same_length = False               # use normal context length when loss calculating (TransformerXL back bone)
            start_time = time.time()
            train_loss, train_loss_dataset, new_lr, data_info = self._run_epoch(is_train=True)
            self.trained_time += time.time() - start_time
            log_dict.update(
                {f'eval_{dataset_name[:-3]}/train_loss': loss
                for dataset_name, loss in train_loss_dataset.items()}
            )
            log_dict.update({
                "losses/train_loss": train_loss, 
                'info/trained_time': self.trained_time,
                'info/MACs': self.model_macs_per_batch * self.args.batch_num * (self.epoch_now+1),
            })
            if new_lr is not None:
                log_dict.update({'info/lr': new_lr})

            # log train data if necessary
            if self.logger is not None and self.args.traindata_logger:
                logged_data = self.dataset_train.get_logged_data()
                for env_name in self.args.eval_env_names:
                    self.logger[env_name].log_data(logged_data, seed=self.seed, is_train=True)

            #import pprint
            #pp = pprint.PrettyPrinter(indent=4)
            #print('='*10+f'{self.local_rank}'+'='*10)
            #pp.pprint({k: v for k,v in log_dict.items()})
            #print()
            
            # gather info from all GPU
            log_value_tensor = torch.tensor([log_dict[k] for k in sorted(log_dict)] + [cst_ave_return,]).to(self.local_rank)
            gather_list = [torch.zeros_like(log_value_tensor) for _ in range(self.world_size)]
            dist.barrier()
            dist.gather(
                log_value_tensor,
                gather_list if self.local_rank == 0 else None, 
                dst=0
            )

            # gather data info from all GPU
            dataset_info = [len(data_info[dataset_name]) for dataset_name in self.args.eval_dataset_names]
            for dataset_name in self.args.eval_dataset_names:
                dataset_info += data_info[dataset_name]

            dataset_info_tensor = torch.tensor(dataset_info).to(self.local_rank)
            dataset_info_gather_list = [torch.zeros_like(dataset_info_tensor) for _ in range(self.world_size)]
            dist.barrier()
            dist.gather(
                dataset_info_tensor,
                dataset_info_gather_list if self.local_rank == 0 else None, 
                dst = 0
            )

            if self.local_rank == 0:
                data_info_gather = {dataset_name: [] for dataset_name in self.args.eval_dataset_names}   
                for info in dataset_info_gather_list:
                    info_len, info = info[:len(self.dataset_train.datasets)], info[len(self.dataset_train.datasets):]
                    for l, dataset_name in zip(info_len, self.args.eval_dataset_names):
                        data_info_gather[dataset_name].append(info[:l])
                        info = info[l:]
                
                epoch_data_num = 0
                for dataset_name in self.args.eval_dataset_names:
                    dataset_idxs = torch.cat(data_info_gather[dataset_name])        # 合并各个卡在当前 epoch 中访问的数据索引
                    #unique_idxs, _ = torch.unique(dataset_idxs, return_inverse=True)
                    #assert len(unique_idxs) == dataset_idxs.numel()                # 各个卡访问的索引应该没有重叠（混合数据集时不一定）
                    epoch_data_num += len(dataset_idxs)
                    data_visited_cnt[dataset_name][dataset_idxs.cpu()] += 1
                assert epoch_data_num >= self.args.batch_size * self.args.batch_num # 多卡无法均分数据时默认向上取整

            # only do file saving and logging at rank0
            if self.local_rank == 0:
                log_value_tensor = torch.mean(torch.stack(gather_list), axis=0, dtype=torch.float32)
                cst_ave_return = log_value_tensor[-1].item()

                # save ckpt for policy learning
                if self.args.save_ckpt and not self.args.is_obs_pretrain and epoch != 0 and epoch % self.args.save_interval == 0:
                    #assert cst_ave_return != 0
                    self._save_checkpoint(ave_return=cst_ave_return)

                # log to wandb
                for i, key in enumerate(sorted(log_dict)):
                    log_dict[key] = log_value_tensor[i].item()
                    if key == 'losses/eval_loss':
                        self.early_stopping(log_value_tensor[i].item())
                        # save ckpt for pretraining
                        if self.args.save_ckpt and self.args.is_obs_pretrain and epoch != 0:
                            self._save_checkpoint(eval_loss=log_value_tensor[i].item())

                #print('='*10+f'summary'+'='*10)
                #pp.pprint({k: v for k,v in log_dict.items()})
                #print()

                log_dict.update({"info/epoch": epoch})
                for dataset_name in self.args.eval_dataset_names:
                    dataset_visited_cnt = data_visited_cnt[dataset_name]
                    visited_part = dataset_visited_cnt[dataset_visited_cnt!=0]

                    log_name = dataset_name[:dataset_name.find('_')]
                    # 训练数据集被访问比例
                    log_dict.update({f'info/{log_name}_ratio': len(visited_part)/len(dataset_visited_cnt)})
                    # 训练数据集中被访问子集的尺寸
                    wandb.run.summary[f"info/{log_name}_visited"] = len(visited_part)           
                    # 训练数据集访问计数            
                    wandb.run.summary[f"info/{log_name}_num"] = np.sum(visited_part)
                    # 训练数据集中被访问子集的访问次数
                    wandb.run.summary[f"info/{log_name}_times"] = np.sum(visited_part)/len(visited_part)    

                wandb.log(log_dict)

                # early stopping
                if self.args.use_early_stopping and self.early_stopping.early_stop:
                    early_stop_signal = torch.tensor([1,]).to(self.local_rank)
                    dist.broadcast(tensor=early_stop_signal, src=0)
            
            if early_stop_signal:
                print(f'[GPU{self.local_rank}] Early Stop')
                break