import torch
import torch.nn as nn
import torch.nn.functional as F
import numbers


class nconv(nn.Module):
    """Parameter-free propagation step that mixes one axis of a 4-D tensor with a matrix."""

    def __init__(self):
        super().__init__()

    def forward(self, x, A):
        """Contract axis 2 of ``x`` with axis 1 of ``A``.

        Computes ``out[n, c, v, l] = sum_w x[n, c, w, l] * A[v, w]``.

        Args:
            x: 4-D tensor indexed as ``(n, c, w, l)``.
            A: 2-D tensor indexed as ``(v, w)``.

        Returns:
            A contiguous 4-D tensor indexed as ``(n, c, v, l)``.
        """
        mixed = torch.einsum('ncwl,vw->ncvl', x, A)
        # einsum may hand back a strided view; force a dense memory layout.
        return mixed.contiguous()


class linear_MTGNN(nn.Module):
    """Channel-wise linear map implemented as a 1x1 2-D convolution."""

    def __init__(self, c_in, c_out, bias=True):
        """Create the 1x1 convolution.

        Args:
            c_in: Number of input channels.
            c_out: Number of output channels.
            bias: Whether the convolution carries a bias term.
        """
        super().__init__()
        # Zero padding and unit stride are the Conv2d defaults, so only the
        # kernel size needs to be spelled out.
        self.mlp = nn.Conv2d(c_in, c_out, kernel_size=1, bias=bias)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x`` (channels on axis 1) and return the result."""
        projected = self.mlp(x)
        return projected


class linear(nn.Module):
    """1x1-convolution linear map for 3-D inputs.

    The input is temporarily given a singleton axis at position 2 so that a
    ``Conv2d`` can be applied, and that axis is removed again afterwards.
    """

    def __init__(self, c_in, c_out, bias=True):
        """Create the 1x1 convolution.

        Args:
            c_in: Number of input channels.
            c_out: Number of output channels.
            bias: Whether the convolution carries a bias term.
        """
        super().__init__()
        # Zero padding and unit stride are the Conv2d defaults.
        self.mlp = nn.Conv2d(c_in, c_out, kernel_size=1, bias=bias)

    def forward(self, x):
        """Map the channel axis (axis 1) of ``x`` from ``c_in`` to ``c_out``."""
        expanded = x.unsqueeze(2)
        projected = self.mlp(expanded)
        return projected.squeeze(2)


class CustomLinear(nn.Module):
    """Learned convex combination over the second-to-last axis of the input.

    One scalar weight is kept per entry of that axis; the weights are passed
    through a softmax so they are positive and sum to one.
    """

    def __init__(self, input_dim):
        """Create the mixing weights.

        Args:
            input_dim: Number of entries along the second-to-last axis of the
                tensors passed to ``forward``. The weight vector is sized from
                this value (it was previously fixed at 3 regardless of the
                argument, which broke any other stacking depth).
        """
        super(CustomLinear, self).__init__()
        self.weights = nn.Parameter(torch.rand(input_dim, 1))

    def forward(self, x):
        """Return the softmax-weighted sum of ``x`` over its second-to-last axis.

        Args:
            x: Tensor whose second-to-last axis has ``input_dim`` entries.

        Returns:
            Tensor with that axis reduced away.
        """
        # Shape (input_dim, 1) broadcasts against the last two axes of x.
        weight = torch.softmax(self.weights, dim=0).to(x.device)
        return torch.sum(x * weight, dim=-2)
    

class GraphConv(nn.Module):
    """Fixed binary adjacency matrix derived from a precomputed correlation matrix.

    The adjacency is built once in the constructor: the diagonal is zeroed,
    only the ``top_k`` largest entries per row are kept, and the survivors are
    binarized with a 0.8 threshold. ``forward`` simply returns this matrix.

    The embedding and linear layers are created but are not used by
    ``forward``; they are kept so the module's parameter set is unchanged.
    """

    def __init__(self, corr, high_correlated_count, node_num:int=21, d_node:int=48, top_k:int=8,
                 tanh_alpha:float=3, device:str="cuda:0") -> None:
        """Build the static adjacency matrix.

        Args:
            corr: Square NumPy array of pairwise correlations. It is copied,
                so the caller's array is left untouched.
            high_correlated_count: NumPy array whose first-axis length gives
                the number of diagonal entries to zero.
            node_num: Number of nodes; used for the embedding tables and for
                the top-k mask (presumably equal to ``corr.shape[0]`` -
                confirm against the caller).
            d_node: Embedding width.
            top_k: Number of entries kept per row before thresholding.
            tanh_alpha: Stored on the instance; not used by this class.
            device: Stored on the instance; not used by this class.

        Raises:
            ValueError: If ``corr`` or ``high_correlated_count`` is ``None``.
        """
        super(GraphConv, self).__init__()
        if corr is None or high_correlated_count is None:
            # Previously this surfaced later as an opaque AttributeError.
            raise ValueError(
                "GraphConv requires both `corr` and `high_correlated_count`"
            )
        # torch.from_numpy shares memory with the NumPy array; clone so the
        # in-place diagonal zeroing below cannot corrupt the caller's data.
        correlated_matrix = torch.from_numpy(corr).clone()
        self.correlated_count = torch.from_numpy(high_correlated_count)
        self.node_num = node_num
        self.d_node = d_node
        self.top_k = top_k
        self.tanh_alpha = tanh_alpha
        self.device = device

        # Two sets of node embeddings (unused by forward, see class docstring).
        self.g_embed1 = nn.Embedding(node_num, d_node)
        self.g_embed2 = nn.Embedding(node_num, d_node)
        self.linear1 = nn.Linear(d_node, d_node)
        self.linear2 = nn.Linear(d_node, d_node)

        # Remove self-correlation.
        self.diag_indices = torch.arange(self.correlated_count.size(0))
        correlated_matrix[self.diag_indices, self.diag_indices] = 0

        # Select the top-k positions in each row; the small random noise
        # breaks ties between equal entries.
        mask = torch.zeros(node_num, node_num)
        noisy = correlated_matrix + torch.rand_like(correlated_matrix) * 0.01
        _, top_idx = noisy.topk(self.top_k, 1)
        mask.scatter_(1, top_idx, 1.0)
        correlated_matrix = correlated_matrix * mask

        # Binarize: entries > 0.8 become 1, everything else becomes 0.
        strong = correlated_matrix > 0.8
        correlated_matrix[~strong] = 0
        correlated_matrix[strong] = 1
        self.correlated_matrix = correlated_matrix

    def forward(self, id_list):
        """Return the static adjacency matrix twice, on ``id_list``'s device.

        Args:
            id_list: Tensor used only to pick the target device.

        Returns:
            Tuple ``(A, A0)``; both elements hold the precomputed binary
            adjacency matrix.
        """
        return self.correlated_matrix.to(id_list.device), self.correlated_matrix.to(id_list.device)


class GraphConv_MTGNN(nn.Module):
    """Learned, uni-directional sparse adjacency built from two node embeddings."""

    def __init__(self, node_num:int=21, d_node:int=48, top_k:int=8, tanh_alpha:float=3) -> None:
        """Create the embedding tables and their linear projections.

        Args:
            node_num: Number of rows in each embedding table.
            d_node: Embedding width.
            top_k: Maximum number of neighbours kept per row.
            tanh_alpha: Scale applied inside every ``tanh`` in ``forward``.
        """
        super(GraphConv_MTGNN, self).__init__()
        self.node_num = node_num
        self.d_node = d_node
        self.top_k = top_k
        self.tanh_alpha = tanh_alpha

        # Two independent sets of node embeddings.
        self.g_embed1 = nn.Embedding(node_num, d_node)
        self.g_embed2 = nn.Embedding(node_num, d_node)
        self.linear1 = nn.Linear(d_node, d_node)
        self.linear2 = nn.Linear(d_node, d_node)

    def forward(self, id_list):
        """Compute the sparse adjacency for the given node ids.

        Args:
            id_list: 1-D integer tensor of node ids.

        Returns:
            Tuple ``(adj, a)`` where ``a`` is the antisymmetric score matrix
            and ``adj`` is ``relu(tanh(tanh_alpha * a))`` with at most
            ``top_k`` entries kept per row.
        """
        # Use the configured scale here as well (it used to be a literal 3,
        # which silently ignored any non-default ``tanh_alpha``).
        nodevec1 = torch.tanh(self.tanh_alpha * self.linear1(self.g_embed1(id_list)))
        nodevec2 = torch.tanh(self.tanh_alpha * self.linear2(self.g_embed2(id_list)))
        # Antisymmetric scores give a uni-directional graph once passed
        # through ReLU(tanh(.)).
        a = torch.mm(nodevec1, nodevec2.transpose(1, 0)) - torch.mm(nodevec2, nodevec1.transpose(1, 0))
        adj = F.relu(torch.tanh(self.tanh_alpha * a))

        # Keep the top-k entries of each row. k is clamped so that a
        # ``top_k`` larger than the number of ids keeps every entry instead
        # of raising inside ``topk``. The noise breaks ties between equal
        # scores.
        k = min(self.top_k, adj.size(1))
        mask = torch.zeros_like(adj)
        _, top_idx = (adj + torch.rand_like(adj) * 0.01).topk(k, 1)
        mask.scatter_(1, top_idx, 1.0)
        adj = adj * mask
        return adj, a


class mixprop(nn.Module):
    """Multi-hop graph propagation whose hop outputs are merged by ``CustomLinear``."""

    def __init__(self, c_in, c_out, gdep, dropout, alpha, d_model):
        """Store the propagation settings and create the hop-mixing layer.

        Args:
            c_in: Accepted but not used by this class.
            c_out: Accepted but not used by this class.
            gdep: Number of propagation hops.
            dropout: Stored on the instance; not applied in ``forward``.
            alpha: Share of the original input retained at every hop.
            d_model: Stored on the instance.
        """
        super().__init__()
        # One mixing weight per stacked tensor: the input plus gdep hops.
        self.mlp = CustomLinear(gdep+1)
        self.d_model = d_model
        self.gdep = gdep
        self.dropout = dropout
        self.alpha = alpha

    def forward(self, x, A):
        """Propagate ``x`` over the graph ``A`` and merge all hops.

        Args:
            x: 3-D tensor indexed as ``(n, w, l)``; axis 1 is the axis that
                is mixed by the graph.
            A: Square adjacency matrix indexed as ``(v, w)``.

        Returns:
            Result of applying ``self.mlp`` to the input and all hop outputs
            stacked along a new second-to-last axis.
        """
        device = x.device
        # A' = A + I, then row-normalize: A_tilde = D^-1 A' with D_ii = sum_j A'_ij.
        adj = A.to(device) + torch.eye(A.size(0)).to(device)
        row_sum = adj.sum(1)
        norm_adj = adj / row_sum.view(-1, 1)

        state = x
        hops = [state.unsqueeze(-2)]
        for _ in range(self.gdep):
            # The propagation itself is carried out in float64.
            spread = torch.einsum('nwl,vw->nvl', state.double(), norm_adj.double())
            spread = spread.to(device).contiguous()
            state = self.alpha * x + (1 - self.alpha) * spread
            hops.append(state.unsqueeze(-2))

        # Stack the hops on the new axis and return to float32 before mixing.
        stacked = torch.cat(hops, dim=-2).float()
        return self.mlp(stacked)


class mixprop_MTGNN(nn.Module):
    """Multi-hop graph propagation whose hops are concatenated on the channel axis."""

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        """Create the propagation and output layers.

        Args:
            c_in: Channels of the input tensor.
            c_out: Channels produced by the output projection.
            gdep: Number of propagation hops.
            dropout: Stored on the instance; not applied in ``forward``.
            alpha: Share of the original input retained at every hop.
        """
        super().__init__()
        self.nconv = nconv()
        # The projection sees the input plus gdep hops stacked on the channel axis.
        self.mlp = linear_MTGNN((gdep+1)*c_in, c_out)
        self.gdep = gdep
        self.dropout = dropout
        self.alpha = alpha

    def forward(self, x, adj):
        """Propagate ``x`` over ``adj`` and project the concatenated hops.

        Args:
            x: Input tensor; hops are concatenated along its axis 1.
            adj: Square adjacency matrix.

        Returns:
            Output of ``self.mlp`` applied to the concatenation of the input
            and every hop.
        """
        # Add self-loops, then divide each row by its sum.
        with_loops = adj + torch.eye(adj.size(0)).to(x.device)
        degree = with_loops.sum(1)
        norm_adj = with_loops / degree.view(-1, 1)

        hops = [x]
        state = x
        for _ in range(self.gdep):
            state = self.alpha*x + (1-self.alpha)*self.nconv(state, norm_adj)
            hops.append(state)

        return self.mlp(torch.cat(hops, dim=1))


class mixprop_RPG(nn.Module):
    """Multi-hop graph propagation with a selectable hop-merging layer.

    ``mlp_type`` chooses how the input and the ``gdep`` hop outputs are merged:

    * ``0``: concatenate along axis 1 and project with ``linear``.
    * ``1``: stack on a new axis and reduce it with a bias-free ``nn.Linear``.
    * anything else: stack on a new axis and reduce it with ``CustomLinear``.
    """

    def __init__(self, c_in, c_out, gdep, mlp_type, alpha, d_model):
        """Create the merging layer and store the propagation settings.

        Args:
            c_in: Size of axis 1 of the input (used when ``mlp_type == 0``).
            c_out: Size of axis 1 of the output (used when ``mlp_type == 0``).
            gdep: Number of propagation hops.
            mlp_type: Merging strategy, see the class docstring.
            alpha: Share of the original input retained at every hop.
            d_model: Stored on the instance.
        """
        super(mixprop_RPG, self).__init__()
        if mlp_type == 0:
            self.mlp = linear((gdep+1)*c_in, c_out)
        elif mlp_type == 1:
            self.mlp = nn.Linear((gdep+1), 1, bias=False)
        else:
            self.mlp = CustomLinear(gdep+1)
        self.mlp_type = mlp_type
        self.d_model = d_model
        self.gdep = gdep
        self.alpha = alpha

    def forward(self, x, A):
        """Propagate ``x`` over the graph ``A`` and merge all hops.

        Args:
            x: 3-D tensor indexed as ``(n, w, l)``; axis 1 is mixed by the graph.
            A: Square adjacency matrix indexed as ``(v, w)``.

        Returns:
            The merged tensor produced by the layer selected via ``mlp_type``.
        """
        # A' = A + I, then row-normalize: A_tilde = D^-1 A' with D_ii = sum_j A'_ij.
        A = A.to(x.device) + torch.eye(A.size(0)).to(x.device)
        D = A.sum(1)
        A_tilde = (A / D.view(-1, 1)).double()

        # Only the channel-concat variant keeps hops 3-D; the other variants
        # stack them on a new second-to-last axis. This mirrors __init__, so
        # any mlp_type other than 0 or 1 consistently uses the CustomLinear path.
        concat_on_channels = self.mlp_type == 0

        h = x
        out = [h if concat_on_channels else h.unsqueeze(-2)]
        for _ in range(self.gdep):
            # The propagation itself is carried out in float64.
            output = torch.einsum('nwl,vw->nvl', (h.double(), A_tilde)).to(x.device)
            output = output.contiguous()
            h = self.alpha * x + (1 - self.alpha) * output
            out.append(h if concat_on_channels else h.unsqueeze(-2))

        if concat_on_channels:
            return self.mlp(torch.cat(out, dim=1).float())

        stacked = torch.cat(out, dim=-2).float()
        if self.mlp_type == 1:
            # nn.Linear reduces the last axis, so move the hop axis there.
            # Squeeze only that reduced axis: a bare squeeze() would also
            # drop a batch (or any other) axis of size 1.
            return self.mlp(stacked.permute(0, 1, 3, 2)).squeeze(-1)
        return self.mlp(stacked)


class LayerNorm(nn.Module):
    """Layer normalization over the trailing ``normalized_shape`` axes."""

    __constants__ = ['normalized_shape', 'weight', 'bias', 'eps', 'elementwise_affine']

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        """Create the normalization layer.

        Args:
            normalized_shape: Int or iterable of ints giving the shape of the
                trailing axes that are normalized.
            eps: Value added to the variance for numerical stability.
            elementwise_affine: If true, learn a per-element scale and shift.
                The attribute was previously read without ever being set, so
                constructing the layer always raised ``AttributeError``; the
                default keeps the affine parameters that the constructor was
                written to create.
        """
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape))
            self.bias = nn.Parameter(torch.empty(self.normalized_shape))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the affine parameters to the identity transform (scale 1, shift 0)."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, x):
        """Normalize ``x`` over its trailing ``normalized_shape`` axes.

        The configured shape is used rather than the full shape of ``x``;
        the two coincide whenever ``x.shape == normalized_shape``, which is
        the only case in which the affine parameters could match the input
        under the earlier full-shape call.
        """
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)



class LayerNorm_MTGNN(nn.Module):
    """Layer normalization over all non-batch axes, with affine parameters sliced by index."""

    __constants__ = ['normalized_shape', 'weight', 'bias', 'eps', 'elementwise_affine']

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        """Create the normalization layer.

        Args:
            normalized_shape: Int or iterable of ints giving the shape of the
                affine parameters.
            eps: Value added to the variance for numerical stability.
            elementwise_affine: If true, learn a per-element scale and shift.
        """
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not self.elementwise_affine:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        else:
            self.weight = nn.Parameter(torch.Tensor(*self.normalized_shape))
            self.bias = nn.Parameter(torch.Tensor(*self.normalized_shape))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the affine parameters to scale 1 and shift 0 (no-op without affine)."""
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, input, idx):
        """Normalize ``input`` over every axis except the first.

        Args:
            input: Tensor to normalize; axis 0 is left out of the statistics.
            idx: Index into axis 1 of the 3-D affine parameters, selecting
                the entries that correspond to ``input``. Ignored when the
                layer has no affine parameters.

        Returns:
            The normalized tensor, same shape as ``input``.
        """
        norm_shape = tuple(input.shape[1:])
        if not self.elementwise_affine:
            return F.layer_norm(input, norm_shape, self.weight, self.bias, self.eps)
        scale = self.weight[:, idx, :]
        shift = self.bias[:, idx, :]
        return F.layer_norm(input, norm_shape, scale, shift, self.eps)

    def extra_repr(self):
        """Return the configuration summary shown in the module's repr."""
        return (
            f'{self.normalized_shape}, eps={self.eps}, '
            f'elementwise_affine={self.elementwise_affine}'
        )
