from typing import Callable, Optional
import torch
from torch import nn
from layers.GRformer_layer import TruncateModule, GCEncoder, Flatten_Head
from layers.Position_Embed import PositionalEncoding, positional_encoding, LocalRNN, RNN
from layers.RevIn import RevIN
from layers.Graph import GraphConv
import warnings

# NOTE(review): with no category/module filter this silences *every* warning
# process-wide as a side effect of importing this module, not only warnings
# raised by the model code below.
warnings.filterwarnings('ignore')


class GRformer(nn.Module):
    """Patch-based Transformer forecaster with optional graph mixing across channels.

    Pipeline, as implemented in ``forward``:

    1. Optional RevIN normalisation.
    2. Split each channel's series into (possibly overlapping) patches.
    3. Apply a linear patch embedding.
    4. Add RNN-based or table-based positional information.
    5. Run a stack of ``GCEncoder`` layers, optionally fed a graph produced
       by ``GraphConv``.
    6. Project with ``Flatten_Head``.
    7. Optional RevIN de-normalisation.

    Args:
        args: configuration namespace (fields read are listed in ``__init__``).
        corr: forwarded to ``GraphConv`` when ``args.use_gcn`` is set.
        high_correlated_count: forwarded to ``GraphConv`` when ``args.use_gcn`` is set.
    """

    def __init__(self, args, corr=None, high_correlated_count=None) -> None:
        """Build all sub-modules from ``args``.

        Fields of ``args`` that are read:

        - Model sizes: batch_size, enc_in, d_model, e_layers, dropout,
          n_heads, pred_len.
        - Patching: seq_len, patch_len, stride, padding_patch.
        - Positional information: rnn, pos_embed_type.
        - Normalisation: revin, affine.
        - Graph: use_gcn, d_node, subgraph_size, tanh_alpha, prop_alpha,
          gcn_depth.
        - Encoder and head: mlp_type, activation, individual, head_dropout.

        Raises:
            ValueError: if the patching configuration cannot produce at least
                one patch (see ``_patch_layout``).
        """
        super().__init__()
        # Basic configuration.
        self.batch_size = args.batch_size
        self.channels = args.enc_in
        self.d_model = args.d_model
        self.e_layers = args.e_layers
        self.dropout = args.dropout
        self.use_rnn = bool(args.rnn > 0)

        # Optional reversible instance normalisation.
        self.revin = args.revin
        if self.revin:
            self.revin_layer = RevIN(self.channels, affine=args.affine, subtract_last=False)

        # Patching. If seq_len is not a multiple of the stride, the series is
        # either right-padded by replicating its last value
        # (padding_patch == "end") or truncated to the largest multiple of the
        # stride (any other padding_patch value).
        self.seq_len = args.seq_len
        self.patch_len = args.patch_len
        self.stride = self.patch_len // 2 if args.stride == "half" else self.patch_len
        processed_len, self.patch_num = self._patch_layout(
            self.seq_len, self.patch_len, self.stride, args.padding_patch)
        if processed_len == self.seq_len:
            # Already a multiple of the stride: pass through unchanged.
            self.process_layer = nn.Identity()
        elif processed_len > self.seq_len:
            self.process_layer = nn.ReplicationPad1d((0, processed_len - self.seq_len))
        else:
            self.process_layer = TruncateModule(processed_len)
        self.embedPatch = nn.Linear(self.patch_len, self.d_model)

        # Positional information: a recurrent module applied to the embeddings,
        # or a fixed-size table built by positional_encoding.
        if self.use_rnn:
            if args.rnn == 1:
                self.E_P = LocalRNN(args.d_model, args.d_model)
            else:
                self.E_P = RNN(args.d_model, args.d_model)
        else:
            self.E_P = positional_encoding(args.pos_embed_type, True, self.patch_num, self.d_model)
        self.dropP = nn.Dropout(self.dropout)

        # Optional graph module whose output is passed to every encoder layer.
        self.use_gcn = args.use_gcn
        if self.use_gcn:
            self.gc = GraphConv(corr=corr, high_correlated_count=high_correlated_count,
                                node_num=self.channels, d_node=args.d_node, top_k=args.subgraph_size,
                                tanh_alpha=args.tanh_alpha)
            # Node indices fed to GraphConv. Registering them as a
            # non-persistent buffer lets module.to(device) move them along with
            # the parameters, without adding a key to the state_dict.
            self.register_buffer("id_list", torch.arange(self.channels), persistent=False)

        # Stack of e_layers encoder layers.
        self.backbone = nn.ModuleList([GCEncoder(
            channels=args.enc_in, d_model=self.d_model, n_heads=args.n_heads,
            dropout=self.dropout, use_gcn=args.use_gcn, prop_alpha=args.prop_alpha, gcn_depth=args.gcn_depth, mlp_type=args.mlp_type,
            activation=args.activation
        ) for _ in range(self.e_layers)])

        # Output head over the flattened [patch_num x d_model] representation.
        self.head_nf = self.d_model * self.patch_num
        self.head = Flatten_Head(args.individual, self.channels, self.head_nf, args.pred_len, head_dropout=args.head_dropout)

    @staticmethod
    def _patch_layout(seq_len, patch_len, stride, padding_patch):
        """Return ``(processed_len, patch_num)`` for the patching step.

        ``processed_len`` is the series length after padding/truncation.
        ``patch_num`` is exactly the number of windows ``Tensor.unfold``
        yields with ``size=patch_len`` and ``step=stride`` on that length.

        Raises:
            ValueError: if ``stride`` is not positive, or if the processed
                series is shorter than one patch.
        """
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride} (patch_len={patch_len})")
        remainder = seq_len % stride
        if remainder == 0:
            processed_len = seq_len
        elif padding_patch == "end":
            processed_len = seq_len + (stride - remainder)
        else:
            processed_len = seq_len - remainder
        if processed_len < patch_len:
            raise ValueError(
                f"series of length {processed_len} (seq_len={seq_len}) is shorter "
                f"than patch_len={patch_len}")
        # Integer arithmetic matches Tensor.unfold for every stride/patch_len
        # combination, including an odd patch_len with a "half" stride.
        return processed_len, (processed_len - patch_len) // stride + 1

    def forward(self, x):
        """Forecast from a batch of multivariate series.

        Args:
            x: input laid out as [batch_size, seq_len, channels]. The time
                axis is moved last before patching.

        Returns:
            The head output, permuted so channels come last. This is
            presumably [batch_size, pred_len, channels], but depends on
            ``Flatten_Head``. When ``use_gcn`` is set, returns
            ``(output, A0, A)``, i.e. both matrices returned by ``GraphConv``.
        """
        if self.use_gcn:
            A, A0 = self.gc(self.id_list.to(x.device))
        else:
            A, A0 = None, None

        if self.revin:
            x = self.revin_layer(x, 'norm')
        # [batch_size, seq_len, channels] -> [batch_size, channels, seq_len]
        x = x.permute(0, 2, 1)

        # Pad/truncate, then cut overlapping windows along time:
        # [batch_size, channels, L] -> [batch_size, channels, patch_num, patch_len]
        patches = self.process_layer(x)
        patches = patches.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        # [batch_size, channels, patch_num, patch_len] -> [batch_size, channels, patch_num, d_model]
        patches = self.embedPatch(patches)
        batch, n_vars, n_patches, d_model = patches.shape
        # Fold channels into the batch axis:
        # -> [batch_size * channels, patch_num, d_model]
        patches = torch.reshape(patches, (batch * n_vars, n_patches, d_model))

        # Add positional information.
        if self.use_rnn:
            patches = patches + self.E_P(patches)
        else:
            patches = patches + self.E_P.to(x.device)
        output = self.dropP(patches)

        for layer in self.backbone:
            output = layer(output, A)

        # [batch_size * channels, patch_num, d_model] -> [batch_size, channels, patch_num, d_model]
        output = torch.reshape(output, (-1, self.channels, output.shape[-2], output.shape[-1]))
        output = self.head(output)
        output = output.permute(0, 2, 1)
        if self.revin:
            output = self.revin_layer(output, 'denorm')
        if self.use_gcn:
            return output, A0, A
        return output
