import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
import math


class GatedAttention(nn.Module):
    """Gated attention scorer.

    Two parallel projections of the input (one followed by tanh, one by
    sigmoid) are multiplied element-wise and mapped to ``n_cls`` attention
    scores per input row by a final linear layer.
    """

    def __init__(self, L, D, dropout=None, n_cls=1):
        """Build the two gating branches and the scoring layer.

        Args:
            L (int): Input feature dimension.
            D (int): Hidden layer feature dimension.
            dropout (float, optional): Dropout probability appended to both
                branches. No dropout layer is added when None. Defaults to None.
            n_cls (int, optional): Number of attention scores produced per
                input row. Defaults to 1.
        """
        super().__init__()
        # The Linear layers are created in the order a, b, c so that seeded
        # initialisation consumes the RNG in a fixed order.
        tanh_branch = [nn.Linear(L, D), nn.Tanh()]
        sigmoid_branch = [nn.Linear(L, D), nn.Sigmoid()]
        if dropout is not None:
            tanh_branch.append(nn.Dropout(dropout))
            sigmoid_branch.append(nn.Dropout(dropout))
        self.attention_a = nn.Sequential(*tanh_branch)
        self.attention_b = nn.Sequential(*sigmoid_branch)
        self.attention_c = nn.Linear(D, n_cls)

    def forward(self, x):
        """Score ``x`` and hand it back unchanged.

        Args:
            x (torch.Tensor): Features whose last dimension is ``L``.

        Returns:
            tuple: ``(A, x)`` where ``A`` holds the un-normalised attention
            scores (last dimension ``n_cls``) and ``x`` is the input itself.
        """
        gated = self.attention_a(x) * self.attention_b(x)
        return self.attention_c(gated), x



