import torch.nn as nn
import torch.nn.functional as F
import torch as th
import numpy as np
from utils.transformer import Transformer
from utils.embed import polynomial_embed, binary_embed

class MtTransformerCriticCont(nn.Module):
    """Multi-task transformer critic that can score a continuous action vector.

    The task's decomposer splits each agent observation into own, enemy and
    ally features.  Every entity becomes one token of width
    ``args.policy_entity_embed_dim``; the tokens are mixed by ``Transformer``
    and pooled into a fixed-size vector.  With ``is_v`` the module returns a
    state value from that vector, otherwise it appends an embedding of
    ``action`` (width ``args.latent_dim``) and returns a single Q-value.

    Layer input sizes are read from the first decomposer in
    ``task2decomposer``.
    NOTE(review): this presumes every task shares the same per-entity feature
    sizes -- confirm against the decomposer implementations.
    """

    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, is_v, args) -> None:
        """Build the embedding layers, the transformer and the output heads.

        Args:
            task2input_shape_info: mapping task -> dict holding at least the
                key ``"last_action_shape"``.
            task2decomposer: mapping task -> decomposer object.
            task2n_agents: mapping task -> number of agents in that task.
            is_v: when true a value head ``self.v`` is created and used by
                ``forward``; otherwise ``self.v`` is ``None``.
            args: configuration namespace (embedding sizes, flags, limits).
        """
        super(MtTransformerCriticCont, self).__init__()

        self.args = args
        self.task2last_action_shape = {
            task: task2input_shape_info[task]["last_action_shape"]
            for task in task2input_shape_info
        }
        self.task2decomposer = task2decomposer
        self.task2n_agents = task2n_agents

        self.critic_hidden_dim = args.critic_hidden_dim
        self.is_v = is_v

        self.entity_embed_dim = args.policy_entity_embed_dim
        self.attn_embed_dim = args.attn_embed_dim
        self.action_dim = args.latent_dim

        # Feature sizes come from the first registered task only.
        ref_decomposer = task2decomposer[next(iter(task2decomposer))]
        obs_own_dim = ref_decomposer.own_obs_dim
        obs_en_dim = ref_decomposer.obs_nf_en
        obs_al_dim = ref_decomposer.obs_nf_al
        n_actions_no_attack = ref_decomposer.n_actions_no_attack
        has_attack_action = n_actions_no_attack != ref_decomposer.n_actions

        wrapped_obs_own_dim = obs_own_dim
        if args.obs_agent_id and args.obs_last_action:
            # The own token additionally carries the agent id and the compact
            # last-action encoding.
            wrapped_obs_own_dim += args.id_length + n_actions_no_attack
            if has_attack_action:
                # One extra own feature, and one extra feature per enemy for
                # the attack information appended in forward().
                wrapped_obs_own_dim += 1
                obs_en_dim += 1

        embed_dim = self.entity_embed_dim

        # NOTE: creation order is kept stable so parameter ordering and
        # initialisation match earlier versions of this class.
        self.ally_value = nn.Linear(obs_al_dim, embed_dim)
        self.enemy_value = nn.Linear(obs_en_dim, embed_dim)
        self.own_value = nn.Linear(wrapped_obs_own_dim, embed_dim)

        self.action_value = nn.Linear(self.action_dim, embed_dim)

        encoding_in_dim = args.encoding_dim
        if args.use_role_encoder and not args.only_role_encoding:
            encoding_in_dim = 2 * args.encoding_dim
        self.encoding_value = nn.Linear(encoding_in_dim, embed_dim)

        # Learned embeddings indexed by the entity's slot in the ally / enemy list.
        self.ally_time_embed = nn.Embedding(args.max_ally_num, embed_dim)
        self.enemy_time_embed = nn.Embedding(args.max_enemy_num, embed_dim)

        self.transformer = Transformer(embed_dim, args.policy_head, args.policy_depth, embed_dim)

        # Pooled vector = [encoding?] + own + max(enemies) + max(allies);
        # the Q head sees one more chunk for the embedded action.
        n_pooled = 4 if args.use_encoding else 3
        self.q_skill = nn.Linear((n_pooled + 1) * embed_dim, 1)
        self.v = nn.Linear(n_pooled * embed_dim, 1) if self.is_v else None

    @staticmethod
    def _add_slot_embedding(hidden, slot_embed):
        """Add ``slot_embed(j)`` to token ``j`` of every row of a 3-D ``hidden``."""
        n_rows, n_slots, _ = hidden.shape
        slot_ids = th.arange(n_slots, device=hidden.device).long()
        slot_ids = slot_ids.view(1, n_slots).expand(n_rows, -1)
        return hidden + slot_embed(slot_ids)

    @staticmethod
    def _pool_entity_tokens(outputs, enemy_length, ally_length, has_encoding):
        """Concatenate [encoding?], own token, max over enemy tokens, max over ally tokens."""
        own_idx = 1 if has_encoding else 0
        enemy_start = own_idx + 1
        ally_start = enemy_start + enemy_length
        pooled = [
            outputs[:, own_idx, :],
            th.max(outputs[:, enemy_start:ally_start, :], dim=1)[0],
            th.max(outputs[:, ally_start:ally_start + ally_length, :], dim=1)[0],
        ]
        if has_encoding:
            pooled.insert(0, outputs[:, 0, :])
        return th.cat(pooled, dim=-1)

    def forward(self, inputs, action, task, task_encoding):
        """Evaluate the critic for one task.

        Args:
            inputs: either a 4-D tensor unpacked as
                ``(bs, max_t, n_agents, feat)`` or an already flattened 2-D
                tensor ``(rows, feat)``.  Along the last axis the first
                ``obs_dim`` entries are the observation and the next
                ``last_action_shape`` entries the last action; anything after
                that is ignored (agent ids are rebuilt here).
            action: action tensor; flattened with ``inputs`` in the 4-D case.
                Only embedded when the Q head is used.
            task: key into the per-task mappings given at construction.
            task_encoding: task/role encoding.  In the 4-D case a 3-D tensor
                is repeated along a new time axis before flattening; in the
                2-D case it is used as given.

        Returns:
            ``(bs, max_t, n_agents, 1)`` for 4-D input, otherwise the raw
            head output with one row per input row.
        """
        task_decomposer = self.task2decomposer[task]
        task_n_agents = self.task2n_agents[task]
        last_action_shape = self.task2last_action_shape[task]
        obs_dim = task_decomposer.obs_dim

        need_reshape = len(inputs.shape) == 4
        if need_reshape:
            bs, max_t, n_agents, _ = inputs.shape
            flat_rows = bs * max_t * n_agents
            if len(task_encoding.shape) == 3:
                task_encoding = task_encoding.unsqueeze(1).repeat(1, max_t, 1, 1)
            task_encoding = task_encoding.reshape(flat_rows, -1)
            assert n_agents == task_n_agents, f"n_agents {n_agents} does not match task_n_agents {task_n_agents}"
            # [bs * max_t * n_agents, feat]
            inputs = inputs.reshape(flat_rows, -1)
            action = action.reshape(flat_rows, -1)
            n_id_repeats = bs * max_t
        else:
            # Rows are assumed to be whole groups of ``task_n_agents`` agents
            # -- TODO confirm with callers of the 2-D path.
            n_id_repeats = inputs.shape[0] // task_n_agents

        obs_inputs = inputs[:, :obs_dim]
        last_action_inputs = inputs[:, obs_dim:obs_dim + last_action_shape]

        own_obs, enemy_feats, ally_feats = task_decomposer.decompose_obs(obs_inputs)

        # Agent ids are regenerated with binary_embed and tiled in agent order.
        id_table = th.stack(
            [
                th.as_tensor(binary_embed(i + 1, self.args.id_length, self.args.max_agent), dtype=own_obs.dtype)
                for i in range(task_n_agents)
            ],
            dim=0,
        )
        agent_id_inputs = id_table.repeat(n_id_repeats, 1).to(own_obs.device)
        _, attack_action_info, compact_action_states = task_decomposer.decompose_action_info(last_action_inputs)

        if self.args.obs_last_action and self.args.obs_agent_id:
            # Wrap the own observation with agent id and compact last action.
            own_obs = th.cat([own_obs, agent_id_inputs, compact_action_states], dim=-1)

        # Stack the per-entity tensors on a new leading entity axis.
        enemy_feats = th.stack(enemy_feats, dim=0)
        if np.prod(attack_action_info.shape) > 0:
            # Append the attack information as one extra feature per enemy.
            attack_action_info = attack_action_info.transpose(0, 1).unsqueeze(-1)
            enemy_feats = th.cat([enemy_feats, attack_action_info], dim=-1)
        ally_feats = th.stack(ally_feats, dim=0)

        # Embed every entity as a token; permute puts the entity axis second.
        own_hidden = self.own_value(own_obs).unsqueeze(1)
        ally_hidden = self.ally_value(ally_feats).permute(1, 0, 2)
        enemy_hidden = self.enemy_value(enemy_feats).permute(1, 0, 2)
        encoding_hidden = self.encoding_value(task_encoding).unsqueeze(1)

        ally_hidden = self._add_slot_embedding(ally_hidden, self.ally_time_embed)
        enemy_hidden = self._add_slot_embedding(enemy_hidden, self.enemy_time_embed)

        use_encoding = bool(self.args.use_encoding)
        tokens = [own_hidden, enemy_hidden, ally_hidden]
        if use_encoding:
            tokens.insert(0, encoding_hidden)
        outputs = self.transformer(th.cat(tokens, dim=1), None)

        obs_out = self._pool_entity_tokens(
            outputs, enemy_hidden.shape[1], ally_hidden.shape[1], use_encoding
        )

        if self.is_v:
            out = self.v(obs_out)
        else:
            # The action embedding is only needed by the Q head.
            action_hidden = self.action_value(action)
            out = self.q_skill(th.cat([obs_out, action_hidden], dim=-1))

        if need_reshape:
            out = out.reshape(bs, max_t, n_agents, 1)
        return out




