import torch.nn as nn
import torch.nn.functional as F
import torch as th
import numpy as np
from utils.transformer import Transformer

class TransformerTransitionEncoder(nn.Module):
    """Encode a transition ``(obs, action, next_obs, reward)`` into a vector.

    Two code paths exist, selected by ``args.mlp_encoding``:

    * MLP path: the raw inputs are concatenated and passed through
      ``fc1 -> fc2 -> fc3``.
    * Transformer path: ``obs`` and ``next_obs`` are split by the task's
      decomposer into own / enemy / ally parts, each part is linearly
      embedded, entity-index embeddings are added to the enemy and ally
      tokens, and the token sequences go through ``obs_transformer`` and
      ``next_obs_transformer``.  The own token plus the element-wise max over
      enemy tokens and over ally tokens form the summary of each observation,
      which is fed together with the no-attack action part (and, unless
      ``is_club``, the next-observation summary and the reward) to
      ``fc -> mid -> out``.

    All layer sizes are taken from the decomposer of the first key in
    ``task2decomposer``.
    """

    def __init__(self, task2decomposer, args, is_club=False) -> None:
        """Build all layers.

        Args:
            task2decomposer: mapping from task key to a decomposer object
                exposing ``own_obs_dim``, ``obs_nf_en``, ``obs_nf_al``,
                ``n_actions``, ``n_actions_no_attack`` (and ``obs_dim`` when
                ``args.mlp_encoding`` is set).
            args: config namespace; reads ``entity_embed_dim``,
                ``attn_embed_dim``, ``transition_encoding_dim``,
                ``max_ally_num``, ``max_enemy_num``, ``head``, ``depth``,
                ``encoder_hidden_dim``, ``mlp_encoding`` and, only when the
                reference task has attack actions,
                ``direct_enemy_action_embedding``.
            is_club: when true, the final MLP of the transformer path only
                receives the current-observation summary and the no-attack
                action part.
        """
        super(TransformerTransitionEncoder, self).__init__()

        self.args = args
        self.task2decomposer = task2decomposer
        self.is_club = is_club

        self.entity_embed_dim = args.entity_embed_dim
        self.attn_embed_dim = args.attn_embed_dim
        self.encoding_dim = args.transition_encoding_dim
        embed_dim = self.entity_embed_dim

        # The first registered task serves as the reference for layer sizes.
        ref = task2decomposer[list(task2decomposer.keys())[0]]
        obs_own_dim = ref.own_obs_dim
        obs_en_dim, obs_al_dim = ref.obs_nf_en, ref.obs_nf_al
        n_actions_no_attack = ref.n_actions_no_attack

        has_attack_action = n_actions_no_attack != ref.n_actions
        self.direct_enemy_action_embedding = False
        if has_attack_action:
            self.direct_enemy_action_embedding = args.direct_enemy_action_embedding

        # When attack actions exist and are not embedded directly, forward()
        # appends one attack value to every enemy's feature vector.
        append_attack = has_attack_action and not self.direct_enemy_action_embedding
        enemy_in_dim = obs_en_dim + 1 if append_attack else obs_en_dim

        # NOTE: layers are created in a fixed order (and unused ones are kept)
        # so parameter initialisation and state_dict keys stay unchanged.
        self.ally_value = nn.Linear(obs_al_dim, embed_dim)
        self.enemy_value = nn.Linear(enemy_in_dim, embed_dim)
        self.own_value = nn.Linear(obs_own_dim, embed_dim)

        self.next_ally_value = nn.Linear(obs_al_dim, embed_dim)
        self.next_enemy_value = nn.Linear(obs_en_dim, embed_dim)
        self.next_own_value = nn.Linear(obs_own_dim, embed_dim)

        # Not used by forward(); kept as registered parameters.
        self.no_attack_action_value = nn.Linear(n_actions_no_attack, embed_dim)
        self.attack_action_value = nn.Linear(1, embed_dim)
        self.reward_value = nn.Linear(1, embed_dim)

        # Despite the "time" in the names, these tables are indexed by the
        # position of an entity inside the ally / enemy list (see forward()).
        max_ally_num = args.max_ally_num
        self.ally_time_embed = nn.Embedding(max_ally_num, embed_dim)
        self.next_ally_time_embed = nn.Embedding(max_ally_num, embed_dim)

        max_enemy_num = args.max_enemy_num
        self.enemy_time_embed = nn.Embedding(max_enemy_num, embed_dim)
        self.next_enemy_time_embed = nn.Embedding(max_enemy_num, embed_dim)

        self.obs_transformer = Transformer(embed_dim, args.head, args.depth, embed_dim)
        self.next_obs_transformer = Transformer(embed_dim, args.head, args.depth, embed_dim)

        self.hidden_dim = args.encoder_hidden_dim
        # Each observation summary is [own | max(enemy) | max(ally)] = 3 * embed_dim.
        if self.is_club:
            input_dim = 3 * embed_dim + n_actions_no_attack
        else:
            # obs summary + next-obs summary + action part + scalar reward
            input_dim = 6 * embed_dim + n_actions_no_attack + 1
        self.fc = nn.Linear(input_dim, self.hidden_dim)
        self.mid = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.out = nn.Linear(self.hidden_dim, self.encoding_dim)

        if self.args.mlp_encoding:
            obs_dim = ref.obs_dim
            # The MLP path concatenates the *raw* action tensor, so the input
            # width must use the full action size (as the MLP path of
            # TransformerTransitionRoleEncoder does).  This equals
            # n_actions_no_attack when the task has no attack actions.
            self.fc1 = nn.Linear(2 * obs_dim + ref.n_actions + 1, self.hidden_dim)
            self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
            self.fc3 = nn.Linear(self.hidden_dim, self.encoding_dim)

    @staticmethod
    def _add_index_embedding(hidden, table):
        """Add ``table[i]`` to the i-th entity token of ``hidden`` (batch, n, dim)."""
        n_batch, n_entities, _ = hidden.shape
        index = th.arange(n_entities, device=hidden.device).long()
        index = index.view(1, n_entities).expand(n_batch, -1)
        return hidden + table(index)

    @staticmethod
    def _pool_entities(outputs, n_enemy, n_ally):
        """Concatenate the own token with the max over enemy and over ally tokens."""
        enemy_end = 1 + n_enemy
        pooled_enemy = th.max(outputs[:, 1:enemy_end, :], dim=1)[0]
        pooled_ally = th.max(outputs[:, enemy_end:enemy_end + n_ally, :], dim=1)[0]
        return th.cat([outputs[:, 0, :], pooled_enemy, pooled_ally], dim=-1)

    @staticmethod
    def _restore_shape(encoding, lead_shape):
        """Undo the batch flattening; ``lead_shape`` is None for 2-D input."""
        if lead_shape is None:
            return encoding
        return encoding.reshape(*lead_shape, -1)

    def forward(self, obs, action, next_obs, reward, task):
        """Encode a batch of transitions.

        Args:
            obs, action, next_obs: either 2-D ``(N, feat)`` tensors or 4-D
                tensors whose leading axes are unpacked as
                ``(bs, max_t, n_agents)``.
            reward: 2-D for 2-D input.  For 4-D input it is broadcast over an
                inserted agent axis, so it is assumed to be
                ``(bs, max_t, 1)`` -- TODO confirm against callers.
            task: key into ``task2decomposer`` (only used by the transformer
                path).

        Returns:
            Tensor with last axis ``encoding_dim`` and the same leading axes
            as ``obs``.

        Raises:
            AssertionError: if ``obs`` is neither 2-D nor 4-D, or if the
                transformer path is used with
                ``direct_enemy_action_embedding`` enabled (unsupported).
        """
        lead_shape = None
        if len(obs.shape) == 4:
            bs, max_t, n_agents, _ = obs.shape
            lead_shape = (bs, max_t, n_agents)
            n_rows = bs * max_t * n_agents
            obs = obs.reshape(n_rows, -1)
            action = action.reshape(n_rows, -1)
            next_obs = next_obs.reshape(n_rows, -1)
            # One reward per (batch, step) is shared by all agents.
            reward = reward.unsqueeze(2).repeat(1, 1, n_agents, 1).reshape(n_rows, -1)
        else:
            assert len(obs.shape) == 2, "Input obs should be 2D or 4D tensor"

        if self.args.mlp_encoding:
            # NOTE: this path ignores is_club and always uses all four inputs.
            x = th.cat([obs, action, next_obs, reward], dim=-1)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            return self._restore_shape(self.fc3(x), lead_shape)

        if self.direct_enemy_action_embedding:
            # Explicit raise so the guard also holds under ``python -O``.
            raise AssertionError("Should not reach here!")

        decomposer = self.task2decomposer[task]
        own_obs, enemy_feats, ally_feats = decomposer.decompose_obs(obs)
        next_own_obs, next_enemy_feats, next_ally_feats = decomposer.decompose_obs(next_obs)
        no_attack_action, attack_action_info, _, _ = decomposer.decompose_action(action)

        # Lists of per-entity tensors -> (n_entities, batch, feat).
        enemy_feats = th.stack(enemy_feats, dim=0)
        if np.prod(attack_action_info.shape) > 0:
            # Append each enemy's attack value as one extra feature.
            attack_column = attack_action_info.transpose(0, 1).unsqueeze(-1)
            enemy_feats = th.cat([enemy_feats, attack_column], dim=-1)
        ally_feats = th.stack(ally_feats, dim=0)
        next_enemy_feats = th.stack(next_enemy_feats, dim=0)
        next_ally_feats = th.stack(next_ally_feats, dim=0)

        # Embed and switch to (batch, n_entities, embed_dim).
        own_hidden = self.own_value(own_obs).unsqueeze(1)
        ally_hidden = self.ally_value(ally_feats).permute(1, 0, 2)
        enemy_hidden = self.enemy_value(enemy_feats).permute(1, 0, 2)

        next_own_hidden = self.next_own_value(next_own_obs).unsqueeze(1)
        next_ally_hidden = self.next_ally_value(next_ally_feats).permute(1, 0, 2)
        next_enemy_hidden = self.next_enemy_value(next_enemy_feats).permute(1, 0, 2)

        ally_hidden = self._add_index_embedding(ally_hidden, self.ally_time_embed)
        next_ally_hidden = self._add_index_embedding(next_ally_hidden, self.next_ally_time_embed)
        enemy_hidden = self._add_index_embedding(enemy_hidden, self.enemy_time_embed)
        next_enemy_hidden = self._add_index_embedding(next_enemy_hidden, self.next_enemy_time_embed)

        n_enemy = enemy_hidden.shape[1]
        n_ally = ally_hidden.shape[1]

        # Token order is [own, enemies..., allies...].
        obs_tokens = th.cat([own_hidden, enemy_hidden, ally_hidden], dim=1)
        next_obs_tokens = th.cat([next_own_hidden, next_enemy_hidden, next_ally_hidden], dim=1)
        obs_outputs = self.obs_transformer(obs_tokens, None)
        next_obs_outputs = self.next_obs_transformer(next_obs_tokens, None)

        obs_out = self._pool_entities(obs_outputs, n_enemy, n_ally)
        next_obs_out = self._pool_entities(next_obs_outputs, n_enemy, n_ally)

        if self.is_club:
            encoding_inputs = th.cat([obs_out, no_attack_action], dim=-1)
        else:
            encoding_inputs = th.cat([obs_out, no_attack_action, next_obs_out, reward], dim=-1)
        x = F.relu(self.fc(encoding_inputs))
        x = F.relu(self.mid(x))
        return self._restore_shape(self.out(x), lead_shape)


