import torch.nn as nn
import torch.nn.functional as F
import torch as th
import numpy as np
from utils.transformer import Transformer
from modules.encoders.transition_encoder import TransformerTransitionEncoder, TransformerTransitionRoleEncoder

class TransformerTemporalEncoder(nn.Module):
    """Collapse per-step transition encodings into one vector per agent.

    Each transition is first encoded by ``TransformerTransitionEncoder``; the
    resulting sequence is then reduced over the time axis with one of three
    strategies selected through ``args``:

    * ``args.average_temporal`` -- masked mean over time,
    * ``args.average_gate`` -- softmax-gated weighted sum over time,
    * otherwise -- run ``self.transformer`` and keep its output at position 0.
    """

    def __init__(self, task2decomposer, args, is_club=False) -> None:
        super().__init__()

        self.args = args
        self.task2decomposer = task2decomposer
        self.is_club = is_club

        self.encoding_dim = args.transition_encoding_dim

        self.transition_encoder = TransformerTransitionEncoder(task2decomposer, args, is_club=self.is_club)

        # NOTE: this embedding is not applied anywhere in forward(). It is
        # still created so that the module's parameter set (and therefore its
        # state_dict keys and initialisation order) is unchanged.
        max_step_num = args.max_step_num
        self.step_time_embed = nn.Embedding(max_step_num, self.encoding_dim)

        self.transformer = Transformer(self.encoding_dim, args.head, args.depth, self.encoding_dim)

        # Layers used only by the ``average_gate`` pooling strategy.
        self.gate = nn.Linear(self.encoding_dim, 1)
        self.value = nn.Linear(self.encoding_dim, self.encoding_dim)
        self.output = nn.Linear(self.encoding_dim, self.encoding_dim)

    def forward(self, obs, action, next_obs, reward, mask, task):
        """Encode a batch of trajectories into per-agent temporal encodings.

        Args:
            obs, action, next_obs, task: forwarded unchanged to the
                transition encoder.
            reward: reward tensor; it is multiplied by
                ``args.encoder_reward_scale`` before being forwarded. The
                caller's tensor is left untouched.
            mask: tensor of shape ``(bs, max_t, 1)``; entries equal to 0 mark
                time steps that must not contribute to the result.

        Returns:
            Tensor of shape ``(bs, n_agents, -1)``.
        """
        # Scale out of place: an in-place ``*=`` would modify the caller's
        # tensor, so encoding the same batch twice would compound the scale.
        reward = reward * self.args.encoder_reward_scale

        # [bs, t, n_agents, z_dim]
        transition_encoding = self.transition_encoder(obs, action, next_obs, reward, task)
        bs, max_t, n_agents, _ = transition_encoding.shape
        assert mask.shape == (bs, max_t, 1), (
            f"mask must have shape {(bs, max_t, 1)}, got {tuple(mask.shape)}"
        )

        # Fold the agent axis into the batch axis so that every agent's
        # sequence is pooled independently: [bs * n_agents, t, z_dim].
        transition_encoding = transition_encoding.permute(0, 2, 1, 3).reshape(bs * n_agents, max_t, -1)
        mask = mask.unsqueeze(2).repeat(1, 1, n_agents, 1).permute(0, 2, 1, 3).reshape(bs * n_agents, max_t, -1)

        pooled = self._pool_over_time(transition_encoding, mask)
        return pooled.reshape(bs, n_agents, -1)

    def _pool_over_time(self, encoding, mask):
        """Reduce ``encoding`` (rows, t, dim) over its time axis to (rows, dim)."""
        # Zero out masked steps before any reduction.
        encoding = encoding * mask

        if self.args.average_temporal:
            count = mask.sum(dim=1)
            # A fully masked sequence has count == 0; dividing by it would
            # yield NaN (0 / 0). Its masked sum is already 0, so dividing by 1
            # gives an all-zero encoding instead.
            count = count.masked_fill(count == 0, 1)
            return encoding.sum(dim=1) / count

        if self.args.average_gate:
            # One scalar score per step; masked steps get a very negative
            # score so that softmax gives them (almost) no weight.
            scores = self.gate(encoding).masked_fill(mask == 0, -1e9).transpose(-1, -2)
            attn = F.softmax(scores, dim=-1)
            weighted = th.matmul(attn, self.value(encoding)).squeeze(1)
            return self.output(weighted)

        outputs = self.transformer(encoding, mask)
        return outputs[:, 0, :]

