import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(base_path)

import torch
import numpy as np
from torch import nn
from typing import List, Optional, Tuple, Union
from transformers.models.llama import LlamaConfig
from transformers import AutoConfig, AutoModel, LlamaModel
from transformers.trainer_pt_utils import get_parameter_names
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from dataloader.code.input_specs import RLTaskInput
import torch.nn.functional as F
'''
class LlamaBackbone(LlamaModel):
    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.prefix_spliter = config.prefix_spliter
        self.action_spliter = config.action_spliter
        self.use_dynamic_prefix = config.use_dynamic_prefix
        self.local_pos_encoding = nn.Embedding(config.vocab_size, config.hidden_size)       # 局部位置编码
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        loss_mask: Optional[torch.Tensor] = None,
        prefix_mask_list: List = None,
        local_position_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        compute_loss = True
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        assert input_ids is not None and inputs_embeds is None, "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
        assert not (self.gradient_checkpointing and self.training and use_cache), "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # token embedding
        local_positional_emb = self.local_pos_encoding(local_position_ids)  # (batch_size, seq_len, n_embed)
        inputs_embeds = self.embed_tokens(input_ids)                        # (batch_size, seq_len, n_embed)
        inputs_embeds = inputs_embeds + local_positional_emb                # (batch_size, seq_len, n_embed)

        # kv cache & position cache (for evaluation)
        past_seen_tokens = 0
        if use_cache:  # kept for BC (cache positions)
            if not isinstance(past_key_values, StaticCache):
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                past_seen_tokens = past_key_values.get_seq_length()
        
        assert cache_position is None, 'cache_position 应该仅在推断阶段使用 kvcache 时才会传入'
        if cache_position is None:
            if isinstance(past_key_values, StaticCache):
                raise ValueError("cache_position is a required argument when using StaticCache.")
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # 绝对位置信息
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)      # (1, seq_len)

        # attention mask
        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
        if self.prefix_spliter is not None:
            input_ids_temp = np.array(input_ids.detach().clone().cpu())                                 # (batch_size, seq_len)
            loss_mask_temp = None if loss_mask is None else np.array(loss_mask.detach().clone().cpu())  # (batch_size, seq_len)

            for i in range(len(causal_mask)):
                # prefix段使用双向注意力
                prefix_spliter_idx = torch.where(input_ids[i]==self.prefix_spliter)[0]
                assert len(prefix_spliter_idx) == 1
                causal_mask[i][0][:,:prefix_spliter_idx] = 0
                
                # 动态prefix
                if self.use_dynamic_prefix:
                    prefix_mask_of_actions = prefix_mask_list[i].squeeze()      # (act_num, prefix_dim)
                    if prefix_mask_of_actions.ndim == 1:
                        prefix_mask_of_actions = prefix_mask_of_actions[None,:]
                    
                    #action_predict_idxs = np.where(input_ids_temp[i] == self.action_spliter)[0]
                    action_predict_idxs = np.where(loss_mask_temp[i])[0]
                    if compute_loss:
                        assert len(action_predict_idxs) == len(prefix_mask_of_actions)
                        for j, act_predict_idx in enumerate(action_predict_idxs):
                            prefix_mask = prefix_mask_of_actions[j]
                            causal_mask[i][0][act_predict_idx][:prefix_spliter_idx] += (prefix_mask * torch.finfo(torch.float32).min).to(causal_mask.device)
                    else:
                        assert len(prefix_mask_of_actions) == 1
                        act_predict_idx = action_predict_idxs[-1]
                        prefix_mask = prefix_mask_of_actions[0]
                        causal_mask[i][0][act_predict_idx][:prefix_spliter_idx] += (prefix_mask * torch.finfo(torch.float32).min).to(causal_mask.device)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None
        hidden_states = inputs_embeds
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values, # past_key_values:DynamicCache 会逐层缓存 key & value 
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)        # (batch_size, seq_len, hidden_size)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
            )
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

'''
class LlamaBackbone(LlamaModel):
    """LlamaModel variant with a learned local-position embedding and prefix-aware attention masks.

    Differences from the parent forward pass that this class implements:
      * token embeddings are summed with ``local_pos_encoding(local_position_ids)``;
      * when ``config.prefix_spliter`` is set, the additive attention mask is zeroed over all
        columns located before the prefix-splitter token, for every query row of that sample;
      * when ``config.use_dynamic_prefix`` is also set, the query rows at action-splitter
        positions additionally receive ``prefix_mask * finfo.min`` over those prefix columns.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # Special token ids / switches forwarded through the config.
        self.prefix_spliter = config.prefix_spliter
        self.action_spliter = config.action_spliter
        self.use_dynamic_prefix = config.use_dynamic_prefix
        # Local position embedding, added to the token embedding in forward().
        # NOTE(review): the table has `config.vocab_size` rows, so local position ids must stay
        # below vocab_size -- confirm this sizing is intentional.
        self.local_pos_encoding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.post_init()

    def _apply_prefix_mask(self, causal_mask, input_ids, prefix_mask_list, compute_loss):
        """Return the additive attention mask with the prefix-specific edits applied.

        Args:
            causal_mask: 4D additive mask produced by ``_update_causal_mask``; indexed here as
                ``[sample][0][query_row][key_column]``.
            input_ids: token ids, (batch_size, seq_len).
            prefix_mask_list: per-sample prefix masks; only read when ``use_dynamic_prefix`` is set.
                After ``squeeze()`` each entry is treated as (num_actions, prefix_dim).
            compute_loss: if True, every action-splitter position gets its own prefix mask row;
                otherwise only the last action-splitter position is edited.

        Raises:
            ValueError: if no 4D mask is available to edit.
        """
        if causal_mask is None or causal_mask.dim() != 4:
            got = None if causal_mask is None else tuple(causal_mask.shape)
            raise ValueError(
                f"prefix attention needs a 4D additive attention mask from `_update_causal_mask`, got {got}"
            )

        # The mask can be a batch-expanded view sharing one storage across samples (this looks to be
        # the case in some transformers versions when `attention_mask` is None). Writing into such a
        # view would apply one sample's edits to all samples, so materialise it before editing.
        # `contiguous()` is a no-op when the mask already owns contiguous memory.
        causal_mask = causal_mask.contiguous()
        # Use the mask's own dtype: float32's min overflows to -inf once cast to fp16/bf16.
        min_value = torch.finfo(causal_mask.dtype).min

        for i in range(causal_mask.shape[0]):
            # The prefix segment uses bidirectional attention.
            prefix_spliter_idx = torch.where(input_ids[i] == self.prefix_spliter)[0]
            assert len(prefix_spliter_idx) == 1
            prefix_len = int(prefix_spliter_idx.item())
            sample_mask = causal_mask[i][0]             # view: (query_len, key_len)
            sample_mask[:, :prefix_len] = 0

            # Dynamic prefix: hide selected prefix entries from individual action-prediction rows.
            if not self.use_dynamic_prefix:
                continue
            prefix_mask_of_timestep = prefix_mask_list[i].squeeze()     # (timestep, prefix_dim)
            if prefix_mask_of_timestep.ndim == 1:
                prefix_mask_of_timestep = prefix_mask_of_timestep[None, :]

            action_predict_idxs = torch.where(input_ids[i] == self.action_spliter)[0].tolist()
            if compute_loss:
                assert len(action_predict_idxs) == len(prefix_mask_of_timestep)
                row_edits = list(zip(action_predict_idxs, prefix_mask_of_timestep))
            else:
                assert len(prefix_mask_of_timestep) == 1
                row_edits = [(action_predict_idxs[-1], prefix_mask_of_timestep[0])]

            for act_predict_idx, prefix_mask in row_edits:
                penalty = (prefix_mask * min_value).to(device=causal_mask.device, dtype=causal_mask.dtype)
                sample_mask[act_predict_idx, :prefix_len] += penalty

        return causal_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        loss_mask: Optional[torch.Tensor] = None,
        prefix_mask_list: List = None,
        local_position_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        compute_loss = True
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack on token ids with local-position embeddings and prefix masking.

        Args:
            input_ids: token ids, (batch_size, seq_len). Required.
            attention_mask: forwarded to ``_update_causal_mask``.
            loss_mask: accepted for interface compatibility; not read by this method.
            prefix_mask_list: per-sample prefix masks, used only when ``use_dynamic_prefix`` is set.
            local_position_ids: indices into ``local_pos_encoding``, same shape as ``input_ids``.
            position_ids: absolute positions; defaults to ``cache_position`` with a leading batch axis.
            past_key_values: legacy cache or cache object, only consulted when ``use_cache`` is true.
            inputs_embeds: must be None; embeddings are always computed from ``input_ids``.
            use_cache / output_attentions / output_hidden_states / return_dict: fall back to the
                corresponding ``self.config`` values when None.
            cache_position: must be None (asserted); it is derived from the number of cached tokens.
            compute_loss: selects how dynamic prefix masks are matched to action positions
                (all action-splitter positions vs. only the last one).

        Returns:
            ``BaseModelOutputWithPast`` when ``return_dict`` is true, otherwise a tuple of the
            non-None values among (hidden_states, cache, all_hidden_states, all_attentions).
        """
        assert input_ids is not None and inputs_embeds is None, "`input_ids` is required and `inputs_embeds` is not supported by LlamaBackbone"
        assert not (self.gradient_checkpointing and self.training and use_cache), "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Token embedding + local position embedding, each (batch_size, seq_len, n_embed).
        local_positional_emb = self.local_pos_encoding(local_position_ids)
        inputs_embeds = self.embed_tokens(input_ids) + local_positional_emb

        # KV cache & cache positions (for evaluation).
        past_seen_tokens = 0
        if use_cache and not isinstance(past_key_values, StaticCache):  # kept for BC (cache positions)
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_seen_tokens = past_key_values.get_seq_length()

        assert cache_position is None, 'cache_position 应该仅在推断阶段使用 kvcache 时才会传入'
        if cache_position is None:
            if isinstance(past_key_values, StaticCache):
                raise ValueError("cache_position is a required argument when using StaticCache.")
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # Absolute position information.
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)      # (1, seq_len)

        # Attention mask: causal mask from the parent class, then prefix-specific edits.
        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
        if self.prefix_spliter is not None:
            causal_mask = self._apply_prefix_mask(causal_mask, input_ids, prefix_mask_list, compute_loss)

        # Decoder layers.
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None
        hidden_states = inputs_embeds
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,  # a DynamicCache accumulates key & value layer by layer
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)        # (batch_size, seq_len, hidden_size)

        # Add hidden states from the last decoder layer.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
            )
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class TrajLlama(nn.Module):
    """Sequence model made of a ``LlamaBackbone`` plus an output projection over the token vocabulary.

    The vocabulary size is ``num_discrete_values + num_continuous_bin + len(special_tokens)``.
    ``forward`` returns logits and, when requested, the mean cross-entropy over positions selected
    by the loss mask together with a per-dataset breakdown.
    """

    def __init__(self, config):
        super().__init__()
        assert config.weight_decay_incr_style == "constant"         # only a constant weight-decay coefficient is supported for now
        self.config = config
        self.n_embed = config.n_embed                               # token embedding dimension
        self.n_position = config.n_position                         # number of positions (context length) supported by the model
        self.n_layer = config.n_layer                               # number of transformer blocks
        self.n_q_head = config.n_q_head                             # total number of attention (query) heads
        self.n_kv_head = config.n_kv_head                           # number of key/value head groups (GQA: heads in a group share key & value)
        self.d_head = self.n_embed // self.n_q_head                 # per-head projection dimension
        # SwiGLU intermediate size: 8/3 * n_embed (8/3 rather than 2x because SwiGLU has three weight
        # matrices), floored to a multiple of 64 and then increased by 64 -- so a value that is
        # already an exact multiple of 64 is still bumped by one step.
        self.d_inner = (int(self.n_embed * 8/3 / 64) + 1) * 64
        self.total_vocab_size = config.num_discrete_values + config.num_continuous_bin + len(config.special_tokens)

        # Build the llama backbone.
        action_spliter = config.special_tokens['<|>']
        prefix_spliter = None if not config.use_prefix else config.special_tokens['<X>']
        llama_config = AutoConfig.for_model(
            model_type="llama",
            vocab_size = self.total_vocab_size,
            hidden_size = self.n_embed,
            intermediate_size = self.d_inner,
            num_hidden_layers = self.n_layer,
            num_attention_heads = self.n_q_head,
            num_key_value_heads = self.n_kv_head,
            max_position_embeddings = self.n_position,
            rms_norm_eps = self.config.rms_norm_eps,
            pad_token_id=None,
            bos_token_id=None,
            eos_token_id=None,
            tie_word_embeddings=self.config.share_input_output_embedding,
            attention_bias=False,
            attention_dropout=self.config.dropattn,
            action_spliter=action_spliter,
            prefix_spliter=prefix_spliter,
            use_dynamic_prefix=self.config.use_dynamic_prefix,
        )
        self.model = LlamaBackbone._from_config(llama_config, torch_dtype=torch.float32)
        # Output projection; forward() bypasses it when share_input_output_embedding is set.
        self.lm_head = nn.Linear(self.n_embed, self.total_vocab_size, bias=False)
        self.lm_head.weight.data.normal_(mean=0.0, std=llama_config.initializer_range)

        # For the configured fields, embed raw values straight to n_embed with an MLP
        # (plays the same role as tokenize + nn.Embedding).
        mlp_embedding_hidden_size = int(self.n_embed/3)
        self.mlp_embedding_dict = nn.ModuleDict({
            item_name: nn.Sequential(
                nn.Linear(embedding_info['dim'], mlp_embedding_hidden_size),
                nn.ReLU(),
                nn.Linear(mlp_embedding_hidden_size, self.n_embed),
            ) for item_name, embedding_info in config.mlp_emb_items.items()
        })
        item_names = []
        for emb_name, emb_info in config.mlp_emb_items.items():
            item_names.extend([(emb_name, item_name) for item_name in emb_info['item_name']])
        # {obs_idx: (emb_name, item_name)}; indices start at 1 (_forward_rl treats 0 as "use the token embedding table").
        self.obs_idx_2_name = {i+1:names for i,names in enumerate(item_names)}

    def configure_optimizers(self):
        """Build an AdamW optimizer with weight decay applied to a subset of the parameters.

        Trainable parameters are split into two groups: those that receive weight decay
        (names returned by ``get_parameter_names(self, ALL_LAYERNORM_LAYERS)`` that do not contain
        "bias") and the remaining ones, which use ``weight_decay=0``. Parameters with
        ``requires_grad=False`` are left out of the optimizer entirely.

        Returns:
            torch.optim.AdamW configured from ``lr_begin``, ``adam_beta1``, ``adam_beta2``,
            ``adam_eps`` and ``weight_decay`` of ``self.config``.
        """
        # A set keeps the membership tests below O(1) per parameter.
        decay_candidates = {
            name for name in get_parameter_names(self, ALL_LAYERNORM_LAYERS) if "bias" not in name
        }

        # Only trainable parameters are partitioned, so a frozen parameter cannot end up
        # "unassigned"; every trainable name lands in exactly one of the two groups by construction.
        trainable = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        decay = {pn for pn in trainable if pn in decay_candidates}
        no_decay = set(trainable) - decay

        # Create the pytorch optimizer object (names sorted for a deterministic parameter order).
        optim_groups = [
            {"params": [trainable[pn] for pn in sorted(decay)], "weight_decay": self.config.weight_decay},
            {"params": [trainable[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(
            optim_groups,
            lr=self.config.lr_begin,
            betas=(self.config.adam_beta1, self.config.adam_beta2),
            eps=self.config.adam_eps
        )
        return optimizer

    def forward(self, tasks_input:RLTaskInput, compute_loss:bool=True, mems=None, batch_dataset_name:List=None, batch_raw_obs:Union[List, dict]=None):
        """Compute logits and, optionally, the masked cross-entropy loss.

        Args:
            tasks_input: batch container; this method reads ``tensor_seq``, ``attention_mask``,
                ``loss_mask``, ``prefix_mask``, ``position_id`` and ``label``.
            compute_loss: when True, compute the loss over positions where ``loss_mask == 1``.
            mems: unused.
            batch_dataset_name: dataset name of each sample in the batch; required when
                ``compute_loss`` is True. Every name must be in ``config.eval_dataset_names``.
            batch_raw_obs: unused here (only the unused ``_forward_rl`` helper reads raw observations).

        Returns:
            Tuple ``(lm_logits, loss_ave, loss_datasets, new_mems)``:
              * ``lm_logits``: (batch_size, seq_len, total_vocab_size);
              * ``loss_ave``: mean loss over all selected positions, or None without loss;
              * ``loss_datasets``: dataset name -> mean loss (0 for datasets absent from the batch);
                when ``compute_loss`` is False the values are empty lists;
              * ``new_mems``: always None.
        """
        label = tasks_input.label                   # (batch_size, seq_len)
        loss_mask = tasks_input.loss_mask           # (batch_size, seq_len)

        # The backbone embeds token ids itself; `_forward_rl` is not called on this path.
        outputs = self.model(
            input_ids = tasks_input.tensor_seq,             # (batch_size, seq_len)
            attention_mask = tasks_input.attention_mask,    # (batch_size, seq_len)
            loss_mask = loss_mask,                          # (batch_size, seq_len)
            prefix_mask_list = tasks_input.prefix_mask,     # list of (timestep, prefix_dim), len=batch_size
            local_position_ids = tasks_input.position_id,   # (batch_size, seq_len)
            position_ids = None,
            inputs_embeds = None,
            past_key_values = None,
            use_cache = False,                              # no KV cache during training; inference is not adapted yet
            return_dict = True,
            output_attentions = False,
            output_hidden_states = False,
            cache_position = None,
            compute_loss = compute_loss
        )
        hidden_states = outputs.last_hidden_state           # (batch_size, seq_len, n_embed)

        # Final projection from n_embed to total_vocab_size.
        if self.config.share_input_output_embedding:
            lm_logits = F.linear(hidden_states, self.model.embed_tokens.weight)
            assert lm_logits.shape[:-1] == hidden_states.shape[:-1]
        else:
            lm_logits = self.lm_head(hidden_states)         # (batch_size, seq_len, total_vocab_size)

        # Compute the loss when training.
        loss_ave = None
        loss_datasets = {dataset_name: [] for dataset_name in self.config.eval_dataset_names}
        if compute_loss:
            assert loss_mask is not None and label is not None
            assert batch_dataset_name is not None, "batch_dataset_name is required when compute_loss=True"

            # Boolean indexing flattens in row-major order, so the selected positions of sample 0
            # come first, then sample 1, ... -- the per-dataset split below relies on that.
            is_target = loss_mask.detach().float() == 1                                 # (batch_size, seq_len)
            loss_fct = nn.CrossEntropyLoss(reduction="none")
            act_lm_logits = lm_logits[is_target]                                        # (act_num, vocab_size)
            act_label = label[is_target]                                                # (act_num, )
            loss = loss_fct(act_lm_logits.view(-1, self.total_vocab_size), act_label)   # (act_num, )

            # Average loss over all datasets.
            loss_ave = loss.mean()

            # Per-dataset losses; their count-weighted mean must equal loss_ave.
            loss = loss.detach()
            targets_per_sample = is_target.sum(dim=1)
            loss_end = torch.cumsum(targets_per_sample, dim=0).tolist()
            loss_start = [0] + loss_end[:-1]
            for idx, dataset_name in enumerate(batch_dataset_name):
                start, end = loss_start[idx], loss_end[idx]
                loss_datasets[dataset_name].extend(loss[start:end].tolist())

            loss_num = len(loss)
            assert sum([len(v) for v in loss_datasets.values()]) == loss_num
            weighted_loss = {k:0 if len(v) == 0 else len(v)/loss_num*np.mean(v) for k,v in loss_datasets.items()}
            assert abs(loss_ave.item() - np.sum(list(weighted_loss.values()))) < 1e-3, f'unequal loss: {loss_ave.item()} != {np.sum(list(weighted_loss.values()))}'
            loss_datasets = {k:0 if len(v) == 0 else np.mean(v) for k,v in loss_datasets.items()}

        new_mems = None
        return lm_logits, loss_ave, loss_datasets, new_mems

    def _forward_rl(self, rl_input:RLTaskInput, compute_loss:bool=True, batch_raw_obs:Union[List, dict]=None):
        """Build input embeddings by mixing table-embedded tokens with MLP-embedded raw observations.

        Not called by ``forward`` in this file. Positions whose ``obs_idxs`` entry is 0 are
        embedded with the backbone's token embedding table; positions with entry ``k > 0`` are
        filled with the output of the MLP registered for ``obs_idx_2_name[k]``, applied to the
        matching raw observation. The backbone's local position embedding is added at the end.

        Args:
            rl_input: batch container; reads ``tensor_seq``, ``position_id``, ``label``,
                ``loss_mask``, ``attention_mask`` and ``obs_idxs``.
            compute_loss: must be False when ``batch_raw_obs`` is a dict and True otherwise (asserted).
            batch_raw_obs: a dict ``item_name -> tensor`` shared by the whole batch (evaluation),
                or a per-sample sequence of such dicts (training).

        Returns:
            Tuple ``(hidden_states, loss_mask, attention_mask, label)`` where ``hidden_states`` is
            (batch_size, seq_len, n_embed).
        """
        batch_input_tensor = rl_input.tensor_seq            # (batch_size, seq_len)
        batch_position_id = rl_input.position_id            # (batch_size, seq_len)
        batch_label = rl_input.label
        batch_loss_mask = rl_input.loss_mask
        batch_attention_mask = rl_input.attention_mask
        batch_obs_idxs = rl_input.obs_idxs
        device = batch_input_tensor.device

        # The embedding tables live on the backbone (`self.model`), not on this wrapper.
        token_embedding = self.model.embed_tokens
        local_pos_encoding = self.model.local_pos_encoding

        # token_emb collects the outputs of the different embedding modules.
        batch_size, seq_len = batch_input_tensor.shape[0], batch_input_tensor.shape[1]
        token_emb = torch.zeros((batch_size, seq_len, self.n_embed), device=device)     # (batch_size, seq_len, n_embed)

        if isinstance(batch_raw_obs, dict):
            # Batched evaluation (DDP env): all raw observations come from the same environment,
            # so one obs_idxs row describes every sample and the batch is processed at once.
            assert not compute_loss
            assert list(batch_obs_idxs.shape) == [1, batch_input_tensor.shape[1]]
            obs_idxs = batch_obs_idxs[0]

            # Tokens that are not MLP-embedded go through the embedding table.
            table_slots = obs_idxs == 0
            token_emb[:, table_slots] = token_embedding(batch_input_tensor[:, table_slots])

            # Token holders that are MLP-embedded: look up the matching MLP and embed the raw input.
            for obs_idx, (emb_name, item_name) in self.obs_idx_2_name.items():
                if item_name not in batch_raw_obs:
                    continue
                mlp_module = self.mlp_embedding_dict[emb_name]                          # MLP for this field
                input_dim = self.config.mlp_emb_items[emb_name]['dim']                  # raw input dimension of that MLP
                raw_obs = batch_raw_obs[item_name].reshape((batch_size, -1, input_dim)).to(device)
                item_slots = obs_idxs == obs_idx
                # Number of tokens reserved for this field; may be smaller than the number of raw
                # items because the token sequence may have been truncated to the context length.
                item_token_len = item_slots.sum().item()
                assert (token_emb[:, item_slots]).sum().item() == 0                     # slots must still be empty
                token_emb[:, item_slots] = mlp_module(raw_obs)[:, :item_token_len]
        else:
            # Training: raw observations may come from different environments with different
            # formats, so samples are processed one by one.
            assert compute_loss
            assert len(batch_obs_idxs) == len(batch_raw_obs)
            for i, obs in enumerate(batch_raw_obs):
                obs_idxs = batch_obs_idxs[i]                # field index of every token in this sample

                # Tokens that are not MLP-embedded go through the embedding table.
                table_slots = obs_idxs == 0
                token_emb[i][table_slots] = token_embedding(batch_input_tensor[i][table_slots])

                # Token holders that are MLP-embedded: look up the matching MLP and embed the raw input.
                for obs_idx, (emb_name, item_name) in self.obs_idx_2_name.items():
                    if item_name not in obs:
                        continue
                    mlp_module = self.mlp_embedding_dict[emb_name]              # MLP for this field
                    input_dim = self.config.mlp_emb_items[emb_name]['dim']      # raw input dimension of that MLP
                    raw_obs = obs[item_name].reshape((-1, input_dim)).to(device)  # (obs_item_num, obs_input_dim)
                    item_slots = obs_idxs == obs_idx
                    item_token_len = item_slots.sum().item()                    # may be < obs_item_num after truncation
                    assert (token_emb[i][item_slots]).sum().item() == 0         # slots must still be empty
                    token_emb[i][item_slots] = mlp_module(raw_obs)[:item_token_len]

        local_positional_emb = local_pos_encoding(batch_position_id)        # local position embedding
        hidden_states = token_emb + local_positional_emb                    # (batch_size, seq_len, n_embed)

        return hidden_states, batch_loss_mask, batch_attention_mask, batch_label
