import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(base_path)

import math
from environment.wrapper import LMPromptEnv
from environment.DDP_wrapper import DDP_LMPromptEnv
import numpy as np
import torch
import time
import random
from tqdm import tqdm
from dataloader.code.problem_loader import DDP_ProblemLoader
from argparse import Namespace
from typing import Optional, Union, List
from torch.nn import functional as F
from dataloader.code.tokenizer import ContinuousScalarTokenizer
from dataloader.code.dataset import get_loss_flag_and_position_id
from dataloader.code.input_specs import RLTaskInput
from utils.utils import COP_FAILED_RWD

def masked_logits_for_action(
    args,
    logits,                                          # (batch_size, total_vocab_size)
    env_action_mask: Optional[torch.Tensor] = None,  # (batch_size, total_vocab_size)
):
    """Mask logits so that only allowed action tokens can be predicted.

    ``1e10 * env_action_mask`` is subtracted from ``logits``. With a 0/1 (or
    bool) mask where 1 marks a forbidden token, those tokens end up ~1e10 below
    the allowed ones and get (numerically) zero probability after a softmax.
    A row whose mask is all ones is shifted by the same amount everywhere, so
    its softmax stays finite rather than turning into NaN.

    Args:
        args: unused; kept for interface compatibility.
        logits: logits tensor. It is NOT modified.
        env_action_mask: tensor broadcastable against ``logits``. It is required;
            the ``None`` default only exists for signature compatibility.

    Returns:
        A new tensor equal to ``logits - env_action_mask * 1e10``.
    """
    assert env_action_mask is not None
    # Out-of-place on purpose: the caller passes a view of the model output
    # (``logits[:, -1, :]``). The previous in-place ``-=`` silently rewrote it.
    return logits - env_action_mask * 1e10

def truncate_memory(mems, obs_len, act_len):
    """Cut ``obs_len + act_len + 1`` leading positions off every memory tensor.

    Each element of ``mems`` is sliced along axis 1, which presumably is the
    memory/time axis (verify against the model). The extra ``+ 1`` presumably
    accounts for the separator token of one transition.

    Returns:
        A new list holding one sliced tensor per element of ``mems``.
    """
    drop = obs_len + act_len + 1
    return [layer_mem[:, drop:] for layer_mem in mems]

