# RNN agent with adaptors
import os
import numpy as np
from functools import partial
from copy import deepcopy
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.func import stack_module_state, vmap
from utils.gp_nets_simple import FCNet
from utils.calc import count_total_parameters
from utils.embed import polynomial_embed, binary_embed
from utils.rl_utils import RunningMeanStd, EMAMeanStd
from utils.cont_utils import unif_wo_replace_np

# ------------------------------ Basic Components ------------------------------
    
class StateAttnFeatureExtractor(nn.Module):
    def __init__(self, task2decomposer, args) -> None:
        super(StateAttnFeatureExtractor, self).__init__()
        self.task2decomposer = task2decomposer
        self.surrogate_decomposer = surrogate_decomposer = deepcopy(task2decomposer[list(task2decomposer.keys())[0]])
        self.args = args
        
        self.embed_dim = args.mixing_embed_dim
        self.attn_embed_dim = args.attn_embed_dim
        self.entity_embed_dim = args.entity_embed_dim

        match self.args.env:
            case 'sc2' | 'sc2_v2':
                state_nf_al, state_nf_en, timestep_state_dim = \
                    surrogate_decomposer.aligned_state_nf_al, surrogate_decomposer.aligned_state_nf_en, surrogate_decomposer.timestep_number_state_dim
            case 'gymma' | 'grid_mpe' | 'mamujoco':
                state_nf_al, state_nf_en, timestep_state_dim = \
                    surrogate_decomposer.state_nf_al, surrogate_decomposer.state_nf_en, surrogate_decomposer.timestep_number_state_dim
            case _:
                raise NotImplementedError(f"Env {self.args.env} not implemented.")
        # timestep_state_dim = 0/1 denote whether encode the "t" of s
        # get detailed state shape information
        self.state_last_action, self.state_timestep_number = surrogate_decomposer.state_last_action, surrogate_decomposer.state_timestep_number
        
        # get action dimension information
        self.n_actions_no_attack = surrogate_decomposer.n_actions_no_attack

        # define state information processor
        if self.state_last_action:
            state_nf_al += self.n_actions_no_attack + 1
        self.ally_encoder = nn.Linear(state_nf_al, self.entity_embed_dim)
        self.enemy_encoder = nn.Linear(state_nf_en, self.entity_embed_dim)
        
        self.query = nn.Linear(self.entity_embed_dim, self.attn_embed_dim)
        self.key = nn.Linear(self.entity_embed_dim, self.attn_embed_dim)
        
        self.cls_token = nn.Parameter(th.randn(1, 1, 1, self.entity_embed_dim))

        mixing_input_dim = self.entity_embed_dim
        entity_mixing_input_dim = self.entity_embed_dim + self.entity_embed_dim
        if self.state_timestep_number:
            mixing_input_dim += timestep_state_dim
            entity_mixing_input_dim += timestep_state_dim
        
    def forward(self, states, task):
        # agent_qs: [batch_size, seq_len, n_agents]
        # states: [batch_size, seq_len, state_dim]
        task_decomposer = self.task2decomposer[task]
        
        bs, seq_len, _ = states.size()
        n_agents, n_enemies = task_decomposer.n_agents, task_decomposer.n_enemies
        n_entities = n_agents + n_enemies

        # get decomposed state information
        ally_states, enemy_states, last_action_states, timestep_number_state = task_decomposer.decompose_state(states)
        ally_states = th.stack(ally_states, dim=0)  # [n_agents, bs, seq_len, state_nf_al]
        enemy_states = th.stack(enemy_states, dim=0)  # [n_enemies, bs, seq_len, state_nf_en]

        # stack action information
        if self.state_last_action:
            last_action_states = th.stack(last_action_states, dim=0) # (n_agents, bs, seq_len, n_actions)
            _, _, last_compact_action_states = task_decomposer.decompose_action_info(last_action_states)
            ally_states = th.cat([ally_states, last_compact_action_states], dim=-1)
        
        # do inference and get entity_embed
        ally_embed = self.ally_encoder(ally_states) # [n_agents, bs, seq_len, entity_embed_dim]
        enemy_embed = self.enemy_encoder(enemy_states)

        cls_token = self.cls_token.expand(1, bs, seq_len, -1) # differentiating agent numbers
        
        # we ought to do self-attention
        entity_embed = th.cat([cls_token, ally_embed, enemy_embed], dim=0) # [n_entity, bs, seq_len, entity_embed_dim]
        n_entities = n_entities + 1

        # do attention
        proj_query = self.query(entity_embed).permute(1, 2, 0, 3).reshape(bs * seq_len, n_entities, self.attn_embed_dim)
        proj_key = self.key(entity_embed).permute(1, 2, 0, 3).reshape(bs * seq_len, n_entities, self.attn_embed_dim)
        energy = th.bmm(proj_query, proj_key.transpose(1, 2)) / (self.attn_embed_dim ** (1 / 2))
        score = F.softmax(energy, dim=-1) # (bs*seq_len, n_entities, n_entities)
        proj_value = entity_embed.permute(1, 2, 0, 3).reshape(bs * seq_len, n_entities, self.entity_embed_dim)
        out = th.bmm(score, proj_value) # (bs * seq_len, n_entities, entity_embed_dim)
        
        # cls_feature instead of mean pooling over entity
        cls_feature = out[:, 0, :].reshape(bs, seq_len, self.entity_embed_dim)
        return cls_feature
    
        # mean pooling over entity
        # out = out.mean(dim=1).reshape(bs, seq_len, self.entity_embed_dim)
        # return out, cls_feature
        
        
