import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Graph import GraphConv_MTGNN, mixprop_MTGNN, LayerNorm_MTGNN
from layers.MTGNN_layers import dilated_inception
from layers.RevIn import RevIN

class gtnet(nn.Module):
    """MTGNN-style spatio-temporal forecaster.

    Stacks ``args.e_layers`` blocks of gated dilated-inception temporal
    convolutions. When ``args.use_gcn`` is set, each block is followed by a
    two-direction ``mixprop_MTGNN`` graph propagation; otherwise a 1x1
    residual convolution is used. Per-layer skip outputs are summed and
    decoded by two 1x1 convolutions into ``args.pred_len`` output channels.

    Args:
        args: configuration namespace. Attributes read here: use_gcn, enc_in,
            dropout, predefined_A, revin, affine, residual_channels,
            conv_channels, skip_channels, end_channels, d_node, subgraph_size,
            tanh_alpha, seq_len, pred_len, dilation_exponential, e_layers,
            gcn_depth, prop_alpha.
        in_dim: number of input channels per node (default 1).
        layer_norm_affline: forwarded as ``elementwise_affine`` to every
            ``LayerNorm_MTGNN`` (parameter name kept as-is for compatibility).
    """

    def __init__(self, args, in_dim=1, layer_norm_affline=True):
        super(gtnet, self).__init__()
        self.gcn_true = args.use_gcn
        # Hard-wired: the adjacency always comes from the graph-learning layer.
        self.buildA_true = True
        self.num_nodes = args.enc_in
        self.dropout = args.dropout
        self.predefined_A = args.predefined_A

        self.revin = args.revin
        if self.revin:
            self.revin_layer = RevIN(in_dim, affine=args.affine, subtract_last=False)

        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.gconv1 = nn.ModuleList()
        self.gconv2 = nn.ModuleList()
        self.norm = nn.ModuleList()
        self.start_conv = nn.Conv2d(in_channels=in_dim,
                                    out_channels=args.residual_channels,
                                    kernel_size=(1, 1))
        # Graph-learning layer.
        self.gc = GraphConv_MTGNN(self.num_nodes, args.d_node, args.subgraph_size, tanh_alpha=args.tanh_alpha)

        # Dilated-convolution settings.
        self.seq_length = args.seq_len
        kernel_size = 7
        dil_exp = args.dilation_exponential
        if dil_exp > 1:
            self.receptive_field = int(1 + (kernel_size - 1) * (dil_exp ** args.e_layers - 1) / (dil_exp - 1))
        else:
            self.receptive_field = args.e_layers * (kernel_size - 1) + 1
        # Time width of the input entering the stack: forward() left-pads
        # inputs shorter than the receptive field up to receptive_field.
        time_width = max(self.seq_length, self.receptive_field)

        rf_size_i = 1  # receptive field before the first layer
        new_dilation = 1
        for j in range(1, args.e_layers + 1):
            if dil_exp > 1:
                rf_size_j = int(rf_size_i + (kernel_size - 1) * (dil_exp ** j - 1) / (dil_exp - 1))
            else:
                rf_size_j = rf_size_i + j * (kernel_size - 1)
            # Remaining time width after layer j's temporal convolution.
            out_width = time_width - rf_size_j + 1

            self.filter_convs.append(dilated_inception(args.residual_channels, args.conv_channels, dilation_factor=new_dilation))
            self.gate_convs.append(dilated_inception(args.residual_channels, args.conv_channels, dilation_factor=new_dilation))
            self.residual_convs.append(nn.Conv2d(in_channels=args.conv_channels,
                                                 out_channels=args.residual_channels,
                                                 kernel_size=(1, 1)))
            # Skip conv collapses the remaining time width to 1.
            self.skip_convs.append(nn.Conv2d(in_channels=args.conv_channels,
                                             out_channels=args.skip_channels,
                                             kernel_size=(1, out_width)))

            if self.gcn_true:
                self.gconv1.append(mixprop_MTGNN(args.conv_channels, args.residual_channels, args.gcn_depth, args.dropout, args.prop_alpha))
                self.gconv2.append(mixprop_MTGNN(args.conv_channels, args.residual_channels, args.gcn_depth, args.dropout, args.prop_alpha))

            self.norm.append(LayerNorm_MTGNN((args.residual_channels, self.num_nodes, out_width),
                                             elementwise_affine=layer_norm_affline))
            # Grow the dilation factor for the next dilated convolution.
            new_dilation *= dil_exp

        # Final two 1x1 convolution layers.
        self.layers = args.e_layers
        self.end_conv_1 = nn.Conv2d(in_channels=args.skip_channels, out_channels=args.end_channels, kernel_size=(1, 1), bias=True)
        self.end_conv_2 = nn.Conv2d(in_channels=args.end_channels, out_channels=args.pred_len, kernel_size=(1, 1), bias=True)
        # skip0 reads the raw (padded) input; skipE reads the last layer's output.
        # Both collapse the time axis to width 1.
        self.skip0 = nn.Conv2d(in_channels=in_dim, out_channels=args.skip_channels, kernel_size=(1, time_width), bias=True)
        self.skipE = nn.Conv2d(in_channels=args.residual_channels, out_channels=args.skip_channels,
                               kernel_size=(1, time_width - self.receptive_field + 1), bias=True)
        self.idx = torch.arange(self.num_nodes)

    def forward(self, x, idx=None):
        """Run the network on one batch.

        Args:
            x: 3-D input tensor whose dimension 1 must equal ``self.seq_length``
                (it becomes the time axis below). Presumably
                (batch, seq_len, num_nodes) -- TODO confirm against the loader.
            idx: optional node-index tensor passed to the graph learner and the
                layer norms; defaults to ``self.idx`` (all nodes).

        Returns:
            Tuple ``(output, A0, adp)``. ``output`` has the channel axis of
            size ``pred_len`` at dimension 1, with the trailing width-1 time
            axis removed. ``adp`` and ``A0`` are what ``self.gc`` returns
            (presumably the learned adjacency); both are None when
            ``use_gcn`` is False.
        """
        if self.revin:
            x = self.revin_layer(x, 'norm')

        # (B, L, N) -> (B, 1, N, L): channels first, nodes as height, time as width.
        x = x.unsqueeze(-1).permute(0, 3, 2, 1)

        seq_len = x.size(3)
        assert seq_len == self.seq_length, 'input sequence length not equal to preset sequence length'

        # Left-pad the time axis so the dilated stack has a full receptive field.
        if self.seq_length < self.receptive_field:
            x = F.pad(x, (self.receptive_field - self.seq_length, 0, 0, 0))

        node_idx = self.idx if idx is None else idx

        # Defaults keep the return statement valid when graph convolution is off.
        adp, A0 = None, None
        if self.gcn_true:
            if self.buildA_true:
                adp, A0 = self.gc(node_idx.to(x.device))
            else:
                # Adjacency loaded directly from the dataset instead.
                adp = self.predefined_A

        output = self.start_conv(x)
        # Initial skip connection taken straight from the (padded) input.
        skip = self.skip0(F.dropout(x, self.dropout, training=self.training))
        for i in range(self.layers):
            residual = output
            # Gated temporal convolution: tanh(filter) * sigmoid(gate).
            filt = torch.tanh(self.filter_convs[i](output))
            gate = torch.sigmoid(self.gate_convs[i](output))
            output = F.dropout(filt * gate, self.dropout, training=self.training)
            # The temporal-conv output also feeds this layer's skip connection.
            skip = self.skip_convs[i](output) + skip
            if self.gcn_true:
                # Propagate along both edge directions and sum the results.
                output = self.gconv1[i](output, adp) + self.gconv2[i](output, adp.transpose(1, 0))
            else:
                output = self.residual_convs[i](output)
            # Residual connection, cropped to the (shorter) current time width.
            output = output + residual[:, :, :, -output.size(3):]
            output = self.norm[i](output, node_idx)

        skip = self.skipE(output) + skip
        output = F.relu(skip)
        output = F.relu(self.end_conv_1(output))
        # Drop only the width-1 time axis; a bare squeeze() would also drop the
        # batch, node, or pred_len axis whenever one of them has size 1.
        output = self.end_conv_2(output).squeeze(-1)
        if self.revin:
            output = self.revin_layer(output, 'denorm')
        return output, A0, adp

class GRUNet(nn.Module):
    """Single-layer GRU followed by a linear head applied at every time step.

    Args:
        input_size: number of features per input time step.
        hidden_size: GRU hidden-state size.
        output_size: number of output features produced per time step.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(GRUNet, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        """Encode ``input`` with the GRU and project every step's hidden state.

        With ``batch_first=True``, a batched input is (batch, seq, input_size).
        The result is (batch, seq, output_size). The linear head is applied to
        all time steps, not only the last one.
        """
        output, _ = self.gru(input)
        return self.fc(output)