import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb

from torch.nn.parameter import Parameter
from torch import nn
from einops import rearrange

class HGNN_conv_layer(nn.Module):
    """One hypergraph convolution step: ``G @ (x @ weight + bias)``.

    Args:
        in_ft: size of the last dimension of the incoming features.
        out_ft: size of the last dimension of the produced features.
        bias: when true, a learnable bias of length ``out_ft`` is added after
            the linear projection and before propagation with ``G``.
    """

    def __init__(self, in_ft, out_ft, bias=True):
        super().__init__()
        # Storage is allocated uninitialised; reset_parameters() fills it.
        self.weight = Parameter(torch.FloatTensor(in_ft, out_ft))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.FloatTensor(out_ft))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight (then bias) uniformly from ``[-b, b]``, ``b = 1/sqrt(out_ft)``."""
        bound = 1. / math.sqrt(self.weight.size(1))
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor, G: torch.Tensor):
        """Project ``x`` with the layer weight, add the bias, then left-multiply by ``G``."""
        projected = torch.matmul(x, self.weight)
        if self.bias is not None:
            projected = projected + self.bias
        return torch.matmul(G, projected)

class HGNN_conv(nn.Module):
    """Two stacked hypergraph convolutions with ReLU activations.

    Dropout is applied between the two convolutions and only while the
    module is in training mode.

    Args:
        c_in: feature dimension of the input.
        c_hid: feature dimension after the first convolution.
        c_out: feature dimension after the second convolution.
        dropout: dropout probability used between the two convolutions.
    """

    def __init__(self, c_in, c_hid, c_out, dropout):
        super(HGNN_conv, self).__init__()
        self.hgc1 = HGNN_conv_layer(c_in, c_hid)
        self.hgc2 = HGNN_conv_layer(c_hid, c_out)
        self.dropout = dropout

    def activations_hook(self, grad):
        """Store ``grad`` on the module as ``final_conv_grads``."""
        self.final_conv_grads = grad

    def forward(self, X, G): # Graph convolution part
        """Return ``relu(hgc2(dropout(relu(hgc1(X, G))), G))``.

        Args:
            X: node features passed to the first convolution.
            G: propagation matrix handed to both convolutions.
        """
        out = F.relu(self.hgc1(X, G))
        # F.dropout defaults to training=True; without the explicit flag the
        # activations would still be dropped after model.eval(), making
        # inference stochastic.
        out = F.dropout(out, self.dropout, training=self.training)
        out = F.relu(self.hgc2(out, G))

        return out

class LayerNorm(nn.Module):
    """Normalise over the last dimension with a learnable gain and offset.

    The statistics are the mean and ``Tensor.std`` of the last dimension, and
    ``eps`` is added to the standard deviation (not to the variance).

    Args:
        features: length of the last dimension; sizes ``a_2`` and ``b_2``.
        eps: constant added to the standard deviation before dividing.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # gain, starts at 1
        self.b_2 = nn.Parameter(torch.zeros(features))  # offset, starts at 0
        self.eps = eps

    def forward(self, x):
        """Return ``a_2 * (x - mean) / (std + eps) + b_2``."""
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        scaled = self.a_2 * centered / spread
        return scaled + self.b_2

class SublayerConnection(nn.Module):
    """Residual wrapper: ``x + dropout(sublayer(norm(x)))``.

    Args:
        size: feature dimension handed to the internal ``LayerNorm``.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Normalise ``x``, run ``sublayer`` on it, and add the result back to ``x``."""
        normalised = self.norm(x)
        branch = self.dropout(sublayer(normalised))
        return x + branch