class ObsAttnFeatureExtractor(nn.Module):
    # o: [bs, T, n, d_n] --> [bs, T, n, d], independently for [bs, T, n]
    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, surrogate_decomposer, args):
        super().__init__()
        self.args = args
        self.mixing_embed_dim = args.mixing_embed_dim
        self.entity_embed_dim = args.entity_embed_dim
        self.attn_embed_dim = args.attn_embed_dim
        self.task2decomposer = task2decomposer
        self.task2n_agents = task2n_agents
        self.task2last_action_shape = {
            task: task2input_shape_info[task]["last_action_shape"]
            for task in task2input_shape_info
        }
        self.have_attack_action = (surrogate_decomposer.n_actions != surrogate_decomposer.n_actions_no_attack)
        self.state_last_action, self.state_timestep_number = surrogate_decomposer.state_last_action, surrogate_decomposer.state_timestep_number
        self.output_dim = args.entity_embed_dim

        ## get obs shape information
        match self.args.env:
            case "sc2" | "sc2_v2":
                obs_own_dim, obs_en_dim, obs_al_dim = (
                    surrogate_decomposer.aligned_own_obs_dim,
                    surrogate_decomposer.aligned_obs_nf_en,
                    surrogate_decomposer.aligned_obs_nf_al,
                )
                ## enemy_obs ought to add attack_action_infos
                if self.args.obs_last_action:
                    obs_en_dim += 1
            case "gymma" | "grid_mpe" | "mamujoco":
                obs_own_dim, obs_en_dim, obs_al_dim = (
                    surrogate_decomposer.own_obs_dim,
                    surrogate_decomposer.obs_nf_en,
                    surrogate_decomposer.obs_nf_al,
                )
                ## enemy_obs ought to add attack_action_infos
                if self.args.obs_last_action:
                    obs_en_dim += surrogate_decomposer.n_actions_attack
            case _:
                raise NotImplementedError
        
        # get action dimension information
        n_actions_no_attack = surrogate_decomposer.n_actions_no_attack
        wrapped_obs_own_dim = obs_own_dim + self.args.id_length
        
        if self.args.obs_last_action:
            if self.args.env not in ['grid_mpe']:
                wrapped_obs_own_dim += n_actions_no_attack + 1
            else:
                wrapped_obs_own_dim += n_actions_no_attack

        self.query = nn.Linear(wrapped_obs_own_dim, self.attn_embed_dim * self.args.head)
        self.ally_key = nn.Linear(obs_al_dim, self.attn_embed_dim * self.args.head)
        self.ally_value = nn.Linear(obs_al_dim, self.entity_embed_dim * self.args.head)
        self.enemy_key = nn.Linear(obs_en_dim, self.attn_embed_dim * self.args.head)
        self.enemy_value = nn.Linear(obs_en_dim, self.entity_embed_dim * self.args.head)
        self.own_value = nn.Linear(wrapped_obs_own_dim, self.entity_embed_dim * self.args.head)
        self.out_layer = nn.Linear(self.entity_embed_dim * self.args.head * 3, self.entity_embed_dim * self.args.head)

    def _multi_head_attention(self, q, k, v, attn_dim): # TODO masking
        """
            q: [bs*n_agents, attn_dim*n_heads]
            k: [bs*n_agents,n_entity, attn_dim*n_heads]
            v: [bs*n_agents, n_entity, value_dim*n_heads]
        """
        bs = q.shape[0]
        q = q.unsqueeze(1).view(bs, 1, self.args.head, self.attn_embed_dim)
        k = k.view(bs, -1, self.args.head, self.attn_embed_dim)
        v = v.view(bs, -1, self.args.head, self.entity_embed_dim)

        q = q.transpose(1, 2).contiguous().view(bs*self.args.head, 1, self.attn_embed_dim)
        k = k.transpose(1, 2).contiguous().view(bs*self.args.head, -1, self.attn_embed_dim)
        v = v.transpose(1, 2).contiguous().view(bs*self.args.head, -1, self.entity_embed_dim)

        energy = th.bmm(q, k.transpose(1, 2)) / (attn_dim ** (1 / 2))
        assert energy.shape[0] == bs * self.args.head and energy.shape[1] == 1
        # shape[2] == n_entity
        score = F.softmax(energy, dim=-1)
        out = th.bmm(score, v).view(bs, self.args.head, 1, self.entity_embed_dim) # (bs*head, 1, entity_embed_dim) 
        out = out.transpose(1, 2).contiguous().view(bs, 1, self.entity_embed_dim * self.args.head).squeeze(1)
        return out
        
    def _attention(self, q, k, v, attn_dim):
        """
            q: [bsn, t, attn_dim]
            k: [bsn, t, attn_dim]
            v: [bsn, t, value_dim]
        """
        assert self.args.head == 1
        if len(q.shape) == 2:
            q = q.unsqueeze(1)
        energy = th.bmm(q, k.transpose(1, 2))/(attn_dim ** (1 / 2))
        score = F.softmax(energy, dim=-1)
        out = th.bmm(score, v).squeeze(1)
        return out
    
    def forward(self, obs:th.Tensor, task):
        # CHECK pure obs or obs+lastact
        # NOTE: obs sequences are regarded as individual trajectories
        original_shape = obs.shape
        shape = obs.shape
        
        if len(shape) == 2: # [n, d]
            obs = obs.view(1, 1, *obs.shape)
        elif len(shape) == 3: # [bs, n, d]
            obs = obs.unsqueeze(1)

        bs, t, n, _ = obs.size()
        
        obs_flat = obs.permute(0, 2, 1, 3).contiguous().view(bs * n * t, -1)
        
        task_decomposer = self.task2decomposer[task]
        last_action_shape = self.task2last_action_shape[task]
        
        # get decomposed obs information
        obs_dim = task_decomposer.obs_dim
        obs_inputs = obs_flat[:, :obs_dim]
        last_action_inputs = obs_flat[:, obs_dim:obs_dim+last_action_shape]
        
        # decompose obs input
        own_obs, enemy_feats, ally_feats = task_decomposer.decompose_obs(obs_inputs)
        
        if task not in getattr(self, '_id_cache', {}):
            if not hasattr(self, '_id_cache'):
                self._id_cache = {}
            agent_ids = []
            for i in range(n):
                agent_id = th.as_tensor(binary_embed(i + 1, self.args.id_length, self.args.max_agent), 
                                      dtype=own_obs.dtype, device=own_obs.device)
                agent_ids.append(agent_id)
            self._id_cache[task] = th.stack(agent_ids, dim=0)
        
        agent_ids_tensor = self._id_cache[task].to(own_obs.device)
        
        agent_id_inputs = agent_ids_tensor.unsqueeze(0).expand(bs * t, -1, -1).reshape(bs * t * n, -1)
        _, attack_action_info, compact_action_states = task_decomposer.decompose_action_info(last_action_inputs)

        # incorporate agent_id embed and compact_action_states
        own_obs = th.cat([own_obs, agent_id_inputs, compact_action_states], dim=-1)

        # incorporate attack_action_info (bs*n_agents*T, n_enemies, 1) into enemy_feats
        enemy_feats = th.stack(enemy_feats, dim=1)
        if self.have_attack_action:
            attack_action_info = attack_action_info.unsqueeze(-1)
            enemy_feats = th.cat([enemy_feats, attack_action_info], dim=-1)
        # (bs*n_agents*T, n_enemies, obs_nf_en+1)
        ally_feats = th.stack(ally_feats, dim=1)
        
        # compute k, q, v for (multi-head) attention
        query = self.query(own_obs)
        ally_keys = self.ally_key(ally_feats)  # (bs * n * t, n_ally, attn_dim *n_heads)
        enemy_keys = self.enemy_key(enemy_feats)
        ally_values = self.ally_value(ally_feats)
        enemy_values = self.enemy_value(enemy_feats)
        own_feature = self.own_value(own_obs) #(bs * n * t, entity_embed_dim * n_heads)
        
        if self.args.head == 1:
            ally_feature = self._attention(query, ally_keys, ally_values, self.attn_embed_dim)
            enemy_feature = self._attention(query, enemy_keys, enemy_values, self.attn_embed_dim)
        else:
            ally_feature = self._multi_head_attention(query, ally_keys, ally_values, self.attn_embed_dim)
            enemy_feature = self._multi_head_attention(query, enemy_keys, enemy_values, self.attn_embed_dim)
        attn_feature = th.cat([own_feature, ally_feature, enemy_feature], dim=-1) # (bs * n * t, entity_embed_dim * n_heads * 3)
        attn_feature = self.out_layer(attn_feature)
        
        return attn_feature, enemy_feats
    
