import torch
import torch.nn as nn
import math
import numpy as np
import torch.nn.init as init
import torch.nn.functional as F
from .Attention_Module import Attention, AttentionLayer
from .Frequency_mask_router import Frequency_Mask_Router

class Transformer(nn.Module):
    """Two-pass reconstruction model built around one shared encoder stack.

    Pass 1 ("PTM") runs the encoder as self-attention over the embedded raw input.
    Pass 2 ("FRM") runs the *same* encoder (shared weights) with queries taken from
    an embedding of the frequency-masked input and keys/values taken from the
    pass-1 output. Both hidden sequences go through one shared linear projection.

    Args:
        mask: Forwarded unchanged to ``Attention``; its meaning is defined there.
        win_size: Window length, forwarded to ``Attention`` and ``Frequency_Mask_Router``.
        enc_in: Number of input channels (last dim of the input).
        c_out: Number of output channels produced by the projection.
        d_model: Hidden width.
        n_heads: Number of attention heads.
        e_layers: Number of stacked ``EncoderLayer`` modules.
        d_ff: Feed-forward width inside each encoder layer.
        dropout: Dropout probability used by embeddings, attention and encoder layers.
        activation: ``'relu'`` selects ReLU in the feed-forward; anything else selects GELU.
    """

    def __init__(self, mask, win_size, enc_in, c_out, d_model=512, n_heads=8, e_layers=3, d_ff=512,
                 dropout=0, activation='gelu'):
        super(Transformer, self).__init__()
        self.q_embedding = DataEmbedding(enc_in, d_model, dropout)
        self.embedding = DataEmbedding(enc_in, d_model, dropout)

        def make_layer():
            # Build innermost-first so module creation (and hence random init)
            # happens in the same order as a nested constructor expression.
            core = Attention(win_size, mask, attention_dropout=dropout)
            wrapped = AttentionLayer(core, d_model, n_heads)
            return EncoderLayer(wrapped, d_model, d_ff, dropout=dropout, activation=activation)

        stacked = [make_layer() for _ in range(e_layers)]
        self.encoder = Encoder(stacked, n_heads, norm_layer=nn.LayerNorm(d_model))
        # NOTE(review): the meaning of the constant 5 is defined by
        # Frequency_Mask_Router, which is not visible here.
        self.freq_mask_process = Frequency_Mask_Router(5, win_size)
        self.projection = nn.Linear(d_model, c_out, bias=True)

    def forward(self, x):
        """Run both encoder passes and project them back to ``c_out`` channels.

        Args:
            x: Input of shape ``(batch, length, enc_in)``.

        Returns:
            ``(rec_PTM, rec_FRM, attn_PTM, attn_FRM)``. The attention entries are the
            per-layer lists returned by ``Encoder``.
        """
        # Only the first of the router's three outputs is used here.
        masked, _, _ = self.freq_mask_process(x)
        query = self.q_embedding(masked)
        tokens = self.embedding(x)

        ptm_hidden, ptm_attn = self.encoder(tokens, tokens, tokens, self_attn=True)
        frm_hidden, frm_attn = self.encoder(query, ptm_hidden, ptm_hidden)

        return self.projection(ptm_hidden), self.projection(frm_hidden), ptm_attn, frm_attn

