import torch
from torch import nn
from layers.RevIN import RevIN
from layers.Olivia_EncDec import TSTEncoder, TowerEncoder, PretrainHead
from layers.pos_encoding import positional_encoding



class LearnableOrthoTrans1D(nn.Module):
    """Learnable orthogonal transform applied along the time axis (dim 1).

    The ``T x T`` matrix ``Q`` is the product of ``K`` Householder reflections
    ``H_k = I - 2 v_k v_k^T`` built from the rows of the learnable parameter
    ``self.v``: ``Q = H_{K-1} ... H_1 H_0``. Each ``v_k`` is normalised by
    ``||v_k|| + eps``, so ``Q`` is orthogonal up to that ``eps`` and
    ``decode`` inverts ``encode`` by applying ``Q^T``.

    Args:
        T: length of the time axis the transform acts on.
        K: number of Householder reflections; defaults to ``T``.
        eps: added to each ``||v_k||`` before normalising (avoids division by
            zero for an all-zero row).
        keep_k: if not ``None``, ``encode`` zeroes every coefficient at
            time-index ``>= keep_k``.
    """

    def __init__(self, T, K=None, eps=1e-6, keep_k=None):
        super().__init__()
        self.T = T
        self.K = K if K is not None else T
        self.eps = eps
        self.keep_k = keep_k
        # One reflection vector per row: shape (K, T).
        self.v = nn.Parameter(torch.randn(self.K, self.T) * 0.02)

    def _build_Q(self):
        """Return the ``T x T`` matrix ``Q = H_{K-1} ... H_1 H_0``.

        ``H_k @ Q`` is evaluated as the rank-1 update ``Q - 2 v (v^T Q)``,
        which costs O(T^2) per reflection instead of materialising the dense
        ``H_k`` and paying an O(T^3) matrix product.
        """
        Q = torch.eye(self.T, device=self.v.device, dtype=self.v.dtype)
        for v in self.v:
            v = v / (v.norm(p=2) + self.eps)
            # (I - 2 v v^T) @ Q == Q - 2 * outer(v, v^T Q)
            Q = Q - 2.0 * torch.outer(v, v @ Q)
        return Q

    @staticmethod
    def _as_3d(x):
        """View a 2-D ``(B, T)`` tensor as ``(B, T, 1)``; return others unchanged."""
        return x.unsqueeze(-1) if x.dim() == 2 else x

    def encode(self, x):
        """Project ``x`` onto the learned orthogonal basis along dim 1.

        Args:
            x: ``(B, T, C)`` tensor (``C`` may be 1), or ``(B, T)`` which is
                treated as ``(B, T, 1)``.

        Returns:
            ``(z, cache)`` where ``z[b, s, c] = sum_t Q[s, t] * x[b, t, c]``
            has shape ``(B, T, C)``, and ``cache`` is a dict holding ``"Q"``
            and ``"keep_k"`` so ``decode`` can reuse the same matrix.
        """
        Q = self._build_Q()
        # (T, T) @ (B, T, C) broadcasts over the batch; every channel is
        # transformed independently along time.
        z = Q @ self._as_3d(x)
        if self.keep_k is not None:
            k = self.keep_k
            # Truncate the basis: keep the first k coefficients, zero the rest.
            z = torch.cat([z[:, :k], torch.zeros_like(z[:, k:])], dim=1)

        cache = {
            "Q": Q,
            "keep_k": self.keep_k
        }
        return z, cache

    def decode(self, z, cache):
        """Map coefficients back: ``x_hat[b, t, c] = sum_s Q[s, t] * z[b, s, c]``.

        Zeroed coefficients (``keep_k``) are not restored, so decoding a
        truncated ``z`` yields a lossy reconstruction.

        Args:
            z: ``(B, T, C)`` or ``(B, T)`` coefficients.
            cache: dict returned by ``encode`` (its ``"Q"`` is reused), or
                ``None`` to rebuild ``Q`` from the current parameters.

        Returns:
            ``(B, T, C)`` tensor (``(B, T, 1)`` for 2-D input).
        """
        Q = self._build_Q() if cache is None else cache["Q"]
        # Q^T inverts the (approximately) orthogonal Q.
        return Q.t() @ self._as_3d(z)

    def forward(self, x, return_cache=False):
        """Run ``encode``; return ``(z, cache)`` if ``return_cache`` else ``z``."""
        z, cache = self.encode(x)
        return (z, cache) if return_cache else z