class Ours(nn.Module):
    """Multi-magnification attention-pooling classifier.

    Patch features at the lowest magnification are projected, scored by a
    gated attention module and attention-pooled. For every further
    magnification the highest-scoring patches are "zoomed into": their child
    patches at the next magnification are projected, scored and pooled the
    same way. The pooled vectors of all magnifications are concatenated and
    fed to the classification head.
    """

    # Per magnification level: data key of the patch features, names of the
    # projection modules (plain / HF / LF) and name of the attention module.
    # The HF / LF features are read from ``<key>_HF`` / ``<key>_LF``.
    _LEVELS = (
        ('low_mag_feats', 'fc_low_mag', 'fc_HF', 'fc_LF', 'ga_low_mag'),
        ('mid_mag_feats', 'fc_mid_mag', 'fc_x2_HF', 'fc_x2_LF', 'ga_mid_mag'),
        ('high_mag_feats', 'fc_high_mag', 'fc_x3_HF', 'fc_x3_LF', 'ga_high_mag'),
    )

    def __init__(self, in_feat_dim, hidden_feat_dim=256, out_feat_dim=512, PE_dim=64, dropout=None, k_sample=12, k_sigma=0.002, n_cls=3,
                n_mag=3, use_HF=False, use_LF=False, use_graph=False, use_pool=False, pooling_ratio=0.8,
                LapPE_k=64, RWPE_length=16, use_LSPE=False, gnn_type='gat', n_layers=3, use_two_branches=False):
        """
        Args:
            in_feat_dim (int): Input feature dimension.
            hidden_feat_dim (int, optional): Hidden dimension of the gated attention modules. Defaults to 256.
            out_feat_dim (int, optional): Output dimension of every projection layer. Defaults to 512.
            PE_dim (int, optional): Dimension the positional encodings are projected to. Forced to 0
                when both ``LapPE_k`` and ``RWPE_length`` are <= 0. Defaults to 64.
            dropout (float, optional): Dropout probability. No dropout layer is created when None. Defaults to None.
            k_sample (int, optional): Number of top-attention patches to zoom into at the next higher
                magnification (used when ``use_pool`` is False). Defaults to 12.
            k_sigma (float, optional): Stored on the module; not read anywhere else in this class. Defaults to 2e-3.
            n_cls (int, optional): Number of output classes. Defaults to 3.
            n_mag (int, optional): Number of magnifications (1, 2 or 3). Defaults to 3.
            use_HF (bool, optional): Also project and concatenate the ``*_HF`` features. Defaults to False.
            use_LF (bool, optional): Also project and concatenate the ``*_LF`` features. Defaults to False.
            use_graph (bool, optional): Whether ``HF_edge_index`` is available in the input; only needed by
                the LSPE layers. Defaults to False.
            use_pool (bool, optional): Select ``ceil(pooling_ratio * N)`` patches instead of ``k_sample``. Defaults to False.
            pooling_ratio (float, optional): Fraction of patches kept when ``use_pool`` is True. Defaults to 0.8.
            LapPE_k (int, optional): Width of the ``LapPE`` input; <= 0 disables it. Defaults to 64.
            RWPE_length (int, optional): Width of the ``RWPE`` input; <= 0 disables it. Defaults to 16.
            use_LSPE (bool, optional): Refine the projected positional encoding with the layers returned
                by ``get_gnn``. Defaults to False.
            gnn_type (str, optional): Forwarded to ``get_gnn``; ``'mlp'`` layers are called without an
                edge index. Defaults to 'gat'.
            n_layers (int, optional): Forwarded to ``get_gnn``. Defaults to 3.
            use_two_branches (bool, optional): Use a separate mean-pooled low-magnification branch for the
                first logit and the attention branch for the remaining ``n_cls - 1``. Defaults to False.

        Raises:
            ValueError: If ``n_mag`` is not 1, 2 or 3.
        """
        super(Ours, self).__init__()
        self.n_mag = n_mag
        if self.n_mag > 3:
            raise ValueError(" Maximum number of magnifications supported is 3.")
        if self.n_mag < 1:
            raise ValueError("At least one magnification is required (n_mag >= 1).")
        self.activation = nn.LeakyReLU()
        self.use_HF = use_HF
        self.use_LF = use_LF
        self.use_graph = use_graph
        self.LapPE_k = LapPE_k
        self.RWPE_length = RWPE_length
        self.use_LSPE = use_LSPE
        self.n_layers = n_layers
        self.use_pool = use_pool
        self.pooling_ratio = pooling_ratio
        self.use_two_branches = use_two_branches
        self.gnn_type = gnn_type

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.k_sample = k_sample
        self.k_sigma = k_sigma
        self.n_cls = n_cls

        use_PE = self.RWPE_length > 0 or self.LapPE_k > 0
        if not use_PE:
            PE_dim = 0
        self.PE_dim = PE_dim

        # Width of one magnification's pooled vector: one out_feat_dim slice
        # per enabled feature stream. The positional encoding is only fed to
        # the attention modules, so it widens their input but not the head's.
        head_dim = out_feat_dim
        if self.use_HF:
            head_dim += out_feat_dim
        if self.use_LF:
            head_dim += out_feat_dim
        gate_dim = head_dim + PE_dim

        def projection():
            return self._make_projection(in_feat_dim, out_feat_dim, self.activation, dropout)

        def gate():
            return GatedAttention(L=gate_dim, D=hidden_feat_dim, dropout=dropout, n_cls=1)

        # Modules are created in a fixed order so that seeded initialisation
        # stays reproducible.
        self.fc_low_mag = projection()

        if use_PE:
            self.fc_PE = nn.Sequential(nn.Linear(LapPE_k + RWPE_length, PE_dim), nn.LayerNorm(PE_dim), nn.Tanh())
            if self.use_LSPE:
                # NOTE(review): get_gnn is neither defined nor imported in the
                # visible part of this file - confirm where it comes from.
                self.gnn_PE = get_gnn(gnn_type, n_layers, PE_dim, PE_dim, num_heads=4, dropout=dropout)
        self.ga_low_mag = gate()
        if self.n_mag > 1:
            self.fc_mid_mag = projection()
            self.ga_mid_mag = gate()
        if self.n_mag == 3:
            self.fc_high_mag = projection()
            self.ga_high_mag = gate()
        if self.use_HF:
            self.fc_HF = projection()
            if self.n_mag > 1:
                self.fc_x2_HF = projection()
            if self.n_mag == 3:
                self.fc_x3_HF = projection()
        if self.use_LF:
            self.fc_LF = projection()
            if self.n_mag > 1:
                self.fc_x2_LF = projection()
            if self.n_mag == 3:
                self.fc_x3_LF = projection()

        if use_two_branches:
            self.fc_low_mag_base = projection()
            self.classify_head = nn.ModuleList([nn.Linear(out_feat_dim, 1)])
            self.classify_head.append(nn.Linear(n_mag * head_dim, n_cls - 1))
        else:
            self.classify_head = nn.Linear(n_mag * head_dim, n_cls)

    @staticmethod
    def _make_projection(in_dim, out_dim, activation, dropout):
        """Build Linear -> LayerNorm -> activation [-> Dropout]; Dropout is skipped when ``dropout`` is None."""
        layers = [nn.Linear(in_dim, out_dim), nn.LayerNorm(out_dim), activation]
        if dropout is not None:
            # nn.Dropout(None) raises, so the layer is only added on request.
            layers.append(nn.Dropout(dropout))
        return nn.Sequential(*layers)

    def _embed(self, data, level, idx=None):
        """Project the (optionally index-selected) feature streams of one level and concatenate them."""
        key, fc, fc_HF, fc_LF = self._LEVELS[level][:4]
        streams = [(key, fc)]
        if self.use_HF:
            streams.append((key + '_HF', fc_HF))
        if self.use_LF:
            streams.append((key + '_LF', fc_LF))
        projected = []
        for data_key, fc_name in streams:
            x = data[data_key]
            if idx is not None:
                x = x[:, idx, :]
            projected.append(getattr(self, fc_name)(x))
        if len(projected) == 1:
            return projected[0]
        return torch.cat(projected, dim=-1)

    def _positional_encoding(self, data):
        """Project LapPE and/or RWPE to ``PE_dim`` and optionally refine it with the LSPE layers."""
        if self.RWPE_length > 0 and self.LapPE_k > 0:
            raw = torch.cat([data['LapPE'], data['RWPE']], dim=-1)
        elif self.RWPE_length > 0:
            raw = data['RWPE']
        else:
            raw = data['LapPE']
        PE = self.fc_PE(raw)
        if self.use_LSPE:
            edge_index = None
            if self.gnn_type != 'mlp':
                if not (self.use_graph and self.use_HF):
                    raise ValueError(
                        "use_LSPE with a graph layer needs 'HF_edge_index', "
                        "which is only read when use_graph and use_HF are set.")
                edge_index = data['HF_edge_index'][0]
            # The LSPE layers are applied without the leading batch dimension.
            PE = PE.squeeze(dim=0)
            for gnn_PE in self.gnn_PE:
                if self.gnn_type == 'mlp':
                    PE = torch.tanh(gnn_PE(PE) + PE)
                else:
                    PE = torch.tanh(gnn_PE(PE, edge_index) + PE)
            PE = PE.unsqueeze(dim=0)
        return PE

    def _attention_pool(self, gate, feats, PE):
        """Score ``feats`` (with ``PE`` appended, if given) and return (attention [b, 1, N], pooled [b, 1, dim])."""
        gate_in = feats if PE is None else torch.cat([feats, PE], dim=-1)
        A, _ = gate(gate_in)                                   # [b, N, 1]
        A = F.softmax(A.permute(0, 2, 1), dim=-1)              # [b, 1, N]
        # Pooling uses the features without the positional encoding.
        return A, A @ feats

    def _top_patches(self, A):
        """Return the indices of the highest-attention patches of a [1, 1, N] attention map."""
        # Explicit indexing instead of squeeze(): squeeze() would collapse a
        # single-patch map to a 0-dim tensor and break the later concatenation.
        scores = A[0, 0]
        n_patches = scores.shape[0]
        if self.use_pool:
            k = math.ceil(self.pooling_ratio * n_patches)
        else:
            k = min(self.k_sample, n_patches)
        return torch.topk(scores, k, dim=0)[1]

    def forward(self, data):
        """Classify one bag of multi-magnification patch features.

        Args:
            data: Mapping indexed by string keys. Reads ``low_mag_feats`` and,
                depending on the configuration, ``mid_mag_feats``,
                ``high_mag_feats``, their ``*_HF`` / ``*_LF`` variants,
                ``LapPE``, ``RWPE`` and ``HF_edge_index``. Feature tensors are
                indexed as [b, N, in_feat_dim]; with more than one
                magnification b must be 1.

        Returns:
            tuple: ``(logits, Y_hat, Y_prob)`` - the raw logits [b, n_cls], the
            index of the largest logit [b, 1] and the softmax of the logits.

        Raises:
            ValueError: If ``n_mag > 1`` and the batch size is not 1, if a
                higher magnification has fewer patches than the one below it,
                or if LSPE graph layers are used without an edge index.
        """
        use_PE = self.RWPE_length > 0 or self.LapPE_k > 0

        # ---- lowest magnification: project, encode positions, pool --------
        feats = self._embed(data, 0)                                           # [b, N_1, dim]
        PE = self._positional_encoding(data) if use_PE else None
        A, M = self._attention_pool(self.ga_low_mag, feats, PE)
        pooled = [M]

        # ---- zoom into the top patches of every further magnification -----
        idx = None
        for level in range(1, self.n_mag):
            if A.shape[0] != 1:
                raise ValueError("Zooming into higher magnifications supports batch size 1 only.")
            n_parent = data[self._LEVELS[level - 1][0]].shape[1]
            n_child = data[self._LEVELS[level][0]].shape[1]
            ratio = n_child // n_parent
            if ratio < 1:
                raise ValueError("A higher magnification must have at least as many patches as the one below it.")

            perm = self._top_patches(A)
            # Position of the selected patches within the full parent level.
            parent_idx = perm if idx is None else idx[perm]
            # Children of parent p are p*ratio ... p*ratio + ratio-1. The result
            # is child-major: child 0 of every parent, then child 1, and so on.
            idx = torch.cat([parent_idx * ratio + i for i in range(ratio)], dim=0)

            feats = self._embed(data, level, idx)
            if PE is not None:
                # Tile along the patch axis so that row m carries the encoding
                # of parent m % k, matching the child-major order of idx.
                PE = PE[:, perm, :].repeat(1, ratio, 1)
            A, M = self._attention_pool(getattr(self, self._LEVELS[level][4]), feats, PE)
            pooled.append(M)

        M = torch.cat(pooled, dim=-1).squeeze(dim=1)                            # [b, n_mag * dim]

        # ---- classifier head ----------------------------------------------
        if self.use_two_branches:
            base = self.fc_low_mag_base(data['low_mag_feats']).mean(dim=1)     # [b, out_dim]
            logits = torch.cat([self.classify_head[0](base), self.classify_head[1](M)], dim=-1)   # [b, n_cls]
        else:
            logits = self.classify_head(M)
        Y_hat = torch.topk(logits, 1, dim=-1)[-1]
        Y_prob = F.softmax(logits, dim=-1)

        return logits, Y_hat, Y_prob