class AttenLayer(nn.Module):
    """Self-attention followed by a feed-forward block, each inside a residual wrapper.

    Args:
        size: feature dimension used by both ``SublayerConnection`` wrappers.
        self_attn: module called as ``self_attn(x, x, x)``; it must expose an
            ``attn_score`` attribute after being called.
        feed_forward: module called on the attention output.
        dropout: dropout probability of the residual wrappers.
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # One residual wrapper for attention, one for the feed-forward block.
        self.sublayer = nn.ModuleList(
            [SublayerConnection(size, dropout) for _ in range(2)]
        )
        self.size = size

    def forward(self, x):
        """Run both sublayers and keep the attention scores on ``self.attn_score``."""

        def attend(inp):
            # Query, key and value are all the same tensor.
            return self.self_attn(inp, inp, inp)

        attended = self.sublayer[0](x, attend)  # Multi-Head Attention
        self.attn_score = self.self_attn.attn_score
        return self.sublayer[1](attended, self.feed_forward)  # Feed-Forward Network

def attention(Q, K, V, dropout=None):
    """Scaled dot-product attention.

    Args:
        Q: queries; the size of its last dimension sets the scaling factor.
        K: keys, multiplied with ``Q`` over their last dimension.
        V: values, mixed according to the attention weights.
        dropout: optional callable applied to the attention weights.

    Returns:
        Tuple ``(weights @ V, weights)``, where ``weights`` is the softmax of
        ``Q @ K^T / sqrt(d_k)`` over the last dimension, after ``dropout`` if
        one was supplied.
    """
    scale = math.sqrt(Q.size(-1))
    logits = torch.matmul(Q, K.transpose(-2, -1)) / scale

    weights = F.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)

    return torch.matmul(weights, V), weights

class MHSA(nn.Module):
    """Multi-head attention with four ``d_model -> d_model`` linear maps.

    The first three linears project query, key and value; the last one
    projects the re-assembled heads.

    Args:
        n_feats: number of heads; must divide ``d_model``.
        d_model: feature dimension of the inputs and of the output.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, n_feats, d_model, dropout=0.1):
        super().__init__()
        assert d_model % n_feats == 0
        self.d_k = d_model // n_feats  # per-head width
        self.h = n_feats               # number of heads

        # The plain list is kept as an attribute next to the registered ModuleList.
        self.linears_ = [nn.Linear(d_model, d_model) for _ in range(4)]
        self.linears = nn.ModuleList(self.linears_)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value):
        """Attend ``query`` over ``key``/``value``; the weights are kept on ``self.attn_score``."""
        batch = query.size(0)

        def split_heads(proj, tensor):
            # (batch, seq, d_model) -> (batch, h, seq, d_k)
            return proj(tensor).view(batch, -1, self.h, self.d_k).transpose(1, 2)

        Q = split_heads(self.linears[0], query)
        K = split_heads(self.linears[1], key)
        V = split_heads(self.linears[2], value)

        context, self.attn_score = attention(Q, K, V, dropout=self.dropout)

        # (batch, h, seq, d_k) -> (batch, seq, h * d_k)
        merged = context.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)

        return self.linears[-1](merged)
    
class FeedForward(nn.Module):
    """Two-layer feed-forward block over ``c_hid * c_in`` features.

    Args:
        n_node: accepted for interface compatibility; not used by this block.
        c_in: multiplier on the feature widths below.
        c_hid: input and hidden width is ``c_hid * c_in``.
        out_dim: output width is ``out_dim * c_in``.
        dropout: dropout probability applied after the ReLU (training only).
    """

    def __init__(self, n_node, c_in, c_hid, out_dim, dropout):
        super().__init__()
        width = c_hid * c_in
        self.lin1 = nn.Linear(width, width)  # Readout
        self.lin2 = nn.Linear(width, out_dim * c_in)  # Predictor
        self.dropout = dropout

    def forward(self, x):
        """Return ``lin2(dropout(relu(lin1(x))))``."""
        hidden = F.relu(self.lin1(x))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        return self.lin2(hidden)

