import torch.nn as nn
import torch.nn.functional as F
import torch

class ConvLayer(nn.Module):
    """Distilling block: circular Conv1d -> BatchNorm1d -> ELU -> strided MaxPool1d.

    Takes sequence-major input ``[B, L, c_in]`` and returns ``[B, L_out, c_in]``.
    With ``kernel_size=3`` and ``padding=2`` the convolution produces ``L + 2``
    steps; the stride-2 pool (kernel 3, padding 1) then yields
    ``L_out = floor((L + 1) / 2) + 1``.
    """

    def __init__(self, c_in):
        super(ConvLayer, self).__init__()
        conv_kwargs = dict(
            in_channels=c_in,
            out_channels=c_in,
            kernel_size=3,
            padding=2,
            padding_mode='circular',
        )
        self.downConv = nn.Conv1d(**conv_kwargs)
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # Conv1d / BatchNorm1d / MaxPool1d all expect channels in dim 1.
        channels_first = x.permute(0, 2, 1)
        features = self.activation(self.norm(self.downConv(channels_first)))
        pooled = self.maxPool(features)
        # Back to sequence-major layout [B, L_out, c_in].
        return pooled.transpose(1, 2)

# Dynamic graph learning (DGL) components: DGL, DynamicGraphUpdate, gconv, GraphConv.
class DGL(nn.Module):
    """Dynamic graph learning block.

    Learns a per-sample adjacency with ``DynamicGraphUpdate``, propagates the
    node features over it, then mixes the ``d_len`` feature channels into
    ``configs.d_model`` channels with a 1x1 convolution.

    Input ``x`` is ``[B, N, d_len]``; returns ``(features, adj)`` where
    ``features`` is ``[B, configs.d_model, N]`` (channels-first) and ``adj``
    is the adjacency produced by ``DynamicGraphUpdate``.
    """

    def __init__(self, configs, d_len, hops):
        super(DGL, self).__init__()
        self.d_len = d_len
        self.dynamicGNN = DynamicGraphUpdate(configs, d_len, hops)
        self.agg_mlp = torch.nn.Conv1d(
            in_channels=d_len,
            out_channels=configs.d_model,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

    def forward(self, x):
        # graph_features: [B, d_len, N] (channels-first, ready for Conv1d).
        graph_features, learned_adj = self.dynamicGNN(x)
        return self.agg_mlp(graph_features), learned_adj
class DynamicGraphUpdate(nn.Module):
    """Learn a per-sample adjacency and propagate node features over it.

    Two static node embeddings ("source" ``[enc_in, nodedim]`` and "target"
    ``[nodedim, enc_in]``) are each shifted by a gated linear projection of
    the input features. Their product, after ReLU and a row-wise softmax,
    is the adjacency; ``GraphConv`` then propagates the features over it.

    Input ``x`` is ``[B, N, deep_len]``. ``N`` must equal ``configs.enc_in``
    because the embeddings are concatenated with ``x`` along the last axis.
    Returns ``(features, adj)``: ``features`` is ``[B, deep_len, N]`` and
    ``adj`` is ``[B, N, N]`` with each row summing to 1.
    """

    def __init__(self, configs, deep_len, hops):
        super(DynamicGraphUpdate, self).__init__()
        self.enc_in = configs.enc_in
        self.d_model = configs.d_model
        self.deep_len = deep_len
        self.dropout = configs.dropout
        self.nd = configs.nodedim
        n_nodes, emb_dim = self.enc_in, self.nd

        # Static source / target node embeddings.
        self.nodeEmbedding_1 = nn.Parameter(torch.randn(n_nodes, emb_dim))
        self.nodeEmbedding_2 = nn.Parameter(torch.randn(emb_dim, n_nodes))

        # Per-node scalar gates; Tanh followed by ReLU keeps them in [0, 1).
        gate_in = self.deep_len + emb_dim
        self.nodeEmb_gate1 = nn.Sequential(nn.Linear(gate_in, 1), nn.Tanh(), nn.ReLU())
        self.nodeEmb_gate2 = nn.Sequential(nn.Linear(gate_in, 1), nn.Tanh(), nn.ReLU())

        # Project input features into the embedding space.
        self.nodeLinear1 = nn.Linear(self.deep_len, emb_dim)
        self.nodeLinear2 = nn.Linear(self.deep_len, emb_dim)

        self.mhGNN = GraphConv(self.deep_len, self.deep_len, self.dropout, multiHop=hops)

    def forward(self, x):
        batch, _, _ = x.shape
        emb_src = self.nodeEmbedding_1.view(1, self.enc_in, self.nd).repeat(batch, 1, 1)  # [B, N, nd]
        emb_dst = self.nodeEmbedding_2.view(1, self.nd, self.enc_in).repeat(batch, 1, 1)  # [B, nd, N]

        gate_src = self.nodeEmb_gate1(torch.cat([x, emb_src], dim=-1))
        gate_dst = self.nodeEmb_gate2(torch.cat([x, emb_dst.permute(0, 2, 1)], dim=-1))

        vec_src = emb_src + gate_src * self.nodeLinear1(x)                    # [B, N, nd]
        vec_dst = emb_dst + (gate_dst * self.nodeLinear2(x)).permute(0, 2, 1)  # [B, nd, N]

        # Row-stochastic adjacency over non-negative similarity scores.
        adj = F.softmax(F.relu(torch.matmul(vec_src, vec_dst)), dim=-1)

        # GraphConv expects channels-first features and a list of supports.
        propagated = self.mhGNN(x.permute(0, 2, 1), [adj])
        return propagated, adj
class gconv(nn.Module):
    """Parameter-free batched graph propagation: ``out[b] = x[b] @ A[b]``.

    ``x`` is ``[B, F, N]`` and ``A`` is ``[B, N, V]``; the result is a
    contiguous ``[B, F, V]`` tensor.
    """

    def __init__(self):
        super(gconv, self).__init__()

    def forward(self, x, A):
        propagated = torch.einsum('bfn,bnv->bfv', x, A)
        return propagated.contiguous()
class GraphConv(nn.Module):
    """Multi-hop graph convolution followed by a 1x1 Conv1d over channels.

    For each adjacency ``a`` in ``adj`` the input is propagated ``multiHop``
    times (``x @ a``, ``x @ a @ a``, ...). The input and every propagated
    tensor are concatenated along the channel axis, mixed to ``c_out``
    channels, and passed through ReLU.

    Args:
        c_in: channels of the input ``x`` (dim 1).
        c_out: channels of the output.
        dropout: stored on the module; NOTE(review): it is not applied in
            ``forward`` — confirm whether that is intentional.
        multiHop: propagation hops per adjacency (>= 0).
        support_len: number of adjacency matrices ``forward`` receives
            (>= 1); the default of 1 matches the single-support usage.

    Shapes:
        x: ``[B, c_in, N]``; each adjacency: ``[B, N, N]``;
        output: ``[B, c_out, N]``.

    Raises:
        ValueError: on a negative ``multiHop``, a ``support_len`` below 1,
            or when ``len(adj)`` differs from ``support_len``.
    """

    def __init__(self, c_in, c_out, dropout, multiHop=2, support_len=1):
        super(GraphConv, self).__init__()
        if multiHop < 0:
            raise ValueError(f"multiHop must be >= 0, got {multiHop}")
        if support_len < 1:
            raise ValueError(f"support_len must be >= 1, got {support_len}")
        self.gconv = gconv()
        # Input itself + multiHop propagated copies per adjacency.
        c_in = (multiHop * support_len + 1) * c_in
        self.linear = torch.nn.Conv1d(c_in, c_out, kernel_size=1, padding=0, stride=1, bias=True)
        self.dropout = dropout
        self.multiHop = multiHop
        self.support_len = support_len

    def forward(self, x, adj):  # [B,D,N]
        if len(adj) != self.support_len:
            raise ValueError(
                f"expected {self.support_len} adjacency matrices, got {len(adj)}"
            )
        multi_X = [x]
        for a in adj:
            h = x
            # Exactly multiHop hops, so the channel count always matches
            # self.linear (the old loop appended one hop even for multiHop=0).
            for _ in range(self.multiHop):
                h = self.gconv(h, a)
                multi_X.append(h)

        x_cat = torch.cat(multi_X, dim=1)
        x_cat = self.linear(x_cat)  # [B,D,N]
        return F.relu(x_cat)
# End of dynamic graph learning (DGL) components.

class EncoderLayer(nn.Module):
    """Transformer encoder layer whose feed-forward sub-block is replaced by
    dynamic graph learning (DGL).

    Pipeline: self-attention + residual + LayerNorm, then DGL over the
    normalized tokens + residual + LayerNorm.

    Args:
        attention: module called as ``attention(x, x, x, attn_mask=...,
            tau=..., delta=...)`` that returns ``(output, attn)``.
        d_model: feature width of the tokens.
        d_ff: accepted for interface compatibility; unused because the
            convolutional feed-forward block is replaced by DGL.
        config: required. ``config.order`` sets the number of graph hops, and
            DGL also reads ``enc_in``, ``d_model``, ``dropout`` and ``nodedim``.
        dropout: dropout probability for both residual branches.
        activation: ``"relu"`` or anything else for GELU; stored but unused
            by ``forward``.

    Shapes:
        x: ``[B, N, d_model]``. ``N`` must equal ``config.enc_in`` because DGL
        concatenates per-node embeddings of that size with the tokens.

    Returns:
        ``(x, attn, adjacency)`` from ``forward``.

    Raises:
        ValueError: if ``config`` is None.
    """

    def __init__(self, attention, d_model, d_ff=None, config=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        if config is None:
            raise ValueError(
                "EncoderLayer requires `config` (reads .order; DGL reads "
                ".enc_in, .d_model, .dropout and .nodedim)"
            )
        self.attention = attention
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

        # config.order = number of propagation hops (author note: 1 worked best, 2 worse).
        self.dgl = DGL(config, d_model, config.order)

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        new_x, attn = self.attention(
            x, x, x,
            attn_mask=attn_mask,
            tau=tau, delta=delta
        )
        x = self.norm1(x + self.dropout(new_x))

        # DGL.forward takes only the node features; passing `attn` as a
        # second positional argument raised TypeError on every call.
        output, adjacency_matrix = self.dgl(x)
        # DGL returns channels-first [B, d_model, N]; restore [B, N, d_model].
        y = self.dropout(output).permute(0, 2, 1)

        return self.norm2(x + y), attn, adjacency_matrix


class Encoder(nn.Module):
    """Run a stack of encoder layers, optionally interleaved with conv layers.

    Each attention layer is called as ``layer(x, attn_mask=..., tau=...,
    delta=...)`` and must return ``(x, attn, adj)``, the contract of
    ``EncoderLayer`` in this module.

    When ``conv_layers`` is given, ``conv_layers[i]`` runs after
    ``attn_layers[i]``, and the last attention layer then runs on its own
    (without ``attn_mask``). In that mode ``delta`` is forwarded only to the
    first layer. Without conv layers, every layer receives the same
    ``attn_mask``, ``tau`` and ``delta``.

    NOTE(review): a conv layer such as ``ConvLayer`` shortens dim 1, while a
    following DGL-based layer expects dim 1 == ``enc_in``. Confirm that the
    two are never combined.

    Returns:
        ``(x, attns, adjs)``: the final features (after ``norm_layer`` if
        set) and per-layer lists of attention maps and adjacency matrices.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        attns = []
        adjs = []
        if self.conv_layers is not None:
            for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
                delta = delta if i == 0 else None
                # Layers return (x, attn, adj); unpacking two values raised ValueError here.
                x, attn, adj = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                x = conv_layer(x)
                attns.append(attn)
                adjs.append(adj)
            x, attn, adj = self.attn_layers[-1](x, tau=tau, delta=None)
            attns.append(attn)
            adjs.append(adj)
        else:
            for attn_layer in self.attn_layers:
                x, attn, adj = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                attns.append(attn)
                adjs.append(adj)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns, adjs


class DecoderLayer(nn.Module):
    """Standard Transformer decoder layer.

    Applies self-attention, then cross-attention over ``cross``, then a
    position-wise feed-forward network built from two 1x1 Conv1d layers.
    Each sub-block is residual and is followed by a LayerNorm.

    ``d_ff`` defaults to ``4 * d_model`` when falsy. The attention modules are
    called as ``attn(q, k, v, attn_mask=..., tau=..., delta=...)`` and only
    element ``[0]`` of their result is used. The self-attention always gets
    ``delta=None``.
    """

    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
                 dropout=0.1, activation="relu"):
        super(DecoderLayer, self).__init__()
        hidden = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        self_out = self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]
        x = self.norm1(x + self.dropout(self_out))

        cross_out = self.cross_attention(x, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta)[0]
        x = self.norm2(x + self.dropout(cross_out))

        # Conv1d wants channels in dim 1, so swap the last two axes around the FFN.
        ffn_hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        ffn_out = self.dropout(self.conv2(ffn_hidden).transpose(-1, 1))
        return self.norm3(x + ffn_out)


class Decoder(nn.Module):
    """Stack of decoder layers with an optional final norm and projection.

    Every layer is called as ``layer(x, cross, x_mask=..., cross_mask=...,
    tau=..., delta=...)``. After the stack, ``norm_layer`` is applied first
    and ``projection`` second, each only if provided.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        out = x
        for block in self.layers:
            out = block(out, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta)

        # Order matters: normalize, then project.
        for post in (self.norm, self.projection):
            if post is not None:
                out = post(out)
        return out