class ObsAttnGRUBase(nn.Module):
    # o: [bs, n, d_n] --> [bs, n, d]; o_CA + h -> h, independently for [bs, n]
    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, surrogate_decomposer, args):
        super(ObsAttnGRUBase, self).__init__()
        self.args = args
        self.entity_embed_dim = args.entity_embed_dim
        self.task2n_agents = task2n_agents
        self.attn_enc = ObsAttnFeatureExtractor(task2input_shape_info,
                                        task2decomposer, task2n_agents,
                                        surrogate_decomposer, args)
        self.rnn = nn.GRUCell(self.entity_embed_dim * self.args.head, self.args.rnn_hidden_dim)
    
    def forward(self, inputs, hidden_state, task):
        attn_feature, enemy_feats = self.attn_enc(inputs, task)
        
        if len(attn_feature.shape) != 2:
            if len(inputs.shape) == 4:  # [bs, t, n, d]
                bs, t, n = inputs.shape[:3]
                attn_feature = attn_feature.view(bs * n * t, -1)
            elif len(inputs.shape) == 3:  # [bs, n, d]
                bs, n = inputs.shape[:2]
                attn_feature = attn_feature.view(bs * n, -1)
            else:  # [n, d]
                n = inputs.shape[0]
                attn_feature = attn_feature.view(n, -1)
        
        if hidden_state.is_contiguous():
            h_in = hidden_state.view(-1, self.args.rnn_hidden_dim)
        else:
            h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)
        h = self.rnn(attn_feature, h_in) # (bs*n_agents, y)

        return h, enemy_feats

class ObsTrajXformer(nn.Module):
    """Observation-trajectory encoder: entity attention per step, then temporal attention.

    Each agent's trajectory is encoded independently. Per-step features come
    from ``ObsAttnFeatureExtractor``. One masked self-attention block
    (residual + LayerNorm + FFN + LayerNorm) then runs over time, followed by
    an output projection.
    """

    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, surrogate_decomposer, args):
        super().__init__()
        self.args = args
        self.entity_embed_dim = args.entity_embed_dim
        self.attn_embed_dim = args.attn_embed_dim
        self.output_dim = args.rnn_hidden_dim
        self.task2n_agents = task2n_agents

        inputs = [task2input_shape_info, task2decomposer, task2n_agents, surrogate_decomposer, args]
        self.obs_attn_enc = ObsAttnFeatureExtractor(*inputs)

        self.temp_query = nn.Linear(self.entity_embed_dim, self.attn_embed_dim)
        self.temp_key = nn.Linear(self.entity_embed_dim, self.attn_embed_dim)
        self.traj_feat_trfm = FCNet(self.entity_embed_dim, self.entity_embed_dim, hidden_layer=2, hidden_dim=self.entity_embed_dim, use_last_activ=True)
        self.traj_ln1 = nn.LayerNorm(self.entity_embed_dim)
        self.traj_ln2 = nn.LayerNorm(self.entity_embed_dim)

        self.out_proj = FCNet(self.entity_embed_dim, self.output_dim, hidden_layer=2, use_last_activ=True)

    def _attention(self, q, k, v, attn_dim, mask=None):
        """Masked scaled dot-product self-attention over time.

        Args:
            q: ``[bsn, t, attn_dim]``
            k: ``[bsn, t, attn_dim]``
            v: ``[bsn, t, value_dim]``
            mask: optional ``[bsn, t, 1]``. A pair (i, j) is penalised unless
                both steps are valid.
        """
        assert self.args.head == 1
        if len(q.shape) == 2:
            q = q.unsqueeze(1)
        energy = th.bmm(q, k.transpose(1, 2))/(attn_dim ** (1 / 2))
        if mask is not None:
            mask = mask.float()
            attn_mask = 1 - th.bmm(mask, mask.transpose(1, 2))  # [bsn, t, t]
            energy = energy - attn_mask * (1e6)
        score = F.softmax(energy, dim=-1)
        out = th.bmm(score, v).squeeze(1)
        return out

    def forward(self, obs: th.Tensor, mask, task):
        """Encode observation trajectories.

        Args:
            obs: ``[bs, T, n, *]`` observations.
            mask: ``[bs, T, 1]`` validity mask, shared by all agents of a batch item.
            task: task key forwarded to the entity encoder.

        Returns:
            ``[bs, T, n, output_dim]`` per-step, per-agent features.
        """
        # NOTE: obs sequences are regarded as individual trajectories.
        bs, t, n, _ = obs.size()

        # The entity encoder emits rows ordered (batch, agent, time).
        attn_feature, _ = self.obs_attn_enc(obs, task)
        attn_feature = attn_feature.reshape(bs * n, t, -1)  # [bs*n, t, d], row = b*n + agent
        # Repeat each batch item's mask n times consecutively so that row b*n + i
        # gets mask[b]. repeat(n, 1, 1) would tile the masks in (agent, batch)
        # order and mix batch items up.
        mask = mask.repeat_interleave(n, dim=0)  # [bs*n, t, 1]

        entity_query = self.temp_query(attn_feature)
        entity_key = self.temp_key(attn_feature)
        entity_value = attn_feature
        entity_attn = self._attention(entity_query, entity_key, entity_value, self.attn_embed_dim, mask) * mask
        entity_ln1 = self.traj_ln1(entity_attn + entity_value)
        entity_ln2 = self.traj_ln2(entity_ln1 + self.traj_feat_trfm(entity_ln1))
        entity_out = entity_ln2.reshape(bs, n, t, -1).permute(0, 2, 1, 3)  # [bs, t, n, d]

        return self.out_proj(entity_out)  # [bs, t, n, output_dim]

