import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from layers.Transformer_EncDec import TSEncoder, t_layer, d_layer, c_layer
from layers.SelfAttention_Family import TSMixer, ResAttention
from layers.Embed import PatchEmbed
import numpy as np


class Model(nn.Module):
    """Encoder-only forecaster: patch embedding -> TSEncoder stack -> linear head.

    The input window is standardised along dim 1 before embedding, and the
    forecast is mapped back to the original scale with the same statistics.

    Args:
        configs: Experiment configuration object. Attributes read here:
            ``enc_in``, ``period``, ``seq_len``, ``pred_len``, ``num_p``,
            ``d_model``, ``d_ff``, ``n_heads``, ``dropout``, ``activation``,
            ``t_layers`` and ``e_layers``. The object is also handed to
            ``PatchEmbed``. NOTE: ``configs.num_p`` is filled in (the object
            is mutated in place) with ``seq_len // period`` when it is
            missing or ``None``.
    """

    # Channel count above which the axial c_layer variant is selected.
    AXIAL_CHANNEL_THRESHOLD = 100
    # Added to the per-series standard deviation so the division is safe.
    NORM_EPS = 1e-5
    # Last-dim width of the all-zero stand-in used when no time marks are given.
    DEFAULT_MARK_DIM = 4

    def __init__(self, configs):
        super().__init__()

        self.c_in = configs.enc_in
        self.period = configs.period
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.num_p = self.seq_len // self.period
        # Treat a missing attribute the same as an explicit None.
        if getattr(configs, "num_p", None) is None:
            configs.num_p = self.num_p

        # long-term and short-term embedding
        self.embedding = PatchEmbed(configs, num_p=self.num_p)

        # Encoder-only architecture
        self.encoder = TSEncoder(self.layers_init(configs))

        # Merge the last two dims (sized num_p * d_model by the Linear below)
        # and project them to the forecast horizon.
        self.decoder = nn.Sequential(
            nn.Flatten(start_dim=-2),
            nn.Linear(configs.num_p * configs.d_model, configs.pred_len, bias=False)
        )

    def layers_init(self, configs):
        """Build the ordered list of layers that is passed to ``TSEncoder``.

        The returned list contains, in order:

        1. ``configs.t_layers`` ``t_layer`` instances (``stable=True``,
           ``num_p=self.num_p``),
        2. one ``d_layer`` with ``in_p=self.num_p`` and
           ``out_p=configs.num_p``,
        3. one ``t_layer`` with ``num_p=configs.num_p`` -- only when
           ``configs.num_p > 1``,
        4. ``configs.e_layers - 2`` ``c_layer`` instances (none if that
           count is not positive).

        Only layers that end up in the returned list are instantiated.

        Args:
            configs: Same configuration object given to ``__init__``;
                ``configs.num_p`` must already be set.

        Returns:
            list: The layer modules in execution order.
        """
        def new_mixer():
            # Every layer gets its own mixer instance.
            return TSMixer(ResAttention(), configs.d_model, configs.n_heads)

        shared = dict(dropout=configs.dropout, activation=configs.activation)

        layer_p1 = [
            t_layer(
                new_mixer(), configs.d_model, configs.d_ff,
                stable=True, num_p=self.num_p, **shared
            )
            for _ in range(configs.t_layers)
        ]

        # Zero or one layer; it is only part of the stack when num_p > 1.
        layer_p2 = [
            t_layer(
                new_mixer(), configs.d_model, configs.d_ff,
                stable=False, num_p=configs.num_p, **shared
            )
        ] if configs.num_p > 1 else []

        layer_d = d_layer(
            new_mixer(), configs.d_model, configs.d_ff, stable=False,
            in_p=self.num_p, out_p=configs.num_p, **shared
        )

        # With many channels the axial variant is used, which relies on
        # c_layer's own default for `axial`; otherwise axial is switched off.
        if self.c_in > self.AXIAL_CHANNEL_THRESHOLD:
            variant = {}
        else:
            variant = {"axial": False}
        layer_c = [
            c_layer(
                new_mixer(), configs.d_model, configs.d_ff,
                stable=False, enc_in=self.c_in, **shared, **variant
            )
            for _ in range(configs.e_layers - 2)
        ]

        return [*layer_p1, layer_d, *layer_p2, *layer_c]

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Normalise the input, encode it, decode it and undo the normalisation.

        Args:
            x_enc: Input window. Assumed to be ``[B, seq_len, enc_in]``
                (statistics are taken over dim 1) -- TODO confirm with caller.
            x_mark_enc: Time-feature tensor forwarded to the embedding, or
                ``None``; ``None`` is replaced by zeros shaped like ``x_enc``
                with a last dim of ``DEFAULT_MARK_DIM``.
            x_dec: Unused; kept for interface compatibility.
            x_mark_dec: Unused; kept for interface compatibility.

        Returns:
            The decoder output, transposed on its last two dims and rescaled
            to the original data scale.
        """
        if x_mark_enc is None:
            # new_zeros inherits dtype and device from x_enc, so non-float32
            # inputs do not get a mismatching placeholder.
            x_mark_enc = x_enc.new_zeros((*x_enc.shape[:-1], self.DEFAULT_MARK_DIM))

        # Statistics are detached: they act as constants in the backward pass.
        mean = x_enc.mean(1, keepdim=True).detach()
        scale = x_enc.std(1, keepdim=True).detach() + self.NORM_EPS
        x_enc = (x_enc - mean) / scale

        tokens = self.embedding(x_enc, x_mark_enc)
        # First element of the encoder output, restricted to the first c_in
        # entries along dim 1 (presumably dropping extra tokens contributed
        # by the time marks -- verify against PatchEmbed).
        enc_out = self.encoder(tokens)[0][:, :self.c_in, ...]
        dec_out = self.decoder(enc_out).transpose(-1, -2)

        # Invert with the very same scale (eps included) used for the
        # normalisation; otherwise low-variance series would be shrunk.
        return dec_out * scale + mean

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Run ``forecast`` and keep the last ``pred_len`` steps along dim 1.

        ``mask`` is accepted for interface compatibility and is not used.
        """
        dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        return dec_out[:, -self.pred_len:, :]  # [B, L, D]