import torch.nn as nn
import torch.nn.functional as F
import torch as th
import numpy as np
from utils.transformer import Transformer

class TransformerTransitionDecoder(nn.Module):
    """Predict observation deltas and reward for one transition.

    The observation and action are split by the task's decomposer into an
    own-agent part, per-enemy parts, per-ally parts and a non-attack action
    part. Each part is linearly embedded, concatenated along the token axis
    as ``[encoding, own, enemies..., allies..., no_attack_action]`` and passed
    through ``self.transformer``. Each output token is then concatenated with
    the embedded task representation and the embedded non-attack action and
    fed to a three-layer MLP head.

    Feature sizes are read from the first task in ``task2decomposer`` only.
    """

    def __init__(self, task2decomposer, args) -> None:
        """Build the embedding layers, the transformer and the decoder heads.

        Args:
            task2decomposer: mapping from task key to a decomposer object
                exposing ``own_obs_dim``, ``obs_nf_en``, ``obs_nf_al``,
                ``n_actions_no_attack``, ``n_actions``, ``decompose_obs`` and
                ``decompose_action``.
            args: configuration namespace; the attributes read are visible
                in the body below.
        """
        super(TransformerTransitionDecoder, self).__init__()

        self.args = args
        self.task2decomposer = task2decomposer

        # TODO: finish writing the decoder

        self.entity_embed_dim = args.entity_embed_dim
        self.attn_embed_dim = args.attn_embed_dim
        self.encoding_dim = args.transition_encoding_dim

        task_0 = list(task2decomposer.keys())[0]
        decomposer_0 = task2decomposer[task_0]
        obs_own_dim = decomposer_0.own_obs_dim
        obs_en_dim, obs_al_dim = decomposer_0.obs_nf_en, decomposer_0.obs_nf_al
        n_actions_no_attack = decomposer_0.n_actions_no_attack

        has_attack_action = n_actions_no_attack != decomposer_0.n_actions
        self.direct_enemy_action_embedding = False
        if has_attack_action:
            self.direct_enemy_action_embedding = args.direct_enemy_action_embedding

        # NOTE: sub-modules are created in a fixed order so that parameter
        # initialisation under a fixed seed stays reproducible.
        self.ally_value = nn.Linear(obs_al_dim, self.entity_embed_dim)
        # When attack actions are not embedded directly, one extra scalar per
        # enemy (the attack-action info) is appended to the enemy features.
        append_attack_flag = has_attack_action and not self.direct_enemy_action_embedding
        enemy_in_dim = obs_en_dim + 1 if append_attack_flag else obs_en_dim
        self.enemy_value = nn.Linear(enemy_in_dim, self.entity_embed_dim)
        self.own_value = nn.Linear(obs_own_dim, self.entity_embed_dim)

        # NOTE(review): the input size is args.encoding_dim, while
        # self.encoding_dim is read from args.transition_encoding_dim --
        # confirm which one matches the task representation passed to forward.
        self.encoding_value = nn.Linear(args.encoding_dim, self.entity_embed_dim)

        self.no_attack_action_value = nn.Linear(n_actions_no_attack, self.entity_embed_dim)
        # Not used by forward(); kept so the module's parameters are unchanged.
        self.attack_action_value = nn.Linear(1, self.entity_embed_dim)

        # Learned embeddings indexed by entity position (not by time step,
        # despite the attribute names).
        max_ally_num = args.max_ally_num
        self.ally_time_embed = nn.Embedding(max_ally_num, self.entity_embed_dim)

        max_enemy_num = args.max_enemy_num
        self.enemy_time_embed = nn.Embedding(max_enemy_num, self.entity_embed_dim)

        self.transformer = Transformer(self.entity_embed_dim, args.head, args.depth, self.entity_embed_dim)

        self.hidden_dim = args.encoder_hidden_dim

        # Every head consumes [encoding, token, no-attack action] embeddings.
        head_in_dim = self.entity_embed_dim * 3
        self.ally_decoder = self._build_head(head_in_dim, self.hidden_dim, obs_al_dim)
        self.enemy_decoder = self._build_head(head_in_dim, self.hidden_dim, obs_en_dim)
        self.own_decoder = self._build_head(head_in_dim, self.hidden_dim, obs_own_dim)
        self.reward_decoder = self._build_head(head_in_dim, self.hidden_dim, 1)

    @staticmethod
    def _build_head(in_dim, hidden_dim, out_dim):
        """Return a Linear-ReLU-Linear-ReLU-Linear decoder head."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    @staticmethod
    def _entity_index_embedding(hidden, embedding):
        """Look up ``embedding`` at indices 0..n-1 for every row of ``hidden``.

        ``hidden`` is indexed as (batch, n_entities, dim); the result has the
        same leading two axes.
        """
        batch_size, n_entities, _ = hidden.shape
        indices = th.arange(n_entities, device=hidden.device).long()
        indices = indices.view(1, n_entities).expand(batch_size, -1)
        return embedding(indices)

    @staticmethod
    def _masked_squared_error(prediction, target, mask):
        """Sum of squared masked errors divided by the sum of the mask."""
        return (((prediction - target) * mask) ** 2).sum() / mask.sum()

    def forward(self, obs, action, task_repre, task):
        """Run the decoder on a batch of flat observations.

        Args:
            obs: 2-D observation tensor (asserted); first axis is the batch.
            action: action tensor handed to ``decompose_action``.
            task_repre: task representation fed to ``self.encoding_value``
                (presumably (batch, args.encoding_dim) -- TODO confirm).
            task: key into ``self.task2decomposer``.

        Returns:
            Tuple ``(reward_outputs, own_outputs, enemy_outputs,
            ally_outputs)``. Each keeps a token axis at dim 1: one token for
            reward and own, one per enemy / ally for the other two.

        Raises:
            NotImplementedError: if ``direct_enemy_action_embedding`` is set;
                that path was never wired into the token sequence.
        """
        assert len(obs.shape) == 2, "Input obs should be 2D tensor"
        if self.direct_enemy_action_embedding:
            raise NotImplementedError(
                "direct_enemy_action_embedding is not supported by TransformerTransitionDecoder"
            )
        task_decomposer = self.task2decomposer[task]

        own_obs, enemy_feats, ally_feats = task_decomposer.decompose_obs(obs)
        no_attack_action, attack_action_info, _, _ = task_decomposer.decompose_action(action)

        # Stacking the per-entity lists puts the entity axis first.
        enemy_feats = th.stack(enemy_feats, dim=0)
        if np.prod(attack_action_info.shape) > 0:
            # Append the per-enemy attack info as one extra feature.
            attack_flag = attack_action_info.transpose(0, 1).unsqueeze(-1)
            enemy_feats = th.cat([enemy_feats, attack_flag], dim=-1)
        ally_feats = th.stack(ally_feats, dim=0)

        encoding_hidden = self.encoding_value(task_repre).unsqueeze(1)
        own_hidden = self.own_value(own_obs).unsqueeze(1)
        # permute moves the batch axis in front of the entity axis.
        ally_hidden = self.ally_value(ally_feats).permute(1, 0, 2)
        enemy_hidden = self.enemy_value(enemy_feats).permute(1, 0, 2)
        no_attack_action_hidden = self.no_attack_action_value(no_attack_action).unsqueeze(1)

        ally_hidden = ally_hidden + self._entity_index_embedding(ally_hidden, self.ally_time_embed)
        enemy_hidden = enemy_hidden + self._entity_index_embedding(enemy_hidden, self.enemy_time_embed)

        total_hidden = th.cat(
            [encoding_hidden, own_hidden, enemy_hidden, ally_hidden, no_attack_action_hidden],
            dim=1,
        )
        outputs = self.transformer(total_hidden, None)

        # Token layout: [encoding | own | enemies | allies | no-attack action].
        own_length = 1
        enemy_length = enemy_hidden.shape[1]
        ally_length = ally_hidden.shape[1]
        own_end = 1 + own_length
        enemy_end = own_end + enemy_length
        ally_end = enemy_end + ally_length

        def with_context(tokens):
            # Pair every token with the encoding and action embeddings.
            n_tokens = tokens.shape[1]
            return th.cat(
                [
                    encoding_hidden.repeat(1, n_tokens, 1),
                    tokens,
                    no_attack_action_hidden.repeat(1, n_tokens, 1),
                ],
                dim=-1,
            )

        # The reward is read from the last (action) token.
        reward_outputs = self.reward_decoder(with_context(outputs[:, -1:, :]))
        own_outputs = self.own_decoder(with_context(outputs[:, 1:own_end, :]))
        enemy_outputs = self.enemy_decoder(with_context(outputs[:, own_end:enemy_end, :]))
        ally_outputs = self.ally_decoder(with_context(outputs[:, enemy_end:ally_end, :]))

        return reward_outputs, own_outputs, enemy_outputs, ally_outputs

    def get_decoding_loss(self, obs, action, task_repre, task, next_obs, reward, mask):
        """Compute masked squared-error losses for the predicted transition.

        The own / enemy / ally heads are trained against the feature deltas
        ``next_obs - obs``; the reward head is trained against
        ``reward * args.encoder_reward_scale``. Each term is the sum of
        squared masked errors divided by the sum of the (token-expanded) mask.

        Args:
            obs, action, task_repre, task: forwarded to :meth:`forward`.
            next_obs: next observation, decomposed the same way as ``obs``.
            reward: reward tensor; it is NOT modified in place.
            mask: 2-D validity mask (assumed (batch, 1) -- TODO confirm).

        Returns:
            ``(total_loss, own_loss, enemy_loss, ally_loss, reward_loss)``,
            where ``total_loss`` weights the four terms with
            ``args.{own,enemy,ally,reward}_loss_weight``.
        """
        # Out-of-place scaling: the caller's reward tensor must stay intact.
        reward = reward * self.args.encoder_reward_scale
        reward_outputs, own_outputs, enemy_outputs, ally_outputs = self.forward(obs, action, task_repre, task)

        task_decomposer = self.task2decomposer[task]

        def split(observation):
            # Returns own / enemy / ally features with the entity axis at dim 1.
            own, enemies, allies = task_decomposer.decompose_obs(observation)
            return (
                own.unsqueeze(1),
                th.stack(enemies, dim=0).permute(1, 0, 2),
                th.stack(allies, dim=0).permute(1, 0, 2),
            )

        next_own_feats, next_enemy_feats, next_ally_feats = split(next_obs)
        own_feats, enemy_feats, ally_feats = split(obs)
        reward = reward.unsqueeze(1)

        delta_own_feats = next_own_feats - own_feats
        delta_enemy_feats = next_enemy_feats - enemy_feats
        delta_ally_feats = next_ally_feats - ally_feats

        # Give the mask an explicit token axis. Multiplying a (B, 1, D) tensor
        # by a bare (B, 1) mask would right-align the mask to (1, B, 1) and
        # broadcast to (B, B, D), which makes the mask ineffective.
        mask_token = mask.unsqueeze(1)
        mask_enemy = mask_token.repeat(1, enemy_outputs.shape[1], 1)
        mask_ally = mask_token.repeat(1, ally_outputs.shape[1], 1)

        own_loss = self._masked_squared_error(own_outputs, delta_own_feats, mask_token)
        enemy_loss = self._masked_squared_error(enemy_outputs, delta_enemy_feats, mask_enemy)
        ally_loss = self._masked_squared_error(ally_outputs, delta_ally_feats, mask_ally)
        reward_loss = self._masked_squared_error(reward_outputs, reward, mask_token)

        # mpe: reward_weight 0.2, own_weight 0.05, enemy_weight 0, ally_weight 0
        total_loss = (
            self.args.own_loss_weight * own_loss
            + self.args.enemy_loss_weight * enemy_loss
            + self.args.ally_loss_weight * ally_loss
            + self.args.reward_loss_weight * reward_loss
        )

        return total_loss, own_loss, enemy_loss, ally_loss, reward_loss


class TransformerTransitionRoleDecoder(nn.Module):
    """Predict per-action scores from an observation and a task encoding.

    The observation is split by the task's decomposer into own-agent,
    per-enemy and per-ally parts. Each part is linearly embedded, concatenated
    along the token axis as ``[encoding, own, enemies..., allies...]`` and
    passed through ``self.transformer``. Scores for the non-attack actions
    come from ``self.q_skill``; one extra score per enemy comes from
    ``self.attack_skill`` when the task has attack actions.

    Feature sizes are read from the first task in ``task2decomposer`` only.
    """

    def __init__(self, task2decomposer, args) -> None:
        """Build the embedding layers, the transformer and the output layers.

        Args:
            task2decomposer: mapping from task key to a decomposer object
                exposing ``own_obs_dim``, ``obs_nf_en``, ``obs_nf_al``,
                ``n_actions_no_attack``, ``n_actions`` and ``decompose_obs``.
            args: configuration namespace; the attributes read are visible
                in the body below.
        """
        super(TransformerTransitionRoleDecoder, self).__init__()

        self.args = args
        self.task2decomposer = task2decomposer

        # TODO: finish writing the decoder

        self.entity_embed_dim = args.entity_embed_dim
        self.attn_embed_dim = args.attn_embed_dim
        self.encoding_dim = args.transition_encoding_dim

        task_0 = list(task2decomposer.keys())[0]
        decomposer_0 = task2decomposer[task_0]
        obs_own_dim = decomposer_0.own_obs_dim
        obs_en_dim, obs_al_dim = decomposer_0.obs_nf_en, decomposer_0.obs_nf_al
        n_actions_no_attack = decomposer_0.n_actions_no_attack

        has_attack_action = n_actions_no_attack != decomposer_0.n_actions
        self.direct_enemy_action_embedding = False
        if has_attack_action:
            self.direct_enemy_action_embedding = args.direct_enemy_action_embedding

        # NOTE: sub-modules are created in a fixed order so that parameter
        # initialisation under a fixed seed stays reproducible.
        self.ally_value = nn.Linear(obs_al_dim, self.entity_embed_dim)
        self.enemy_value = nn.Linear(obs_en_dim, self.entity_embed_dim)
        self.own_value = nn.Linear(obs_own_dim, self.entity_embed_dim)

        # NOTE(review): the input size is args.encoding_dim, while
        # self.encoding_dim is read from args.transition_encoding_dim --
        # confirm which one matches the task representation passed to forward.
        self.encoding_value = nn.Linear(args.encoding_dim, self.entity_embed_dim)

        # Learned embeddings indexed by entity position (not by time step,
        # despite the attribute names).
        max_ally_num = args.max_ally_num
        self.ally_time_embed = nn.Embedding(max_ally_num, self.entity_embed_dim)

        max_enemy_num = args.max_enemy_num
        self.enemy_time_embed = nn.Embedding(max_enemy_num, self.entity_embed_dim)

        self.transformer = Transformer(self.entity_embed_dim, args.head, args.depth, self.entity_embed_dim)

        self.hidden_dim = args.encoder_hidden_dim

        # q_skill reads [encoding, own, max-pooled enemies, max-pooled allies];
        # attack_skill reads [encoding, one enemy token].
        self.q_skill = nn.Linear(4 * self.entity_embed_dim, n_actions_no_attack)
        self.attack_skill = nn.Linear(2 * self.entity_embed_dim, 1)

    @staticmethod
    def _entity_index_embedding(hidden, embedding):
        """Look up ``embedding`` at indices 0..n-1 for every row of ``hidden``.

        ``hidden`` is indexed as (batch, n_entities, dim); the result has the
        same leading two axes.
        """
        batch_size, n_entities, _ = hidden.shape
        indices = th.arange(n_entities, device=hidden.device).long()
        indices = indices.view(1, n_entities).expand(batch_size, -1)
        return embedding(indices)

    def forward(self, obs, task_repre, task):
        """Return action scores for a batch of flat observations.

        Args:
            obs: 2-D observation tensor (asserted); first axis is the batch.
            task_repre: task representation fed to ``self.encoding_value``
                (presumably (batch, args.encoding_dim) -- TODO confirm).
            task: key into ``self.task2decomposer``.

        Returns:
            A 2-D tensor ``(batch, n_scores)``: the ``q_skill`` outputs,
            followed by one score per enemy when the task's ``n_actions``
            differs from its ``n_actions_no_attack``.

        Raises:
            NotImplementedError: if ``direct_enemy_action_embedding`` is set;
                this decoder has no attack-action embedding layer.
        """
        assert len(obs.shape) == 2, "Input obs should be 2D tensor"
        if self.direct_enemy_action_embedding:
            raise NotImplementedError(
                "direct_enemy_action_embedding is not supported by TransformerTransitionRoleDecoder"
            )
        task_decomposer = self.task2decomposer[task]

        own_obs, enemy_feats, ally_feats = task_decomposer.decompose_obs(obs)

        # Stacking the per-entity lists puts the entity axis first.
        enemy_feats = th.stack(enemy_feats, dim=0)
        ally_feats = th.stack(ally_feats, dim=0)

        encoding_hidden = self.encoding_value(task_repre).unsqueeze(1)
        own_hidden = self.own_value(own_obs).unsqueeze(1)
        # permute moves the batch axis in front of the entity axis.
        ally_hidden = self.ally_value(ally_feats).permute(1, 0, 2)
        enemy_hidden = self.enemy_value(enemy_feats).permute(1, 0, 2)

        ally_hidden = ally_hidden + self._entity_index_embedding(ally_hidden, self.ally_time_embed)
        enemy_hidden = enemy_hidden + self._entity_index_embedding(enemy_hidden, self.enemy_time_embed)

        total_hidden = th.cat([encoding_hidden, own_hidden, enemy_hidden, ally_hidden], dim=1)
        outputs = self.transformer(total_hidden, None)

        # Token layout: [encoding | own | enemies | allies].
        enemy_length = enemy_hidden.shape[1]
        ally_length = ally_hidden.shape[1]
        enemy_start = 2
        ally_start = enemy_start + enemy_length

        encoding_inputs = outputs[:, 0, :]
        base_action_inputs = outputs[:, 1, :]
        # Max-pool over the entity axis; [0] selects values, not indices.
        obs_enemy = th.max(outputs[:, enemy_start:ally_start, :], dim=1)[0]
        obs_ally = th.max(outputs[:, ally_start:ally_start + ally_length, :], dim=1)[0]
        obs_out = th.cat([encoding_inputs, base_action_inputs, obs_enemy, obs_ally], dim=-1)
        q_base = self.q_skill(obs_out)

        if task_decomposer.n_actions_no_attack == task_decomposer.n_actions:
            return q_base

        q_attack_list = []
        for i in range(enemy_feats.size(0)):
            if self.args.use_encoding:
                attack_action_inputs = th.cat([encoding_inputs, outputs[:, enemy_start + i, :]], dim=-1)
            else:
                # NOTE(review): index 1 + i starts at the own-agent token and
                # the vector has entity_embed_dim features, while
                # attack_skill expects 2 * entity_embed_dim -- this branch
                # looks unfinished; confirm before using use_encoding=False.
                attack_action_inputs = outputs[:, 1 + i, :]
            q_attack_list.append(self.attack_skill(attack_action_inputs))

        # Each entry is (batch, 1); concatenating on the last axis gives
        # (batch, n_enemies) and stays 2-D for batch size 1 or a single
        # enemy, where a dim-less squeeze() would drop an axis.
        q_attack = th.cat(q_attack_list, dim=-1)

        return th.cat([q_base, q_attack], dim=-1)

    def get_decoding_loss(self, obs, action_long, task_repre, task, mask):
        """Masked cross-entropy between predicted scores and taken actions.

        Args:
            obs, task_repre, task: forwarded to :meth:`forward`.
            action_long: action indices with a trailing singleton axis that
                is squeezed away before use.
            mask: weights with the same shape as the per-sample loss column,
                i.e. (batch, 1) (asserted).

        Returns:
            Scalar tensor: sum of masked per-sample losses over ``mask.sum()``.
        """
        q = self.forward(obs, task_repre, task)
        label = th.squeeze(action_long, -1).to(th.long)
        action_loss = F.cross_entropy(q, label, reduction='none').reshape(-1, 1)
        assert action_loss.shape == mask.shape
        action_loss = (action_loss * mask).sum() / mask.sum()

        return action_loss