from typing import List
from ding.torch_utils.network.transformer import Attention
import torch
import torch.nn as nn
import numpy as np
import random

def random_mask(input_tensor, p=0.8):
    """Randomly zero out whole feature vectors of ``input_tensor``.

    Every position along all but the last dimension is dropped independently
    with probability ``p``: its entire last-dimension vector is multiplied by 0.
    Kept positions are multiplied by 1, with no inverted-dropout rescaling.

    Args:
        input_tensor: tensor of shape ``(..., D)``.
        p: probability of dropping each position (``0`` keeps everything,
            ``>= 1`` drops everything).

    Returns:
        Masked tensor with the same shape and device as ``input_tensor``.
    """
    # Draw on the input's own device; a hard-coded 'cuda' broke CPU / multi-GPU use.
    draws = torch.rand(input_tensor.shape[:-1], device=input_tensor.device)
    keep = (draws >= p).float()
    # Broadcast the per-position keep flag over the feature dimension.
    return input_tensor * keep.unsqueeze(-1)

class Block(nn.Module):
    """Post-LayerNorm transformer block: attention -> add & norm -> MLP -> add & norm.

    ``forward`` takes and returns an ``(x, mask)`` tuple, so several blocks can be
    chained with ``nn.Sequential`` while sharing one attention mask.
    """

    def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None:
        super().__init__()
        self.attention = Attention(h_dim, h_dim, h_dim, n_heads, nn.Dropout(drop_p))
        self.att_drop = nn.Dropout(drop_p)
        self.mlp = nn.Sequential(
            nn.Linear(h_dim, 4 * h_dim),
            nn.GELU(),
            nn.Linear(4 * h_dim, h_dim),
            nn.Dropout(drop_p),
        )
        self.ln1 = nn.LayerNorm(h_dim)
        self.ln2 = nn.LayerNorm(h_dim)

        # Default lower-triangular (causal) mask, used when forward() receives no mask.
        # As a buffer it is not a trainable parameter but follows .to()/.cuda().
        # NOTE(review): this is max_T * max_T bools (64 MiB for max_T = 8192) per block,
        # and it is stored in state_dict, so shrinking it would break checkpoint loading.
        mask = torch.tril(torch.ones((max_T, max_T), dtype=torch.bool)).view(1, 1, max_T, max_T)
        self.register_buffer('mask', mask)

    def forward(self, x_agent_mask):
        """Apply one block.

        Args:
            x_agent_mask: tuple ``(x, agent_mask)``. ``x`` is assumed to be
                ``(batch, seq_len, h_dim)`` (TODO confirm). ``agent_mask`` is a
                ``(batch, seq_len, seq_len)`` boolean-like mask or ``None``. When it is
                ``None``, the causal buffer cropped to ``seq_len`` is used.

        Returns:
            ``(x, agent_mask)`` with ``x`` updated and the mask passed through unchanged.
        """
        x, agent_mask = x_agent_mask
        if agent_mask is not None:
            # Add a head axis so the mask broadcasts over attention heads.
            attn_mask = agent_mask.unsqueeze(1).bool().detach()
        else:
            # Crop the (1, 1, max_T, max_T) buffer to the actual sequence length;
            # the full buffer only broadcast when seq_len == max_T.
            seq_len = x.shape[1]
            attn_mask = self.mask[:, :, :seq_len, :seq_len]
        # Attention -> LayerNorm -> MLP -> LayerNorm, each sub-layer with a residual.
        x = x + self.att_drop(self.attention(x, attn_mask))
        x = self.ln1(x)
        x = x + self.mlp(x)
        x = self.ln2(x)
        return (x, agent_mask)