class MLP(nn.Module):
    """Two linear layers, each followed by a ReLU, with dropout in between.

    Args:
        c_in: input width.
        c_hid: hidden width.
        out_dim: output width.
        dropout: dropout probability applied after the first ReLU (training only).
    """

    def __init__(self, c_in, c_hid, out_dim, dropout):
        super().__init__()
        self.lin1 = nn.Linear(c_in, c_hid)  # Readout
        self.lin2 = nn.Linear(c_hid, out_dim)  # Predictor
        self.dropout = dropout

    def forward(self, x):
        """Return ``relu(lin2(dropout(relu(lin1(x)))))``."""
        hidden = F.relu(self.lin1(x))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        return F.relu(self.lin2(hidden))

class CLF(nn.Module):
    """Classifier head: flatten nodes and channels, then a two-layer MLP.

    Args:
        n_node: number of nodes ``n`` in the ``(b, n, c)`` input.
        c_in: channels per node; the flattened input width is ``n_node * c_in``.
        c_hid: the hidden width is ``n_node * c_hid // 2``.
        c_out: number of outputs.
        dropout: dropout probability applied after the ReLU (training only).
    """

    def __init__(self, n_node, c_in, c_hid, c_out, dropout):
        super().__init__()
        hidden_width = n_node * c_hid // 2
        self.lin1 = nn.Linear(n_node * c_in, hidden_width)  # Readout
        self.lin2 = nn.Linear(hidden_width, c_out)  # Predictor
        self.dropout = dropout

    def forward(self, x):
        """Return ``lin2(dropout(relu(lin1(flatten(x)))))`` for ``x`` of shape ``(b, n, c)``."""
        flat = rearrange(x, 'b n c -> b (n c)')
        hidden = F.relu(self.lin1(flat))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        return self.lin2(hidden)