class ActionAttnFeatureExtractor(nn.Module):
    """Embed joint actions per agent, optionally mixing interaction actions across agents.

    The "own" part of each action (first ``n_actions_no_attack`` columns) is
    embedded linearly. When the action space has attack actions, the compact
    action encoding is embedded and self-attended across agents at every
    (batch, timestep). The two embeddings are summed.
    """

    def __init__(self, task2decomposer, task2n_agents, surrogate_decomposer, args):
        super().__init__()
        self.args = args
        self.task2decomposer = task2decomposer
        self.task2n_agents = task2n_agents
        self.surrogate_decomposer = surrogate_decomposer

        self.entity_embed_dim = args.entity_embed_dim
        self.attn_embed_dim = args.attn_embed_dim

        self.n_actions_no_attack = surrogate_decomposer.n_actions_no_attack
        self.have_attack_action = surrogate_decomposer.n_actions != surrogate_decomposer.n_actions_no_attack

        # Fixed-width "own" action part.
        self.action_own_encoder = nn.Linear(self.n_actions_no_attack, self.entity_embed_dim)

        if not self.have_attack_action:
            # No interaction actions: keep the attributes so callers can test them for None.
            self.action_interact_encoder = None
            self.query = None
            self.key = None
        else:
            self.action_interact_encoder = nn.Linear(self.n_actions_no_attack + 1, self.entity_embed_dim)
            self.query = nn.Linear(self.entity_embed_dim, self.attn_embed_dim)
            self.key = nn.Linear(self.entity_embed_dim, self.attn_embed_dim)

    def forward(self, actions, task):
        """Return per-agent action features.

        Args:
            actions: tensor unpacked as ``[bs, T, n_agents, n_actions]``.
            task: key into ``task2decomposer``.

        Returns:
            ``[bs, T, n_agents, entity_embed_dim]``.
        """
        decomposer = self.task2decomposer[task]
        # Looked up for parity with the task registry; the agent count actually
        # used comes from actions.shape below.
        _ = self.task2n_agents[task]

        lead_shape = actions.shape[:-1]
        batch, steps, agents, _ = actions.shape

        own, _, compact = decomposer.decompose_action_info(actions)
        own_feat = self.action_own_encoder(own)  # [bs, T, n_agents, entity_embed_dim]

        if not self.have_attack_action:
            interact_feat = th.zeros_like(own_feat)
        else:
            per_step = compact.view(batch * steps, agents, self.n_actions_no_attack + 1)
            encoded = self.action_interact_encoder(per_step)  # [bs*T, n_agents, entity_embed_dim]
            q = self.query(encoded)
            k = self.key(encoded)
            weights = F.softmax(th.bmm(q, k.transpose(1, 2)) / (self.attn_embed_dim ** (1 / 2)), dim=-1)
            # Values are the unprojected encodings.
            interact_feat = th.bmm(weights, encoded).view(batch, steps, agents, self.entity_embed_dim)

        return (own_feat + interact_feat).reshape(*lead_shape, -1)