class MtTransformerCritic(nn.Module):
    """Multi-task transformer critic with a discrete-action (or fixed-width) head.

    The task's decomposer splits each agent observation into own, enemy and
    ally features.  Every entity becomes one token of width
    ``args.policy_entity_embed_dim``; the tokens are mixed by ``Transformer``
    and pooled into a fixed-size vector.  Three output modes exist:

    * ``output_shape == 1``: a scalar value head ``self.v`` is used.
    * ``is_vqvae``: ``self.q_skill`` maps the pooled vector to
      ``output_shape`` values.
    * otherwise: ``self.q_skill`` yields one Q-value per non-attack action
      and ``self.attack_skill`` yields one Q-value per enemy token.

    Layer input sizes are read from the first decomposer in
    ``task2decomposer``.
    NOTE(review): this presumes every task shares the same per-entity feature
    sizes -- confirm against the decomposer implementations.
    """

    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, output_shape, args, is_vqvae=False) -> None:
        """Build the embedding layers, the transformer and the output heads.

        Args:
            task2input_shape_info: mapping task -> dict holding at least the
                key ``"last_action_shape"``.
            task2decomposer: mapping task -> decomposer object.
            task2n_agents: mapping task -> number of agents in that task.
            output_shape: requested head width; ``1`` selects the value head.
            args: configuration namespace (embedding sizes, flags, limits).
            is_vqvae: when true ``q_skill`` has ``output_shape`` outputs and
                no attack head is created.
        """
        super(MtTransformerCritic, self).__init__()

        self.args = args
        self.task2last_action_shape = {
            task: task2input_shape_info[task]["last_action_shape"]
            for task in task2input_shape_info
        }
        self.task2decomposer = task2decomposer
        self.task2n_agents = task2n_agents

        self.critic_hidden_dim = args.critic_hidden_dim
        self.is_v = bool(output_shape == 1)
        self.is_vqvae = is_vqvae
        # Always stored: previously it was only set for is_vqvae with
        # output_shape != 1, so is_vqvae together with output_shape == 1
        # failed with AttributeError when building q_skill below.
        self.output_shape = output_shape

        self.entity_embed_dim = args.policy_entity_embed_dim
        self.attn_embed_dim = args.attn_embed_dim

        # Feature sizes come from the first registered task only.
        ref_decomposer = task2decomposer[next(iter(task2decomposer))]
        obs_own_dim = ref_decomposer.own_obs_dim
        obs_en_dim = ref_decomposer.obs_nf_en
        obs_al_dim = ref_decomposer.obs_nf_al
        n_actions_no_attack = ref_decomposer.n_actions_no_attack
        has_attack_action = n_actions_no_attack != ref_decomposer.n_actions

        wrapped_obs_own_dim = obs_own_dim
        if args.obs_agent_id and args.obs_last_action:
            # The own token additionally carries the agent id and the compact
            # last-action encoding.
            wrapped_obs_own_dim += args.id_length + n_actions_no_attack
            if has_attack_action:
                # One extra own feature, and one extra feature per enemy for
                # the attack information appended in forward().
                wrapped_obs_own_dim += 1
                obs_en_dim += 1

        embed_dim = self.entity_embed_dim

        # NOTE: creation order is kept stable so parameter ordering and
        # initialisation match earlier versions of this class.
        self.ally_value = nn.Linear(obs_al_dim, embed_dim)
        self.enemy_value = nn.Linear(obs_en_dim, embed_dim)
        self.own_value = nn.Linear(wrapped_obs_own_dim, embed_dim)

        encoding_in_dim = args.encoding_dim
        if args.use_role_encoder and not args.only_role_encoding:
            encoding_in_dim = 2 * args.encoding_dim
        self.encoding_value = nn.Linear(encoding_in_dim, embed_dim)

        # Learned embeddings indexed by the entity's slot in the ally / enemy list.
        self.ally_time_embed = nn.Embedding(args.max_ally_num, embed_dim)
        self.enemy_time_embed = nn.Embedding(args.max_enemy_num, embed_dim)

        self.transformer = Transformer(embed_dim, args.policy_head, args.policy_depth, embed_dim)

        # Pooled vector = [encoding?] + own + max(enemies) + max(allies).
        pooled_dim = (4 if args.use_encoding else 3) * embed_dim
        if not self.is_vqvae:
            self.q_skill = nn.Linear(pooled_dim, n_actions_no_attack)
            self.v = nn.Linear(pooled_dim, 1) if self.is_v else None
            # The attack head sees one enemy token, prefixed by the encoding
            # token when encodings are used.
            attack_in_dim = (2 if args.use_encoding else 1) * embed_dim
            self.attack_skill = nn.Linear(attack_in_dim, 1)
        else:
            self.q_skill = nn.Linear(pooled_dim, self.output_shape)
            self.v = nn.Linear(pooled_dim, 1) if self.is_v else None

    @staticmethod
    def _add_slot_embedding(hidden, slot_embed):
        """Add ``slot_embed(j)`` to token ``j`` of every row of a 3-D ``hidden``."""
        n_rows, n_slots, _ = hidden.shape
        slot_ids = th.arange(n_slots, device=hidden.device).long()
        slot_ids = slot_ids.view(1, n_slots).expand(n_rows, -1)
        return hidden + slot_embed(slot_ids)

    @staticmethod
    def _pool_entity_tokens(outputs, enemy_length, ally_length, has_encoding):
        """Concatenate [encoding?], own token, max over enemy tokens, max over ally tokens."""
        own_idx = 1 if has_encoding else 0
        enemy_start = own_idx + 1
        ally_start = enemy_start + enemy_length
        pooled = [
            outputs[:, own_idx, :],
            th.max(outputs[:, enemy_start:ally_start, :], dim=1)[0],
            th.max(outputs[:, ally_start:ally_start + ally_length, :], dim=1)[0],
        ]
        if has_encoding:
            pooled.insert(0, outputs[:, 0, :])
        return th.cat(pooled, dim=-1)

    @staticmethod
    def _join_attack_q(q_base, per_enemy_q):
        """Append the per-enemy ``(rows, 1)`` Q columns to ``q_base`` on the last axis."""
        # Concatenating keeps the row and enemy axes even when either has
        # size 1; an unqualified squeeze() would drop them.
        return th.cat([q_base, *per_enemy_q], dim=-1)

    def forward(self, inputs, task, task_encoding):
        """Evaluate the critic for one task.

        Args:
            inputs: 4-D tensor unpacked as ``(bs, max_t, n_agents, feat)``.
                Along the last axis the first ``obs_dim`` entries are the
                observation and the next ``last_action_shape`` entries the
                last action; anything after that is ignored (agent ids are
                rebuilt here).
            task: key into the per-task mappings given at construction.
            task_encoding: task/role encoding; a 3-D tensor is repeated along
                a new time axis, then it is flattened to one row per
                ``(batch, time, agent)``.

        Returns:
            Tensor of shape ``(bs, max_t, n_agents, k)`` where ``k`` is ``1``
            for the value head, ``output_shape`` in VQ-VAE mode, and the
            decomposer's ``n_actions`` otherwise.
        """
        task_decomposer = self.task2decomposer[task]
        task_n_agents = self.task2n_agents[task]
        last_action_shape = self.task2last_action_shape[task]
        obs_dim = task_decomposer.obs_dim

        bs, max_t, n_agents, _ = inputs.shape
        flat_rows = bs * max_t * n_agents
        if len(task_encoding.shape) == 3:
            task_encoding = task_encoding.unsqueeze(1).repeat(1, max_t, 1, 1)
        task_encoding = task_encoding.reshape(flat_rows, -1)
        assert n_agents == task_n_agents, f"n_agents {n_agents} does not match task_n_agents {task_n_agents}"
        # [bs * max_t * n_agents, feat]
        inputs = inputs.reshape(flat_rows, -1)

        obs_inputs = inputs[:, :obs_dim]
        last_action_inputs = inputs[:, obs_dim:obs_dim + last_action_shape]

        own_obs, enemy_feats, ally_feats = task_decomposer.decompose_obs(obs_inputs)

        # Agent ids are regenerated with binary_embed and tiled in agent order.
        id_table = th.stack(
            [
                th.as_tensor(binary_embed(i + 1, self.args.id_length, self.args.max_agent), dtype=own_obs.dtype)
                for i in range(task_n_agents)
            ],
            dim=0,
        )
        agent_id_inputs = id_table.repeat(bs * max_t, 1).to(own_obs.device)
        _, attack_action_info, compact_action_states = task_decomposer.decompose_action_info(last_action_inputs)

        if self.args.obs_last_action and self.args.obs_agent_id:
            # Wrap the own observation with agent id and compact last action.
            own_obs = th.cat([own_obs, agent_id_inputs, compact_action_states], dim=-1)

        # Stack the per-entity tensors on a new leading entity axis.
        enemy_feats = th.stack(enemy_feats, dim=0)
        if np.prod(attack_action_info.shape) > 0:
            # Append the attack information as one extra feature per enemy.
            attack_action_info = attack_action_info.transpose(0, 1).unsqueeze(-1)
            enemy_feats = th.cat([enemy_feats, attack_action_info], dim=-1)
        ally_feats = th.stack(ally_feats, dim=0)

        # Embed every entity as a token; permute puts the entity axis second.
        own_hidden = self.own_value(own_obs).unsqueeze(1)
        ally_hidden = self.ally_value(ally_feats).permute(1, 0, 2)
        enemy_hidden = self.enemy_value(enemy_feats).permute(1, 0, 2)
        encoding_hidden = self.encoding_value(task_encoding).unsqueeze(1)

        ally_hidden = self._add_slot_embedding(ally_hidden, self.ally_time_embed)
        enemy_hidden = self._add_slot_embedding(enemy_hidden, self.enemy_time_embed)

        use_encoding = bool(self.args.use_encoding)
        tokens = [own_hidden, enemy_hidden, ally_hidden]
        if use_encoding:
            tokens.insert(0, encoding_hidden)
        outputs = self.transformer(th.cat(tokens, dim=1), None)

        obs_out = self._pool_entity_tokens(
            outputs, enemy_hidden.shape[1], ally_hidden.shape[1], use_encoding
        )

        if self.is_v:
            return self.v(obs_out).reshape(bs, max_t, n_agents, 1)

        if self.is_vqvae:
            return self.q_skill(obs_out).reshape(bs, max_t, n_agents, self.output_shape)

        q = self.q_skill(obs_out)
        if task_decomposer.n_actions_no_attack != task_decomposer.n_actions:
            # One attack Q-value per enemy, read from that enemy's token.
            first_enemy = 2 if use_encoding else 1
            q_attack_list = []
            for i in range(enemy_feats.size(0)):
                attack_action_inputs = outputs[:, first_enemy + i, :]
                if use_encoding:
                    attack_action_inputs = th.cat([outputs[:, 0, :], attack_action_inputs], dim=-1)
                q_attack_list.append(self.attack_skill(attack_action_inputs))
            q = self._join_attack_q(q, q_attack_list)
        return q.reshape(bs, max_t, n_agents, task_decomposer.n_actions)