def get_action(
    args,
    model,
    env,
    current_seq,
    cont_tokenizer,
    len_fixed_prompt,
    model_memory,
    prompt_strategy: str = "fixed_prompt",
    action_masks: List[np.ndarray] = None,
    sample_action: bool = False,
    obs_idxs: torch.tensor = None,
    raw_obs: dict = None,
    device: Optional[Union[int, str, torch.device]] = None
):
    """Autoregressively generate one action token per action dimension.

    For each of the ``env.dataset.act_dim`` action dimensions:

    1. The model is run on ``current_seq``.
    2. The logits of the last position are masked with a fresh mask from
       ``env.get_action_mask``.
    3. A token is chosen. It is sampled with ``torch.multinomial`` when
       ``sample_action`` is True; otherwise the arg-max is taken.
    4. Without memory, the token is appended to ``current_seq``. With memory,
       it becomes the next single-token input.

    Args:
        args: run configuration. This function reads ``use_prefix``,
            ``use_prompt`` and ``n_position``.
        model: called as ``model(x, compute_loss=False, mems=..., ...)`` and
            expected to return ``(logits, _, _, new_mems)``.
        env: environment wrapper. Provides ``dataset``, ``get_prefix_mask`` and
            ``get_action_mask``.
        current_seq: token tensor of shape (problem_batch_size, seq_len).
        cont_tokenizer: not used by this function.
        len_fixed_prompt: not used by this function.
        model_memory: per-layer memory list, or None when running without memory.
        prompt_strategy: decides how an over-long sequence is windowed.
            The memory path requires it to differ from "fixed_prompt".
        action_masks: ignored. It is overwritten by ``env.get_action_mask(...)``
            for every action dimension.
        sample_action: sample from the masked distribution instead of arg-max.
        obs_idxs: passed through to ``RLTaskInput``.
        raw_obs: dict of raw observation arrays. They are converted to tensors
            for the model, and shifted by one step when the sequence is windowed.
        device: device the model input and tokens are moved to.

    Returns:
        ``(act, current_seq, model_memory, raw_obs)``:

        - ``act``: the result of ``env.dataset.adapter.recover_raw_act`` applied
          to the chosen tokens, laid out as (problem_batch_size, act_dim).
        - ``current_seq``, ``model_memory``, ``raw_obs``: the updated values.
    """
    prefix_dim = env.dataset.prefix_dim
    obs_dim = env.dataset.obs_dims_after_mlp_emb
    act_dim = env.dataset.act_dim
    trans_dim = env.dataset.trans_dim
    problem_batch_size = current_seq.shape[0]
    assert env.dataset.act_type_spec == 'int'
    batch_prefix_mask_list = env.get_prefix_mask()  # list of (act_dim, prefix_dim) or list of (prefix_dim, ) when act_dim == 1
    # Regroup so that batch_prefix_mask_list_by_act_dim[i] holds the prefix
    # masks of every problem for action dimension i.
    if act_dim > 1:
        batch_prefix_mask_list_by_act_dim = [[prefix_mask[i] for prefix_mask in batch_prefix_mask_list] for i in range(act_dim)]
    else:
        batch_prefix_mask_list_by_act_dim = [batch_prefix_mask_list, ]

    # Generate action vectors dimension by dimension
    batch_action_seq = []
    for i_act in range(act_dim):
        # Compute flags / position ids for a sequence one token longer, then
        # drop the first loss flag and the last position id so both match the
        # current length. Presumably this aligns them with the next token to be
        # predicted; TODO confirm against get_loss_flag_and_position_id.
        loss_flag, pos_id = get_loss_flag_and_position_id(
            current_seq.shape[-1]+1,
            prefix_dim if args.use_prefix else None,
            obs_dim, act_dim,
        )
        loss_flag, pos_id = loss_flag[1:], pos_id[:-1]
        # Tile the 1-D arrays into (problem_batch_size, seq_len) tensors.
        pos_id = pos_id[:, None].repeat(problem_batch_size, axis=1).T
        pos_id = torch.tensor(pos_id, dtype=torch.long)
        loss_flag = loss_flag[:, None].repeat(problem_batch_size, axis=1).T
        loss_flag = torch.tensor(loss_flag, dtype=torch.long)

        # Constructs current_seq as an RLTaskInput that the model can enter
        x = RLTaskInput(
            tensor_seq=current_seq, # (problem_batch_size, seq_len, ) Note that when generating the first action (using all zero mem), seq_len can be longer than 1024
            position_id=pos_id,     # (problem_batch_size, seq_len)
            attention_mask=torch.ones_like(current_seq, dtype=torch.int8),
            text_seq=None,
            obs_idxs=obs_idxs,      # callers in this file pass shape (1, seq_len)
            loss_mask=loss_flag,
            label=None,
            seq_len=None,
            prefix_mask=batch_prefix_mask_list_by_act_dim[i_act],
        )
        x.to(device=device)

        # model one step generating
        logits, _, _, new_mems = model(
            x, compute_loss=False,
            mems=model_memory,
            batch_dataset_name=[env.dataset.dataset_name,],
            batch_raw_obs={k:torch.tensor(v) for k, v in raw_obs.items()}
        )                                   # logits: (problem_batch_size, seq_len, total_vocab_size)
        if model_memory is not None:
            model_memory = new_mems         # n_layer * [(problem_batch_size, mem_len, n_embed)]
            assert model_memory[0].shape[1] == args.n_position
            # if model_memory[0].shape[1] >= args.n_position:
            #     model_memory = truncate_memory(model_memory, obs_dim, act_dim)

        # Set mask according to the output space of current generating dim of action
        # NOTE(review): the trailing comma below makes `action_masks` a 1-tuple.
        # np.array(...).squeeze() on the next line removes that extra leading
        # axis, and also the batch axis when the batch size is 1; the line after
        # re-adds the batch axis. The `action_masks` argument is discarded here.
        action_masks = env.get_action_mask(hard_action_constraint=True, generated_actions=np.array(batch_action_seq).T)[i_act],
        action_masks = torch.from_numpy(np.array(action_masks)).to(device).squeeze()
        action_masks = action_masks[None, :] if action_masks.ndim == 1 else action_masks    # (batch_size, total_vocab_size)

        # Only the logits of the last position are used to pick the next token.
        logits = masked_logits_for_action(                                          # (problem_batch_size, total_vocab_size)
            args,
            logits[:, -1, :],
            env_action_mask = action_masks
        )
        probs = F.softmax(logits, dim=-1)                                           # (problem_batch_size, total_vocab_size)

        # get pred action token
        if sample_action:
            pred_token = torch.multinomial(probs, num_samples=1)                    # (problem_batch_size, 1)
        else:
            _, pred_token = torch.topk(probs, k=1, dim=-1)                          # (problem_batch_size, 1)

        # Force token 0 for rows whose mask forbids every token (presumably the
        # problems that are already solved).
        if action_masks is not None:
            pred_token[action_masks.all(dim=1)] = 0

        # update current_seq
        if model_memory is None:
            current_seq = current_seq.to(device=device)
            pred_token = pred_token.to(device=device)
            current_seq = torch.cat([current_seq, pred_token], dim=1)         # (problem_batch_size, seq_len,)
            if current_seq.shape[1] > args.n_position:
                # Drop the oldest raw observation along axis 1.
                # NOTE(review): this indexing needs 3-D raw_obs values; confirm
                # for callers that pass 2-D arrays.
                raw_obs = {k:v[:,1:,:] for k, v in raw_obs.items()}
                if (args.use_prompt and prompt_strategy == "moving_prompt") or (not args.use_prompt and not args.use_prefix):
                    # Without prefix/prompt, or with "moving_prompt": keep only the tail of the sequence
                    # as it grows, so the prompt part keeps being refreshed.
                    current_seq = current_seq[:, trans_dim:]
                elif args.use_prefix:
                    # With a prefix: keep the prefix at the head of the sequence unchanged.
                    # Rotate the non-prefix part left by trans_dim. torch.roll with no dims flattens the
                    # tensor, so for batch > 1 the wrapped tail comes from the next row; that tail is
                    # dropped two lines below, so the kept columns are still correct.
                    window_seq_view = torch.roll(current_seq[:,prefix_dim+1:], -trans_dim)
                    # Write back in place: the oldest non-prefix transition now sits at the tail.
                    current_seq[:,prefix_dim+1:].data.copy_(window_seq_view.data)
                    # Remove that tail, i.e. the former oldest transition.
                    current_seq = current_seq[:,:-trans_dim]
                else:
                    raise NotImplementedError
        else:
            # although cpu() may have a new copy, prevent side effect
            # of recover_model_predict_token_to_tokenizer_raw where
            # there are some inplace operations
            # XXX(DB1): memory net uses moving prompt!
            assert prompt_strategy != "fixed_prompt"
            current_seq = pred_token.to(device=device)

        # recover model predict token to tokenizer raw
        batch_action_seq.append(pred_token.squeeze().cpu())

    '''
    for i in range(len(current_seq)):
        print(i, current_seq[i])
    '''

    # pass last dim of action to model to update model_memory
    # NOTE(review): unlike the call inside the loop, this call passes neither
    # batch_dataset_name nor batch_raw_obs; confirm the model accepts that.
    if model_memory is not None:
        x = RLTaskInput(
            tensor_seq=current_seq,                                             # (problem_batch_size, 1)
            position_id=torch.zeros((problem_batch_size, 1), dtype=torch.long), # (problem_batch_size, 1)
            attention_mask=torch.ones_like(current_seq, dtype=torch.int8),
            text_seq=None,
            loss_mask=None,
            label=None,
            seq_len=None,
            prefix_mask=env.get_prefix_mask(),
        )
        x.to(device=device)

        _, _, _, model_memory = model(x, compute_loss=False, mems=model_memory)
        assert model_memory[0].shape[1] == args.n_position
        # if model_memory[0].shape[1] >= args.n_position:
        #     model_memory = truncate_memory(model_memory, obs_dim, act_dim)

    # token -> action
    act = np.vstack(batch_action_seq).T                     # (problem_batch_size, act_dim)
    act = env.dataset.adapter.recover_raw_act(act)          # (problem_batch_size, act_dim)
    return act, current_seq, model_memory, raw_obs