class TransformerTransitionRoleEncoder(nn.Module):
    """Encode an ``(obs, action)`` pair into a fixed-size vector.

    Two code paths exist, selected by ``args.mlp_encoding``:

    * MLP path: ``obs`` and the raw ``action`` are concatenated and passed
      through ``fc1 -> fc2 -> fc3``.
    * Transformer path: ``obs`` is split by the task's decomposer into own /
      enemy / ally parts, each part is linearly embedded, entity-index
      embeddings are added to the enemy and ally tokens, and the token
      sequence goes through ``obs_transformer``.  The own token plus the
      element-wise max over enemy tokens and over ally tokens are
      concatenated with the no-attack action part and fed to
      ``fc -> mid -> out``.

    All layer sizes are taken from the decomposer of the first key in
    ``task2decomposer``.
    """

    def __init__(self, task2decomposer, args, is_club=False) -> None:
        """Build all layers.

        Args:
            task2decomposer: mapping from task key to a decomposer object
                exposing ``own_obs_dim``, ``obs_nf_en``, ``obs_nf_al``,
                ``n_actions``, ``n_actions_no_attack`` (and ``obs_dim`` when
                ``args.mlp_encoding`` is set).
            args: config namespace; reads ``entity_embed_dim``,
                ``attn_embed_dim``, ``transition_encoding_dim``,
                ``max_ally_num``, ``max_enemy_num``, ``head``, ``depth``,
                ``encoder_hidden_dim``, ``mlp_encoding`` and, only when the
                reference task has attack actions,
                ``direct_enemy_action_embedding``.
            is_club: stored on the instance; not read by this class.
        """
        super(TransformerTransitionRoleEncoder, self).__init__()

        self.args = args
        self.task2decomposer = task2decomposer
        self.is_club = is_club

        self.entity_embed_dim = args.entity_embed_dim
        self.attn_embed_dim = args.attn_embed_dim
        self.encoding_dim = args.transition_encoding_dim
        embed_dim = self.entity_embed_dim

        # The first registered task serves as the reference for layer sizes.
        ref = task2decomposer[list(task2decomposer.keys())[0]]
        obs_own_dim = ref.own_obs_dim
        obs_en_dim, obs_al_dim = ref.obs_nf_en, ref.obs_nf_al
        n_actions_no_attack = ref.n_actions_no_attack

        has_attack_action = n_actions_no_attack != ref.n_actions
        self.direct_enemy_action_embedding = False
        if has_attack_action:
            self.direct_enemy_action_embedding = args.direct_enemy_action_embedding

        # When attack actions exist and are not embedded directly, forward()
        # appends one attack value to every enemy's feature vector.
        append_attack = has_attack_action and not self.direct_enemy_action_embedding
        enemy_in_dim = obs_en_dim + 1 if append_attack else obs_en_dim

        # NOTE: layers are created in a fixed order (and unused ones are kept)
        # so parameter initialisation and state_dict keys stay unchanged.
        self.ally_value = nn.Linear(obs_al_dim, embed_dim)
        self.enemy_value = nn.Linear(enemy_in_dim, embed_dim)
        self.own_value = nn.Linear(obs_own_dim, embed_dim)

        # Not used by forward(); kept as registered parameters.
        self.no_attack_action_value = nn.Linear(n_actions_no_attack, embed_dim)
        self.attack_action_value = nn.Linear(1, embed_dim)
        self.reward_value = nn.Linear(1, embed_dim)

        # Despite the "time" in the names, these tables are indexed by the
        # position of an entity inside the ally / enemy list (see forward()).
        # The ``next_*`` tables are not used by forward().
        max_ally_num = args.max_ally_num
        self.ally_time_embed = nn.Embedding(max_ally_num, embed_dim)
        self.next_ally_time_embed = nn.Embedding(max_ally_num, embed_dim)

        max_enemy_num = args.max_enemy_num
        self.enemy_time_embed = nn.Embedding(max_enemy_num, embed_dim)
        self.next_enemy_time_embed = nn.Embedding(max_enemy_num, embed_dim)

        self.obs_transformer = Transformer(embed_dim, args.head, args.depth, embed_dim)

        self.hidden_dim = args.encoder_hidden_dim
        # Observation summary is [own | max(enemy) | max(ally)] = 3 * embed_dim.
        input_dim = 3 * embed_dim + n_actions_no_attack
        self.fc = nn.Linear(input_dim, self.hidden_dim)
        self.mid = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.out = nn.Linear(self.hidden_dim, self.encoding_dim)

        if self.args.mlp_encoding:
            # The MLP path consumes the raw action, hence the full action size.
            self.fc1 = nn.Linear(ref.obs_dim + ref.n_actions, self.hidden_dim)
            self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
            self.fc3 = nn.Linear(self.hidden_dim, self.encoding_dim)

    @staticmethod
    def _add_index_embedding(hidden, table):
        """Add ``table[i]`` to the i-th entity token of ``hidden`` (batch, n, dim)."""
        n_batch, n_entities, _ = hidden.shape
        index = th.arange(n_entities, device=hidden.device).long()
        index = index.view(1, n_entities).expand(n_batch, -1)
        return hidden + table(index)

    @staticmethod
    def _pool_entities(outputs, n_enemy, n_ally):
        """Concatenate the own token with the max over enemy and over ally tokens."""
        enemy_end = 1 + n_enemy
        pooled_enemy = th.max(outputs[:, 1:enemy_end, :], dim=1)[0]
        pooled_ally = th.max(outputs[:, enemy_end:enemy_end + n_ally, :], dim=1)[0]
        return th.cat([outputs[:, 0, :], pooled_enemy, pooled_ally], dim=-1)

    @staticmethod
    def _restore_shape(encoding, lead_shape):
        """Undo the batch flattening; ``lead_shape`` is None for 2-D input."""
        if lead_shape is None:
            return encoding
        return encoding.reshape(*lead_shape, -1)

    def forward(self, obs, action, task):
        """Encode a batch of ``(obs, action)`` pairs.

        Args:
            obs, action: 2-D ``(N, feat)`` tensors, 3-D tensors whose leading
                axes are unpacked as ``(bs, max_t)``, or 4-D tensors whose
                leading axes are unpacked as ``(bs, max_t, n_agents)``.
            task: key into ``task2decomposer`` (only used by the transformer
                path).

        Returns:
            Tensor with last axis ``encoding_dim`` and the same leading axes
            as ``obs`` (for both the MLP and the transformer path).

        Raises:
            AssertionError: if ``obs`` is not 2-D, 3-D or 4-D, or if the
                transformer path is used with
                ``direct_enemy_action_embedding`` enabled (unsupported).
        """
        lead_shape = None
        n_dims = len(obs.shape)
        if n_dims == 4:
            bs, max_t, n_agents, _ = obs.shape
            lead_shape = (bs, max_t, n_agents)
            n_rows = bs * max_t * n_agents
        elif n_dims == 3:
            bs, max_t, _ = obs.shape
            lead_shape = (bs, max_t)
            n_rows = bs * max_t
        else:
            assert n_dims == 2, "Input obs should be 2D, 3D or 4D tensor"
        if lead_shape is not None:
            obs = obs.reshape(n_rows, -1)
            action = action.reshape(n_rows, -1)

        if self.args.mlp_encoding:
            x = th.cat([obs, action], dim=-1)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            # Restore the same leading axes as the transformer path does, so a
            # 4-D input keeps its agent axis instead of having it folded into
            # the feature axis.
            return self._restore_shape(self.fc3(x), lead_shape)

        if self.direct_enemy_action_embedding:
            # Explicit raise so the guard also holds under ``python -O``.
            raise AssertionError("Should not reach here!")

        decomposer = self.task2decomposer[task]
        own_obs, enemy_feats, ally_feats = decomposer.decompose_obs(obs)
        no_attack_action, attack_action_info, _, _ = decomposer.decompose_action(action)

        # Lists of per-entity tensors -> (n_entities, batch, feat).
        enemy_feats = th.stack(enemy_feats, dim=0)
        if np.prod(attack_action_info.shape) > 0:
            # Append each enemy's attack value as one extra feature.
            attack_column = attack_action_info.transpose(0, 1).unsqueeze(-1)
            enemy_feats = th.cat([enemy_feats, attack_column], dim=-1)
        ally_feats = th.stack(ally_feats, dim=0)

        # Embed and switch to (batch, n_entities, embed_dim).
        own_hidden = self.own_value(own_obs).unsqueeze(1)
        ally_hidden = self.ally_value(ally_feats).permute(1, 0, 2)
        enemy_hidden = self.enemy_value(enemy_feats).permute(1, 0, 2)

        ally_hidden = self._add_index_embedding(ally_hidden, self.ally_time_embed)
        enemy_hidden = self._add_index_embedding(enemy_hidden, self.enemy_time_embed)

        # Token order is [own, enemies..., allies...].
        obs_tokens = th.cat([own_hidden, enemy_hidden, ally_hidden], dim=1)
        obs_outputs = self.obs_transformer(obs_tokens, None)
        obs_out = self._pool_entities(obs_outputs, enemy_hidden.shape[1], ally_hidden.shape[1])

        encoding_inputs = th.cat([obs_out, no_attack_action], dim=-1)
        x = F.relu(self.fc(encoding_inputs))
        x = F.relu(self.mid(x))
        return self._restore_shape(self.out(x), lead_shape)