class TransformerCritic(nn.Module):
    """Single-task transformer critic over decomposed agent observations.

    The decomposer splits each agent observation into own, enemy and ally
    features.  Every entity becomes one token of width
    ``args.entity_embed_dim`` and the tokens are mixed by ``Transformer``.
    The own token (index 0) feeds either the value head (when
    ``output_shape == 1``) or ``q_skill`` for the non-attack actions; when the
    decomposer has attack actions, each enemy token contributes one attack
    Q-value, computed as the mean of ``q_skill`` applied to that token.
    """

    def __init__(self, input_shape, output_shape, decomposer, args) -> None:
        """Build the embedding layers, the transformer and the output heads.

        Args:
            input_shape: accepted for interface compatibility; not read here.
            output_shape: ``1`` selects the value head.
            decomposer: observation/action decomposer for the task.
            args: configuration namespace (embedding sizes, flags, limits).
        """
        super(TransformerCritic, self).__init__()

        self.args = args
        self.decomposer = decomposer
        self.critic_hidden_dim = args.critic_hidden_dim
        self.is_v = bool(output_shape == 1)

        self.entity_embed_dim = args.entity_embed_dim
        self.attn_embed_dim = args.attn_embed_dim

        obs_own_dim = decomposer.own_obs_dim
        obs_en_dim = decomposer.obs_nf_en
        obs_al_dim = decomposer.obs_nf_al
        n_actions_no_attack = decomposer.n_actions_no_attack

        self.n_agents = args.n_agents

        has_attack_action = n_actions_no_attack != decomposer.n_actions

        wrapped_obs_own_dim = obs_own_dim
        if args.obs_agent_id and args.obs_last_action:
            # The own token additionally carries the agent id and the compact
            # last-action encoding.
            wrapped_obs_own_dim += args.id_length + n_actions_no_attack
            if has_attack_action:
                # One extra own feature, and one extra feature per enemy for
                # the attack information appended in forward().
                wrapped_obs_own_dim += 1
                obs_en_dim += 1

        embed_dim = self.entity_embed_dim

        # NOTE: creation order is kept stable so parameter ordering and
        # initialisation match earlier versions of this class.
        self.ally_value = nn.Linear(obs_al_dim, embed_dim)
        self.enemy_value = nn.Linear(obs_en_dim, embed_dim)
        self.own_value = nn.Linear(wrapped_obs_own_dim, embed_dim)

        # Learned embeddings indexed by the entity's slot in the ally / enemy
        # list; table sizes fall back to this task's counts when args does
        # not define the maxima.
        max_ally_num = getattr(args, 'max_ally_num', args.n_agents)
        self.ally_time_embed = nn.Embedding(max_ally_num, embed_dim)

        max_enemy_num = getattr(args, 'max_enemy_num', decomposer.n_enemies)
        self.enemy_time_embed = nn.Embedding(max_enemy_num, embed_dim)

        self.transformer = Transformer(embed_dim, args.head, args.depth, embed_dim)

        self.q_skill = nn.Linear(embed_dim, n_actions_no_attack)
        self.v = nn.Linear(embed_dim, 1) if self.is_v else None

    @staticmethod
    def _add_slot_embedding(hidden, slot_embed):
        """Add ``slot_embed(j)`` to token ``j`` of every row of a 3-D ``hidden``."""
        n_rows, n_slots, _ = hidden.shape
        slot_ids = th.arange(n_slots, device=hidden.device).long()
        slot_ids = slot_ids.view(1, n_slots).expand(n_rows, -1)
        return hidden + slot_embed(slot_ids)

    @staticmethod
    def _join_attack_q(q_base, per_enemy_q):
        """Append the per-enemy ``(rows, 1)`` Q columns to ``q_base`` on the last axis."""
        # Concatenating keeps the row and enemy axes even when either has
        # size 1; an unqualified squeeze() would drop them.
        return th.cat([q_base, *per_enemy_q], dim=-1)

    def forward(self, inputs):
        """Evaluate the critic.

        Args:
            inputs: 4-D tensor unpacked as ``(bs, max_t, n_agents, feat)``.
                Along the last axis the first ``obs_dim`` entries are the
                observation and the next ``args.n_actions`` entries the last
                action; anything after that is ignored (agent ids are rebuilt
                here for ``args.n_agents`` agents).

        Returns:
            ``(bs, max_t, n_agents, 1)`` for the value head, otherwise
            ``(bs, max_t, n_agents, decomposer.n_actions)``.
        """
        task_decomposer = self.decomposer
        last_action_shape = self.args.n_actions
        obs_dim = task_decomposer.obs_dim

        bs, max_t, n_agents, _ = inputs.shape
        # [bs * max_t * n_agents, feat]
        inputs = inputs.reshape(bs * max_t * n_agents, -1)

        obs_inputs = inputs[:, :obs_dim]
        last_action_inputs = inputs[:, obs_dim:obs_dim + last_action_shape]

        own_obs, enemy_feats, ally_feats = task_decomposer.decompose_obs(obs_inputs)

        # Agent ids are regenerated with binary_embed and tiled in agent order.
        id_table = th.stack(
            [
                th.as_tensor(binary_embed(i + 1, self.args.id_length, self.args.max_agent), dtype=own_obs.dtype)
                for i in range(self.n_agents)
            ],
            dim=0,
        )
        agent_id_inputs = id_table.repeat(bs * max_t, 1).to(own_obs.device)
        _, attack_action_info, compact_action_states = task_decomposer.decompose_action_info(last_action_inputs)

        if self.args.obs_last_action and self.args.obs_agent_id:
            # Wrap the own observation with agent id and compact last action.
            own_obs = th.cat([own_obs, agent_id_inputs, compact_action_states], dim=-1)

        # Stack the per-entity tensors on a new leading entity axis.
        enemy_feats = th.stack(enemy_feats, dim=0)
        if np.prod(attack_action_info.shape) > 0:
            # Append the attack information as one extra feature per enemy.
            attack_action_info = attack_action_info.transpose(0, 1).unsqueeze(-1)
            enemy_feats = th.cat([enemy_feats, attack_action_info], dim=-1)
        ally_feats = th.stack(ally_feats, dim=0)

        # Embed every entity as a token; permute puts the entity axis second.
        own_hidden = self.own_value(own_obs).unsqueeze(1)
        ally_hidden = self.ally_value(ally_feats).permute(1, 0, 2)
        enemy_hidden = self.enemy_value(enemy_feats).permute(1, 0, 2)

        ally_hidden = self._add_slot_embedding(ally_hidden, self.ally_time_embed)
        enemy_hidden = self._add_slot_embedding(enemy_hidden, self.enemy_time_embed)

        # Token order: own, enemies, allies.
        total_hidden = th.cat([own_hidden, enemy_hidden, ally_hidden], dim=1)
        outputs = self.transformer(total_hidden, None)
        own_out = outputs[:, 0, :]

        if self.is_v:
            return self.v(own_out).reshape(bs, max_t, n_agents, 1)

        q = self.q_skill(own_out)
        if task_decomposer.n_actions_no_attack != task_decomposer.n_actions:
            # One attack Q-value per enemy: mean of q_skill over that enemy's token.
            q_attack_list = [
                th.mean(self.q_skill(outputs[:, 1 + i, :]), 1, True)
                for i in range(enemy_feats.size(0))
            ]
            q = self._join_attack_q(q, q_attack_list)
        return q.reshape(bs, max_t, n_agents, task_decomposer.n_actions)