class OutputHead(nn.Module):
    # [bs, n, d] --> [bs, n, n_actions] (e.g. Q func for discrete action or pi func for continuous action)
    def __init__(self, surrogate_decomposer, args, input_dim=None):
        super().__init__()
        self.args = args
        self.entity_embed_dim = args.entity_embed_dim
        self.hidden_dim = args.rnn_hidden_dim
        if input_dim is None:
            input_dim = self.hidden_dim
        self.have_attack_action = (surrogate_decomposer.n_actions != surrogate_decomposer.n_actions_no_attack)
        ## get obs shape information
        match self.args.env:
            case "sc2" | "sc2_v2":
                obs_en_dim = surrogate_decomposer.aligned_obs_nf_en
                obs_en_dim += 1
                n_actions_no_attack = surrogate_decomposer.n_actions_no_attack
            case "gymma" | "grid_mpe" | "mamujoco":
                obs_en_dim = surrogate_decomposer.obs_nf_en + surrogate_decomposer.n_actions_attack
                n_actions_no_attack = surrogate_decomposer.n_actions_no_attack
            case _:
                raise NotImplementedError
        self.n_actions_no_attack = n_actions_no_attack
        self.id_action_no_attack = th.eye(self.n_actions_no_attack, dtype=th.float32, device=args.device)
            
        policy_hidden_dim = getattr(self.args, 'pa_hidden_dim', 64)
        policy_hidden_layer = getattr(self.args, 'pa_hidden_layer', 3)
            
        self.no_attack_layer = FCNet(input_dim, self.n_actions_no_attack, hidden_layer=policy_hidden_layer, hidden_dim=policy_hidden_dim, use_layer_norm=True)
        if self.have_attack_action:
            self.attack_layer = FCNet(input_dim + self.hidden_dim, 1, hidden_layer=policy_hidden_layer, hidden_dim=policy_hidden_dim, use_layer_norm=True)
            self.enemy_embed = FCNet(obs_en_dim, self.args.rnn_hidden_dim, hidden_layer=policy_hidden_layer - 1, hidden_dim=policy_hidden_dim, use_layer_norm=True)
            
        print("Output Head init...")
        count_total_parameters(self, is_concrete=True)
    
    def forward(self, h, enemy_feats):
        wo_action_act = self.no_attack_layer(h) # (bsn, n_no_attack)
        
        if self.have_attack_action:
            enemy_feature = self.enemy_embed(enemy_feats) # (bsn, n_enemy, hid)
            attack_action_input = th.cat(
                [
                    enemy_feature,
                    h.unsqueeze(1).expand(-1, enemy_feats.size(1), -1),
                ],
                dim=-1,
            )  # (bsn, n_enemy, 2*hid)
            attack_action_act = self.attack_layer(attack_action_input).squeeze(-1) # (bsn, n_enemy)
            act = th.cat([wo_action_act, attack_action_act], dim=1)
        else:
            act = wo_action_act
        # act: [bsn, n_act, d_phi]
        return act

class MultiHeadBase:
    """Registry of per-task adaptor modules, with optional vmapped evaluation.

    ``init_adaptor`` creates (once) an adaptor for a task and rebuilds the
    parallel evaluator. Once at least two adaptors exist, the most recent
    ``min(max_head, n_tasks - 1) + include_cur`` of them are stacked with
    ``torch.func.stack_module_state`` and evaluated together through ``vmap``.
    Per the torch.func docs, the stacked tensors are new leaf tensors (a
    snapshot). Later updates to the individual adaptors are therefore not
    seen by ``parallel_forward`` until the evaluator is rebuilt.

    This class is not an ``nn.Module``, so adaptors stored here are not
    registered as submodules of an owning module.
    """

    def __init__(self, adaptor_func, device='cuda', max_head=4, include_cur=False, in_dim=1, batched_input=False):
        self.adaptor_func = adaptor_func
        self.device = device
        self.task2adaptors: dict[str, nn.Module] = {}
        self.cur_all_task = []
        self.max_head = max_head
        self.include_cur = include_cur
        self.in_dim = in_dim
        self.batched_input = batched_input
        self._reset_parallel_state()

    def _reset_parallel_state(self):
        """Drop any stacked state so parallel_forward falls back to cur_adaptor."""
        self._vectorized_forward = None
        self._selected_tasks = []
        self._stacked_params = None
        self._stacked_buffers = None
        self._base_model = None

    def switch_adaptor(self, task):
        """Make ``task``'s adaptor current. Unknown tasks are ignored."""
        if task not in self.task2adaptors:
            return
        self.cur_adaptor = self.task2adaptors[task]
        self.cur_adaptor_task = task

    def init_adaptor(self, task: str):
        """Create ``task``'s adaptor if missing, rebuild the parallel evaluator, and return the adaptor."""
        if task not in self.task2adaptors:
            self.task2adaptors[task] = self.adaptor_func().to(self.device)
        self.cur_all_task = list(self.task2adaptors)
        self._build_parallel_forward()
        return self.task2adaptors[task]

    def _build_parallel_forward(self):
        """Stack the selected adaptors' state and build a vmapped functional forward."""
        candidates = [t for t in self.cur_all_task if t in self.task2adaptors]
        if len(candidates) < 2:
            self._reset_parallel_state()
            return

        n_reuse = min(self.max_head, len(candidates) - 1) + int(self.include_cur)
        self._selected_tasks = candidates if len(candidates) <= n_reuse else candidates[-n_reuse:]

        heads = [self.task2adaptors[t] for t in self._selected_tasks]
        self._base_model = heads[0]
        self._stacked_params, self._stacked_buffers = stack_module_state(heads)

        def _call_base(params, buffers, *x):
            return th.func.functional_call(self._base_model, (params, buffers), x)

        # Parameters/buffers are mapped over dim 0. Inputs are either shared
        # (None) or carry a leading per-head dim (0).
        input_axis = 0 if self.batched_input else None
        self._vectorized_forward = vmap(_call_base, in_dims=(0, 0, *([input_axis] * self.in_dim)))

    def parallel_forward(self, *x):
        """Evaluate all selected heads at once, or only the current adaptor if fewer than two exist."""
        if self._vectorized_forward is not None:
            return self._vectorized_forward(self._stacked_params, self._stacked_buffers, *x)
        return self.cur_adaptor(*x)

    def get_selected_tasks(self):
        """Return a copy of the tasks whose heads are stacked for parallel_forward."""
        return list(self._selected_tasks)

class BaseActor(nn.Module):
    """Shared actor scaffolding: per-task agent counts, hidden-state init and (de)serialization."""

    def __init__(self, task2_nagents, args):
        super().__init__()
        self.task2n_agents = task2_nagents
        self.args = args

    def init_hidden(self, batch_size, task):
        """Return a zero hidden state with ``batch_size * n_agents(task)`` rows of width ``rnn_hidden_dim``."""
        n_rows = batch_size * self.task2n_agents[task]
        return th.zeros(n_rows, self.args.rnn_hidden_dim, device=self.args.device)

    def save(self, path, name='actor'):
        """Write this module's state_dict to ``<path>/<name>.th``."""
        th.save(self.state_dict(), f"{path}/{name}.th")

    def load(self, path, name='actor'):
        """Load ``<path>/<name>.th`` non-strictly (storages stay on CPU) and print key mismatches."""
        checkpoint = th.load(f"{path}/{name}.th", map_location=lambda storage, loc: storage)
        missing, unexpected = self.load_state_dict(checkpoint, strict=False)
        print(f"Missing: {missing}, Unexpected: {unexpected}.")

