import torch
from torch import nn

from einops import rearrange
from opencood.models.sub_modules.relative_pose_embedding import HRPE, RPE
        
class Agent_wise_Attention(nn.Module):
    """
    Vanilla CAV (agent-wise) attention.

    Each spatial location (h, w) is processed independently: the L per-agent
    feature vectors at that location attend to one another with multi-head
    scaled dot-product attention. An optional relative pose embedding
    (``HRPE`` or ``RPE``) can be added to the normalized input.

    Parameters
    ----------
    dim : int
        Channel size C of the input features.
    heads : int
        Number of attention heads M.
    dim_head : int
        Channel size per head; the inner width is ``heads * dim_head``.
    dropout : float
        Dropout probability after the output projection. ``0`` (or less)
        replaces the dropout layer with ``nn.Identity``.
    embed_config : dict or None
        Optional embedding configuration. Keys read: ``'is_embed'``,
        ``'embed_type'`` (``'hetero'`` -> ``HRPE``, ``'normal'`` -> ``RPE``),
        ``'per_degree'``, ``'per_dist'`` and, for ``'hetero'`` only,
        ``'learnable'``.

    Raises
    ------
    ValueError
        If the embedding is enabled with an unsupported ``embed_type``.
    """
    def __init__(self, dim, heads, dim_head=64, dropout=0.1, embed_config=None):
        super(Agent_wise_Attention, self).__init__()
        inner_dim = heads * dim_head

        self.pre_norm = nn.LayerNorm(dim)

        self.is_embed = False
        # Always defined so that code inspecting the module never hits an
        # AttributeError when no embed_config is supplied.
        self.embed_type = None
        self.embed_config = embed_config
        if self.embed_config is not None:
            self.is_embed = embed_config['is_embed']
            self.embed_type = embed_config['embed_type']

        if self.is_embed:
            if self.embed_type == 'hetero':
                self.emb = HRPE(dim, embed_config['learnable'],
                                embed_config['per_degree'],
                                embed_config['per_dist'])
            elif self.embed_type == 'normal':
                self.emb = RPE(dim, embed_config['per_degree'],
                               embed_config['per_dist'])
            else:
                # Fail loudly: an unknown type used to silently disable the
                # embedding (no self.emb was built and forward skipped it).
                raise ValueError(
                    f"Unsupported embed_type {self.embed_type!r}; "
                    f"expected 'hetero' or 'normal'.")

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        )

        # NOTE(review): to_s_out is not used in forward(). It is presumably
        # kept so existing checkpoints / state_dict keys still load; confirm
        # before removing it.
        self.to_s_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        )

    def forward(self, x, mask, infra, init_flag=False, hrpe_comp=False):
        """
        Apply agent-wise attention.

        Parameters
        ----------
        x : torch.Tensor
            Features of shape (B, L, H, W, C).
        mask : torch.Tensor
            Attention mask; entries equal to 0 are excluded. After
            ``unsqueeze(1)`` it must broadcast against the attention map of
            shape (B, M, H, W, L_query, L_key). NOTE(review): to mask
            *agents* it should vary along the key axis, e.g. (B, H, W, 1, L)
            or (B, 1, 1, 1, L) - confirm against the caller.
        infra :
            Passed to the 'hetero' embedding only (semantics defined by HRPE).
        init_flag : bool
            The pose embedding is added only when this is True.
        hrpe_comp : bool
            Forwarded unchanged to the embedding module.

        Returns
        -------
        torch.Tensor
            Output of shape (B, L, H, W, C).

        Raises
        ------
        ValueError
            If ``x`` is not 5-dimensional.
        """
        if len(x.shape) != 5:
            raise ValueError(
                f"Agent_wise_Attention expects a 5-D input (B, L, H, W, C), "
                f"got shape {tuple(x.shape)}")
        x = self.pre_norm(x)

        if init_flag and self.is_embed:
            if self.embed_type == 'hetero':
                x = self.emb(x, infra, hrpe_comp)
            elif self.embed_type == 'normal':
                x = self.emb(x, hrpe_comp)

        # x: (B, L, H, W, C) -> (B, H, W, L, C)
        x = x.permute(0, 2, 3, 1, 4)
        # Insert a head axis so the mask broadcasts over all M heads.
        mask = mask.unsqueeze(1)
        # qkv: [(B, H, W, L, C_inner) * 3]
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # q, k, v: (B, M, H, W, L, C_head)
        q, k, v = map(lambda t: rearrange(t, 'b h w l (m c) -> b m h w l c',
                                          m=self.heads), qkv)

        # attention logits: (B, M, H, W, L, L)
        att_map = torch.einsum('b m h w i c, b m h w j c -> b m h w i j',
                               q, k) * self.scale
        # NOTE(review): if every key of a query row is masked, softmax over
        # an all -inf row yields NaN; callers presumably guarantee at least
        # one valid agent (e.g. ego) per location - confirm.
        att_map = att_map.masked_fill(mask == 0, -float('inf'))
        att_map = self.attend(att_map)

        # out: (B, M, H, W, L, C_head) -> (B, H, W, L, C_inner)
        out = torch.einsum('b m h w i j, b m h w j c -> b m h w i c', att_map,
                           v)
        out = rearrange(out, 'b m h w l c -> b h w l (m c)',
                        m=self.heads)
        out = self.to_out(out)
        # back to (B, L, H, W, C)
        out = out.permute(0, 3, 1, 2, 4)
        return out