#@torch.no_grad()
@torch.inference_mode()
def evalute_one_episode(
    args: Namespace,
    model: torch.nn.Module,
    env: LMPromptEnv,
    cont_tokenizer: ContinuousScalarTokenizer,
    sample_action: bool = False,
    hard_action_constraint: bool = False,
    regen_times: int = 1,
    problem_info: tuple = None,
    problem_obj: tuple = None,
    device: Optional[Union[int, str, torch.device]] = None
):
    """Roll out one evaluation episode on a non-DDP env, with optional voting.

    The episode is generated ``regen_times`` times. When ``regen_times > 1``,
    the action sequences are compared with ``env.env.is_same_episode``. The most
    frequent one wins, with ties broken by ``random.choice``.

    Args:
        args: run configuration. It must have ``use_ddp_env`` False.
        model: the policy model. A ``.module`` wrapper is unwrapped for the
            transformer checks.
        env: non-DDP prompt environment.
        cont_tokenizer: forwarded to ``get_action``.
        sample_action: sample actions instead of taking the arg-max.
        hard_action_constraint: forwarded to ``env.get_action_mask`` and
            ``get_action_value_space``.
        regen_times: number of rollouts used for voting. Values below 1 are
            treated as 1.
        problem_info, problem_obj: forwarded to ``env.reset``.
        device: device for model inputs. Defaults to ``cuda:args.device[0]``
            when available, else CPU.

    Returns:
        ``(ans_ret, episode_obj, ans_safe, episode_length, time_used, episode)``:

        - ``ans_ret``: dict of summed 'AM' / 'DB1' rewards.
        - ``ans_safe``: 0 when a COPTask's 'AM' return equals
          ``COP_FAILED_RWD``, otherwise 1.
        - ``time_used``: wall time of all rollouts, in seconds.
        - ``episode``: the logged trajectory dict when ``args.policy_logger``
          is set, otherwise None.
    """
    assert not args.use_ddp_env
    eval_prompt_strat = args.prompt_strategy.split(";")[-1]     # e.g. "moving_prompt"
    assert args.use_prefix ^ (args.use_prompt and eval_prompt_strat == "moving_prompt")
    # Roll out at least once. Previously only 0 was mapped to 1; a negative
    # value left the result variables unbound and crashed at the return.
    regen_times = max(regen_times, 1)
    max_step_size = args.eval_max_step_size
    prefix_dim = env.dataset.prefix_dim
    trans_dim = env.dataset.trans_dim
    device = device if device is not None else torch.device(f"cuda:{args.device[0]}" if torch.cuda.is_available() and torch.cuda.device_count() >= args.device[0]+1 else "cpu")
    spliter_token = torch.tensor([env.dataset.spliter_token_id], dtype=torch.long)
    prefix_spliter_token = torch.tensor([args.special_tokens['<X>']], dtype=torch.long)

    # A COPTask generates its solution several times and votes on it
    # (self-consistency). An RLTask uses regen_times == 1.
    results_acts, results_cnt, results_info = [], [], []
    time_start = time.time()
    for _ in range(regen_times):
        # reset env
        current_seq, info = env.reset(options={
            'problem_info': problem_info,
            'problem_obj': problem_obj,
            'use_default_policy_obj': args.use_default_policy_obj
        })

        # Build the initial token sequence. In every branch the current obs is
        # followed by a spliter token, from which the action is generated.
        len_fixed_prompt = 0
        raw_prefix = None
        if args.use_prefix:
            prefix_tensor, raw_prefix = env.get_prefix(with_raw=True)
            current_seq = torch.cat([prefix_tensor, prefix_spliter_token, current_seq, spliter_token])
            raw_obs = {k:v[None,:] for k,v in info['obs'].items()}
        elif args.use_prompt:
            fixed_prompt, raw_obs = env.get_prompt(strict_length=args.strict_length, minimal_expert_data=args.minimal_expert_data)
            len_fixed_prompt = len(fixed_prompt)    # NOTE(XXX): the prompt length may differ from the model context length (1024)
            current_seq = torch.cat([fixed_prompt, current_seq, spliter_token])
            raw_obs = {k:np.vstack((v, info['obs'][k][None,:])) for k, v in raw_obs.items()}
        else:
            raw_obs = {k:v[None,:] for k,v in info['obs'].items()}
            current_seq = torch.cat([current_seq, spliter_token])

        trans_num = math.ceil((len(current_seq)-prefix_dim)/trans_dim)
        assert trans_num == len(list(raw_obs.values())[0])
        obs_idxs = ([0] * prefix_dim + env.dataset.obs_idxs * trans_num)[:len(current_seq)]
        obs_idxs = torch.tensor([obs_idxs,])                # (1, seq_len)
        current_seq = current_seq[None, :]                  # (1, seq_len)
        assert current_seq.shape[1] == obs_idxs.shape[1]

        # START EVALUATION
        raw_model = model.module if hasattr(model, "module") else model
        assert raw_model.transformer.same_length == args.use_mem
        # n_layer * [(batch_size, mem_len, n_embed)], initialised to all-zero tensors
        model_memory = raw_model.transformer.init_mem(batch_size=1) if args.use_mem else None

        episode_return, episode_obj, episode_length = {'AM':0, 'DB1':0}, 0, 0
        # Actions of this rollout. They are recorded unconditionally because the
        # vote below needs them even when args.policy_logger is off.
        episode_acts = []
        if args.policy_logger:
            obss, acts, rewards, value_spaces = [info['obs'], ], [], {'AM':[], 'DB1':[]}, [env.env.get_action_value_space(hard_action_constraint), ]
        while True:
            # Generate action
            assert current_seq.shape[1] <= raw_model.transformer.n_position
            act, current_seq, model_memory, raw_obs = get_action(
                args=args,
                model=model,
                env=env,
                current_seq=current_seq,
                cont_tokenizer=cont_tokenizer,
                len_fixed_prompt=len_fixed_prompt,
                model_memory=model_memory,
                prompt_strategy=eval_prompt_strat,
                action_masks=env.get_action_mask(hard_action_constraint),
                sample_action=sample_action,
                obs_idxs=obs_idxs,
                raw_obs=raw_obs,
                device=device
            )
            # .item() requires exactly one element, so act becomes a Python scalar.
            act = act.item()

            # env one step update
            new_seq, reward, terminated, truncated, info = env.step(act)
            new_seq = new_seq.unsqueeze(0) if new_seq.ndim == 0 else new_seq
            episode_return['AM'] += reward['AM']
            episode_return['DB1'] += reward['DB1']
            episode_length += 1
            episode_acts.append(act)
            if args.policy_logger:
                acts.append(act)
                rewards['AM'].append(reward['AM'])
                rewards['DB1'].append(reward['DB1'])
                obss.append(info['obs'])
                value_spaces.append(env.env.get_action_value_space(hard_action_constraint))

            # exit point
            if terminated or truncated or (max_step_size is not None and episode_length >= max_step_size):
                episode_obj = info['obj'].item()
                if args.policy_logger:
                    # The last obs / value space belongs to the post-terminal state.
                    obss.pop(-1)
                    value_spaces.pop(-1)
                break

            # update current_seq
            assert model_memory is None
            if model_memory is None:
                # Append the new observation and the spliter.
                current_seq = torch.cat([current_seq.squeeze(), new_seq, spliter_token])       # (seq_len,)

                # Handle a sequence that outgrew the model context.
                if len(current_seq) > args.n_position:
                    if (args.use_prompt and eval_prompt_strat == "moving_prompt") or (not args.use_prompt and not args.use_prefix):
                        # Without prefix/prompt, or with "moving_prompt": keep only the tail,
                        # so the prompt part keeps being refreshed.
                        current_seq = current_seq[trans_dim:]
                    elif args.use_prefix:
                        # With a prefix: keep the head prefix unchanged, rotate the rest left by
                        # trans_dim so the oldest transition lands at the tail, then drop it.
                        window_seq_view = torch.roll(current_seq[prefix_dim+1:], -trans_dim)
                        current_seq[prefix_dim+1:].data.copy_(window_seq_view.data)
                        current_seq = current_seq[:-trans_dim]
                    else:
                        raise NotImplementedError

                    raw_obs = {k:np.vstack((v[1:], info['obs'][k][None,:])) for k, v in raw_obs.items()}
                    current_seq = current_seq[None, :]                  # (1, seq_len)
                else:
                    raw_obs = {k:np.vstack((v, info['obs'][k][None,:])) for k, v in raw_obs.items()}
                    trans_num = math.ceil((len(current_seq)-prefix_dim)/trans_dim)
                    assert trans_num == len(list(raw_obs.values())[0])
                    obs_idxs = ([0] * prefix_dim + env.dataset.obs_idxs * trans_num)[:len(current_seq)]
                    obs_idxs = torch.tensor([obs_idxs,])                # (1, seq_len)
                    current_seq = current_seq[None, :]
                    assert current_seq.shape[1] == obs_idxs.shape[1]

        # Assemble the logged evaluation trajectory.
        if args.policy_logger:
            if isinstance(info['obs'], dict):
                obs_dict = {}
                for k in obss[0].keys():
                    if obss[0][k].shape[0] == 1:
                        obs_dict[k] = np.concatenate([obs[k] for obs in obss]).astype(env.dataset.obs_type_spec[k])
                    else:
                        obs_dict[k] = np.vstack([obs[k] for obs in obss]).astype(env.dataset.obs_type_spec[k])
                obss = obs_dict
            else:
                obss = np.array(obss).astype(env.dataset.obs_type_spec)

            episode = {
                'prefix': raw_prefix,
                'observations': obss,
                'actions': np.array(acts).astype(env.dataset.act_type_spec),
                'rewards': {'AM': np.array(rewards['AM']).astype(np.float32), 'DB1': np.array(rewards['DB1']).astype(np.float32)},
                'act_value_space': value_spaces
            }
        else:
            episode = None

        # Self-consistency vote. COP envs are deterministic, so comparing the
        # action sequences is enough to tell whether two rollouts agree.
        ans_ret = episode_return
        if regen_times > 1:
            # episode_acts holds Python scalars (from act.item()), so build a
            # 1-D array directly. np.concatenate rejects 0-d inputs, and `acts`
            # only exists when args.policy_logger is set.
            this_acts = np.asarray(episode_acts)
            is_new_res = True
            for i, epi_acts in enumerate(results_acts):
                if env.env.is_same_episode(epi_acts, this_acts):
                    results_cnt[i] += 1
                    is_new_res = False
            if is_new_res:
                results_acts.append(this_acts)
                results_info.append((episode_return, episode_obj, episode_length, episode))
                results_cnt.append(1)

    time_end = time.time()
    time_used = time_end - time_start

    if regen_times > 1:
        results_cnt = np.array(results_cnt)
        max_vote = results_cnt[np.argmax(results_cnt)]
        max_idxs = np.where(results_cnt == max_vote)[0]
        max_idx = random.choice(max_idxs)
        ans_ret, episode_obj, episode_length, episode = results_info[max_idx]

    if args.policy_logger and env.env_name == 'Env_TSP_V3':
        episode['prefix'] = env.env.position.copy()

    ans_safe = 0 if env.task_type == 'COPTask' and ans_ret['AM'] == COP_FAILED_RWD else 1
    return ans_ret, episode_obj, ans_safe, episode_length, time_used, episode

