import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from utils.embed import polynomial_embed, binary_embed
from utils.transformer import Transformer
class MtUPDeTAgent(nn.Module):
    """Multi-task transformer agent producing per-action Q-values.

    Own, ally and enemy observation features (plus an optional task / role
    encoding and the recurrent hidden token) are each projected to
    ``entity_embed_dim`` and fed as a token sequence to a ``Transformer``.
    Q-values for the non-attack actions are read from the own-agent token
    (optionally concatenated with max-pooled enemy / ally tokens), and one
    attack Q-value is read from every enemy token.  An optional
    mixture-of-experts (MoE) head, gated by the role-encoding token, can
    replace the dense output heads.
    """

    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, args):
        """Build all embedding layers, the transformer and the output heads.

        Args:
            task2input_shape_info: mapping task -> dict that contains at least
                the key ``"last_action_shape"``.
            task2decomposer: mapping task -> observation/action decomposer.
                Feature dimensions are read from the first task's decomposer
                only.
            task2n_agents: mapping task -> number of agents in that task.
            args: configuration namespace (attributes read below).
        """
        super(MtUPDeTAgent, self).__init__()
        self.args = args
        self.task2last_action_shape = {
            task: task2input_shape_info[task]["last_action_shape"]
            for task in task2input_shape_info
        }

        self.task2decomposer = task2decomposer
        self.task2n_agents = task2n_agents

        # ---- dimension bookkeeping -------------------------------------
        self.entity_embed_dim = args.policy_entity_embed_dim
        self.attn_embed_dim = args.attn_embed_dim

        # Feature sizes are taken from the first registered task.
        task_0 = list(task2decomposer.keys())[0]
        ref_decomposer = task2decomposer[task_0]
        obs_own_dim = ref_decomposer.own_obs_dim
        obs_en_dim, obs_al_dim = ref_decomposer.obs_nf_en, ref_decomposer.obs_nf_al
        n_actions_no_attack = ref_decomposer.n_actions_no_attack
        has_attack_action = n_actions_no_attack != ref_decomposer.n_actions

        if args.obs_agent_id and args.obs_last_action:
            # Own observation is wrapped with the agent-id embedding and the
            # compact last-action state.
            wrapped_obs_own_dim = obs_own_dim + args.id_length + n_actions_no_attack
            if has_attack_action:
                wrapped_obs_own_dim += 1
                # Each enemy feature gets one extra "was attacked" slot.
                obs_en_dim += 1
        else:
            wrapped_obs_own_dim = obs_own_dim

        # ---- token embeddings (creation order is kept stable) ----------
        self.ally_value = nn.Linear(obs_al_dim, self.entity_embed_dim)
        self.enemy_value = nn.Linear(obs_en_dim, self.entity_embed_dim)
        self.own_value = nn.Linear(wrapped_obs_own_dim, self.entity_embed_dim)

        use_role_encoder = getattr(args, "use_role_encoder", True)
        self.separate_role_encoding = getattr(args, "separate_role_encoding", False) and use_role_encoder
        self.role_encoding_residual = getattr(args, "role_encoding_residual", False) and self.separate_role_encoding
        if use_role_encoder and not getattr(args, "only_role_encoding", False):
            if not self.separate_role_encoding:
                # Task and role encodings arrive concatenated -> one token.
                self.encoding_value = nn.Linear(2 * args.encoding_dim, self.entity_embed_dim)
            else:
                # One token for each half of the encoding.
                self.encoding_value = nn.Linear(args.encoding_dim, self.entity_embed_dim)
                self.role_encoding_value = nn.Linear(args.encoding_dim, self.entity_embed_dim)
        else:
            self.encoding_value = nn.Linear(args.encoding_dim, self.entity_embed_dim)

        # Learned per-slot embeddings added to ally / enemy tokens.
        max_ally_num = args.max_ally_num
        self.ally_time_embed = nn.Embedding(max_ally_num, self.entity_embed_dim)

        max_enemy_num = args.max_enemy_num
        self.enemy_time_embed = nn.Embedding(max_enemy_num, self.entity_embed_dim)

        self.transformer = Transformer(self.entity_embed_dim, args.policy_head, args.policy_depth, self.entity_embed_dim)

        # ---- optional mixture-of-experts heads --------------------------
        self.use_moe = getattr(args, "use_moe", False) and self.separate_role_encoding
        if self.use_moe:
            self.n_experts = getattr(args, "n_experts", 10)
            self.top_k = getattr(args, "top_k", 1)
            # NOTE(review): expert input width is fixed to 5 * embed_dim, which
            # matches the q-head input only when max pooling is enabled.
            self.experts = nn.ModuleList(
                [nn.Linear(5 * self.entity_embed_dim, n_actions_no_attack) for _ in range(self.n_experts)])
            self.share_expert = nn.Linear(5 * self.entity_embed_dim, n_actions_no_attack)
            self.attack_experts = nn.ModuleList(
                [nn.Linear(3 * self.entity_embed_dim, 1) for _ in range(self.n_experts)])
            self.share_attack_expert = nn.Linear(3 * self.entity_embed_dim, 1)
            self.gate = nn.Sequential(
                nn.Linear(self.entity_embed_dim, 128),
                nn.ReLU(inplace=True),
                nn.Linear(128, self.n_experts)
            )

        # ---- dense output heads -----------------------------------------
        # The heads consume a concatenation of transformer output tokens:
        #   context tokens (0, 1 or 2) + own token [+ pooled enemy + pooled ally]
        if args.use_encoding:
            n_context_tokens = 2 if self.separate_role_encoding else 1
        else:
            n_context_tokens = 0
        n_q_tokens = n_context_tokens + 1
        if getattr(self.args, "use_max_pooling", True):
            n_q_tokens += 2
        self.q_skill = nn.Linear(n_q_tokens * self.entity_embed_dim, n_actions_no_attack)
        self.attack_skill = nn.Linear((n_context_tokens + 1) * self.entity_embed_dim, 1)

        if self.use_moe:
            # Per-expert noise scale (passed through softplus) for the gate.
            self.noise_epsilon = nn.Parameter(th.zeros(self.n_experts))

    def _noisy_logits(self, logits):
        """Add trainable-scale Gaussian noise to gate logits while training.

        Args:
            logits: gate logits, (B, K) with K == number of experts.

        Returns:
            ``logits`` unchanged in eval mode; otherwise ``logits`` plus
            standard-normal noise scaled per expert by
            ``softplus(noise_epsilon)``.
        """
        if not self.training:
            return logits
        scale = F.softplus(self.noise_epsilon)  # (K,)
        noise = th.randn_like(logits) * scale.unsqueeze(0)
        return logits + noise

    def init_hidden(self):
        """Return a zero hidden token of shape (1, entity_embed_dim).

        The tensor is created from ``q_skill.weight`` so that it shares the
        model's dtype and device.
        """
        return self.q_skill.weight.new_zeros((1, self.entity_embed_dim))

    @staticmethod
    def _add_slot_embedding(hidden, table):
        """Add a learned per-slot embedding to a rank-3 token tensor.

        Args:
            hidden: tensor unpacked as (rows, n_slots, embed_dim).
            table: ``nn.Embedding`` indexed by slot position ``0..n_slots-1``.

        Returns:
            ``hidden`` plus the embedding of each token's slot index.
        """
        n_rows, n_slots, _ = hidden.shape
        slot_ids = th.arange(n_slots, device=hidden.device).long()
        slot_ids = slot_ids.view(1, n_slots).expand(n_rows, -1)
        return hidden + table(slot_ids)

    @staticmethod
    def _stack_attack_q(q_attack_list):
        """Stack per-enemy (B, 1) attack Q-values into a (B, n_enemies) tensor.

        Only the trailing singleton axis is squeezed.  A bare ``squeeze()``
        would also drop the batch or enemy axis whenever either has size 1,
        which breaks the concatenation with the non-attack Q-values.
        """
        return th.stack(q_attack_list, dim=1).squeeze(-1)

    def forward(self, inputs, hidden_state, task, task_encoding):
        """Compute Q-values for every row of ``inputs``.

        Args:
            inputs: 2-D tensor whose columns are laid out as
                ``[observation | last action | agent id]``.
            hidden_state: recurrent hidden token(s); reshaped to
                (-1, 1, entity_embed_dim).
            task: key selecting the decomposer / agent count / last-action
                width for this batch.
            task_encoding: encoding vector(s) fed to ``encoding_value`` (and,
                with separate role encoding, split in two along the last
                axis at ``args.encoding_dim``).

        Returns:
            Tuple ``(q, h, load_loss)``:
            ``q`` - Q-values, non-attack actions first, then one per enemy
            when the task has attack actions;
            ``h`` - the last output token, shape (rows, 1, entity_embed_dim);
            ``load_loss`` - MoE load-balancing term, or ``None`` when the MoE
            head is not active.
        """
        args = self.args
        hidden_state = hidden_state.view(-1, 1, self.entity_embed_dim)

        # Task-specific helpers and sizes.
        task_decomposer = self.task2decomposer[task]
        task_n_agents = self.task2n_agents[task]
        last_action_shape = self.task2last_action_shape[task]

        # Split the flat input.  The trailing agent-id columns are ignored:
        # the id embedding is rebuilt below from the agent index.
        obs_dim = task_decomposer.obs_dim
        obs_inputs = inputs[:, :obs_dim]
        last_action_inputs = inputs[:, obs_dim:obs_dim + last_action_shape]

        # own_obs: [bs * n_agents, own_obs_dim]; enemy/ally feats are
        # sequences with one entry per entity.
        own_obs, enemy_feats, ally_feats = task_decomposer.decompose_obs(obs_inputs)
        _, attack_action_info, compact_action_states = task_decomposer.decompose_action_info(last_action_inputs)

        if args.obs_last_action and args.obs_agent_id:
            # Wrap own observation with agent-id embedding + compact action.
            bs = own_obs.shape[0] // task_n_agents
            agent_id_inputs = th.stack(
                [th.as_tensor(binary_embed(i + 1, args.id_length, args.max_agent), dtype=own_obs.dtype)
                 for i in range(task_n_agents)],
                dim=0,
            ).repeat(bs, 1).to(own_obs.device)
            own_obs = th.cat([own_obs, agent_id_inputs, compact_action_states], dim=-1)

        # Stack per-entity features on a new leading axis and, if present,
        # append the per-enemy attack flag from the last action.
        enemy_feats = th.stack(enemy_feats, dim=0)
        if attack_action_info.numel() > 0:
            attack_action_info = attack_action_info.transpose(0, 1).unsqueeze(-1)
            enemy_feats = th.cat([enemy_feats, attack_action_info], dim=-1)
        ally_feats = th.stack(ally_feats, dim=0)

        # ---- token embeddings -------------------------------------------
        own_hidden = self.own_value(own_obs).unsqueeze(1)
        ally_hidden = self.ally_value(ally_feats).permute(1, 0, 2)
        enemy_hidden = self.enemy_value(enemy_feats).permute(1, 0, 2)

        single_encoding_token = (not self.separate_role_encoding) or getattr(args, "only_role_encoding", False)
        role_encoding_hidden = None
        if single_encoding_token:
            encoding_hidden = self.encoding_value(task_encoding).unsqueeze(1)
        else:
            enc_dim = args.encoding_dim
            encoding_hidden = self.encoding_value(task_encoding[..., :enc_dim]).unsqueeze(1)
            role_encoding_hidden = self.role_encoding_value(task_encoding[..., enc_dim:]).unsqueeze(1)

        if getattr(args, "use_time_embedding", True):
            ally_hidden = self._add_slot_embedding(ally_hidden, self.ally_time_embed)
            enemy_hidden = self._add_slot_embedding(enemy_hidden, self.enemy_time_embed)

        # ---- assemble the token sequence and run the transformer --------
        # Layout: [context tokens..., own, enemies..., allies..., (history)]
        if not args.use_encoding:
            tokens = []
        elif single_encoding_token:
            tokens = [encoding_hidden]
        else:
            tokens = [encoding_hidden, role_encoding_hidden]
        tokens += [own_hidden, enemy_hidden, ally_hidden]
        if args.use_hidden:
            tokens.append(hidden_state)

        outputs = self.transformer(th.cat(tokens, dim=1), None)
        h = outputs[:, -1:, :]

        # ---- read-out ----------------------------------------------------
        # NOTE(review): the read-out offset depends only on
        # `separate_role_encoding`, whereas the token layout above also
        # depends on `only_role_encoding`; with both enabled the indices do
        # not line up with the layout.  Kept as in the original - confirm
        # the intended behaviour for that configuration.
        if args.use_encoding:
            n_context = 2 if self.separate_role_encoding else 1
        else:
            n_context = 0
        context_outputs = [outputs[:, idx, :] for idx in range(n_context)]
        base_action_inputs = outputs[:, n_context, :]
        enemy_start = n_context + 1
        enemy_length = enemy_hidden.shape[1]
        ally_length = ally_hidden.shape[1]

        q_inputs = context_outputs + [base_action_inputs]
        if getattr(args, "use_max_pooling", True):
            ally_start = enemy_start + enemy_length
            obs_enemy = th.max(outputs[:, enemy_start:ally_start, :], dim=1)[0]
            obs_ally = th.max(outputs[:, ally_start:ally_start + ally_length, :], dim=1)[0]
            q_inputs += [obs_enemy, obs_ally]
        obs_out = th.cat(q_inputs, dim=-1)

        # The MoE head needs the role-encoding token, which only exists when
        # encodings are part of the sequence; otherwise fall back to the
        # dense heads for both base and attack actions.
        moe_active = self.use_moe and args.use_encoding
        load_loss = None
        gated_probs_expanded = None
        if not moe_active:
            q_base = self.q_skill(obs_out)
        else:
            if not self.role_encoding_residual:
                gate_input = context_outputs[1]        # role-encoding output token
            else:
                gate_input = role_encoding_hidden[:, 0, :]  # pre-transformer role token
            gates_logits = self._noisy_logits(self.gate(gate_input))  # (B, K)
            gates = F.softmax(gates_logits, dim=-1)
            _, topk_idx = th.topk(gates, self.top_k, dim=-1)
            topk_mask = th.zeros_like(gates)                           # (B, K)
            topk_mask.scatter_(1, topk_idx, 1.0)                       # mark selected experts
            # Zero out non-selected probabilities and renormalise so the
            # selected ones sum to 1.
            gated_probs = gates * topk_mask
            gated_probs = gated_probs / gated_probs.sum(dim=-1, keepdim=True).clamp_min(1e-9)
            gated_probs_expanded = gated_probs.unsqueeze(-1)

            expert_outputs = th.stack([expert(obs_out) for expert in self.experts], dim=1)
            q_base = th.sum(gated_probs_expanded * expert_outputs, dim=1) + self.share_expert(obs_out)

            # Load-balancing term: K * sum(importance * load).
            # importance uses the soft gate probabilities so gradients flow;
            # load is the (non-differentiable) fraction routed to each expert.
            n_rows = obs_out.shape[0]
            importance = gates.sum(dim=0) / n_rows   # (K,)
            load = topk_mask.sum(dim=0) / n_rows     # (K,)
            load_loss = self.n_experts * th.sum(importance * load)

        if task_decomposer.n_actions_no_attack == task_decomposer.n_actions:
            q = q_base
        else:
            # One attack Q-value per enemy, read from that enemy's token
            # (prefixed by the context tokens when encodings are used).
            q_attack_list = []
            for i in range(enemy_feats.size(0)):
                attack_action_inputs = th.cat(context_outputs + [outputs[:, enemy_start + i, :]], dim=-1)
                if not moe_active:
                    q_enemy = self.attack_skill(attack_action_inputs)
                else:
                    attack_expert_outputs = th.stack(
                        [expert(attack_action_inputs) for expert in self.attack_experts], dim=1)
                    share_attack_expert_outputs = self.share_attack_expert(attack_action_inputs)
                    q_enemy = th.sum(gated_probs_expanded * attack_expert_outputs, dim=1) + share_attack_expert_outputs
                q_attack_list.append(q_enemy)
            q_attack = self._stack_attack_q(q_attack_list)

            q = th.cat([q_base, q_attack], dim=-1)
        ### TODO: Finish Top-K MoE
        return q, h, load_loss