import torch.nn as nn
import torch.nn.functional as F
import torch

from layers.SelfAttention_Family import FullAttention, AttentionLayer, ProbAttention


class ConvLayer(nn.Module):
    """Down-sampling block: circular Conv1d -> BatchNorm1d -> ELU -> MaxPool1d.

    The number of channels is preserved (``c_in`` in, ``c_in`` out).

    Args:
        c_in: channel count of the convolution and of the batch norm.
    """

    def __init__(self, c_in):
        super().__init__()
        # kernel_size=3 with padding=2 makes the convolution output two steps
        # longer than its input; the stride-2 pooling below then shortens it.
        self.downConv = nn.Conv1d(
            c_in,
            c_in,
            kernel_size=3,
            padding=2,
            padding_mode='circular',
        )
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Apply the block to a 3-D tensor.

        Dims 1 and 2 of ``x`` are swapped before the convolution (so dim 2 of
        the input is what Conv1d treats as channels) and swapped back on the
        way out.

        Returns:
            Tensor with the same dim-0 and dim-2 sizes as ``x`` and a pooled
            dim 1.
        """
        channels_first = x.permute(0, 2, 1)
        features = self.norm(self.downConv(channels_first))
        pooled = self.maxPool(self.activation(features))
        return pooled.transpose(1, 2)


class EncoderLayer(nn.Module):
    """Attention sub-layer followed by a two-convolution feed-forward block.

    Both sub-layers use a residual connection followed by LayerNorm.

    Args:
        attention: callable module invoked as
            ``attention(x, fft, x, attn_mask=..., tau=..., delta=..., **kwargs)``
            and expected to return ``(output, attention_weights)``.
        d_model: feature size handled by the LayerNorms and 1x1 convolutions.
        d_ff: hidden width of the feed-forward block; falls back to
            ``4 * d_model`` when falsy.
        dropout: dropout probability shared by all dropout applications.
        activation: ``"relu"`` selects ``F.relu``; any other value selects
            ``F.gelu``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu", **kwargs):
        super().__init__()
        if not d_ff:
            d_ff = 4 * d_model
        self.attention = attention
        # kernel_size=1 convolutions act independently at every position.
        self.conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, fft, attn_mask=None, tau=None, delta=None, **kwargs):
        """Run one encoder layer.

        ``fft`` is passed as the second positional argument of the attention
        module (presumably the keys -- confirm against the attention
        implementation), while ``x`` is passed as the first and third.

        Returns:
            ``(output, attn)`` where ``attn`` is whatever the attention module
            returned as its second value.
        """
        attended, attn = self.attention(
            x, fft, x,
            attn_mask=attn_mask,
            tau=tau, delta=delta,
            **kwargs
        )
        residual = self.norm1(x + self.dropout(attended))

        # Conv1d wants channels on dim 1, hence the transposes around the FFN.
        hidden = self.conv1(residual.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        ffn_out = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm2(residual + ffn_out), attn


class Encoder(nn.Module):
    """Stack of attention layers, optionally interleaved with conv layers.

    Args:
        attn_layers: iterable of modules called as
            ``layer(x, fft, attn_mask=..., tau=..., delta=..., **kwargs)`` and
            returning ``(x, attn)``.
        conv_layers: optional iterable of modules called as ``conv(x)`` after
            each paired attention layer. When given, the last attention layer
            is additionally applied once after the paired loop.
        norm_layer: optional module applied to the final output.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, fft, attn_mask=None, tau=None, delta=None, **kwargs):
        """Run every layer and collect their attention outputs.

        Args:
            x: input tensor (original comment says ``[B, L, D]``).
            fft: second input forwarded unchanged to every attention layer.
            attn_mask, tau, delta: forwarded to the attention layers. In the
                conv branch ``delta`` is only passed to the first layer and
                ``attn_mask`` is not passed to the trailing layer.
            **kwargs: forwarded to every attention layer.

        Returns:
            ``(x, attns)`` where ``attns`` lists the second return value of
            each attention-layer call, in call order.
        """
        attns = []
        if self.conv_layers is not None:
            # NOTE(review): conv layers may change the length of ``x`` while
            # ``fft`` is passed through unchanged -- confirm the attention
            # module accepts that combination.
            for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
                delta = delta if i == 0 else None
                # ``fft`` is a required argument of the layers (the non-conv
                # branch already passes it); omitting it here raised TypeError.
                x, attn = attn_layer(x, fft, attn_mask=attn_mask, tau=tau, delta=delta, **kwargs)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x, fft, tau=tau, delta=None, **kwargs)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, fft, attn_mask=attn_mask, tau=tau, delta=delta, **kwargs)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns


# class EncoderLayer(nn.Module):
#     def __init__(self, d_model, node_num, batch_size, enc_in, n_heads, dropout=0.1, activation="relu", gdep=3, alpha=0.3):
#         super(EncoderLayer, self).__init__()
#
#         # self.attention = attention
#         self.norm = nn.LayerNorm(d_model)
#         self.dropout = nn.Dropout(dropout)
#         self.activation = F.relu if activation == "relu" else F.gelu
#         self.hgnn = HyperGraph(activation)
#         self.HyperEdge_num = 5
#         # self.metrix_para = 100
#         # self.metrix1 = nn.Parameter(torch.randn(batch_size*enc_in, node_num, node_num), requires_grad=True)
#         # self.metrix2 = nn.Parameter(torch.randn(batch_size*enc_in, self.HyperEdge_num, node_num), requires_grad=True)
#         self.metrix1 = nn.Parameter(torch.randn(node_num, node_num), requires_grad=True)
#         self.metrix2 = nn.Parameter(torch.randn(self.HyperEdge_num, node_num), requires_grad=True)
#         # self.softmax = nn.Softmax(dim=1)
#         # self.gcn1 = GCN(in_features=d_model, hidden_size=d_model*2, out_features=d_model, activation=activation)
#         # self.gcn2 = GCN(in_features=d_model, hidden_size=d_model*2, out_features=d_model, activation=activation)
#         # self.encoder_output = nn.Linear(d_model*n_heads, d_model)
#         self.K = 5
#
#     def forward(self, x):
#         # x [Batch, Patch, d_model]
#         # if self.training == 0:
#         #     print(self.metrix1.shape, x.shape)
#         # Ev = torch.matmul(self.metrix1, x)  # b, node, d_model
#         # Eh = torch.matmul(self.metrix2, x)  # b, edge, d_model
#         Ev = torch.einsum('bnd,nn->bnd', (x, self.metrix1))
#         Eh = torch.einsum('bnd,en->bed', (x, self.metrix2))
#         # print(Ev.shape, Eh.shape)
#
#         H = F.softmax(self.activation(torch.matmul(Ev, Eh.permute(0, 2, 1).contiguous())))
#
#         values, indices = torch.topk(H, self.K, dim=1)
#         top_H = torch.zeros_like(H)
#         top_H.scatter_(1, indices, values)
#         top_H = torch.masked_fill(top_H, top_H < 0.5 , 0)
#         # print(H.shape)
#         out = self.hgnn(x, top_H)
#         # if self.training == 0:
#         #     print(out.shape)
#         return self.norm(out + x), None

        # attn = self.attention(x, x)  # [Batch Head Patch Patch]

        # attn_sum = torch.sum(attn, dim=-1)
        # # arr_sum.shape
        # B, H, P = attn_sum.shape
        # attn_sum = attn_sum.reshape(B * H, P)
        # # arr_sum.shape
        # # adp = []
        # sorted_indices = torch.argsort(attn_sum, descending=True, dim=-1)
        # # print(sorted_indices)
        # # print('end')
        # # adp = torch.zeros((B * H, P, P)).to(x.device)
        # # for i in range(B * H):
        # #
        # #     for j in range(P):
        # #         for k in range(1, self.N + 1):
        # #             if j * self.N + k < P:
        # #                 adp[i][sorted_indices[i][j * self.N + k]][sorted_indices[i][j]] = 1
        # #             else:
        # #                 break
        # #
        # # adp = adp.reshape(B, H, P, P)
        # # adp = adp * attn
        #
        # index = torch.arange(P).to(x.device)
        # index = ((index - 1) / self.N).int()
        # index = index[1:]
        # index = index.unsqueeze(0)
        # index = index.repeat(sorted_indices.shape[0], 1).to(torch.int64)
        # edge_end = torch.gather(sorted_indices, dim=-1, index=index)
        # edge_start = sorted_indices[:, 1:]
        # edges = torch.stack([edge_start, edge_end])
        # edges = edges.permute(1, 0, 2).contiguous()  # [B*H, 2, P-1] P-1 edges
        #
        # # weight = attn.reshape(B * H, P, P)
        # # values = weight[torch.arange(weight.size(0)).unsqueeze(1), edge_start, edge_end]
        # ndata = x.repeat(1, H, 1, 1).reshape(B, H, P, -1)
        # # print(ndata.shape)
        #
        # # graphs = []
        # # for i in range(edges.shape[0]):
        # #     g = dgl.graph((edge_start[i], edge_end[i]), num_nodes=P, device=x.device)
        # #     g.ndata['x'] = ndata[i]
        # #     g.edata['weight'] = values[i]
        # #     graphs.append(g)
        # # graphs = dgl.batch(graphs)
        #
        # adp = []
        # for i in range(edges.shape[0]):
        #     temp_adp = torch.sparse_coo_tensor(edges[i], torch.ones(P-1).to(x.device), size=[P, P])
        #     temp_adp = temp_adp.to_dense()
        #     adp.append(temp_adp)
        #
        # adp = torch.stack(adp)
        # adp = adp.reshape(B, H, P, P)
        # adp = adp * attn
        #
        #
        # out = self.gcn1(adp, ndata)
        # out = self.activation(out)
        # out = self.gcn2(adp, out)
        # out = self.dropout(out)
        # # print(out.shape)
        # # out = out.reshape(B, H, P, -1)
        # out = out.permute(0, 2, 1, 3).contiguous()
        # B, N, H, _ = out.shape
        # out = out.reshape(B, N, -1)
        # # print(out.shape)
        # out = self.activation(self.encoder_output(out))

        # return self.norm(out + x), attn




# class Encoder(nn.Module):
#     def __init__(self, attn_layers, norm_layer):
#         super(Encoder, self).__init__()
#         self.attn_layers = nn.ModuleList(attn_layers)
#
#     def forward(self, x):
#         # x [B, L, D]
#         attns = []
#         for attn_layer in self.attn_layers:
#             x, attn = attn_layer(x)
#             attns.append(attn)
#
#         return x, attns

class mixprop(nn.Module):
    """Multi-step propagation over a normalised adjacency, then a linear mix.

    Args:
        c_in: feature size of the input ``x``.
        c_out: feature size produced by the final linear layer.
        gdep: number of propagation steps; ``gdep + 1`` feature blocks are
            concatenated before the linear layer.
        dropout: stored on the module; not used in ``forward``.
        alpha: weight of the original input retained at every step.
    """

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = nconv()
        self.mlp = linear((gdep + 1) * c_in, c_out)
        self.gdep = gdep
        self.dropout = dropout
        self.alpha = alpha

    def forward(self, x, adj):
        """Propagate ``x`` over ``adj`` for ``gdep`` steps.

        ``adj`` must be 4-D (unpacked as ``B, H, N, _``); ``x`` gains a new
        dim 1 that is repeated ``H`` times (so presumably ``x`` is
        ``(B, N, c_in)`` -- confirm with callers).
        """
        # Add self-loops, then divide every row by its sum.
        self_loops = torch.eye(adj.size(-1)).to(x.device)
        adj = adj + self_loops
        batch, heads, nodes, _ = adj.shape
        row_sum = adj.sum(-1)
        norm_adj = adj / row_sum.view(batch, heads, nodes, 1)

        x = x.unsqueeze(1).repeat(1, heads, 1, 1)
        state = x
        hops = [state]
        for _ in range(self.gdep):
            state = self.alpha * x + (1 - self.alpha) * self.nconv(state, norm_adj)
            hops.append(state)

        return self.mlp(torch.cat(hops, dim=-1))


class nconv(nn.Module):
    """Batched graph propagation: multiply node features by an adjacency.

    Computes ``out[b, h, m, :] = sum_n A[b, h, m, n] * x[b, h, n, :]``,
    i.e. ``A @ x`` over the last two dims.
    """

    def __init__(self):
        super(nconv, self).__init__()

    def forward(self, x, A):
        """Aggregate features of ``x`` according to ``A``.

        Args:
            x: tensor indexed as ``(b, h, n, d)``.
            A: tensor indexed as ``(b, h, m, n)``.

        Returns:
            Contiguous tensor indexed as ``(b, h, m, d)``.
        """
        # The previous equation 'bhnd,bhnn->bhnd' repeated ``n`` inside ``A``,
        # which makes einsum use only the diagonal of ``A`` -- every node was
        # merely rescaled by its self-weight and never saw its neighbours.
        x = torch.einsum('bhnd,bhmn->bhmd', x, A)
        return x.contiguous()


class linear(nn.Module):
    """Thin wrapper around a single ``torch.nn.Linear`` stored as ``mlp``.

    Args:
        c_in: input feature size.
        c_out: output feature size.
        bias: whether the wrapped layer has a bias term.
    """

    def __init__(self, c_in, c_out, bias=True):
        super().__init__()
        self.mlp = nn.Linear(in_features=c_in, out_features=c_out, bias=bias)

    def forward(self, x):
        """Apply the wrapped linear layer to the last dim of ``x``."""
        projected = self.mlp(x)
        return projected



class DecoderLayer(nn.Module):
    """Self-attention, cross-attention and a two-convolution feed-forward block.

    Each of the three sub-layers is wrapped in a residual connection followed
    by its own LayerNorm.

    Args:
        self_attention: module called as ``self_attention(x, x, x, ...)``;
            only element ``[0]`` of its return value is used.
        cross_attention: module called as
            ``cross_attention(x, cross, cross, ...)``; only element ``[0]`` of
            its return value is used.
        d_model: feature size handled by the LayerNorms and 1x1 convolutions.
        d_ff: hidden width of the feed-forward block; falls back to
            ``4 * d_model`` when falsy.
        dropout: dropout probability shared by all dropout applications.
        activation: ``"relu"`` selects ``F.relu``; any other value selects
            ``F.gelu``.
    """

    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
                 dropout=0.1, activation="relu"):
        super().__init__()
        if not d_ff:
            d_ff = 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        """Run one decoder layer and return the transformed ``x``.

        ``delta`` is only handed to the cross-attention; the self-attention
        always receives ``delta=None``.
        """
        self_out = self.self_attention(
            x, x, x,
            attn_mask=x_mask,
            tau=tau, delta=None
        )[0]
        x = self.norm1(x + self.dropout(self_out))

        cross_out = self.cross_attention(
            x, cross, cross,
            attn_mask=cross_mask,
            tau=tau, delta=delta
        )[0]
        x = self.norm2(x + self.dropout(cross_out))

        # Conv1d wants channels on dim 1, hence the transposes around the FFN.
        hidden = self.conv1(x.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        ffn_out = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm3(x + ffn_out)


class Decoder(nn.Module):
    """Sequential stack of decoder layers with optional norm and projection.

    Args:
        layers: iterable of modules called as
            ``layer(x, cross, x_mask=..., cross_mask=..., tau=..., delta=...)``.
        norm_layer: optional module applied after the last layer.
        projection: optional module applied after the norm.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        """Feed ``x`` through every layer, then the norm, then the projection."""
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta)

        # Post-processing stages are skipped when they were not configured.
        for stage in (self.norm, self.projection):
            if stage is not None:
                hidden = stage(hidden)
        return hidden