@torch.inference_mode()
def evalute_batch_episode(
    args: Namespace,
    model: torch.nn.Module,
    env: DDP_LMPromptEnv,
    problemloader: DDP_ProblemLoader,
    cont_tokenizer: ContinuousScalarTokenizer,
    sample_action: bool = False,
    hard_action_constraint: bool = False,
    desc: str = '',
    device: Optional[Union[int, str, torch.device]] = None
):
    """Evaluate ``args.problem_batch_num`` batches of problems on a DDP env.

    Each batch of ``args.problem_batch_size`` problems is stepped in lock step
    until every problem is done. A tqdm bar (positioned by ``RANK``) shows
    running statistics.

    Returns:
        ``(env_return, env_obj, env_obj_std, safe_ratio, time_used, episodes)``:

        - ``env_return``: dict with the mean 'AM' / 'DB1' reward received on
          the terminating step of each terminated episode.
        - ``env_obj`` / ``env_obj_std``: mean / std of ``info['obj']`` on those
          terminating steps.
        - ``safe_ratio``: terminated episodes divided by finished episodes
          (terminated or truncated).
        - ``time_used``: elapsed wall time per finished episode, in seconds.
        - ``episodes``: ``[]`` when ``args.policy_logger`` is set, otherwise
          None. Per-episode trajectories are currently not assembled.

        Every statistic is 0 when no episode contributed to it, including the
        case where no batch was evaluated.
    """
    assert args.use_ddp_env
    eval_prompt_strat = args.prompt_strategy.split(";")[-1]     # e.g. "moving_prompt"
    assert args.use_prefix ^ (args.use_prompt and eval_prompt_strat == "moving_prompt") or args.use_prompt == args.use_prefix == False
    rank = int(os.environ.get("RANK", default='0'))
    prefix_dim = env.dataset.prefix_dim
    trans_dim = env.dataset.trans_dim
    device = device if device is not None else torch.device(f"cuda:{args.device[0]}" if torch.cuda.is_available() and torch.cuda.device_count() >= args.device[0]+1 else "cpu")
    # NOTE(review): Env_FFSP_V1 uses 24x as many rows for the spliter tokens;
    # presumably its env expands each problem into 24 rows. Confirm with the env.
    p_batch_size = args.problem_batch_size * 24 if env.env_name == 'Env_FFSP_V1' else args.problem_batch_size
    spliter_tokens = torch.full((p_batch_size, 1), args.special_tokens['<|>'], dtype=torch.long).to(device)
    prefix_spliter_tokens = torch.full((p_batch_size, 1), args.special_tokens['<X>'], dtype=torch.long).to(device)

    # get ready
    problemloader.reset()
    raw_model = model.module if hasattr(model, "module") else model
    assert raw_model.transformer.same_length == args.use_mem
    episodes = [] if args.policy_logger else None

    # START EVALUATION
    episode_return, episode_obj, episode_safe_cnt, episode_slove_cnt = {'AM':[], 'DB1':[]}, [], 0, 0
    # Summary statistics. They keep these defaults if the loop never updates
    # them (e.g. problem_batch_num == 0), instead of being unbound at the return.
    env_obj, env_obj_std, safe_ratio, time_used = 0, 0, 0, 0
    time_start = time.time()
    iters = args.problem_batch_num * args.problem_batch_size
    with tqdm(total=iters, desc=desc, position=rank) as pbar:
        for _ in range(args.problem_batch_num):
            # reset env to set the eval problem
            problem_info, problem_obj = problemloader.get_problem(args.problem_batch_size)
            problem_idx = list(range(args.problem_batch_size))
            current_seq, info = env.reset(options={     # current_seq: (problem_batch_size, obs_dim)
                'problem_info':problem_info,
                'problem_obj':problem_obj,
                'problem_idx':problem_idx,
                'use_default_policy_obj': args.use_default_policy_obj
            })
            current_seq = current_seq.to(device)

            # Prepare the prompt or prefix. In every branch the current obs is
            # followed by a spliter token, from which the action is generated.
            if args.use_prompt:
                fixed_prompt, raw_obs = env.get_prompt(args.strict_length, args.minimal_expert_data)    # (problem_batch_size, len_prompt)
                fixed_prompt = fixed_prompt.to(device)
                current_seq = torch.cat([fixed_prompt, current_seq, spliter_tokens], dim=-1)            # (problem_batch_size, len_prompt+obs_dim+1)
                len_fixed_prompt = fixed_prompt.shape[1]                                                # NOTE(XXX): the prompt length may differ from the model context length (1024)
                raw_prefix = None
                raw_obs = {k: np.concatenate((v, info['obs'][k][:,None,:]), axis=1) if info['obs'][k].ndim==2 else \
                              np.concatenate((v, info['obs'][k][:,None,None]), axis=1) for k, v in raw_obs.items()}
            elif args.use_prefix:
                prefix_tensor, raw_prefix = env.get_prefix(with_raw=True)
                prefix_tensor = prefix_tensor.to(device)
                current_seq = torch.cat([prefix_tensor, prefix_spliter_tokens, current_seq, spliter_tokens], dim=-1)
                len_fixed_prompt = 0
                raw_obs = {k: v[:,None,:] if v.ndim==2 else v[:,None,None] for k,v in info['obs'].items()}
            else:
                current_seq = torch.cat([current_seq, spliter_tokens], dim=-1)              # (problem_batch_size, obs_dim+1)
                len_fixed_prompt = 0
                raw_prefix = None
                raw_obs = {k: v[:,None,:] if v.ndim==2 else v[:,None,None] for k,v in info['obs'].items()}

            trans_num = math.ceil((current_seq.shape[1]-prefix_dim)/env.dataset.trans_dim)
            assert trans_num == list(raw_obs.values())[0].shape[1]
            obs_idxs = ([0] * prefix_dim + env.dataset.obs_idxs * trans_num)[:current_seq.shape[1]]
            obs_idxs = torch.tensor([obs_idxs,])                # (1, seq_len)
            assert current_seq.shape[1] == obs_idxs.shape[1]

            # eval a batch of problem
            # n_layer * [(problem_batch_size, mem_len, n_embed)], initialised to all-zero tensors
            model_memory = None if not args.use_mem else \
                            raw_model.transformer.init_mem(args.problem_batch_size)
            batch_done = np.zeros(args.problem_batch_size, dtype=bool)
            if args.policy_logger:
                batch_obss, batch_acts, batch_rewards, batch_value_spaces = [info['obs'], ], [], {'AM':[], 'DB1':[]}, [env.env.get_action_value_space(hard_action_constraint),]
            while batch_done.sum() < args.problem_batch_size:
                # Generate action
                assert current_seq.shape[1] <= raw_model.transformer.n_position
                act, current_seq, model_memory, raw_obs = get_action(
                    args=args,
                    model=model,
                    env=env,
                    current_seq=current_seq,
                    cont_tokenizer=cont_tokenizer,
                    len_fixed_prompt=len_fixed_prompt,
                    model_memory=model_memory,
                    prompt_strategy=eval_prompt_strat,
                    sample_action=sample_action,
                    obs_idxs=obs_idxs,
                    raw_obs=raw_obs,
                    device=device
                )
                act = act.squeeze()     # (problem_batch_size, act_dim) or (problem_batch_size, ) when act_dim==1

                # env one step update
                new_seq, reward, terminated, truncated, info = env.step(act)    # new_seq: (problem_batch_size, obs_dim)
                assert reward['AM'][batch_done==1].sum() == 0, 'Some episode that have ended send back reward signals'

                done = terminated | truncated                                   # (problem_batch_size,)
                assert (batch_done + done).max() < 2, 'Some episode end more than one times'
                batch_done |= done                                              # (problem_batch_size,)

                # Only the step on which an episode terminates contributes.
                episode_return['AM'].extend(reward['AM'][terminated==1].tolist())
                episode_return['DB1'].extend(reward['DB1'][terminated==1].tolist())
                episode_obj.extend(info['obj'][terminated==1].tolist())
                episode_safe_cnt += terminated.sum()
                episode_slove_cnt += done.sum()
                assert episode_safe_cnt == len(episode_obj) == len(episode_return['AM']) == len(episode_return['DB1'])

                if args.policy_logger:
                    batch_acts.append(act)
                    batch_rewards['AM'].append(reward['AM'])
                    batch_rewards['DB1'].append(reward['DB1'])
                    batch_obss.append(info['obs'])
                    batch_value_spaces.append(env.env.get_action_value_space(hard_action_constraint))

                # update current_seq
                assert model_memory is None
                if model_memory is None:
                    # Append the new observation and the spliter.
                    new_seq = new_seq.to(device)
                    current_seq = torch.cat([current_seq, new_seq, spliter_tokens], dim=-1)         # (problem_batch_size, seq_len)

                    # Handle a sequence that outgrew the model context.
                    if current_seq.shape[1] > args.n_position:
                        if (args.use_prompt and eval_prompt_strat == "moving_prompt") or (not args.use_prompt and not args.use_prefix):
                            # Without prefix/prompt, or with "moving_prompt": keep only the tail,
                            # so the prompt part keeps being refreshed.
                            current_seq = current_seq[:, trans_dim:]
                        elif args.use_prefix:
                            # With a prefix: keep the head prefix unchanged. Rotate the rest left by
                            # trans_dim (torch.roll without dims flattens, but the wrapped-in columns
                            # are exactly the ones dropped next), then drop the former oldest transition.
                            window_seq_view = torch.roll(current_seq[:,prefix_dim+1:], -trans_dim)
                            current_seq[:,prefix_dim+1:].data.copy_(window_seq_view.data)
                            current_seq = current_seq[:,:-trans_dim]
                        else:
                            raise NotImplementedError

                        raw_obs = {
                            k: np.concatenate((v[:,1:,:], info['obs'][k][:,None,:]), axis=1) if info['obs'][k].ndim==2 else \
                                np.concatenate((v[:,1:,:], info['obs'][k][:,None,None]), axis=1) for k, v in raw_obs.items()
                        }
                    else:
                        raw_obs = {
                            k: np.concatenate((v, info['obs'][k][:,None,:]), axis=1) if info['obs'][k].ndim==2 else \
                                np.concatenate((v, info['obs'][k][:,None,None]), axis=1) for k, v in raw_obs.items()
                        }
                        trans_num = math.ceil((current_seq.shape[1]-prefix_dim)/env.dataset.trans_dim)
                        assert trans_num == list(raw_obs.values())[0].shape[1]
                        obs_idxs = ([0] * prefix_dim + env.dataset.obs_idxs * trans_num)[:current_seq.shape[1]]
                        obs_idxs = torch.tensor([obs_idxs,])                # (1, seq_len)
                        assert current_seq.shape[1] == obs_idxs.shape[1]

                # Running statistics over all episodes finished so far.
                env_return_AM = 0 if len(episode_return['AM']) == 0 else np.mean(episode_return['AM'])
                env_return_DB1 = 0 if len(episode_return['DB1']) == 0 else np.mean(episode_return['DB1'])
                env_obj = 0 if len(episode_obj) == 0 else np.mean(episode_obj)
                env_obj_std = 0 if len(episode_obj) == 0 else np.std(episode_obj)
                safe_ratio = 0 if episode_slove_cnt == 0 else episode_safe_cnt / episode_slove_cnt
                time_used = 0 if episode_slove_cnt == 0 else (time.time() - time_start) / episode_slove_cnt
                # Named `postfix` so it does not shadow the env's `info`.
                postfix = {
                    'ret_AM': f'{env_return_AM:.4f}',
                    'ret_DB1': f'{env_return_DB1:.4f}',
                    'obj': f'{env_obj:.2f}',
                    'std': f'{env_obj_std:.2f}',
                    'safe': f'{safe_ratio:.2f}',
                    'time': f'{time_used:.2f}',
                }
                pbar.set_postfix(postfix)
                pbar.update(done.sum())

            # NOTE: the batch_* trajectories collected under args.policy_logger
            # are not yet assembled into per-episode dicts, so `episodes` stays empty.

    env_return = {
        'AM': 0 if episode_return['AM'] == [] else np.mean(episode_return['AM']),
        'DB1': 0 if episode_return['DB1'] == [] else np.mean(episode_return['DB1'])
    }
    return env_return, env_obj, env_obj_std, safe_ratio, time_used, episodes

