import torch
import torch.nn as nn
from layers.StandardNorm import myNormalize
from layers import STGCN_layers as layers


class STGCNChebGraphConv(nn.Module):
    """STGCN using the ChebyNet graph convolution (ChebGraphConv).

    Layout: 'TGTND TGTND TNFF'. ChebGraphConv uses Chebyshev polynomials of
    the first kind as the graph filter.

    Each spatio-temporal block ('TGTND'):
        T: Gated Temporal Convolution Layer (GLU or GTU)
        G: Graph Convolution Layer (ChebGraphConv)
        T: Gated Temporal Convolution Layer (GLU or GTU)
        N: Layer Normalization
        D: Dropout

    Output head ('TNFF'):
        T: Gated Temporal Convolution Layer (GLU or GTU)
        N: Layer Normalization
        F: Fully-Connected Layer
        F: Fully-Connected Layer

    Args:
        Kt: temporal kernel size.
        Ks: spatial kernel size, forwarded to every STConvBlock.
        graph_conv_type: graph-convolution variant name, forwarded to STConvBlock.
        gso: graph shift operator, forwarded to STConvBlock.
        dropout: dropout rate forwarded to the ST blocks / output block.
        n_his: number of input time steps.
        blocks: channel specification. ST block ``l`` takes ``blocks[l][-1]``
            input channels and ``blocks[l + 1]`` as its channel spec, so
            ``blocks[1:-2]`` configure the ST blocks; ``blocks[-2]`` configures
            the output head and ``blocks[-1][0]`` is the end channel count.
        n_vertex: number of graph nodes.

    Raises:
        ValueError: if ``n_his`` is too short for the stacked temporal
            convolutions (the remaining temporal length ``Ko`` would be < 0).
    """

    def __init__(self, Kt, Ks, graph_conv_type, gso, dropout, n_his, blocks, n_vertex):
        super(STGCNChebGraphConv, self).__init__()
        num_st_blocks = len(blocks) - 3
        # Temporal length left after the ST blocks. The formula budgets two
        # temporal convolutions of size Kt (the two 'T's of 'TGTND') per block.
        Ko = n_his - num_st_blocks * 2 * (Kt - 1)
        if Ko < 0:
            raise ValueError(
                f'n_his={n_his} is too short for {num_st_blocks} ST blocks with Kt={Kt}: '
                f'need n_his >= {num_st_blocks * 2 * (Kt - 1)}'
            )
        self.Ko = Ko

        modules = []
        for l in range(num_st_blocks):
            modules.append(layers.STConvBlock(Kt, Ks, n_vertex, blocks[l][-1], blocks[l + 1], 'glu',
                                              graph_conv_type, gso, True, dropout))
        self.st_blocks = nn.Sequential(*modules)

        if self.Ko >= 1:
            # Any positive remaining length (including Ko == 1) gets the full
            # output block; previously Ko == 1 built no head at all and returned
            # un-projected features.
            self.output = layers.OutputBlock(Ko, blocks[-3][-1], blocks[-2], blocks[-1][0], n_vertex, 'glu',
                                             True, dropout)
        else:  # Ko == 0: plain two-layer FC head over the channel axis.
            self.fc1 = nn.Linear(in_features=blocks[-3][-1], out_features=blocks[-2][0], bias=True)
            self.fc2 = nn.Linear(in_features=blocks[-2][0], out_features=blocks[-1][0], bias=True)
            self.relu = nn.ReLU()
            # NOTE(review): constructed but never applied in forward(); kept so
            # the module attributes stay unchanged.
            self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Run the ST blocks followed by the output head.

        ``x`` is permuted with (0, 2, 3, 1) before the FC layers, so the
        channel axis is presumably dim 1 — TODO confirm against STConvBlock.
        """
        x = self.st_blocks(x)
        if self.Ko >= 1:
            x = self.output(x)
        else:
            # Move channels last for nn.Linear, then restore the layout.
            x = self.fc1(x.permute(0, 2, 3, 1))
            x = self.relu(x)
            x = self.fc2(x).permute(0, 3, 1, 2)

        return x


class Model(nn.Module):
    """STGCN forecaster configured from ``configs`` with a 'TGTND ... TNFF' layout.

    ``configs.e_layers`` spatio-temporal blocks ('TGTND') are stacked and
    followed by one output block ('TNFF'):

        T: Gated Temporal Convolution Layer (GLU or GTU)
        G: Graph Convolution Layer
        T: Gated Temporal Convolution Layer (GLU or GTU)
        N: Layer Normalization
        D: Dropout

        T: Gated Temporal Convolution Layer (GLU or GTU)
        N: Layer Normalization
        F: Fully-Connected Layer
        F: Fully-Connected Layer

    Note: ``graph_conv_type`` is hard-coded to ``'cheb_graph_conv'`` and passed
    to every STConvBlock, so the Chebyshev (ChebGraphConv) variant is requested,
    not the GCN GraphConv variant (which uses the renormalization trick and is
    prone to over-smoothing).

    Inputs are normalized with ``myNormalize`` before the ST blocks and
    de-normalized after the output block.

    Args:
        configs: object providing ``e_layers``, ``d_model``, ``seq_len``,
            ``pred_len``, ``enc_in`` and ``dropout``.

    Raises:
        ValueError: if ``seq_len`` is too short for ``e_layers`` blocks with
            temporal kernel ``Kt`` (the remaining temporal length ``Ko`` < 1).
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.configs = configs
        self.num_blocks = configs.e_layers
        self.d_model = configs.d_model
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        # NOTE(review): gso is normally a graph shift operator (matrix), but here
        # it is the integer d_model; confirm STGCN_layers accepts this.
        self.gso = configs.d_model
        self.Kt = 5  # temporal kernel size
        self.Ks = 5  # spatial (Chebyshev) kernel size
        self.graph_conv_type = 'cheb_graph_conv'

        # Temporal length left after the ST blocks (two temporal convs of size
        # Kt per block). OutputBlock needs Ko >= 1 as its temporal kernel.
        Ko = self.seq_len - self.num_blocks * 2 * (self.Kt - 1)
        if Ko < 1:
            raise ValueError(
                f'seq_len={self.seq_len} is too short for e_layers={self.num_blocks} '
                f'with Kt={self.Kt}: need seq_len >= {self.num_blocks * 2 * (self.Kt - 1) + 1}'
            )
        self.Ko = Ko

        modules = []
        for l in range(self.num_blocks):
            modules.append(layers.STConvBlock(self.Kt, self.Ks, configs.enc_in, self.d_model, self.d_model, 'glu',
                                              self.graph_conv_type, self.gso, True, self.configs.dropout))
        self.st_blocks = nn.Sequential(*modules)
        self.output = layers.OutputBlock(Ko, self.d_model, self.d_model, self.pred_len, configs.enc_in, 'glu',
                                         True, self.configs.dropout)
        self.instance_norm = myNormalize(configs.enc_in, affine=True)

    def forward(self, x, target_x=None, x_enc_mark=None, x_dec=None, x_dec_mark=None):
        """Normalize, run the ST blocks and output block, then de-normalize.

        Args:
            x: model input, passed to ``instance_norm`` in ``'norm'`` mode. Its
                expected shape is defined by myNormalize / STConvBlock — not
                visible here, TODO confirm.
            target_x: forwarded to ``instance_norm`` together with ``x``.
            x_enc_mark, x_dec, x_dec_mark: accepted for interface
                compatibility; unused.

        Returns:
            The output block's result after ``instance_norm`` in ``'denorm'`` mode.
        """
        x = self.instance_norm(x, target_x, 'norm')
        x = self.st_blocks(x)
        x = self.output(x)
        x = self.instance_norm(x, mode='denorm')
        return x