import torch
import torch.nn as nn
from einops import rearrange

from src.modules.llm_config import LLMConfig
from src.modules.encoder import Encoder
from src.modules.decoder import Decoder
from src.modules.revin import RevIN


class Sentinel(nn.Module):
    """Patch-based forecaster: RevIN -> patch embedding -> Encoder/Decoder -> linear head.

    The input series is right-padded and split into overlapping patches along
    the time axis. Each patch is projected to ``d_model`` and offset by a
    learnable per-patch-index embedding. The result is passed through the
    project's ``Encoder`` and ``Decoder``. Each channel's decoder features are
    flattened and mapped to ``pred_len`` steps by a linear head. The output is
    then de-normalised with the same ``RevIN`` layer that normalised the input.

    Args:
        config: Model configuration. Fields read directly here: ``seq_len``,
            ``patch_size``, ``stride``, ``enc_in``, ``d_model``, ``pred_len``
            (``Encoder``/``Decoder`` receive the whole config and may read more).
    """

    def __init__(self, config: LLMConfig):
        super().__init__()

        self.config = config

        # Number of unfold windows over a length-``seq_len`` series ...
        self.patch_num = (config.seq_len - config.patch_size) // config.stride + 1
        # ... plus one extra window contributed by the ``stride``-long right padding.
        self.patch_num += 1

        self.revin_layer = RevIN(config.enc_in)

        # Pads the end of the time axis by ``stride`` steps, repeating the last value.
        self.padding_patch_layer = nn.ReplicationPad1d((0, config.stride))
        self.in_layer_embed = nn.Linear(config.patch_size, config.d_model)
        # Learnable embedding per patch index; broadcast over batch and channels.
        self.pe = nn.Parameter(torch.randn(self.patch_num, config.d_model) * 1e-2)

        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

        # Maps one channel's flattened (patch_num * d_model) features to pred_len steps.
        self.out_layer = nn.Linear(
            config.d_model * self.patch_num,
            config.pred_len,
        )

    def forward(self, x, x_mark=None):
        """Forecast ``pred_len`` future steps for each channel of ``x``.

        Args:
            x: 3-D tensor ``(B, T, C)`` (batch, time, channels, as labelled by
                the rearrange patterns below). The number of patches produced
                from ``T`` must equal ``self.patch_num``, which is derived from
                ``config.seq_len``; otherwise adding ``self.pe`` fails.
            x_mark: Optional tensor concatenated to ``x`` along the last axis,
                so it must share ``x``'s first two dimensions. It is appended
                after RevIN normalisation, so it is not normalised. Its
                channels are dropped from the returned forecast.

        Returns:
            Tensor of shape ``(B, pred_len, C)``. This assumes ``RevIN``'s
            ``'denorm'`` mode preserves shape (TODO confirm).
        """
        B, _, C = x.shape

        # Normalisation statistics are computed from ``x`` only (C channels).
        x = self.revin_layer(x, 'norm')

        if x_mark is not None:
            x = torch.cat((x, x_mark), dim=2)

        # Channels fed to the backbone: C, or C + x_mark channels.
        n_channels = x.shape[-1]

        x = rearrange(x, 'b t c -> b c t')

        x = self.padding_patch_layer(x)
        # B x C' x T -> B x C' x PATCH_NUM x PATCH_SIZE
        x = x.unfold(
            dimension=-1,
            size=self.config.patch_size,
            step=self.config.stride,
        )

        # B x C' x PATCH_NUM x D_MODEL
        x = self.in_layer_embed(x) + self.pe

        x_enc = rearrange(x, 'b c pn d_model -> b pn c d_model')
        enc_output = self.encoder(x_enc)

        # B x PN x C' x D_MODEL -> B x C' x PN x D_MODEL
        # (assumes the encoder returns the same layout it was given — TODO confirm)
        enc_output = rearrange(enc_output, "b pn c d_model -> b c pn d_model")

        dec_out = self.decoder(x, enc_output)

        # Flatten each channel's decoder features. Each channel must flatten to
        # patch_num * d_model values to match out_layer's input size
        # (presumably the decoder output is B x C' x PN x D_MODEL — verify).
        dec_out = dec_out.reshape(B, n_channels, -1)
        dec_out = self.out_layer(dec_out)

        dec_out = rearrange(dec_out, 'b c pred_len -> b pred_len c')

        # Keep only the original C channels and undo the RevIN normalisation.
        # With x_mark this drops its channels; without it, ``:C`` keeps every channel.
        return self.revin_layer(dec_out[:, :, :C], 'denorm')