# ------------------------------ Basic Components ------------------------------

class TemporalAttention(nn.Module):
    """One masked self-attention block over time, applied independently per agent.

    Block: ``LN1(attn * mask + V)`` then ``LN2(FFN(.) + .)``.
    """

    def __init__(self, input_dim, attn_dim, output_dim):
        super(TemporalAttention, self).__init__()
        self.input_dim = input_dim
        self.attn_dim = attn_dim
        self.output_dim = output_dim

        self.temp_query = nn.Linear(input_dim, attn_dim)
        self.temp_key = nn.Linear(input_dim, attn_dim)
        self.temp_value = nn.Linear(input_dim, output_dim)
        self.layer_norm1 = nn.LayerNorm(output_dim)
        self.layer_norm2 = nn.LayerNorm(output_dim)
        self.ffn = FCNet(output_dim, output_dim, hidden_layer=2, hidden_dim=output_dim, use_last_activ=True)

    def _attention(self, q, k, v, attn_dim, mask=None):
        """Masked scaled dot-product attention.

        Args:
            q: ``[bsn, t, attn_dim]``
            k: ``[bsn, t, attn_dim]``
            v: ``[bsn, t, value_dim]``
            mask: optional ``[bsn, t, 1]``. A pair (i, j) is penalised unless
                both steps are valid.
        """
        if len(q.shape) == 2:
            q = q.unsqueeze(1)
        energy = th.bmm(q, k.transpose(1, 2))/(attn_dim ** (1 / 2))
        if mask is not None:
            mask = mask.float()
            attn_mask = 1 - th.bmm(mask, mask.transpose(1, 2))  # [bsn, t, t]
            energy = energy - attn_mask * (1e6)
        score = F.softmax(energy, dim=-1)
        out = th.bmm(score, v).squeeze(1)
        return out

    def forward(self, x, mask):
        """Apply temporal attention.

        Args:
            x: ``[bs, t, n, d]``.
            mask: ``[bs, t, 1]`` validity mask, shared by all agents of a batch item.

        Returns:
            ``[bs, t, n, output_dim]``.
        """
        bs, t, n, d = x.shape

        x_reshaped = x.permute(0, 2, 1, 3).reshape(bs * n, t, d)  # row = b*n + agent
        # Match that (batch, agent) row order: each batch item's mask is repeated
        # n times consecutively. repeat(n, 1, 1) would tile in (agent, batch)
        # order and give agents another batch item's mask.
        mask = mask.repeat_interleave(n, dim=0)  # [bs*n, t, 1]

        query = self.temp_query(x_reshaped)  # [bs*n, t, attn_dim]
        key = self.temp_key(x_reshaped)      # [bs*n, t, attn_dim]
        value = self.temp_value(x_reshaped)  # [bs*n, t, output_dim]

        attn = self._attention(query, key, value, self.attn_dim, mask) * mask
        out_ln1 = self.layer_norm1(attn + value)
        out_ln2 = self.layer_norm2(self.ffn(out_ln1) + out_ln1)
        return out_ln2.reshape(bs, n, t, self.output_dim).permute(0, 2, 1, 3)  # [bs, t, n, output_dim]

class SkillEncoder(nn.Module):
    """Posterior skill encoder: (state, joint action) trajectory -> per-step, per-agent skill code.

    The state feature (one per step) is broadcast to every agent and
    concatenated with each agent's action feature. The result is projected to
    ``rnn_hidden_dim``, passed through masked temporal attention, and mapped
    to ``rho_dim``.
    """

    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, surrogate_decomposer, args):
        super().__init__()
        self.rho_dim = args.rho_dim
        hidden = args.rnn_hidden_dim

        self.state_enc = StateAttnFeatureExtractor(task2decomposer, args)
        self.action_enc = ActionAttnFeatureExtractor(task2decomposer, task2n_agents, surrogate_decomposer, args)
        # Input: [state feature ; action feature], each entity_embed_dim wide.
        self.in_proj = nn.Linear(2 * args.entity_embed_dim, hidden)

        self.temporal_attn = TemporalAttention(
            input_dim=hidden,
            attn_dim=args.attn_embed_dim,
            output_dim=hidden,
        )
        self.out_proj = FCNet(hidden, self.rho_dim, hidden_layer=2, hidden_dim=self.rho_dim)

        print("Skill Encoder Posterior init...")
        count_total_parameters(self, is_concrete=True)

    def forward(self, obs_traj, mask, state, action, task):
        """Encode a trajectory into skill codes.

        Args:
            obs_traj: currently unused; kept for interface compatibility.
            mask: ``[bs, t, 1]`` validity mask for the temporal attention.
            state: global states ``[bs, t, state_dim]``.
            action: joint actions ``[bs, t, n, n_actions]``.
            task: task key forwarded to the sub-encoders.

        Returns:
            ``out_proj`` applied to ``[bs, t, n, rnn_hidden_dim]`` features;
            presumably ``[bs, t, n, rho_dim]`` (FCNet is defined elsewhere — confirm).
        """
        action_feat = self.action_enc.forward(action, task)  # [bs, t, n, entity_embed_dim]
        n_agents = action_feat.size(2)
        state_feat = self.state_enc.forward(state, task)     # [bs, t, entity_embed_dim]
        state_feat = state_feat.unsqueeze(2).repeat(1, 1, n_agents, 1)

        fused = self.in_proj(th.cat([state_feat, action_feat], dim=-1))  # [bs, t, n, rnn_hidden_dim]
        temporal = self.temporal_attn(fused, mask)                       # [bs, t, n, rnn_hidden_dim]
        return self.out_proj(temporal)

    def save(self, path):
        """Write this module's state_dict to ``<path>/skill_encoder.th``."""
        th.save(self.state_dict(), f"{path}/skill_encoder.th")

    def load(self, path):
        """Load ``<path>/skill_encoder.th`` non-strictly (storages stay on CPU) and print key mismatches."""
        checkpoint = th.load(f"{path}/skill_encoder.th", map_location=lambda storage, loc: storage)
        missing, unexpected = self.load_state_dict(checkpoint, strict=False)
        print(f"Missing: {missing}, Unexpected: {unexpected}.")