class MaskMA(nn.Module):

    def __init__(
            self,
            state_dim: int,
            n_blocks: int,
            h_dim: int,
            train_context_len: int,
            eval_context_len: int,
            n_heads: int,
            drop_p: float,
            random_mask_ratio: float = 0,
            max_timestep: int = 8192,
    ) -> None:
        """Build the entity-token transformer and its prediction heads.

        Args:
            state_dim: size of each entity's feature vector.
            n_blocks: number of transformer ``Block`` layers.
            h_dim: token embedding width.
            train_context_len: stored as ``self.train_context_len``; not read by the
                methods of this class shown here.
            eval_context_len: number of most recent timesteps re-encoded at every
                step in evaluation mode.
            n_heads: attention heads per block.
            drop_p: dropout probability inside the blocks.
            random_mask_ratio: extra training-time attention masking. It can be:
                a number, where each mask entry is dropped with that probability
                (0 disables masking); ``'random'``, which draws a new ratio every
                forward pass; or ``'env'``, which restricts entity-to-entity
                attention by distance. The strings are accepted despite the
                ``float`` annotation.
            max_timestep: size of the timestep embedding table and of each block's
                default causal mask.
        """
        super(MaskMA, self).__init__()
        self.state_dim = state_dim
        self.h_dim = h_dim
        self.train_context_len = train_context_len
        self.eval_context_len = eval_context_len
        self.random_mask_ratio = random_mask_ratio
        # Predefined constants describing the state / action layout.
        self.n_actions_no_attack = 6  # output width of the no_attack_action head
        # Number of unit-type feature bits; position features are read right after
        # them (indices 1 + unit_type_bits .. + 2) when building the 'env' mask.
        self.unit_type_bits = 10
        self.map_size = 32  # scale applied to those position features for distances

        self.max_entity_num = 100
        self.max_timestep = max_timestep
        # Which attack head forward(eval=True) uses. True selects the concatenation MLP
        # (`attack_action`), the only attack head used in training. False selects
        # dot-product logits between action queries and entity tokens.
        self.concat_fc_action = True
        # transformer blocks
        blocks = [Block(h_dim, max_timestep, n_heads, drop_p) for _ in range(n_blocks)]
        self.transformer = nn.Sequential(*blocks)

        # projection heads (project to embedding)
        self.embed_ln = nn.LayerNorm(h_dim)
        self.embed_timestep = nn.Embedding(max_timestep, h_dim)
        # Index 0 = ally (controlled agents), index 1 = enemy; added to entity tokens.
        self.embed_alley_enemy = nn.Embedding(2, h_dim)
        self.embed_state = torch.nn.Linear(state_dim, h_dim)

        # Per-agent logits for the non-attack actions.
        self.no_attack_action = nn.Sequential(
            nn.Linear(h_dim, h_dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim // 2, self.n_actions_no_attack),
        )
        # Scores one (agent token, target token) pair from their concatenation.
        self.attack_action = nn.Sequential(
            nn.Linear(2*h_dim, h_dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim // 2, 1),
        )

        # prediction heads
        self.predict_rtg = torch.nn.Linear(h_dim, 1)
        self.predict_state = torch.nn.Linear(h_dim, state_dim)
        # Step counter for incremental evaluation; reset with init_cache().
        self.time = 0
    
    def init_cache(self):
        """Start a new evaluation episode.

        Only the step counter is reset. When ``self.time`` is 0, ``forward(eval=True)``
        allocates a fresh ``cache_state`` and, in 'DE' mode, ``generate_eval_mask``
        rebuilds ``cache_mask``.
        """
        self.time = int(0)

    def generate_train_mask(self, batch_size, random_mask_ratio, states):
        """Build the boolean attention mask used for a training forward pass.

        Let ``L = self.timestep_len * self.context_len``. The base mask is the union of
        two parts:
        a block diagonal, where every token sees all tokens of its own timestep, and
        a lower triangle, where every token sees every earlier position. Together they
        give full attention within a timestep and to all earlier timesteps.
        On top of that:

        * ``random_mask_ratio == 'env'``: attention between entities is further limited
          to pairs whose distance, computed from ``states``, is below 9.
        * numeric ``random_mask_ratio > 0``: each entry is additionally dropped with that
          probability.

        Args:
            batch_size: leading dimension of the returned mask.
            random_mask_ratio: ``'env'`` or a numeric drop probability (0 disables it).
            states: entity features, read only in the ``'env'`` branch; presumably
                ``(B, context_len, entity_num, state_dim)`` with agents first — TODO confirm.

        Returns:
            Bool tensor of shape ``(batch_size, L, L)``; True entries are the
            (query, key) pairs the construction keeps.
            NOTE(review): the device is hard-coded to 'cuda'.
        """
        # One all-ones (timestep_len x timestep_len) block per timestep on the diagonal.
        block = torch.ones(self.timestep_len, self.timestep_len, device='cuda').detach()
        diagonal_blocks = [block] * self.context_len
        diagonal_block_mask = torch.block_diag(*diagonal_blocks).unsqueeze(0).repeat(batch_size,1,1)  # [B,self.timestep_len*self.context_len,self.timestep_len*self.context_len]
        # OR the block diagonal with a lower-triangular mask (added, then cast to bool).
        diagonal_block_mask = (diagonal_block_mask + torch.tril(torch.ones((batch_size, self.timestep_len*self.context_len, self.timestep_len*self.context_len), device='cuda'))).bool().detach()
        if random_mask_ratio == 'env':
            # Two features right after the unit-type bits, scaled by map_size.
            # Presumably normalised (x, y) coordinates — confirm.
            pos = states[:,:,:,1 + self.unit_type_bits:1 + self.unit_type_bits + 2]*self.map_size
            # Pairwise Euclidean distance between entities at every timestep: (B, T, E, E).
            disp = pos[:, :, :, None, :] - pos[:, :, None, :, :]
            dist = torch.sqrt(torch.sum(disp * disp, dim=-1))
            # 9 map units; presumably the unit sight range — confirm.
            sight_mask = dist < 9
            # Permutation that puts enemies first, then agents, matching the token order of h in forward().
            B_indices = torch.arange(self.entity_num).unsqueeze(-1)     # shape: (entity_num, 1)
            B_indices = torch.cat([B_indices[self.agent_num:], B_indices[:self.agent_num]], dim=0)   # shape: (entity_num, 1)

            # Reorder rows and columns of the sight matrix with that permutation: (B, T, E, E).
            B_mask = sight_mask[:,:,B_indices, B_indices.T] 
            B = B_mask.shape[0]
            # Prepend an always-visible row and column, giving timestep_len = entity_num + 1 slots per step.
            B_mask = torch.cat((torch.ones((B, self.context_len, 1, B_mask.shape[3]), device='cuda'), B_mask), dim=2)
            B_mask = torch.cat((torch.ones((B, self.context_len, B_mask.shape[2], 1), device='cuda'), B_mask), dim=3)
            # (B, T, q, k) -> (B, q, T*k): key index becomes t * timestep_len + k.
            B_mask = B_mask.permute(0,2,1,3).reshape(B, -1, diagonal_block_mask.shape[2])
            # Repeat the same rows for every query timestep -> (B, L, L). Query entity i at any
            # timestep may see key entity j at timestep t iff the step-t sight mask allows it.
            B_mask = B_mask.unsqueeze(1).repeat(1,self.context_len,1,1).reshape(B, -1, diagonal_block_mask.shape[2]).bool()
            diagonal_block_mask = diagonal_block_mask & B_mask

        elif random_mask_ratio > 0:
            # Keep each entry with probability (1 - random_mask_ratio).
            mask = torch.rand((batch_size, self.timestep_len*self.context_len, self.timestep_len*self.context_len), device='cuda').detach() >= random_mask_ratio
            diagonal_block_mask = diagonal_block_mask & mask
        return diagonal_block_mask

    def generate_eval_mask(self, env_mask, mask_type):
        """Build the attention mask for one incremental-evaluation step.

        Args:
            env_mask: per-step entity visibility, read only for ``'DE'``. After two
                leading singleton dims are squeezed, it is indexed as an
                ``(entity_num, entity_num)`` matrix, presumably in the same entity
                order as ``obs['states']`` (agents first) — TODO confirm.
            mask_type: ``'CE'`` lets every token attend to every token. ``'DE'``
                restricts entity-to-entity attention with ``env_mask``. The
                restriction is cached per timestep in ``self.cache_mask``, which holds
                200 timesteps.

        Returns:
            Bool tensor: ``(L, L)`` for ``'CE'`` with ``L = timestep_len * context_len``.
            For ``'DE'`` it is ``(1, L, L)`` covering timesteps ``time_base .. time``.

        Raises:
            ValueError: if ``mask_type`` is neither ``'CE'`` nor ``'DE'``. Previously
                this case crashed later with an unbound local variable.
        """
        if mask_type not in ('CE', 'DE'):
            raise ValueError(f"unknown eval mask_type {mask_type!r}; expected 'CE' or 'DE'")
        if mask_type == 'CE':
            seq_len = self.timestep_len * self.context_len
            return torch.ones((seq_len, seq_len), device='cuda').bool().detach()

        mask = env_mask.squeeze(0).squeeze(0)
        # Reorder entities to enemies first, then agents, matching the token order of h in forward().
        B_indices = torch.arange(self.entity_num).unsqueeze(-1).cuda()     # shape: (entity_num, 1)
        B_indices = torch.cat([B_indices[self.agent_num:], B_indices[:self.agent_num]], dim=0)   # shape: (entity_num, 1)

        B_mask = mask[B_indices, B_indices.T]
        # Prepend an always-visible row and column -> (timestep_len, timestep_len).
        B_mask = torch.cat((torch.ones((1, B_mask.shape[0]), device='cuda'), B_mask), dim=0)
        B_mask = torch.cat((torch.ones((B_mask.shape[0], 1), device='cuda'), B_mask), dim=1)
        step = self.timestep_len
        if self.time == 0:
            # New episode: tile the first step's mask over the whole 200-step cache.
            B_mask = B_mask.repeat((200, 200))
            self.cache_mask = B_mask
        else:
            # Overwrite the key columns of the current timestep for every query row.
            self.cache_mask[:, self.time*step:(self.time+1)*step] = B_mask.repeat((200, 1))
        window = slice(self.time_base * step, (self.time + 1) * step)
        B_mask = self.cache_mask[window, window]
        return B_mask.unsqueeze(0).bool().detach()
        

    def forward(self, obs, actions=None, eval=False, eval_mask_type='DE', eval_agent_num=-1):
        B = obs['states'].shape[0]
        entity_num = obs['states'].shape[2]
        if actions is not None:
            agent_num = actions.shape[2]
        elif eval_agent_num != -1:
            agent_num = eval_agent_num
        enemy_num = entity_num-agent_num
        self.timestep_len = entity_num+1
        self.entity_num = entity_num
        self.agent_num = agent_num
        self.enemy_num = enemy_num
        states = obs['states']
        alley_enemy_embedding = self.embed_alley_enemy(torch.from_numpy(np.array([i for i in range(2)])).cuda()).unsqueeze(0).unsqueeze(1) # shape: (1, 1, 2, H)
        if not eval:
            timesteps = obs['timesteps']
            self.context_len = obs['states'].shape[1]
            time_embeddings = self.embed_timestep(timesteps)
            if len(time_embeddings.shape)==1:
                time_embeddings = time_embeddings.unsqueeze(0)
            time_embeddings = time_embeddings.unsqueeze(0).unsqueeze(2)  # shape: (1, T, 1, H)
            state_embeddings = self.embed_state(states)
            state_embeddings[:, :, :agent_num, :] += alley_enemy_embedding[:, :, 0:1, :]
            state_embeddings[:, :, agent_num:, :] += alley_enemy_embedding[:, :, 1:2, :]
            state_embeddings = state_embeddings + time_embeddings
            h = torch.cat((state_embeddings[:, :, agent_num:, :], state_embeddings[:, :, :agent_num, :]), dim=2).reshape(B, self.context_len*self.timestep_len, self.h_dim)
            h = self.embed_ln(h)
            # transformer and prediction
            if self.random_mask_ratio=='random':
                random_mask_ratio = random.random()
            elif self.random_mask_ratio=='env':
                random_mask_ratio = self.random_mask_ratio
            else:
                random_mask_ratio = self.random_mask_ratio
            h, _ = self.transformer((h, self.generate_train_mask(B, random_mask_ratio, states.detach())))

            h = h.reshape(B, self.context_len, self.timestep_len, self.h_dim)

            # get predictions
            action_query = h[..., enemy_num:enemy_num+agent_num, :]
            action_value = torch.cat((h[..., enemy_num:, :],h[..., :enemy_num, :]), dim=2)
            concat_h = torch.cat([action_query.unsqueeze(3).repeat(1, 1, 1, entity_num, 1), action_value.unsqueeze(2).repeat(1, 1, agent_num, 1, 1)], dim=-1)
            attack_logit = self.attack_action(concat_h).squeeze(-1) 
            no_attack_logit = self.no_attack_action(action_query)
            action_preds = torch.cat((no_attack_logit, attack_logit), dim=-1)
        else:
            if self.time==0:
                self.cache_state = torch.zeros((1, 200, entity_num, self.state_dim), device='cuda').float().detach()
            self.cache_state[:, self.time:self.time+1] = states
            if self.time+1>self.eval_context_len:
                self.time_base = self.time+1-self.eval_context_len
                states = self.cache_state[:,self.time_base:self.time+1]
                timesteps = torch.from_numpy(np.array([i for i in range(self.time_base, self.time+1)])).cuda()
            else:
                states = self.cache_state[:, :self.time+1]
                timesteps = torch.from_numpy(np.array([i for i in range(self.time+1)])).cuda()
                self.time_base = 0
            self.context_len = states.shape[1]
            time_embeddings = self.embed_timestep(timesteps).unsqueeze(0).unsqueeze(2)
            state_embeddings = self.embed_state(states)
            state_embeddings[:, :, :agent_num, :] += alley_enemy_embedding[:, :, 0:1, :]
            state_embeddings[:, :, agent_num:, :] += alley_enemy_embedding[:, :, 1:2, :]
            state_embeddings = state_embeddings + time_embeddings
            h = torch.cat((state_embeddings[:, :, agent_num:, :], state_embeddings[:, :, :agent_num, :]), dim=2).reshape(B, -1, self.h_dim)
            h = self.embed_ln(h)
            mask = self.generate_eval_mask(obs['sight_mask'], eval_mask_type) & self.generate_train_mask(B, 0.0, None)
            h, _ = self.transformer((h, mask))
            h = h.reshape(B, self.context_len, self.timestep_len, self.h_dim)

            action_query = h[:,-1:, enemy_num:enemy_num+agent_num, :]
            action_value = torch.cat((h[:,-1:, enemy_num:, :],h[:,-1:, :enemy_num, :]), dim=2)
            if self.concat_fc_action:
                concat_h = torch.cat([action_query.unsqueeze(3).repeat(1, 1, 1, entity_num, 1), action_value.unsqueeze(2).repeat(1, 1, agent_num, 1, 1)], dim=-1)
                attack_logit = self.attack_action(concat_h).squeeze(-1) 
            else:
                attack_logit = torch.bmm(action_query.reshape(B, agent_num, self.h_dim), action_value.reshape(B, entity_num, self.h_dim).permute(0, 2, 1))
                attack_logit = attack_logit.reshape(B, 1, agent_num, entity_num)
            no_attack_logit = self.no_attack_action(action_query)
            action_preds = torch.cat((no_attack_logit, attack_logit), dim=-1)
            action_mask = obs['action_mask']
            tem_action_mask = torch.zeros((1,1,agent_num,action_preds.shape[-1]), device='cuda').detach()
            tem_action_mask[...,:action_mask.shape[-1]] = action_mask
            actions = (action_preds - 1e9 * (~tem_action_mask.bool())).max(-1).indices
            self.time+=1
        return action_preds, actions, self.context_len