class Model(nn.Module):
    """Encoder/decoder over frequency-embedded patches of a transformed series.

    Pipeline (see ``forward``): RevIN normalisation -> learnable orthogonal
    transform -> patching -> per-patch rFFT / complex linear / irFFT embedding
    to ``d_model`` -> ``TSTEncoder`` -> ``TowerEncoder`` -> ``PretrainHead`` ->
    inverse orthogonal transform -> RevIN de-normalisation.

    Args:
        configs: object providing ``head_type`` (``'pretrain'`` or
            ``'prediction'``), ``c_in``, ``seq_len``, ``pred_len``,
            ``label_len``, ``patch_len``, ``stride``, ``d_model``,
            ``e_layers`` and ``d_layers``.

    Raises:
        ValueError: if ``configs.head_type`` is not ``'pretrain'`` or
            ``'prediction'``.

    NOTE(review): ``head_type`` is only validated here; ``forward`` runs the
    same reconstruction path for both values.
    """

    def __init__(self, configs):

        super().__init__()

        # Explicit check instead of ``assert``: asserts are stripped when Python
        # runs with ``-O``, which would silently accept an invalid config.
        if configs.head_type not in ('pretrain', 'prediction'):
            raise ValueError('head type should be either pretrain or prediction')

        # Architecture hyper-parameters fixed here rather than read from configs.
        head_dropout: float = 0.2
        n_heads: int = 16
        d_ff: int = 256
        norm: str = 'RMSNorm'
        attn_dropout: float = 0.
        dropout: float = 0.
        act: str = "silu"
        res_attention: bool = True
        pre_norm: bool = True
        store_attn: bool = False
        pe: str = 'zeros'
        learn_pe: bool = True
        self.freq_num: int = 4  # NOTE(review): not read anywhere in this class.
        self.n_vars = configs.c_in
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.label_len = configs.label_len
        self.patch_len = configs.patch_len
        self.stride = configs.stride
        self.d_model = configs.d_model
        self.e_layers = configs.e_layers
        self.d_layers = configs.d_layers
        # Sub-modules are created in a fixed order; reordering them would change
        # random initialisation under a fixed seed.
        self.revin_layer = RevIN(self.n_vars, affine=True)
        # Orthogonal transform over the full seq_len window, built from
        # seq_len // 2 Householder reflections.
        self.ortho_trans = LearnableOrthoTrans1D(T=self.seq_len, K=self.seq_len//2)

        # Number of length-patch_len patches, taken every `stride` steps, that
        # fit in the seq_len window (at least 1); the span they cover; and the
        # offset at which that span starts (forward drops the first s_begin steps).
        self.patch_num = (max(self.seq_len, self.patch_len) - self.patch_len) // self.stride + 1
        tgt_len = self.patch_len  + self.stride * (self.patch_num - 1)
        self.s_begin = self.seq_len - tgt_len
        # One positional embedding per (variable, patch) token fed to the encoder.
        self.W_pos = positional_encoding(pe, learn_pe, self.patch_num * self.n_vars, self.d_model)
        self.dropout = nn.Dropout(dropout)
        # Complex-valued linear map from the patch_len // 2 + 1 rFFT bins of a
        # patch to d_model // 2 + 1 bins; forward turns these into a real
        # d_model-dim embedding with irfft(n=d_model).
        self.patch_embed_freq = nn.Linear(int(self.patch_len/2)+1, int(self.d_model/2)+1).to(torch.cfloat)
        self.encoder = TSTEncoder(self.d_model, n_heads=n_heads, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout,
                                  dropout=dropout, pre_norm=pre_norm, activation=act, res_attention=res_attention,
                                  n_layers=self.e_layers, store_attn=store_attn)
        self.decoder = TowerEncoder(self.d_model, n_heads=n_heads, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout,
                                dropout=dropout, pre_norm=pre_norm, activation=act, res_attention=res_attention,
                                n_layers=self.d_layers, store_attn=store_attn)
        # Presumably projects each d_model token back to patch_len values — see PretrainHead.
        self.head = PretrainHead(self.d_model, self.patch_len, head_dropout)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Reconstruct the input window through the patch encoder/decoder.

        Args:
            x_enc: input series; presumably ``(batch, seq_len, c_in)`` — TODO confirm.
            x_mark_enc, x_dec, x_mark_dec: unused here.

        Returns:
            ``(x, x_vec)``: ``x`` is the head output reshaped to
            ``(batch, patch_num * patch_len, n_vars)`` and then passed through
            the inverse orthogonal transform and RevIN de-normalisation;
            ``x_vec`` is the encoder output laid out as
            ``(batch, n_vars, d_model, patch_num)``.
        """
        x = self.revin_layer(x_enc, 'norm')
        # Keep the cache so decode reuses exactly the Q used by encode.
        x, cache = self.ortho_trans.encode(x)
        # Drop the leading steps that no patch covers.
        x = x[:, self.s_begin:, :]
        # -> (bs, patch_num, n_vars, patch_len); unfold appends the window dim.
        x = x.unfold(dimension=1, size=self.patch_len, step=self.stride)
        bs, patch_num, n_vars, patch_len = x.shape
        # Frequency-domain patch embedding: rfft over each patch, complex
        # linear, then irfft back to a real vector of length d_model.
        x_fft = torch.fft.rfft(x, dim=-1)
        x_fft = self.patch_embed_freq(x_fft)
        x = torch.fft.irfft(x_fft, dim=-1, n=self.d_model)  # (bs, patch_num, n_vars, d_model)
        x = x.transpose(1, 2)  # (bs, n_vars, patch_num, d_model)
        # All variables' patches form one token sequence for the encoder.
        u = torch.reshape(x, (-1, n_vars * patch_num, self.d_model))
        u = self.dropout(u + self.W_pos)
        x = self.encoder(u)
        x = torch.reshape(x, (-1, n_vars, patch_num, self.d_model))
        x = x.permute(0, 1, 3, 2)  # (bs, n_vars, d_model, patch_num)
        x_vec = x
        x = self.decoder(x)
        x = self.head(x)
        # NOTE(review): this flattening is time-major only if the head returns
        # (bs, patch_num, patch_len, n_vars), or when n_vars == 1 — confirm
        # against PretrainHead's output layout before using n_vars > 1.
        x = x.reshape(bs, patch_num * patch_len, n_vars)
        # NOTE(review): ortho_trans was built with T = seq_len, so decode
        # presumably needs patch_num * patch_len == seq_len (e.g. s_begin == 0
        # and stride == patch_len) — confirm.
        x = self.ortho_trans.decode(x, cache)
        x = self.revin_layer(x, 'denorm')
        return x, x_vec