class SingleTaskActor(BaseActor):
    """Single-task actor: observation encoder plus one output head.

    The head input is the encoder feature ``h`` with one extra scalar
    condition column appended.
    """

    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, surrogate_decomposer, args):
        super().__init__(task2n_agents, args)

        inputs = [task2input_shape_info, task2decomposer, task2n_agents, surrogate_decomposer, args]

        self.enc = ObsAttnGRUBase(*inputs)
        # +1 input feature: the scalar condition concatenated in forward().
        self.dec = OutputHead(surrogate_decomposer, args, input_dim=args.rnn_hidden_dim + 1)

        print("Single task actor init...")
        count_total_parameters(self, is_concrete=True)

    def forward(self, obs, cond, hidden_state, task):
        """Encode ``obs`` and decode actions conditioned on ``cond``.

        Args:
            obs: Observation tensor (incl. last action / agent id per the
                original note). Its leading dims ``obs.shape[:-1]`` are
                restored on the output.
            cond: Condition tensor flattened to one column. If None, a constant
                1 is used for every row of the encoder output.
            hidden_state: Recurrent state passed to ``self.enc``.
            task: Task key passed to ``self.enc``.

        Returns:
            ``(h, act)``, where ``h`` is the encoder output and ``act`` has
            shape ``obs.shape[:-1] + (-1,)``.
        """
        shape = obs.shape
        h, enemy_feats = self.enc(obs, hidden_state, task)

        if cond is None:
            # One condition per row of h, with matching dtype/device. The old
            # default used obs.shape[1] rows, which only matched h when every
            # other leading dimension of obs was 1.
            cond = h.new_ones(h.shape[0], 1)
        else:
            cond = cond.reshape(-1, 1)

        h_xi = th.cat([h, cond], dim=-1)
        act = self.dec(h_xi, enemy_feats)

        act = act.reshape(*shape[:-1], -1)

        return h, act

class TrActor(BaseActor):
    """Actor with a shared observation encoder and per-task output-head adaptors.

    ``policy_mh`` is a ``MultiHeadBase`` whose adaptors are ``OutputHead``
    instances. ``forward`` decodes with ``policy_mh.cur_adaptor`` (presumably
    the currently active task's head).
    """

    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, surrogate_decomposer, args):
        super().__init__(task2n_agents, args)
        self.policy_mh = MultiHeadBase(
            adaptor_func=partial(OutputHead, surrogate_decomposer, args, input_dim=args.rnn_hidden_dim),
            device=args.device,
            max_head=args.n_reuse_heads,
            in_dim=2,
        )

        inputs = [task2input_shape_info, task2decomposer, task2n_agents, surrogate_decomposer, args]
        self.enc = ObsAttnGRUBase(*inputs)

        print("TrActor init...")
        count_total_parameters(self, is_concrete=True)

    def forward(self, obs, hidden_state, task):
        """Return ``(h, act)``; ``act`` is reshaped to ``obs.shape[:-1] + (-1,)``."""
        shape = obs.shape
        h, enemy_feats = self.enc(obs, hidden_state, task)

        act = self.policy_mh.cur_adaptor(h, enemy_feats).reshape(*shape[:-1], -1)

        return h, act

    def save(self, path, name='actor'):
        """Save the module state dict plus a separate per-task adaptor checkpoint.

        Note: the adaptor file name does not depend on ``name``.
        """
        th.save(self.state_dict(), "{}/{}.th".format(path, name))
        th.save({
            'policy_mh_task2adaptors': {k: v.state_dict() for k, v in self.policy_mh.task2adaptors.items()},
            'policy_mh_cur_all_task': self.policy_mh.cur_all_task,
        }, "{}/tr_actor_adaptors.th".format(path))

    def load(self, path, name='actor'):
        """Load the actor weights and, if present, the adaptor checkpoint.

        The main state dict is loaded non-strictly. Adaptor weights can only be
        copied into tasks already present in ``policy_mh.task2adaptors``. Tasks
        without a matching adaptor are reported instead of being dropped
        silently.
        """
        missing, unexpected = self.load_state_dict(th.load("{}/{}.th".format(path, name), map_location=lambda storage, loc: storage), strict=False)
        print(f"Missing: {missing}, Unexpected: {unexpected}.")

        adaptor_path = "{}/tr_actor_adaptors.th".format(path)
        if not os.path.exists(adaptor_path):
            print("  No adaptor checkpoint found, adaptors will be initialized at runtime")
            return

        adaptor_ckpt = th.load(adaptor_path, map_location=lambda storage, loc: storage)
        for task_name, state_dict in adaptor_ckpt['policy_mh_task2adaptors'].items():
            if task_name in self.policy_mh.task2adaptors:
                self.policy_mh.task2adaptors[task_name].load_state_dict(state_dict)
                print(f"====== Loaded policy_mh adaptor for task {task_name}")
            else:
                # Previously these weights were discarded without any notice.
                print(f"  Skipped policy_mh adaptor for task {task_name}: no matching adaptor to load into")
        self.policy_mh.cur_all_task = adaptor_ckpt['policy_mh_cur_all_task']
        self.policy_mh._build_parallel_forward()

