import torch
from torch import Tensor
from torch import nn
from typing import Optional
from layers.Graph import mixprop_RPG
from layers.Attention import MultiheadAttention
from utils.tools import get_activation_fn

# Custom truncation layer
class TruncateModule(nn.Module):
    """Keep only the leading positions along dimension 2 of the input.

    Args:
        target_length: default number of positions to keep; used by
            ``forward`` when ``truncate_length`` is not passed.
    """

    # Sentinel that distinguishes "argument omitted" from an explicit ``None``.
    # An explicit None keeps the original semantics ``x[:, :, :None]``,
    # i.e. no truncation at all.
    _UNSET = object()

    def __init__(self, target_length):
        super(TruncateModule, self).__init__()
        self.target_length = target_length

    def forward(self, x, truncate_length=_UNSET):
        """Return ``x[:, :, :truncate_length]``.

        Args:
            x: tensor with at least 3 dimensions; it is sliced along dim 2.
            truncate_length: number of positions to keep. When omitted,
                ``self.target_length`` is used (previously the stored
                ``target_length`` was never consulted).

        Returns:
            A view of ``x`` truncated along dim 2.
        """
        if truncate_length is TruncateModule._UNSET:
            truncate_length = self.target_length
        return x[:, :, :truncate_length]


# Transformer encoder layer with an optional graph-convolution step
class GCEncoder(nn.Module):
    """Post-norm Transformer encoder layer with an optional ``mixprop_RPG`` step.

    Pipeline: multi-head self-attention -> (optional) graph convolution on the
    attention output -> residual + norm -> position-wise feed-forward ->
    residual + norm.

    Args:
        channels: number of variables; the leading ``batch_size * channels``
            axis of the input is folded to ``[batch_size, channels, ...]``
            around the graph convolution.
        d_model: feature width of the tokens.
        n_heads, d_k, d_v, attn_dropout: forwarded to ``MultiheadAttention``.
        dropout: used as the attention projection dropout, inside the
            feed-forward network, and on both residual branches.
        d_ff: hidden width of the feed-forward network.
        bias: whether the feed-forward ``nn.Linear`` layers use a bias.
        use_gcn: build and apply the ``mixprop_RPG`` graph convolution.
        prop_alpha, gcn_depth, mlp_type: forwarded to ``mixprop_RPG``.
        norm: a string containing "Batch" selects BatchNorm1d over
            ``d_model``; anything else selects LayerNorm.
        activation: name passed to ``get_activation_fn``.
    """

    def __init__(self, channels:int=21, d_model:int=128, n_heads:int=8, dropout:float=0., attn_dropout:float=0.,
                 d_k:Optional[int]=None, d_v:Optional[int]=None, d_ff:int=256, bias:bool=True,
                 use_gcn:bool=False, prop_alpha:float=0.05, gcn_depth:int=2, mlp_type:int=2,
                 norm:str='BatchNorm', activation="gelu") -> None:
        super().__init__()
        self.channels = channels
        self.use_gcn = use_gcn

        # NOTE: sub-modules are created in a fixed order (attention, graph conv,
        # attention norm, feed-forward, ffn norm) so that random parameter
        # initialisation and state_dict keys stay stable.
        self.MultiHeadAttn = MultiheadAttention(d_model=d_model, n_heads=n_heads, d_k=d_k, d_v=d_v,
                                                attn_dropout=attn_dropout, proj_dropout=dropout,
                                                res_attention=False)
        if use_gcn:
            self.gconv = mixprop_RPG(channels, channels, gcn_depth, mlp_type, prop_alpha, d_model)

        # Residual + norm after attention.
        self.dropout_attn = nn.Dropout(dropout)
        self.norm_attn = self._build_norm(norm, d_model)

        # Position-wise feed-forward network.
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=bias),
            get_activation_fn(activation),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model, bias=bias),
        )

        # Residual + norm after the feed-forward network.
        self.dropout_ffn = nn.Dropout(dropout)
        self.norm_ffn = self._build_norm(norm, d_model)

    @staticmethod
    def _build_norm(norm: str, d_model: int) -> nn.Module:
        """Return BatchNorm1d over the feature axis (between transposes) or LayerNorm."""
        if "Batch" in norm:
            # BatchNorm1d normalises dim 1, so swap token and feature axes around it.
            return nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(d_model), Transpose(1, 2))
        return nn.LayerNorm(d_model)

    def forward(self, x, A) -> Tensor:
        """Run one encoder layer.

        Args:
            x: [batch_size * channels x patch_num x d_model] token tensor.
            A: graph passed to ``self.gconv``; only read when ``use_gcn`` is True.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attn_out, _ = self.MultiHeadAttn(x)

        if self.use_gcn:
            n_patch, width = attn_out.shape[-2], attn_out.shape[-1]
            # [B*C, P, D] -> [B, C, P*D]: each channel becomes one flat row for the graph conv.
            graph_in = attn_out.reshape(-1, self.channels, n_patch * width)
            n_batch = graph_in.shape[0]
            graph_out = self.gconv(graph_in, A)
            # [B, C, P*D] -> [B*C, P, D]
            attn_out = graph_out.reshape(n_batch * self.channels, n_patch, width)

        # Add & Norm around attention (or attention + graph conv).
        hidden = self.norm_attn(x + self.dropout_attn(attn_out))
        # Add & Norm around the feed-forward network.
        return self.norm_ffn(hidden + self.dropout_ffn(self.ff(hidden)))
        

class Flatten_Head(nn.Module):
    """Flatten the last two dims of the input and project them to ``target_window``.

    Args:
        individual: truthy -> every variable gets its own Flatten/Linear/Dropout;
            falsy -> one shared set is applied to the whole tensor.
        n_vars: number of variables (per-variable layers are built when
            ``individual`` is truthy).
        nf: size of the flattened last two dims, i.e. the Linear input width.
        target_window: output length per variable.
        head_dropout: dropout probability applied after the projection.
    """

    def __init__(self, individual, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.individual = individual
        self.n_vars = n_vars

        if not individual:
            self.flatten = nn.Flatten(start_dim=-2)
            self.linear = nn.Linear(nf, target_window)
            self.dropout = nn.Dropout(head_dropout)
            return

        # One independent projection per variable. Registration order
        # (linears, dropouts, flattens) matches the state_dict layout.
        self.linears = nn.ModuleList([nn.Linear(nf, target_window) for _ in range(n_vars)])
        self.dropouts = nn.ModuleList([nn.Dropout(head_dropout) for _ in range(n_vars)])
        self.flattens = nn.ModuleList([nn.Flatten(start_dim=-2) for _ in range(n_vars)])

    def forward(self, x):                                 # x: [bs x nvars x d_model x patch_num]
        if not self.individual:
            return self.dropout(self.linear(self.flatten(x)))

        # Per variable: [bs x d_model x patch_num] -> [bs x d_model * patch_num] -> [bs x target_window]
        per_var = [
            self.dropouts[v](self.linears[v](self.flattens[v](x[:, v, :, :])))
            for v in range(self.n_vars)
        ]
        return torch.stack(per_var, dim=1)                # [bs x nvars x target_window]


class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose`` so it can sit inside ``nn.Sequential``.

    Args:
        *dims: dimensions forwarded to ``Tensor.transpose``.
        contiguous: if True, return a contiguous copy of the transposed tensor.
    """

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        swapped = x.transpose(*self.dims)
        return swapped.contiguous() if self.contiguous else swapped