class EncoderLayer(nn.Module):
    """Post-norm encoder block: attention + residual + LayerNorm, then a
    pointwise (kernel-size-1 Conv1d) feed-forward + residual + LayerNorm.

    Args:
        attention: Module called as ``attention(q, k, v, attn_mask=...)`` that
            returns ``(output, attn)``.
        d_model: Model width.
        d_ff: Hidden width of the feed-forward; falsy values mean ``4 * d_model``.
        dropout: Dropout probability applied after attention and inside the feed-forward.
        activation: ``"relu"`` selects ReLU; any other value selects GELU.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        hidden = d_ff if d_ff else 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.gelu if activation != "relu" else F.relu

    def forward(self, q, k, v, attn_mask=None):
        """Apply the block to query ``q``, attending over ``k``/``v``.

        Returns:
            ``(output, attn)``, where ``output`` has the shape of ``q`` and ``attn``
            is whatever the attention module returned.
        """
        attended, attn = self.attention(q, k, v, attn_mask=attn_mask)
        h = self.norm1(q + self.dropout(attended))

        # Conv1d expects channels in dim 1, so move the feature axis there and back.
        ff = self.conv1(h.transpose(-1, 1))
        ff = self.dropout(self.activation(ff))
        ff = self.dropout(self.conv2(ff).transpose(-1, 1))
        return self.norm2(h + ff), attn


class Encoder(nn.Module):
    """Stack of encoder layers applied in sequence, with an optional final norm.

    Args:
        attn_layers: Layers called as ``layer(q, k, v, attn_mask=...)`` that return
            ``(output, attn)``.
        n_heads: Stored on the instance; not used by this class's own code.
        norm_layer: Optional module applied to the final output.
    """

    def __init__(self, attn_layers, n_heads, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer
        self.n_heads = n_heads

    def forward(self, q, k, v, attn_mask=None, self_attn=False):
        """Run every layer, threading the running query through the stack.

        With ``self_attn=True`` each layer attends over its own current input
        (``k`` and ``v`` are ignored). Otherwise every layer uses the fixed ``k``/``v``
        passed in.

        Returns:
            ``(output, attn_maps)``, where ``attn_maps`` has one entry per layer.
        """
        hidden = q
        maps = []
        for layer in self.attn_layers:
            keys, values = (hidden, hidden) if self_attn else (k, v)
            hidden, attn = layer(hidden, keys, values, attn_mask=attn_mask)
            maps.append(attn)
        if self.norm is not None:
            hidden = self.norm(hidden)
        return hidden, maps


class DataEmbedding(nn.Module):
    """Sum of a convolutional value embedding and a sinusoidal positional
    embedding, followed by dropout.

    Args:
        c_in: Number of input channels.
        d_model: Embedding width.
        dropout: Dropout probability applied to the summed embedding.
    """

    def __init__(self, c_in, d_model, dropout=0.0):
        super(DataEmbedding, self).__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Embed ``x`` of shape ``(batch, length, c_in)`` into ``(batch, length, d_model)``."""
        values = self.value_embedding(x)
        positions = self.position_embedding(x)  # (1, length, d_model), broadcast over batch
        return self.dropout(values + positions)


class TokenEmbedding(nn.Module):
    """Project each time step's channels to ``d_model`` with a width-3 circular
    Conv1d (no bias), initialised with Kaiming-normal weights.

    Args:
        c_in: Number of input channels.
        d_model: Output embedding width.
    """

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # padding=1 keeps the sequence length unchanged for kernel_size=3.
        # NOTE(review): the padding=2 branch presumably works around older
        # torch circular-padding behaviour — confirm before relying on it.
        if torch.__version__ >= '1.5.0':
            pad = 1
        else:
            pad = 2
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=pad, padding_mode='circular', bias=False)
        # tokenConv is the only Conv1d submodule, so initialise it directly.
        nn.init.kaiming_normal_(self.tokenConv.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        """Map ``(batch, length, c_in)`` to ``(batch, length, d_model)``."""
        channels_first = x.permute(0, 2, 1)
        return self.tokenConv(channels_first).permute(0, 2, 1)


class PositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding.

    Builds a ``(1, max_len, d_model)`` table once and registers it as the buffer
    ``pe``. Buffers move with ``.to(device)`` and are saved in ``state_dict``,
    but they are not parameters and are never trained.

    Args:
        d_model: Embedding width. Even columns hold ``sin`` terms and odd columns
            hold ``cos`` terms. Odd widths are supported.
        max_len: Number of positions in the precomputed table.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        pe = torch.zeros(max_len, d_model).float()

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column; trim the
        # frequencies to match (for even d_model this slice is the whole vector).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dim 1 is the sequence length. Its values are not read.

        Returns:
            Tensor of shape ``(1, min(x.size(1), max_len), d_model)``.
        """
        return self.pe[:, :x.size(1)]
    
if __name__ == '__main__':
    # Smoke test: push one random batch through the model and print output shapes.
    # (Transformer takes no ``q_cat`` argument; passing one raised TypeError.)
    batch = torch.rand(32, 100, 5)  # (batch, win_size, enc_in)
    model = Transformer(
        mask=False,
        win_size=100,
        enc_in=5,
        c_out=5,
        e_layers=3,
    )

    outputs = model(batch)
    for item in outputs:
        # The attention outputs are per-layer lists from Encoder, with no .shape.
        if isinstance(item, list):
            print([getattr(a, 'shape', type(a).__name__) for a in item])
        else:
            print(item.shape)