class TransformerPriorRoleEncoder(nn.Module):
    """Encode an observation together with a task embedding.

    Two code paths exist, selected by ``args.mlp_encoding``:

    * MLP path: ``obs`` and ``task_embedding`` are concatenated and passed
      through ``fc1 -> fc2 -> fc3``; only the encoding is returned.
    * Transformer path: ``obs`` is split by the task's decomposer into own /
      enemy / ally parts, each part is linearly embedded, entity-index
      embeddings are added to the enemy and ally tokens, an optional
      ``hidden_state`` token is appended, and the sequence goes through
      ``obs_transformer``.  The own token plus the element-wise max over
      enemy tokens and over ally tokens are concatenated with
      ``task_embedding`` and fed to ``fc -> mid -> out``.  The last output
      token is returned alongside the encoding.

    All layer sizes are taken from the decomposer of the first key in
    ``task2decomposer``.
    """

    def __init__(self, task2decomposer, args, is_club=False) -> None:
        """Build all layers.

        Args:
            task2decomposer: mapping from task key to a decomposer object
                exposing ``own_obs_dim``, ``obs_nf_en``, ``obs_nf_al``,
                ``n_actions``, ``n_actions_no_attack`` (and ``obs_dim`` when
                ``args.mlp_encoding`` is set).
            args: config namespace; reads ``entity_embed_dim``,
                ``attn_embed_dim``, ``transition_encoding_dim`` (output
                size), ``encoding_dim`` (task-embedding size),
                ``max_ally_num``, ``max_enemy_num``, ``head``, ``depth``,
                ``encoder_hidden_dim``, ``mlp_encoding`` and, only when the
                reference task has attack actions,
                ``direct_enemy_action_embedding``.
            is_club: stored on the instance; not read by this class.
        """
        super(TransformerPriorRoleEncoder, self).__init__()

        self.args = args
        self.task2decomposer = task2decomposer
        self.is_club = is_club

        self.entity_embed_dim = args.entity_embed_dim
        self.attn_embed_dim = args.attn_embed_dim
        self.encoding_dim = args.transition_encoding_dim
        embed_dim = self.entity_embed_dim

        # The first registered task serves as the reference for layer sizes.
        ref = task2decomposer[list(task2decomposer.keys())[0]]
        obs_own_dim = ref.own_obs_dim
        obs_en_dim, obs_al_dim = ref.obs_nf_en, ref.obs_nf_al
        n_actions_no_attack = ref.n_actions_no_attack

        has_attack_action = n_actions_no_attack != ref.n_actions
        self.direct_enemy_action_embedding = False
        if has_attack_action:
            self.direct_enemy_action_embedding = args.direct_enemy_action_embedding

        # NOTE(review): the extra enemy input feature is only allocated for
        # the direct-embedding mode, which forward() rejects; forward() never
        # appends attack information to the enemy features.
        enemy_in_dim = obs_en_dim + 1 if self.direct_enemy_action_embedding else obs_en_dim

        # NOTE: layers are created in a fixed order (and unused ones are kept)
        # so parameter initialisation and state_dict keys stay unchanged.
        self.ally_value = nn.Linear(obs_al_dim, embed_dim)
        self.enemy_value = nn.Linear(enemy_in_dim, embed_dim)
        self.own_value = nn.Linear(obs_own_dim, embed_dim)

        # Not used by forward(); kept as registered parameters.
        self.no_attack_action_value = nn.Linear(n_actions_no_attack, embed_dim)
        self.attack_action_value = nn.Linear(1, embed_dim)
        self.reward_value = nn.Linear(1, embed_dim)

        # Despite the "time" in the names, these tables are indexed by the
        # position of an entity inside the ally / enemy list (see forward()).
        # The ``next_*`` tables are not used by forward().
        max_ally_num = args.max_ally_num
        self.ally_time_embed = nn.Embedding(max_ally_num, embed_dim)
        self.next_ally_time_embed = nn.Embedding(max_ally_num, embed_dim)

        max_enemy_num = args.max_enemy_num
        self.enemy_time_embed = nn.Embedding(max_enemy_num, embed_dim)
        self.next_enemy_time_embed = nn.Embedding(max_enemy_num, embed_dim)

        self.obs_transformer = Transformer(embed_dim, args.head, args.depth, embed_dim)

        self.hidden_dim = args.encoder_hidden_dim
        task_embedding_dim = args.encoding_dim

        # Observation summary is [own | max(enemy) | max(ally)] = 3 * embed_dim.
        input_dim = 3 * embed_dim + task_embedding_dim
        self.fc = nn.Linear(input_dim, self.hidden_dim)
        self.mid = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.out = nn.Linear(self.hidden_dim, self.encoding_dim)

        if self.args.mlp_encoding:
            self.fc1 = nn.Linear(ref.obs_dim + task_embedding_dim, self.hidden_dim)
            self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
            self.fc3 = nn.Linear(self.hidden_dim, self.encoding_dim)

    def init_hidden(self):
        """Return a zero tensor of shape ``(1, entity_embed_dim)``.

        The tensor is created from ``self.fc.weight`` so it shares the
        model's dtype and device.
        """
        return self.fc.weight.new_zeros(1, self.entity_embed_dim)

    @staticmethod
    def _add_index_embedding(hidden, table):
        """Add ``table[i]`` to the i-th entity token of ``hidden`` (batch, n, dim)."""
        n_batch, n_entities, _ = hidden.shape
        index = th.arange(n_entities, device=hidden.device).long()
        index = index.view(1, n_entities).expand(n_batch, -1)
        return hidden + table(index)

    @staticmethod
    def _pool_entities(outputs, n_enemy, n_ally):
        """Concatenate the own token with the max over enemy and over ally tokens."""
        enemy_end = 1 + n_enemy
        pooled_enemy = th.max(outputs[:, 1:enemy_end, :], dim=1)[0]
        pooled_ally = th.max(outputs[:, enemy_end:enemy_end + n_ally, :], dim=1)[0]
        return th.cat([outputs[:, 0, :], pooled_enemy, pooled_ally], dim=-1)

    @staticmethod
    def _restore_shape(encoding, lead_shape):
        """Undo the batch flattening; ``lead_shape`` is None for 2-D input."""
        if lead_shape is None:
            return encoding
        return encoding.reshape(*lead_shape, -1)

    def forward(self, obs, task_embedding, task, hidden_state=None):
        """Encode a batch of observations conditioned on a task embedding.

        Args:
            obs: 2-D ``(N, feat)`` tensor, 3-D tensor whose leading axes are
                unpacked as ``(bs, n_agents)``, or 4-D tensor whose leading
                axes are unpacked as ``(bs, max_t, n_agents)``.
            task_embedding: matches ``obs`` for 2-D / 3-D input.  For 4-D
                input it is repeated over an inserted time axis, so it is
                assumed to be ``(bs, n_agents, dim)`` -- TODO confirm
                against callers.
            task: key into ``task2decomposer`` (only used by the transformer
                path).
            hidden_state: optional tensor viewed as
                ``(-1, 1, entity_embed_dim)`` and appended as the last token
                of the transformer input; its first axis must then match the
                flattened batch size.

        Returns:
            Transformer path: ``(encoding, h)`` where ``encoding`` has last
            axis ``encoding_dim`` and the same leading axes as ``obs``, and
            ``h`` is the last output token ``(flat_batch, 1, dim)``.
            MLP path: only ``encoding``.

        Raises:
            AssertionError: if ``obs`` is not 2-D, 3-D or 4-D, or if the
                transformer path is used with
                ``direct_enemy_action_embedding`` enabled (unsupported).
        """
        if hidden_state is not None:
            hidden_state = hidden_state.view(-1, 1, self.entity_embed_dim)

        lead_shape = None
        n_dims = len(obs.shape)
        if n_dims == 4:
            bs, max_t, n_agents, _ = obs.shape
            lead_shape = (bs, max_t, n_agents)
            n_rows = bs * max_t * n_agents
            obs = obs.reshape(n_rows, -1)
            # The same task embedding is used at every time step.
            task_embedding = task_embedding.unsqueeze(1).repeat(1, max_t, 1, 1).reshape(n_rows, -1)
        elif n_dims == 3:
            bs, n_agents, _ = obs.shape
            lead_shape = (bs, n_agents)
            n_rows = bs * n_agents
            obs = obs.reshape(n_rows, -1)
            task_embedding = task_embedding.reshape(n_rows, -1)
        else:
            assert n_dims == 2, "Input obs should be 2D, 3D or 4D tensor"

        if self.args.mlp_encoding:
            x = th.cat([obs, task_embedding], dim=-1)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            # Restore the leading axes of the input.  Reshaping with
            # ``max_t`` here would fail for 3-D input, where no time axis
            # exists, and would fold the agent axis away for 4-D input.
            # NOTE(review): unlike the transformer path, no hidden token is
            # returned here; confirm what callers expect.
            return self._restore_shape(self.fc3(x), lead_shape)

        if self.direct_enemy_action_embedding:
            # Explicit raise so the guard also holds under ``python -O``.
            raise AssertionError("Should not reach here!")

        decomposer = self.task2decomposer[task]
        own_obs, enemy_feats, ally_feats = decomposer.decompose_obs(obs)

        # Lists of per-entity tensors -> (n_entities, batch, feat).
        enemy_feats = th.stack(enemy_feats, dim=0)
        ally_feats = th.stack(ally_feats, dim=0)

        # Embed and switch to (batch, n_entities, embed_dim).
        own_hidden = self.own_value(own_obs).unsqueeze(1)
        ally_hidden = self.ally_value(ally_feats).permute(1, 0, 2)
        enemy_hidden = self.enemy_value(enemy_feats).permute(1, 0, 2)

        ally_hidden = self._add_index_embedding(ally_hidden, self.ally_time_embed)
        enemy_hidden = self._add_index_embedding(enemy_hidden, self.enemy_time_embed)

        # Token order is [own, enemies..., allies..., (hidden)].
        tokens = [own_hidden, enemy_hidden, ally_hidden]
        if hidden_state is not None:
            tokens.append(hidden_state)
        obs_outputs = self.obs_transformer(th.cat(tokens, dim=1), None)

        # Last token: the hidden-state slot when one was supplied, otherwise
        # the last ally token.
        h = obs_outputs[:, -1:, :]
        obs_out = self._pool_entities(obs_outputs, enemy_hidden.shape[1], ally_hidden.shape[1])

        encoding_inputs = th.cat([obs_out, task_embedding], dim=-1)
        x = F.relu(self.fc(encoding_inputs))
        x = F.relu(self.mid(x))
        transition_encoding = self._restore_shape(self.out(x), lead_shape)

        return transition_encoding, h