import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from torch.nn import TransformerEncoderLayer
from torch import Tensor
from typing import Optional

class MLPHead(nn.Module):
    """Two-layer projection MLP: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        in_channels: size of each input feature vector.
        mlp_hidden_size: width of the hidden layer.
        projection_size: size of each output vector.
    """

    def __init__(self, in_channels, mlp_hidden_size, projection_size):
        super().__init__()
        # Layer positions inside ``net`` (net.0, net.1, net.3) form the
        # state_dict keys, so the order here must not change.
        layers = [
            nn.Linear(in_channels, mlp_hidden_size),
            nn.BatchNorm1d(mlp_hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(mlp_hidden_size, projection_size),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` to ``projection_size`` features.

        Assumes ``x`` is ``(batch, in_channels)`` — TODO confirm with callers.
        """
        projected = self.net(x)
        return projected

class FT(nn.Module):
    """A ``BNTF`` encoder followed by a separate head ``g2`` with two outputs.

    ``forward`` reuses only the encoder's attention stack and
    ``dim_reduction``. The encoder's own projection head ``encoder.g`` is
    constructed (and kept in the state dict) but never called here.
    Presumably this is a fine-tuning / classification model — verify.

    Args:
        feature_dim, depth, heads, dim_feedforward: forwarded unchanged to
            ``BNTF``.
    """

    def __init__(self, feature_dim, depth, heads, dim_feedforward):
        super().__init__()
        self.encoder = BNTF(feature_dim, depth, heads, dim_feedforward)
        # Width of the flattened encoder output (reduced features per node
        # times number of nodes). Derived from the encoder instead of the old
        # literal 8 * 100, so g2 cannot drift out of sync with BNTF.
        flat_dim = self.encoder.dim_reduction[0].out_features * self.encoder.node_num
        self.g2 = nn.Sequential(
            nn.Linear(flat_dim, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(),
            nn.Linear(256, 32),
            nn.BatchNorm1d(32),
            nn.LeakyReLU(),
            nn.Linear(32, 2),
        )

    def forward(self, img):
        """Encode ``img`` and map it to two outputs per sample.

        Args:
            img: 3-D tensor. It must be ``(batch, node_num, node_num)``: the
                last dim is the encoder's ``d_model`` and the flattened width
                must match ``g2``'s input.

        Returns:
            Tensor of shape ``(batch, 2)``.
        """
        bz, _, _, = img.shape

        for atten in self.encoder.attention_list:
            img = atten(img)

        node_feature = self.encoder.dim_reduction(img)
        node_feature = node_feature.reshape((bz, -1))
        # NOTE(review): dim_reduction already ends in LeakyReLU, so this
        # applies LeakyReLU a second time (negatives scaled by 0.01**2).
        # Kept as-is to preserve trained behaviour; confirm it is intended.
        node_feature = F.leaky_relu(node_feature)
        node_feature = self.g2(node_feature)
        return node_feature

class InterpretableTransformerEncoder(TransformerEncoderLayer):
    """``TransformerEncoderLayer`` that records its self-attention weights.

    The self-attention sub-block is overridden to request head-averaged
    attention weights from ``self_attn``. It stores them in
    ``self.attention_weights``, exposed via ``get_attention_weights``.
    Constructor arguments are forwarded positionally to
    ``TransformerEncoderLayer``.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu,
                 layer_norm_eps=1e-5, batch_first=False, norm_first=False,
                 device=None, dtype=None) -> None:
        super().__init__(d_model, nhead, dim_feedforward, dropout, activation,
                         layer_norm_eps, batch_first, norm_first, device, dtype)
        # Weights recorded by the most recent _sa_block call; None until then.
        self.attention_weights: Optional[Tensor] = None

    def forward(self, *args, **kwargs) -> Tensor:
        """Run the parent layer after clearing previously recorded weights.

        Arguments pass through unchanged to ``TransformerEncoderLayer.forward``,
        whose signature differs across PyTorch versions. The parent may take a
        fused "fast path" (e.g. eval mode with autograd disabled) that never
        calls ``_sa_block``. Clearing first makes ``get_attention_weights``
        return ``None`` in that case, instead of weights from an earlier input.
        """
        self.attention_weights = None
        return super().forward(*args, **kwargs)

    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
                  is_causal: bool = False) -> Tensor:
        """Self-attention sub-block that also stores the attention weights.

        ``is_causal`` is accepted because PyTorch >= 2.0 passes it as a keyword.
        It is forwarded to ``self_attn`` only when set: older
        ``MultiheadAttention.forward`` does not accept it, and older parents
        never pass it.
        """
        causal_kwargs = {"is_causal": True} if is_causal else {}
        x, weights = self.self_attn(x, x, x,
                                    attn_mask=attn_mask,
                                    key_padding_mask=key_padding_mask,
                                    need_weights=True,
                                    average_attn_weights=True,
                                    **causal_kwargs)
        self.attention_weights = weights
        return self.dropout1(x)

    def get_attention_weights(self) -> Optional[Tensor]:
        """Return the weights recorded during the last forward, or None."""
        return self.attention_weights

class ExplainableBNTF(nn.Module):
    """BNTF variant whose encoder layers record self-attention weights.

    Layout matches ``BNTF``:
    - a stack of attention layers over ``node_num`` (100) tokens;
    - a per-token Linear reduction to 8 features, then flattening;
    - an ``MLPHead`` projection to ``feature_dim``.

    The layers are ``InterpretableTransformerEncoder``, so ``forward`` also
    returns each layer's attention weights.

    Args:
        feature_dim: output size of the projection head ``g``.
        depth: number of encoder layers (converted with ``int``).
        heads: attention heads per layer (converted with ``int``). It must
            divide ``node_num``, as ``MultiheadAttention`` requires.
        dim_feedforward: accepted but currently unused (see NOTE below).
    """

    def __init__(self, feature_dim, depth, heads, dim_feedforward):
        super().__init__()
        self.num_patches = 100  # 112 was noted here as an alternative value
        self.node_num = self.num_patches
        # NOTE(review): ``dim_feedforward`` is ignored; every layer is built
        # with 1024. Wiring it through would change parameter shapes (and
        # existing checkpoints), so confirm intent before changing it.
        self.attention_list = nn.ModuleList(
            InterpretableTransformerEncoder(
                d_model=self.node_num,
                nhead=int(heads),
                dim_feedforward=1024,
                batch_first=True,
            )
            for _ in range(int(depth))
        )
        # Per-token reduction: node_num features -> 8 features.
        self.dim_reduction = nn.Sequential(nn.Linear(self.node_num, 8), nn.LeakyReLU())
        flat_dim = 8 * self.node_num
        self.g = MLPHead(flat_dim, flat_dim * 2, feature_dim)

    def forward(self, img, forward_with_mlp=True):
        """Encode ``img``, collecting attention weights, and optionally project.

        Args:
            img: 3-D tensor. When the head runs it must be
                ``(batch, node_num, node_num)``: the last dim is ``d_model``
                and the flattened width must equal ``8 * node_num``.
            forward_with_mlp: the head runs only when this is exactly ``True``
                (identity check). Any other value, even a truthy one,
                returns the encoder output.

        Returns:
            ``(features, weights)``:
            - ``features`` is ``(batch, feature_dim)``;
            - ``weights`` lists each layer's ``get_attention_weights()``.

            When the head is skipped, only the encoder output tensor is
            returned. NOTE(review): no weights in that case, so the return
            type differs between the two modes.
        """
        batch_size, _, _ = img.shape
        attn_per_layer = []
        hidden = img
        for layer in self.attention_list:
            hidden = layer(hidden)
            attn_per_layer.append(layer.get_attention_weights())
        if forward_with_mlp is True:
            flat = self.dim_reduction(hidden).reshape((batch_size, -1))
            return self.g(flat), attn_per_layer
        return hidden

class BNTF(nn.Module):
    """Transformer encoder over ``node_num`` (100) tokens with a projection head.

    Pipeline:
    - ``depth`` ``TransformerEncoderLayer``s;
    - a per-token Linear reduction to 8 features;
    - flattening;
    - an ``MLPHead`` mapping ``8 * node_num`` features to ``feature_dim``.

    Args:
        feature_dim: output size of the projection head ``g``.
        depth: number of encoder layers (converted with ``int``).
        heads: attention heads per layer (converted with ``int``). It must
            divide ``node_num``, as ``MultiheadAttention`` requires.
        dim_feedforward: accepted but currently unused (see NOTE below).
    """

    def __init__(self, feature_dim, depth, heads, dim_feedforward):
        super().__init__()
        self.num_patches = 100  # 112 was noted here as an alternative value
        self.node_num = self.num_patches
        # NOTE(review): ``dim_feedforward`` is ignored; every layer is built
        # with 1024. Wiring it through would change parameter shapes (and
        # existing checkpoints), so confirm intent before changing it.
        self.attention_list = nn.ModuleList(
            TransformerEncoderLayer(
                d_model=self.node_num,
                nhead=int(heads),
                dim_feedforward=1024,
                batch_first=True,
            )
            for _ in range(int(depth))
        )
        # Per-token reduction: node_num features -> 8 features.
        self.dim_reduction = nn.Sequential(nn.Linear(self.node_num, 8), nn.LeakyReLU())
        flat_dim = 8 * self.node_num
        self.g = MLPHead(flat_dim, flat_dim * 2, feature_dim)

    def forward(self, img, forward_with_mlp=True):
        """Encode ``img`` and optionally project it with the MLP head.

        Args:
            img: 3-D tensor. When the head runs it must be
                ``(batch, node_num, node_num)``: the last dim is ``d_model``
                and the flattened width must equal ``8 * node_num``.
            forward_with_mlp: the head runs only when this is exactly ``True``
                (identity check). Any other value, even a truthy one,
                returns the encoder output.

        Returns:
            ``(batch, feature_dim)`` features when the head runs; otherwise
            the output of the last encoder layer.
        """
        batch_size, _, _ = img.shape
        hidden = img
        for layer in self.attention_list:
            hidden = layer(hidden)
        if forward_with_mlp is True:
            flat = self.dim_reduction(hidden).reshape((batch_size, -1))
            return self.g(flat)
        return hidden