import torch
from torch import nn
from collections import OrderedDict
from timm.models.layers import trunc_normal_
import sys
sys.path.append("../")

class QuickGELU(nn.Module):
    """Element-wise activation computing ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        """Return ``x`` gated by the sigmoid of ``1.702 * x`` (same shape as ``x``)."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate

class ResidualCrossAttentionBlock(nn.Module):
    """Attention block followed by a two-layer MLP with a residual sum.

    In ``forward`` the attention query is ``text_feature`` while keys and
    values come from the layer-normalised ``x``. The attention output is
    added to the mean of ``x`` over ``dim=1``, then an MLP residual is
    applied on top.

    Args:
        d_model: Embedding width used by the attention, LayerNorms and MLP.
        n_head: Number of attention heads.
        attn_mask: Optional mask handed to ``nn.MultiheadAttention``. It is
            stored as a plain attribute (not a buffer) and is re-cast to the
            input's dtype/device on every attention call.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        hidden_dim = d_model * 4

        # Sub-modules are created in the order attn, ln_1, mlp, ln_2.
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
        self.ln_1 = nn.LayerNorm(d_model)

        feed_forward = OrderedDict()
        feed_forward["c_fc"] = nn.Linear(d_model, hidden_dim)
        feed_forward["gelu"] = QuickGELU()
        feed_forward["c_proj"] = nn.Linear(hidden_dim, d_model)
        self.mlp = nn.Sequential(feed_forward)

        self.ln_2 = nn.LayerNorm(d_model)
        self.attn_mask = attn_mask

    def _mask_like(self, ref: torch.Tensor):
        """Cast the stored mask to ``ref``'s dtype/device, store it back, return it.

        Returns ``None`` (and leaves ``self.attn_mask`` as ``None``) when no
        mask was configured.
        """
        mask = self.attn_mask
        if mask is not None:
            mask = mask.to(dtype=ref.dtype, device=ref.device)
        # The converted mask replaces the stored one, so later calls reuse it.
        self.attn_mask = mask
        return mask

    def selfattention(self, x: torch.Tensor):
        """Attend ``x`` to itself and return the attention output only."""
        mask = self._mask_like(x)
        outputs = self.attn(x, x, x, need_weights=False, attn_mask=mask)
        return outputs[0]

    def crossattention(self, x: torch.Tensor, text_feature: torch.Tensor):
        """Attend with ``text_feature`` as query and ``x`` as key and value.

        Returns the attention output only (weights are not requested).
        """
        mask = self._mask_like(x)
        outputs = self.attn(text_feature, x, x, need_weights=False, attn_mask=mask)
        return outputs[0]

    def forward(self, x, text_feature):
        """Fuse ``x`` with ``text_feature`` and return a tensor shaped like the attention output.

        Assumes batched inputs laid out as (batch, sequence, feature), as
        implied by ``batch_first=True`` -- TODO confirm against callers.
        """
        attended = self.crossattention(self.ln_1(x), text_feature)
        # The residual branch is the mean of x over dim=1, broadcast to the
        # shape of the attention output, because the two shapes can differ.
        pooled = x.mean(dim=1, keepdim=True)
        fused = pooled.expand_as(attended) + attended
        return fused + self.mlp(self.ln_2(fused))


class MultiframeIntegrationTransformerwithText(nn.Module):
    """Adds a learned positional embedding to ``x`` and fuses it with text features.

    The fusion is done by a single ``ResidualCrossAttentionBlock``; the mean
    of the original ``x`` over ``dim=1`` is added to the block's output.

    Args:
        T: Length of the positional embedding along ``dim=1``.
        embed_dim: Feature width; the head count is ``embed_dim // 64``.
    """

    def __init__(self, T, embed_dim=512):
        super().__init__()
        self.T = T
        head_count = embed_dim // 64

        # Learned positional embedding of shape (1, T, embed_dim), initialised
        # before the residual block is built.
        self.positional_embedding = nn.Parameter(torch.empty(1, T, embed_dim))
        trunc_normal_(self.positional_embedding, std=0.02)

        self.resblock = ResidualCrossAttentionBlock(
            d_model=embed_dim,
            n_head=head_count,
        )

        # Re-initialise every Linear / LayerNorm sub-module.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise ``m`` in place when it is a Linear or LayerNorm module.

        Linear: truncated-normal weight (std 0.02) and zero bias if present.
        LayerNorm: zero bias and unit weight. Other modules are left as-is.
        """
        if isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)
            return
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x, text_feature):
        """Return the fused features plus the ``dim=1`` mean of the input ``x``.

        Assumes ``x`` is (batch, T, embed_dim) so that the positional
        embedding broadcasts over it -- TODO confirm against callers.
        """
        frames = x
        fused = self.resblock(frames + self.positional_embedding, text_feature)
        # Cast back to the input dtype, then add the input mean broadcast to
        # the shape of the fused output.
        frame_mean = frames.mean(dim=1, keepdim=True)
        return fused.type(frames.dtype) + frame_mean.expand_as(fused)