class MASH(nn.Module):
    """Multi-scale hypergraph classifier.

    For each of ``n_scale`` learnable scales the input features are filtered
    in the basis given by ``eigvec`` (``compute_low`` for the first scale,
    ``compute_mid`` for the others), a sparse incidence matrix is generated
    from the filtered features, and a two-layer hypergraph convolution is
    applied. The per-scale embeddings are concatenated on the feature axis,
    passed through the attention layers and flattened into the classifier.

    Args:
        n_node: number of nodes N; sizes the classifier input.
        c_in: input feature dimension F.
        c_hid: hidden dimension C produced by each per-scale convolution.
        n_class: number of classifier outputs.
        dropout: dropout probability handed to the convolutions, the
            feed-forward block and the classifier.
    """

    def __init__(self, n_node, c_in, c_hid, n_class, dropout=0.5):
        super(MASH, self).__init__()
        self.dropout = dropout

        # Learnable filter scales, initialised to [1.0, 1.5, 2.0].
        self.init_scale = self.initialize_scales(1,2,3)
        self.n_scale = len(self.init_scale) # 3
        self.scales = nn.ParameterList([nn.Parameter(torch.empty((1)).fill_(self.init_scale[i]), requires_grad=True) for i in range(self.n_scale)])
        self.n_hyper = 16 # number of hyperedges E
        self.d_hyper = 16 # embedding width used to score hyperedge membership
        # One [d_hyper, n_hyper] projection per scale, mapping node embeddings
        # to hyperedge scores.
        self.WH = nn.ParameterList([nn.Parameter(torch.randn(self.d_hyper, self.n_hyper), requires_grad=True) for i in range(self.n_scale)])
        self.K = 3 # hyperedges kept per node by top_k_hyperedges
        self.n_attn_layer = 1
        self.emb_layers_, self.conv_layers_, self.attn_layers_ = [[] for _ in range(3)]

        # The creation order of the sub-modules is kept as is so that seeded
        # initialisation reproduces the same parameters.
        for i in range(self.n_scale):
            self.emb_layers_.append(nn.Linear(c_in, self.d_hyper))
            self.conv_layers_.append(HGNN_conv(c_in, c_hid, c_hid, dropout))
        self.emb_layers = nn.ModuleList(self.emb_layers_)
        self.conv_layers = nn.ModuleList(self.conv_layers_)

        c = copy.deepcopy
        attn = MHSA(self.n_scale, self.n_scale * c_hid)
        ffn = FeedForward(n_node, self.n_scale, c_hid, c_hid, dropout)
        # NOTE(review): the residual dropout of the attention layers is fixed
        # at 0.5 and does not follow the `dropout` argument — confirm this is
        # intended.
        for i in range(self.n_attn_layer):
            self.attn_layers_.append(AttenLayer(self.n_scale * c_hid, c(attn), c(ffn), dropout=0.5))
        self.attn_layers = nn.ModuleList(self.attn_layers_)

        self.classifier = CLF(n_node, c_hid * self.n_scale , c_hid, n_class, dropout)

    def initialize_scales(self, start, end, count):
        """Return ``count`` initial scale values between ``start`` and ``end``.

        ``count == 1`` gives ``[mid]`` and ``count == 2`` gives ``[mid, mid]``
        with ``mid = (start + end) / 2``; larger counts give values evenly
        spaced from ``start`` to ``end`` inclusive.

        Raises:
            ValueError: if ``count`` is smaller than 1.
        """
        if count < 1:
            raise ValueError("count must be at least 1.")

        mid = (start + end) / 2
        if count == 1:
            return [mid]
        if count == 2:
            return [mid, mid]

        step = (end - start) / (count - 1)
        return [start + i * step for i in range(count)]

    def compute_low(self, x, eigvec, eigval, t):
        """Filter ``x`` with the kernel ``exp(-eigval * t) ** 2`` in the basis ``eigvec``.

        Computes ``eigvec @ diag(exp(-eigval * t) ** 2) @ eigvec^T @ x`` after
        casting ``eigvec`` and ``eigval`` to float32.
        """
        eigval = eigval.type(torch.float)
        eigvec = eigvec.type(torch.float)

        K = torch.diag_embed(torch.exp(-1 * eigval * t)) ** 2
        ftr = eigvec @ K @ eigvec.transpose(-1,-2)
        ftrX = ftr @ x
        return ftrX

    def g(self, x):
        """Evaluate the piecewise kernel element-wise and normalise by its maximum.

        The pieces are ``(x + eps) ** 2`` for ``x < 1``, the cubic
        ``-5 + 11x - 6x^2 + x^3`` for ``1 <= x <= 2`` and
        ``4 * (x + eps) ** -2`` for ``x > 2``. The result is divided by the
        maximum over all elements of the input (not per sample).
        """
        x_1 = 1.
        x_2 = 2.
        alpha = 2.
        beta = 2.
        zero = torch.zeros(1, device=x.device)
        eps = 1e-5

        def s(x):
            return -5 + 11*x + -6*(x**2) + x**3

        ftr = torch.where(x < x_1, x_1**(-1*alpha) * torch.pow(x+eps,alpha), zero) \
            + torch.where((x >= x_1) & (x <= x_2), s(x), zero) \
            + torch.where(x > x_2, x_2**beta * torch.pow(x+eps,-1*beta), zero)
        return ftr / ftr.max()

    def compute_mid(self, x, eigvec, eigval, t):
        """Filter ``x`` with the kernel ``g(eigval * t) ** 2`` in the basis ``eigvec``.

        Computes ``eigvec @ diag(g(eigval * t) ** 2) @ eigvec^T @ x`` after
        casting ``eigvec`` and ``eigval`` to float32.
        """
        eigval = eigval.type(torch.float)
        eigvec = eigvec.type(torch.float)

        # Apply g to the eigenvalues first and embed afterwards, as
        # compute_low does. Applying g to the embedded matrix would send the
        # zero off-diagonal entries through the x < 1 branch (eps ** 2, not 0)
        # and evaluate the kernel over N x N entries instead of N.
        K = torch.diag_embed(self.g(eigval * t)) ** 2
        ftr = eigvec @ K @ eigvec.transpose(-1,-2)
        ftrX = ftr @ x
        return ftrX

    def top_k_hyperedges(self, H, epsilon):
        """Keep the ``epsilon`` largest entries along dim 2 of ``H`` and zero the rest.

        Raises:
            ValueError: if ``epsilon`` is smaller than 1.
        """
        if epsilon < 1:
            raise ValueError("epsilon must be at least 1")
        topk_values, topk_indices = torch.topk(H, k=epsilon, dim=2)
        result = torch.zeros_like(H)
        result.scatter_(2, topk_indices, topk_values)
        return result

    def generate_H(self, Xs, Ws):
        """Build the incidence matrix ``H`` from node embeddings and a projection.

        ``H = top_k(softmax(relu(Xs @ Ws)))`` with ``self.K`` entries kept per
        node. A detached CPU copy of the result is stored on ``self.H``; it is
        overwritten on every call, so after ``forward`` it holds the matrix of
        the last scale.
        """
        H = torch.einsum('bnc,cd->bnd', Xs, Ws)
        H = F.softmax(F.relu(H), -1)
        H = self.top_k_hyperedges(H, self.K)
        self.H = H.clone().detach().cpu()
        return H

    def generate_G_from_H(self, H):
        """Return ``Dv^-1/2 @ H @ W @ De^-1 @ H^T @ Dv^-1/2`` for incidence ``H``.

        ``W`` is the identity (all hyperedge weights are 1), ``Dv`` holds the
        row sums of ``H`` and ``De`` its column sums; ``eps`` guards both
        inversions against zero degrees.
        """
        eps = 1e-5
        W = torch.ones(self.n_hyper).to(H.device) # [E]
        Dv = torch.sum(H*W, dim=2) # [B, N]
        De = torch.sum(H, dim=1) # [B, E]
        Dv_inv_sqrt = torch.diag_embed(1/torch.sqrt(Dv+eps)) # [B, N, N]
        De_inv = torch.diag_embed(1/(De+eps)) # [B, E, E]
        W = torch.diag_embed(W) # [E, E]
        G = Dv_inv_sqrt @ H @ W @ De_inv @ H.transpose(-1,-2) @ Dv_inv_sqrt # [B, N, N]
        return G

    def forward(self, x, eigvec, eigval):
        """Classify a batch of node-feature matrices.

        Shape legend:
        S: num of scales
        B: num of samples
        N: num of nodes
        F: dim of features
        E: num of hyperedges
        C: dim of hidden units

        Args:
            x: node features, [B, N, F].
            eigvec: basis used by the spectral filters; presumably the
                eigenvectors of a graph Laplacian, [B, N, N] — confirm with
                the caller.
            eigval: matching eigenvalues, [B, N] — confirm with the caller.

        Returns:
            Classifier output of shape [B, n_class].
        """
        emb = []
        for s in range(self.n_scale):
            if s == 0:
                ftrXs = self.compute_low(x, eigvec, eigval, self.scales[s]) # [B, N, F]
            else:
                ftrXs = self.compute_mid(x, eigvec, eigval, self.scales[s]) # [B, N, F]
            Xs = self.emb_layers[s](ftrXs) # [B, N, d_hyper] : embedded filtered features
            # Index the ParameterList directly; stacking every matrix just to
            # pick one would copy all of them on each iteration.
            Ws = self.WH[s] # [d_hyper, E] : learnable weight matrix
            H = self.generate_H(Xs, Ws) # [B, N, E] : incidence matrix
            G = self.generate_G_from_H(H) # [B, N, N]
            emb.append(self.conv_layers[s](ftrXs, G)) # [B, N, C]

        out = torch.cat(emb,-1) # [B, N, (C S)]
        for n in range(self.n_attn_layer):
            out = self.attn_layers[n](out) # [B, N, (C S)]
        out = self.classifier(out) # [B, n_class]

        return out