class GCN(nn.Module):
    """Graph layer: ``nconv`` output concatenated with the input, then a 2-layer MLP.

    Args:
        in_features: feature size of ``nodes``; the first linear layer takes
            ``2 * in_features`` because input and propagated features are
            concatenated.
        hidden_size: width of the hidden linear layer.
        out_features: feature size of the result.
        activation: ``"relu"`` selects ``F.relu``; any other value selects
            ``F.gelu``.
    """

    def __init__(self, in_features, hidden_size, out_features, activation):
        super().__init__()
        self.conv = nconv()
        self.linear1 = nn.Linear(in_features * 2, hidden_size)
        self.linear2 = nn.Linear(hidden_size, out_features)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, graphs, nodes):
        """Combine ``nodes`` with ``self.conv(nodes, graphs)`` and project.

        Args:
            graphs: adjacency-like tensor handed to ``nconv`` as its second
                argument.
            nodes: node-feature tensor handed to ``nconv`` as its first
                argument.
        """
        propagated = self.conv(nodes, graphs)
        combined = torch.cat((nodes, propagated), dim=-1)
        hidden = self.activation(self.linear1(combined))
        return self.linear2(hidden)


class HyperGraph(nn.Module):
    """Parameter-free hypergraph propagation: ``relu((H @ H^T) @ x)``.

    Args:
        activation: ``"relu"`` selects ``F.relu``; any other value selects
            ``F.gelu``. The result is stored as ``self.activation`` but
            ``forward`` always applies ``self.relu``.
    """

    def __init__(self, activation):
        super(HyperGraph, self).__init__()
        self.activation = F.relu if activation == "relu" else F.gelu
        self.relu = nn.ReLU()

    def forward(self, x, H):
        """Propagate ``x`` through the incidence matrix ``H``.

        Args:
            x: feature tensor whose second-to-last dim matches the
                second-to-last dim of ``H``.
            H: incidence-like tensor; its last two dims are multiplied with
                their own transpose.

        Returns:
            ``relu((H @ H^T) @ x)``.
        """
        # The previous implementation built "degree" matrices Dv / De from
        # torch.eye, so both were always the identity: they allocated
        # (B, N, N) and (B, E, E) float32 tensors, cost four extra matmuls per
        # call and raised a dtype error for non-float32 inputs, without
        # changing the result. They are dropped here.
        # NOTE(review): if real degree normalisation was intended, the degrees
        # would have to be derived from H instead -- confirm with the author.
        propagation = torch.matmul(H, H.transpose(-2, -1))
        return self.relu(torch.matmul(propagation, x))
# class GCN(nn.Module):
#     def __init__(self, in_features, hidden_size, out_features, activation):
#         super(GCN, self).__init__()
#         self.W1 = nn.Linear(in_features*2, hidden_size)
#         self.W2 = nn.Linear(hidden_size, out_features)
#         self.activation = F.relu if activation == "relu" else F.gelu

    # def message_func(self, edges):
    #     # print(edges.data['weight'], edges.src['x'], edges.dst['x'])
    #     # print('m:', edges.src['x'] + edges.data['weight'])
    #     print(edges.data['weight'])
    #     return {'m': edges.src['x'] * edges.data['weight']}

    # def apply_nodes(self, nodes):
    #     x = nodes.data['x']
    #     out = nodes.data['out']
    #     # print(x, out)
    #     score = self.W1(torch.cat([x, out], -1))
    #     score = self.activation(score)
    #     score = self.W2(score)
    #
    #     return {'out': score}

    # def forward(self, graph):
    #     with graph.local_scope():
    #         graph.update_all(message_func=fn.u_mul_e('x', 'weight', 'm'), reduce_func=fn.mean('m', 'out'))
    #         # print(graph.edata['weight'])
    #         graph.apply_nodes(self.apply_nodes)
    #         return graph.ndata['out']




