import torch
import torch.nn as nn
import math
from einops import rearrange

from layers.SelfAttention_Family import AttentionLayer, FullAttention


class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> GELU -> Linear.

    The last dimension is expanded from ``embedding_size`` to ``d_hidden`` and
    projected back to ``embedding_size``; leading dimensions pass through.
    """

    def __init__(self, embedding_size: int, d_hidden: int = 512):
        super().__init__()
        # Registration order (linear_1, linear_2, activation) fixes both the
        # state_dict key names and the order parameters are initialised in.
        self.linear_1 = nn.Linear(embedding_size, d_hidden)
        self.linear_2 = nn.Linear(d_hidden, embedding_size)
        self.activation = nn.GELU()

    def forward(self, x):
        hidden = self.activation(self.linear_1(x))
        return self.linear_2(hidden)


class Encoder(nn.Module):
    """Post-norm Transformer encoder layer around an injected attention module.

    Sub-layer 1: attention -> dropout -> residual add -> LayerNorm.
    Sub-layer 2: feed-forward -> dropout -> residual add -> LayerNorm.

    The query is always the layer input. With ``channel_wise=True`` the keys
    and values are the input passed through a kernel-size-1 Conv1d (a per-token
    linear mix of the features); otherwise keys and values are the input.

    Args:
        embedding_size: feature width (last dimension) of input and output.
        mha: attention module, called as ``mha(q, k, v, attn_mask=None)`` and
            unpacked as ``(output, score)``.
        d_hidden: hidden width of the feed-forward sub-layer.
        dropout: dropout probability used after both sub-layers.
        channel_wise: whether keys/values go through the 1x1 convolution.

    ``forward`` returns ``(output, score)`` where ``score`` is whatever the
    attention module returned as its second value.
    """

    def __init__(
        self,
        embedding_size: int,
        mha: AttentionLayer,
        d_hidden: int,
        dropout: float = 0,
        channel_wise=False,
    ):
        super().__init__()

        self.channel_wise = channel_wise
        if channel_wise:
            # kernel_size=1 with padding=0, so padding_mode has no effect.
            self.conv = nn.Conv1d(
                in_channels=embedding_size,
                out_channels=embedding_size,
                kernel_size=1,
                stride=1,
                padding=0,
                padding_mode="reflect",
            )
        self.MHA = mha
        self.feedforward = FeedForward(embedding_size=embedding_size, d_hidden=d_hidden)
        self.dropout = nn.Dropout(p=dropout)
        self.layerNormal_1 = nn.LayerNorm(embedding_size)
        self.layerNormal_2 = nn.LayerNorm(embedding_size)

    def forward(self, x):
        if self.channel_wise:
            # Conv1d wants features on dim 1: swap in, convolve, swap back.
            kv = self.conv(x.permute(0, 2, 1)).transpose(1, 2)
        else:
            kv = x
        attn_out, score = self.MHA(x, kv, kv, attn_mask=None)
        hidden = self.layerNormal_1(self.dropout(attn_out) + x)

        ff_out = self.dropout(self.feedforward(hidden))
        return self.layerNormal_2(ff_out + hidden), score


class MultiPatchFormer(nn.Module):
    """Multi-scale patch Transformer with a semi auto-regressive output head.

    Pipeline (see ``forecast``):
      1. Normalise each sample/channel over time.
      2. Embed four Conv1d patch scales, concatenated on the feature axis,
         and add a fixed sinusoidal table.
      3. Run ``e_layers`` temporal encoders over the patch axis.
      4. Build one token per channel and apply one channel-wise encoder.
      5. Apply eight linear heads. Each head emits one chunk of the horizon,
         and every head after the first also reads the chunks emitted before
         it.
      6. De-normalise.

    ``configs`` is a mapping. The keys read are ``task_name``, ``seq_len``,
    ``pred_len``, ``output_attention``, ``enc_in``, ``e_layers``,
    ``embedding_size``, ``n_heads`` and ``dropout``.

    Raises:
        ValueError: in either of these cases, the model could not run:
            - ``embedding_size`` is not a multiple of 4.
            - ``seq_len`` does not give all four patch scales the same number
              of patches.
    """

    # Number of semi auto-regressive heads (out_linear_1 .. out_linear_8).
    _NUM_OUT_CHUNKS = 8

    def __init__(self, configs):
        super(MultiPatchFormer, self).__init__()
        self.task_name = configs['task_name']
        self.seq_len = configs['seq_len']
        self.pred_len = configs['pred_len']
        self.output_attention = configs['output_attention']  # stored; not used in this class
        # forecast() infers the channel count from its input; kept as config.
        self.d_channel = configs['enc_in']
        self.N = configs['e_layers']
        # Embedding
        self.embedding_size = configs['embedding_size']
        self.d_hidden = configs['embedding_size']  # FFN width tied to embedding width
        self.n_heads = configs['n_heads']
        self.mask = True  # forwarded as FullAttention(mask_flag=...)
        self.dropout = configs['dropout']

        # Four patch embeddings of embedding_size // 4 features are
        # concatenated back to embedding_size, and the positional table is
        # filled in (sin, cos) column pairs. Both require a multiple of 4.
        if self.embedding_size % 4 != 0:
            raise ValueError(
                f"embedding_size must be a multiple of 4, got {self.embedding_size}"
            )

        self.stride1 = 8
        self.patch_len1 = 8
        self.stride2 = 8
        self.patch_len2 = 16
        self.stride3 = 7
        self.patch_len3 = 24
        self.stride4 = 6
        self.patch_len4 = 32
        self.patch_num1 = int((self.seq_len - self.patch_len2) // self.stride2) + 2
        # padding_patch_layer1 is not used by forecast(); kept for compatibility.
        self.padding_patch_layer1 = nn.ReplicationPad1d((0, self.stride1))
        self.padding_patch_layer2 = nn.ReplicationPad1d((0, self.stride2))
        self.padding_patch_layer3 = nn.ReplicationPad1d((0, self.stride3))
        self.padding_patch_layer4 = nn.ReplicationPad1d((0, self.stride4))

        # Scales are concatenated patch-by-patch and share one positional
        # table, so each must yield exactly patch_num1 patches. Scale 1 is
        # embedded unpadded; scales 2-4 are right-padded by their stride.
        patch_counts = [
            self._num_patches(self.seq_len, self.patch_len1, self.stride1),
            self._num_patches(self.seq_len + self.stride2, self.patch_len2, self.stride2),
            self._num_patches(self.seq_len + self.stride3, self.patch_len3, self.stride3),
            self._num_patches(self.seq_len + self.stride4, self.patch_len4, self.stride4),
        ]
        if self.patch_num1 < 1 or any(n != self.patch_num1 for n in patch_counts):
            raise ValueError(
                f"seq_len={self.seq_len} is not supported: the four patch scales "
                f"yield {patch_counts} patches but all must yield {self.patch_num1}"
            )

        self.shared_MHA = nn.ModuleList(
            [
                AttentionLayer(
                    FullAttention(mask_flag=self.mask),
                    d_model=self.embedding_size,
                    n_heads=self.n_heads,
                )
                for _ in range(self.N)
            ]
        )

        self.shared_MHA_ch = nn.ModuleList(
            [
                AttentionLayer(
                    FullAttention(mask_flag=self.mask),
                    d_model=self.embedding_size,
                    n_heads=self.n_heads,
                )
                for _ in range(self.N)
            ]
        )

        self.encoder_list = nn.ModuleList(
            [
                Encoder(
                    embedding_size=self.embedding_size,
                    mha=self.shared_MHA[ll],
                    d_hidden=self.d_hidden,
                    dropout=self.dropout,
                    channel_wise=False,
                )
                for ll in range(self.N)
            ]
        )

        # Every channel-wise encoder shares shared_MHA_ch[0]; forecast() only
        # applies encoder_list_ch[0].
        self.encoder_list_ch = nn.ModuleList(
            [
                Encoder(
                    embedding_size=self.embedding_size,
                    mha=self.shared_MHA_ch[0],
                    d_hidden=self.d_hidden,
                    dropout=self.dropout,
                    channel_wise=True,
                )
                for ll in range(self.N)
            ]
        )

        # Fixed sinusoidal table, shape [1, patch_num1, embedding_size].
        # NOTE(review): i already steps by 2, so the exponent 2*i/E grows twice
        # as fast as the usual 2k/E formulation; kept as-is to preserve
        # model behaviour.
        pe = torch.zeros(self.patch_num1, self.embedding_size)
        for pos in range(self.patch_num1):
            for i in range(0, self.embedding_size, 2):
                wavelength = 10000 ** ((2 * i) / self.embedding_size)
                pe[pos, i] = math.sin(pos / wavelength)
                pe[pos, i + 1] = math.cos(pos / wavelength)
        pe = pe.unsqueeze(0)  # add a batch dimension
        self.register_buffer("pe", pe)

        # Collapses all patch features of a channel into one token.
        self.embedding_channel = nn.Conv1d(
            in_channels=self.embedding_size * self.patch_num1,
            out_channels=self.embedding_size,
            kernel_size=1,
        )

        self.embedding_patch_1 = torch.nn.Conv1d(
            in_channels=1,
            out_channels=self.embedding_size // 4,
            kernel_size=self.patch_len1,
            stride=self.stride1,
        )
        self.embedding_patch_2 = torch.nn.Conv1d(
            in_channels=1,
            out_channels=self.embedding_size // 4,
            kernel_size=self.patch_len2,
            stride=self.stride2,
        )
        self.embedding_patch_3 = torch.nn.Conv1d(
            in_channels=1,
            out_channels=self.embedding_size // 4,
            kernel_size=self.patch_len3,
            stride=self.stride3,
        )
        self.embedding_patch_4 = torch.nn.Conv1d(
            in_channels=1,
            out_channels=self.embedding_size // 4,
            kernel_size=self.patch_len4,
            stride=self.stride4,
        )

        # Heads out_linear_1 .. out_linear_8, registered in that order. Input
        # widths are embedding_size + k * (pred_len // 8). The former
        # `k * pred_len // 8` computed (k * pred_len) // 8, which mismatches
        # the real input width whenever pred_len % 8 != 0.
        head_sizes = self._head_sizes(
            self.embedding_size, self.pred_len, self._NUM_OUT_CHUNKS
        )
        for k, (d_in, d_out) in enumerate(head_sizes, start=1):
            setattr(self, f"out_linear_{k}", torch.nn.Linear(d_in, d_out))

        # Not used by forecast(); kept so state_dict keys stay compatible.
        self.remap = torch.nn.Linear(self.embedding_size, self.seq_len)

    @staticmethod
    def _num_patches(length, patch_len, stride):
        """Return the output length of an unpadded Conv1d(kernel=patch_len, stride=stride)."""
        return (length - patch_len) // stride + 1

    @staticmethod
    def _head_sizes(embedding_size, pred_len, n_chunks=8):
        """Return ``(in_features, out_features)`` for each semi auto-regressive head.

        Every head but the last emits ``pred_len // n_chunks`` steps. The last
        head emits the remainder, so the chunks sum to ``pred_len``. Head ``k``
        (0-based) reads the encoding plus the ``k`` chunks already emitted,
        i.e. ``embedding_size + k * (pred_len // n_chunks)`` features.
        """
        chunk = pred_len // n_chunks
        sizes = [(embedding_size + k * chunk, chunk) for k in range(n_chunks - 1)]
        sizes.append(
            (embedding_size + (n_chunks - 1) * chunk, pred_len - (n_chunks - 1) * chunk)
        )
        return sizes

    def _embed_patches(self, x_i):
        """Embed ``x_i`` ([b, c, l]) at four patch scales and add the positional table.

        Returns a tensor of shape [(b c), patch_num1, embedding_size].
        """
        # Scale 1 is used unpadded; scales 2-4 are right-padded by their stride.
        branches = (
            (x_i, self.embedding_patch_1),
            (self.padding_patch_layer2(x_i), self.embedding_patch_2),
            (self.padding_patch_layer3(x_i), self.embedding_patch_3),
            (self.padding_patch_layer4(x_i), self.embedding_patch_4),
        )
        encodings = [
            # Each channel becomes its own 1-channel sequence [(b c), 1, l].
            # The conv output [(b c), d/4, patches] becomes [(b c), patches, d/4].
            embed(rearrange(x, "b c l -> (b c) l").unsqueeze(1)).permute(0, 2, 1)
            for x, embed in branches
        ]
        return torch.cat(encodings, dim=-1) + self.pe

    def _semi_autoregressive_head(self, encoding):
        """Decode ``encoding`` ([..., embedding_size]) into [..., pred_len].

        Head ``out_linear_k`` reads the encoding concatenated with the outputs
        of heads 1..k-1. The chunks are concatenated in head order.
        """
        chunks = []
        for k in range(1, self._NUM_OUT_CHUNKS + 1):
            head = getattr(self, f"out_linear_{k}")
            chunks.append(head(torch.cat((encoding, *chunks), dim=-1)))
        return torch.cat(chunks, dim=-1)

    def forecast(self, x_enc):
        """Forecast ``pred_len`` steps from ``x_enc``.

        Args:
            x_enc: tensor indexed as [batch, time, channel]. The statistics are
                taken over dim 1, and channels are moved to dim 1 for patching.
                ``time`` presumably equals ``seq_len``, because the patch
                counts and positional table are sized from it.

        Returns:
            Tensor of shape [batch, pred_len, channel] in the input's scale.
        """
        # Normalization (per sample and channel, over time).
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Out-of-place, so the tensor torch.var consumed is not mutated.
        # In-place ops on tensors saved for backward break autograd.
        x_enc = x_enc / stdev
        batch_size, n_channels = x_enc.shape[0], x_enc.shape[2]

        # Multi-scale embedding -> [(b c), patch_num1, embedding_size].
        encoding_patch = self._embed_patches(x_enc.permute(0, 2, 1))

        # Temporal encoding over the patch axis.
        for encoder in self.encoder_list:
            encoding_patch = encoder(encoding_patch)[0]

        # Channel-wise encoding: one token per channel built from all patches.
        x_patch_c = rearrange(
            encoding_patch, "(b c) p d -> b c (p d)", b=batch_size, c=n_channels
        )
        x_ch = self.embedding_channel(x_patch_c.permute(0, 2, 1)).transpose(
            1, 2
        )  # [b c d]
        encoding_1_ch = self.encoder_list_ch[0](x_ch)[0]

        # Semi auto-regressive decoding: [b, c, pred_len] -> [b, pred_len, c].
        final_forecast = self._semi_autoregressive_head(encoding_1_ch).permute(0, 2, 1)

        # De-Normalization: stdev and means are [b, 1, c] and broadcast over time.
        return final_forecast * stdev + means

    def forward(self, x_enc):
        """Forecast from a single-series input.

        A trailing channel axis is added before ``forecast`` and squeezed
        afterwards. This assumes ``x_enc`` is [batch, seq_len] — TODO confirm
        with callers. The returned tensor is also stored on ``self.dec_out``.
        """
        x_enc = x_enc.unsqueeze(-1)
        dec_out = self.forecast(x_enc)
        dec_out = dec_out.squeeze(-1)
        self.dec_out = dec_out[:, -self.pred_len:]
        return self.dec_out