# NOTE(review): Everything between the triple quotes below is a module-level
# string literal, not live code. Python evaluates the string and discards it,
# so the `evalute_batch_episode` function inside is never defined. Note that
# its name is misspelled. It cannot be imported or called from here.
# Its final lines are the same as the tail of the evaluation function just
# above it, so it looks like a disabled alternate copy kept for reference.
# Consider deleting it and relying on version-control history instead.
#
# If this copy is ever re-enabled, be aware of the following in its body:
#   * In the per-dimension action loop, a random action is drawn from each
#     problem's action value space and appended to `batch_action`. However,
#     `env.step` is called with the `act` returned by `get_action`. The random
#     draws only feed `get_action_value_space` for later dimensions.
#   * `act` is reassigned on every iteration of the `act_dim` loop. Only the
#     value from the final iteration is passed to `env.step`.
#   * After every env step it asserts `model_memory is None`. This presumably
#     fails whenever `get_action` returns a memory, e.g. with `args.use_mem`.
#     TODO confirm.
#   * The trajectory-logging section near the end is commented out. As a
#     result, `episodes` is returned empty even when `args.policy_logger` is
#     set.
#   * `env_obj`, `env_obj_std`, `safe_ratio` and `time_used` are only assigned
#     inside the step loop. Returning them raises UnboundLocalError when
#     `args.problem_batch_num` is 0.
'''
@torch.inference_mode()
def evalute_batch_episode(
    args: Namespace,
    model: torch.nn.Module,
    env: DDP_LMPromptEnv,
    problemloader: DDP_ProblemLoader,
    cont_tokenizer: ContinuousScalarTokenizer,
    sample_action: bool = False,
    hard_action_constraint: bool = False,
    desc: str = '',
    device: Optional[Union[int, str, torch.device]] = None
):
    assert args.use_ddp_env
    eval_prompt_strat = args.prompt_strategy.split(";")[-1]     # moving_prompt
    assert args.use_prefix ^ (args.use_prompt and eval_prompt_strat == "moving_prompt") or args.use_prompt == args.use_prefix == False
    rank = int(os.environ.get("RANK", default='0'))
    prefix_dim = env.dataset.prefix_dim
    obs_dim = env.dataset.obs_dims_after_mlp_emb
    act_dim = env.dataset.act_dim
    trans_dim = env.dataset.trans_dim
    device = device if device is not None else torch.device(f"cuda:{args.device[0]}" if torch.cuda.is_available() and torch.cuda.device_count() >= args.device[0]+1 else "cpu")
    spliter_tokens = torch.full((args.problem_batch_size, 1), args.special_tokens['<|>'], dtype=torch.long).to(device)
    prefix_spliter_tokens = torch.full((args.problem_batch_size, 1), args.special_tokens['<X>'], dtype=torch.long).to(device)
    
    # get ready
    problemloader.reset()
    raw_model = model.module if hasattr(model, "module") else model
    assert raw_model.transformer.same_length == args.use_mem
    episodes = [] if args.policy_logger else None
        
    # START EVALUATION
    episode_return, episode_obj, episode_safe_cnt, episode_slove_cnt = {'AM':[], 'DB1':[]}, [], 0, 0
    time_start = time.time()
    iters = args.problem_batch_num * args.problem_batch_size
    with tqdm(total=iters, desc=desc, position=rank) as pbar:
        for _ in range(args.problem_batch_num):
            # reset env to set the eval problem  
            problem_info, problem_obj = problemloader.get_problem(args.problem_batch_size)
            problem_idx = list(range(args.problem_batch_size))
            current_seq, info = env.reset(options={     # current_seq: (problem_batch_size, obs_dim)  
                'problem_info':problem_info,
                'problem_obj':problem_obj, 
                'problem_idx':problem_idx,
                'use_default_policy_obj': args.use_default_policy_obj
            })       
            current_seq = current_seq.to(device)

            # prepare prompt or prefix
            if args.use_prompt:
                fixed_prompt, raw_obs = env.get_prompt(args.strict_length, args.minimal_expert_data)    # (problem_batch_size, len_prompt)
                fixed_prompt = fixed_prompt.to(device)
                current_seq = torch.cat([fixed_prompt, current_seq, spliter_tokens], dim=-1)            # (problem_batch_size, len_prompt+obs_dim+1) 拼接当前 obs 和 spliter，下一步用于自回归生成 action
                len_fixed_prompt = fixed_prompt.shape[1]                                                # NOTE(XXX): prompt长度可能不等于模型上下文长度1024
                raw_prefix = None
                raw_obs = {k: np.concatenate((v, info['obs'][k][:,None,:]), axis=1) if info['obs'][k].ndim==2 else \
                              np.concatenate((v, info['obs'][k][:,None,None]), axis=1) for k, v in raw_obs.items()}
            elif args.use_prefix:
                prefix_tensor, raw_prefix = env.get_prefix(with_raw=True)
                prefix_tensor = prefix_tensor.to(device)
                current_seq = torch.cat([prefix_tensor, prefix_spliter_tokens, current_seq, spliter_tokens], dim=-1)    # 拼接当前obs和spliter，下一步用于自回归生成action
                len_fixed_prompt = 0
                raw_obs = {k: v[:,None,:] if v.ndim==2 else v[:,None,None] for k,v in info['obs'].items()} 
            else:
                current_seq = torch.cat([current_seq, spliter_tokens], dim=-1)              # (problem_batch_size, obs_dim+1) 拼接当前obs和spliter，下一步用于自回归生成action
                len_fixed_prompt = 0
                raw_prefix = None
                raw_obs = {k: v[:,None,:] if v.ndim==2 else v[:,None,None] for k,v in info['obs'].items()} 

            trans_num = math.ceil((current_seq.shape[1]-prefix_dim)/env.dataset.trans_dim)
            assert trans_num == list(raw_obs.values())[0].shape[1]
            obs_idxs = ([0] * prefix_dim + env.dataset.obs_idxs * trans_num)[:current_seq.shape[1]]
            obs_idxs = torch.tensor([obs_idxs,])                # (1, sql_len)
            assert current_seq.shape[1] == obs_idxs.shape[1]

            # eval a batch of problem
            model_memory = None if not args.use_mem else \
                            raw_model.transformer.init_mem(args.problem_batch_size)         # n_layer * [(problem_batch_size, mem_len, n_embed)] 初始化为全 0 张量
            batch_done = np.zeros(args.problem_batch_size, dtype=bool)
            if args.policy_logger:
                batch_obss, batch_acts, batch_rewards, batch_value_spaces = [info['obs'], ], [], {'AM':[], 'DB1':[]}, [env.env.get_action_value_space(hard_action_constraint),]
            while batch_done.sum() < args.problem_batch_size:
                # Generate action
                batch_action = []
                for d in range(act_dim):
                    generated_actions = np.array(batch_action).T
                    ddp_action_space = env.env.get_action_value_space(hard_action_constraint, generated_actions)[d]   # [batch_size * [available job idx]]
                    actions = [0 if len(ddp_action_space[i])==0 else np.random.choice(ddp_action_space[i]) for i in range(len(ddp_action_space))]
                    batch_action.append(actions)

                    assert current_seq.shape[1] <= raw_model.transformer.n_position
                    act, current_seq, model_memory, raw_obs = get_action(
                        args=args,
                        model=model,
                        env=env,
                        current_seq=current_seq,
                        cont_tokenizer=cont_tokenizer,
                        len_fixed_prompt=len_fixed_prompt,
                        model_memory=model_memory,
                        prompt_strategy=eval_prompt_strat,
                        action_masks=env.get_action_mask(hard_action_constraint),
                        sample_action=sample_action,
                        obs_idxs=obs_idxs,
                        raw_obs=raw_obs,
                        device=device
                    )
                    act = act.squeeze()     # (problem_batch_size,)
                    #assert act.max() <= 20

                # env one step update
                new_seq, reward, terminated, truncated, info = env.step(act)    # new_seq: (problem_batch_size, obs_dim)
                assert reward['AM'][batch_done==1].sum() == 0, 'Some episode that have ended send back reward signals'

                done = terminated | truncated                                   # (problem_batch_size, sql_len)
                assert (batch_done + done).max() < 2, 'Some episode end more than one times'
                batch_done |= done                                              # (problem_batch_size, sql_len)

                episode_return['AM'].extend(reward['AM'][terminated==1].tolist())
                episode_return['DB1'].extend(reward['DB1'][terminated==1].tolist())
                episode_obj.extend(info['obj'][terminated==1].tolist())
                episode_safe_cnt += terminated.sum()
                episode_slove_cnt += done.sum()
                assert episode_safe_cnt == len(episode_obj) == len(episode_return['AM']) == len(episode_return['DB1'])

                if args.policy_logger:
                    batch_acts.append(act)
                    batch_rewards['AM'].append(reward['AM'])
                    batch_rewards['DB1'].append(reward['DB1'])
                    batch_obss.append(info['obs'])
                    batch_value_spaces.append(env.env.get_action_value_space(hard_action_constraint))

                # update current_seq
                assert model_memory is None
                # 拼接新的观测和分隔符
                new_seq = new_seq.to(device)
                current_seq = torch.cat([current_seq, new_seq, spliter_tokens], dim=-1)         # (problem_batch_size, sql_len)
                
                # 超长处理
                if current_seq.shape[1] > args.n_position:                                                
                    if (args.use_prompt and eval_prompt_strat == "moving_prompt") or (not args.use_prompt and not args.use_prefix):
                        # 不用 prefix/prompt 或使用 "moving_prompt"，则随着序列增长不断保留尾部序列, 这样 prompt 序列会不断更新
                        current_seq = current_seq[:, trans_dim:]
                    elif args.use_prefix:
                        # 如果设置 prefix, 则维持序列首部的 prefix 不变
                        window_seq_view = torch.roll(current_seq[:,prefix_dim+1:], -trans_dim)  # 将 current_seq 中除 prefix 以外的序列循环左移 prefix_dim
                        current_seq[:,prefix_dim+1:].data.copy_(window_seq_view.data)           # 除 prefix 以外的序列中，首部那个 transition 对应的序列放在尾部
                        current_seq = current_seq[:,:-trans_dim]                                # 把上一步放到尾部的（原首部）transition对应序列去除
                    else:
                        raise NotImplementedError

                    raw_obs = {
                        k: np.concatenate((v[:,1:,:], info['obs'][k][:,None,:]), axis=1) if info['obs'][k].ndim==2 else \
                            np.concatenate((v[:,1:,:], info['obs'][k][:,None,None]), axis=1) for k, v in raw_obs.items()
                    }
                else:
                    raw_obs = {
                        k: np.concatenate((v, info['obs'][k][:,None,:]), axis=1) if info['obs'][k].ndim==2 else \
                            np.concatenate((v, info['obs'][k][:,None,None]), axis=1) for k, v in raw_obs.items()
                    }
                    trans_num = math.ceil((current_seq.shape[1]-prefix_dim)/env.dataset.trans_dim)
                    assert trans_num == list(raw_obs.values())[0].shape[1]
                    obs_idxs = ([0] * prefix_dim + env.dataset.obs_idxs * trans_num)[:current_seq.shape[1]]
                    obs_idxs = torch.tensor([obs_idxs,])                # (1, sql_len)
                    assert current_seq.shape[1] == obs_idxs.shape[1]
                
                

                env_return_AM = 0 if len(episode_return['AM']) == 0 else np.mean(episode_return['AM'])
                env_return_DB1 = 0 if len(episode_return['DB1']) == 0 else np.mean(episode_return['DB1'])
                env_obj = 0 if len(episode_obj) == 0 else np.mean(episode_obj)
                env_obj_std = 0 if len(episode_obj) == 0 else np.std(episode_obj)
                safe_ratio = 0 if episode_slove_cnt == 0 else episode_safe_cnt / episode_slove_cnt
                time_used = 0 if episode_slove_cnt == 0 else (time.time() - time_start) / episode_slove_cnt
                info = {
                    'ret_AM': f'{env_return_AM:.4f}',
                    'ret_DB1': f'{env_return_DB1:.4f}',
                    'obj': f'{env_obj:.2f}',
                    'std': f'{env_obj_std:.2f}',
                    'safe': f'{safe_ratio:.2f}',
                    'time': f'{time_used:.2f}',
                }
                pbar.set_postfix(info)
                pbar.update(done.sum())

            # 得到评估过程轨迹
            #if args.policy_logger:
            #    acts = np.vstack(batch_acts).T                      # (problem_batch_size, max_batch_timestep)
            #    rewards_AM = np.vstack(batch_rewards['AM']).T       # (problem_batch_size, max_batch_timestep)
            #    rewards_DB1 = np.vstack(batch_rewards['DB1']).T     # (problem_batch_size, max_batch_timestep)
            #    timestep_len = np.where(rewards_AM!=0)[1] + 1       # (problem_batch_size, )
            #    assert timestep_len.shape == (args.problem_batch_size, )
            #    
            #    for i in range(args.problem_batch_size):
            #        epi_len = timestep_len[i]                       # epi_len <= max_batch_timestep
            #        epi_spaces = [[space_timestep[0][i],] for space_timestep in batch_value_spaces[:epi_len]]
            #
            #        assert isinstance(env.dataset.obs_type_spec, dict)
            #        assert isinstance(env.dataset.prefix_type_spec, dict)
            #        obs_dict = {}
            #        for k, type_spec in env.dataset.obs_type_spec.items():
            #            obs_dict[k] = np.array([obs[k][i] for obs in batch_obss[:epi_len]]).astype(type_spec)
            #        
            #        episode = {
            #            'prefix': raw_prefix,
            #            'observations': obs_dict,
            #            'actions': acts[i][:epi_len].astype(env.dataset.act_type_spec),
            #            'rewards': {'AM': rewards_AM[i][:epi_len].astype(np.float32), 'DB1': rewards_DB1[i][:epi_len].astype(np.float32)},
            #            'act_value_space': epi_spaces
            #        }
            #        episodes.append(episode)
            
    env_return = {
        'AM': 0 if episode_return['AM'] == [] else np.mean(episode_return['AM']),
        'DB1': 0 if episode_return['DB1'] == [] else np.mean(episode_return['DB1'])        
    } 
    return env_return, env_obj, env_obj_std, safe_ratio, time_used, episodes
'''