class UnifiedTrActor(BaseActor):
    """Multi-task actor whose policy heads are conditioned on a discrete skill ``rho``.

    Components:
      * ``enc``: shared ``ObsAttnGRUBase`` observation encoder.
      * ``policy_mh``: per-task ``OutputHead`` adaptors fed ``[h, discretised rho]``.
      * ``rho_mh``: per-task ``FCNet`` adaptors mapping ``h`` to skill logits
        (the skill prior).
      * ``off_dec``: an extra ``OutputHead`` whose own parameters are frozen.
    """

    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, surrogate_decomposer, args):
        super().__init__(task2n_agents, args)
        self.policy_mh = MultiHeadBase(
            adaptor_func=partial(OutputHead, surrogate_decomposer, args, input_dim=args.rnn_hidden_dim + args.rho_dim),
            device=args.device,
            max_head=args.n_reuse_heads,
            in_dim=2,
            batched_input=True,
        )
        # rho prior: h -> skill logits of size rho_dim.
        self.rho_mh = MultiHeadBase(
            adaptor_func=partial(FCNet, args.rnn_hidden_dim, args.rho_dim, hidden_layer=2, hidden_dim=args.rnn_hidden_dim),
            device=args.device,
            max_head=args.n_reuse_heads,
            in_dim=1,
        )

        self.n_reuse_heads = args.n_reuse_heads

        inputs = [task2input_shape_info, task2decomposer, task2n_agents, surrogate_decomposer, args]

        self.enc = ObsAttnGRUBase(*inputs)
        self.off_dec = OutputHead(surrogate_decomposer, args, input_dim=args.rnn_hidden_dim + args.rho_dim)
        # Freezes off_dec's own weights; gradients still flow through it to its inputs.
        self.off_dec.requires_grad_(False)

        print("UnifiedTrActor init...")
        count_total_parameters(self, is_concrete=True)

    def _discretize_rho(self, rho):
        """Turn skill logits into the head-conditioning vector.

        Draws a hard (one-hot) Gumbel-softmax sample along the last dim. When
        ``args.abla_no_skill`` is set, the sample is replaced by zeros.
        """
        rho_gb = F.gumbel_softmax(rho, dim=-1, tau=1, hard=True)
        if self.args.abla_no_skill:
            rho_gb = th.zeros_like(rho_gb.detach())
        return rho_gb

    def forward(self, obs, rho, hidden_state, task, all_forward=False):
        """Compute actions for the current task head, optionally for all stored heads.

        Args:
            obs: Observations. Outputs are reshaped to ``obs.shape[:-1] + (-1,)``.
            rho: Skill logits. If None (evaluation), the prior's prediction is used.
            hidden_state: Recurrent state passed to ``self.enc``.
            task: Task key passed to ``self.enc``.
            all_forward: Also evaluate every stored head (without grad) and ``off_dec``.

        Returns:
            ``(h, act, old_acts, off_act, rho_infer)``.
            ``old_acts`` has a leading per-head dimension.
            ``old_acts`` and ``off_act`` are None unless ``all_forward``.
        """
        shape = obs.shape
        h, enemy_feats = self.enc(obs, hidden_state, task)
        rho_infer = self.rho_mh.cur_adaptor(h)

        if rho is None:  # eval
            rho = rho_infer
        rho = rho.reshape(-1, rho.shape[-1])

        h_cond = th.cat([h, self._discretize_rho(rho)], dim=-1)
        act = self.policy_mh.cur_adaptor(h_cond, enemy_feats).view(*shape[:-1], -1)

        old_acts = None
        off_act = None
        if all_forward:
            with th.no_grad():
                rho_old = self.rho_mh.parallel_forward(h)
                n_heads = rho_old.shape[0]
                # Condition the stored heads the same way the live head is
                # conditioned (discretised skill), not on raw prior logits.
                rho_old_gb = self._discretize_rho(rho_old)
                h_cond_old = th.cat([h.unsqueeze(0).expand(n_heads, -1, -1), rho_old_gb], dim=-1)
                enemy_feats_old = enemy_feats.unsqueeze(0).expand(n_heads, -1, -1, -1)
                old_acts_flat = self.policy_mh.parallel_forward(h_cond_old, enemy_feats_old)
                old_acts = old_acts_flat.view(old_acts_flat.shape[0], *shape[:-1], -1)

            off_act = self.off_dec(h_cond, enemy_feats).view(*shape[:-1], -1)

        rho_infer = rho_infer.view(*shape[:-1], -1)
        return h, act, old_acts, off_act, rho_infer

    def save(self, path, name='actor'):
        """Save the module state dict plus a checkpoint of both multi-head adaptor sets."""
        th.save(self.state_dict(), "{}/{}.th".format(path, name))
        th.save({
            'policy_mh_task2adaptors': {k: v.state_dict() for k, v in self.policy_mh.task2adaptors.items()},
            'policy_mh_cur_all_task': self.policy_mh.cur_all_task,
            'rho_mh_task2adaptors': {k: v.state_dict() for k, v in self.rho_mh.task2adaptors.items()},
            'rho_mh_cur_all_task': self.rho_mh.cur_all_task,
        }, "{}/unified_tr_actor_adaptors.th".format(path))

    @staticmethod
    def _load_head_adaptors(mh, adaptor_states, label):
        """Copy per-task adaptor weights into ``mh``, reporting tasks that have no adaptor."""
        for task_name, state_dict in adaptor_states.items():
            if task_name in mh.task2adaptors:
                print(f"====== Loaded {label} adaptor for task {task_name}")
                mh.task2adaptors[task_name].load_state_dict(state_dict)
            else:
                print(f"  Skipped {label} adaptor for task {task_name}: no matching adaptor to load into")

    def load(self, path, name='actor'):
        """Load the actor weights and, if present, both adaptor sets.

        The main state dict is loaded non-strictly.
        """
        missing, unexpected = self.load_state_dict(th.load("{}/{}.th".format(path, name), map_location=lambda storage, loc: storage), strict=False)
        print(f"Missing: {missing}, Unexpected: {unexpected}.")

        adaptor_path = "{}/unified_tr_actor_adaptors.th".format(path)
        if not os.path.exists(adaptor_path):
            print("  No adaptor checkpoint found, adaptors will be initialized at runtime")
            return

        adaptor_ckpt = th.load(adaptor_path, map_location=lambda storage, loc: storage)

        self._load_head_adaptors(self.policy_mh, adaptor_ckpt['policy_mh_task2adaptors'], "policy_mh")
        self.policy_mh.cur_all_task = adaptor_ckpt['policy_mh_cur_all_task']
        self.policy_mh._build_parallel_forward()

        self._load_head_adaptors(self.rho_mh, adaptor_ckpt['rho_mh_task2adaptors'], "rho_mh")
        self.rho_mh.cur_all_task = adaptor_ckpt['rho_mh_cur_all_task']
        self.rho_mh._build_parallel_forward()
    