class TransformerTemporalRoleEncoder(nn.Module):
    """Collapse per-step role encodings into one vector over the time axis.

    Each step is first encoded by ``TransformerTransitionRoleEncoder``. Its
    output may be 3-D (``[bs, t, z_dim]``) or 4-D
    (``[bs, t, n_agents, z_dim]``); in the 4-D case every agent's sequence is
    pooled independently. The time axis is reduced with one of three
    strategies selected through ``args``:

    * ``args.average_temporal`` -- masked mean over time,
    * ``args.average_gate`` -- softmax-gated weighted sum over time,
    * otherwise -- run ``self.transformer`` and keep its output at position 0.
    """

    def __init__(self, task2decomposer, args, is_club=False) -> None:
        super().__init__()

        self.args = args
        self.task2decomposer = task2decomposer
        self.is_club = is_club

        self.encoding_dim = args.transition_encoding_dim

        self.transition_encoder = TransformerTransitionRoleEncoder(task2decomposer, args, is_club=self.is_club)

        # NOTE: this embedding is not applied anywhere in forward(). It is
        # still created so that the module's parameter set (and therefore its
        # state_dict keys and initialisation order) is unchanged.
        max_step_num = args.max_step_num
        self.step_time_embed = nn.Embedding(max_step_num, self.encoding_dim)

        self.transformer = Transformer(self.encoding_dim, args.head, args.depth, self.encoding_dim)

        # Layers used only by the ``average_gate`` pooling strategy.
        self.gate = nn.Linear(self.encoding_dim, 1)
        self.value = nn.Linear(self.encoding_dim, self.encoding_dim)
        self.output = nn.Linear(self.encoding_dim, self.encoding_dim)

    def forward(self, obs, action, mask, task):
        """Encode a batch of trajectories into temporal role encodings.

        Args:
            obs, action, task: forwarded unchanged to the transition encoder.
            mask: tensor of shape ``(bs, max_t, 1)``; entries equal to 0 mark
                time steps that must not contribute to the result.

        Returns:
            Tensor of shape ``(bs, -1)`` when the transition encoder returns a
            3-D tensor, otherwise ``(bs, n_agents, -1)``.
        """
        transition_encoding = self.transition_encoder(obs, action, task)

        if transition_encoding.dim() == 3:
            # [bs, t, z_dim]: no agent axis, pool the sequences directly.
            bs, max_t, _ = transition_encoding.shape
            assert mask.shape == (bs, max_t, 1), (
                f"mask must have shape {(bs, max_t, 1)}, got {tuple(mask.shape)}"
            )
            pooled = self._pool_over_time(transition_encoding, mask)
            return pooled.reshape(bs, -1)

        # [bs, t, n_agents, z_dim]
        bs, max_t, n_agents, _ = transition_encoding.shape
        assert mask.shape == (bs, max_t, 1), (
            f"mask must have shape {(bs, max_t, 1)}, got {tuple(mask.shape)}"
        )

        # Fold the agent axis into the batch axis so that every agent's
        # sequence is pooled independently: [bs * n_agents, t, z_dim].
        transition_encoding = transition_encoding.permute(0, 2, 1, 3).reshape(bs * n_agents, max_t, -1)
        mask = mask.unsqueeze(2).repeat(1, 1, n_agents, 1).permute(0, 2, 1, 3).reshape(bs * n_agents, max_t, -1)

        pooled = self._pool_over_time(transition_encoding, mask)
        return pooled.reshape(bs, n_agents, -1)

    def _pool_over_time(self, encoding, mask):
        """Reduce ``encoding`` (rows, t, dim) over its time axis to (rows, dim)."""
        # Zero out masked steps before any reduction.
        encoding = encoding * mask

        if self.args.average_temporal:
            count = mask.sum(dim=1)
            # A fully masked sequence has count == 0; dividing by it would
            # yield NaN (0 / 0). Its masked sum is already 0, so dividing by 1
            # gives an all-zero encoding instead.
            count = count.masked_fill(count == 0, 1)
            return encoding.sum(dim=1) / count

        if self.args.average_gate:
            # One scalar score per step; masked steps get a very negative
            # score so that softmax gives them (almost) no weight.
            scores = self.gate(encoding).masked_fill(mask == 0, -1e9).transpose(-1, -2)
            attn = F.softmax(scores, dim=-1)
            weighted = th.matmul(attn, self.value(encoding)).squeeze(1)
            return self.output(weighted)

        outputs = self.transformer(encoding, mask)
        return outputs[:, 0, :]
