import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class VarTCNBlock(nn.Module):
    """
    Variable-TCN block:
      - depthwise 2D conv across (variables C, patches P) per embedding channel H
      - BN -> GELU -> pointwise expand (pw1) -> GELU -> dropout -> pw2 -> dropout
      - residual add
    Input shape: [B, H, C, P]  (H channels, spatial dims = C x P)
    Output shape: same

    Args:
        hidden_size: number of embedding channels H (conv in/out channels and groups).
        kernel_vars: depthwise kernel extent along the variable axis C.
        kernel_patches: depthwise kernel extent along the patch axis P.
        ff_ratio: channel expansion factor of the pointwise feed-forward part.
        dropout: dropout probability applied after the activation and after pw2.

    Even kernel sizes are supported: the surplus row/column produced by the
    symmetric ``kernel // 2`` padding is cropped so the residual add stays valid.
    """
    def __init__(self, hidden_size, kernel_vars=7, kernel_patches=3, ff_ratio=4, dropout=0.1):
        super().__init__()
        self.hidden = hidden_size
        # depthwise spatial conv: in_channels = H, groups = H (per-channel spatial conv)
        self.dw = nn.Conv2d(in_channels=hidden_size, out_channels=hidden_size,
                            kernel_size=(kernel_vars, kernel_patches),
                            padding=(kernel_vars // 2, kernel_patches // 2),
                            groups=hidden_size, bias=False)
        self.bn = nn.BatchNorm2d(hidden_size)
        # pointwise (1x1) feed-forward: expand to H * ff_ratio, then project back to H
        self.pw1 = nn.Conv2d(hidden_size, hidden_size * ff_ratio, kernel_size=1, bias=True)
        self.pw2 = nn.Conv2d(hidden_size * ff_ratio, hidden_size, kernel_size=1, bias=True)
        self.act = nn.GELU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to ``x`` of shape [B, H, C, P]; returns the same shape."""
        y = self.dw(x)
        # Padding k // 2 on both sides gives size + 1 for an even kernel k.
        # Crop back to the input extent so `x + y` cannot fail or silently
        # broadcast; for odd kernels this slice is a no-op.
        y = y[..., :x.shape[-2], :x.shape[-1]]
        y = self.bn(y)
        y = self.act(y)
        y = self.pw1(y)
        y = self.act(y)
        y = self.drop(y)
        y = self.pw2(y)
        y = self.drop(y)
        return x + y  # residual

# Small constant added to the variance in Model.revin_norm before taking the
# square root, so the resulting std used as a divisor is never exactly zero.
EPS = 1e-6

class Model(nn.Module):
    """
    Vectorized VPNet model: patch encoding/decoding fully vectorized (no Python for-loops).

    Forward pipeline:
      1. normalize the input (per-series mean/std when ``use_revin`` is set,
         otherwise subtract the last time step)
      2. reconstruction branch: sliding patches -> encoder -> decoder
      3. forecasting branch: non-overlapping patches -> encoder -> VarTCN stack
         on the (variables, patches) plane -> per-variable MLP predictor -> decoder
      4. undo the normalization

    Required config attributes: seq_len, pred_len, slice_len, middle_len,
    hidden_len, slice_stride, encoder_dropout.
    Optional config attributes (with defaults): var_tcn_blocks, kernel_vars,
    kernel_patches, var_ff_ratio, var_dropout, use_revin, d_ff.
    """
    def __init__(self, configs):
        super(Model, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        # NOTE: this layer is not used by forward(); it is kept so the module's
        # parameter names (state_dict keys) stay unchanged.
        self.Linear = nn.Linear(self.seq_len, self.pred_len)
        self.patch_length = configs.slice_len
        self.middle_size = configs.middle_len
        self.hidden_size = configs.hidden_len
        self.slice_stride = configs.slice_stride
        self.encoder_dropout = configs.encoder_dropout

        # encoder / decoder (shared by the reconstruction and forecasting branches)
        self.encoder = nn.Sequential(
            nn.Linear(self.patch_length, self.middle_size),
            nn.LeakyReLU(),
            nn.Dropout(self.encoder_dropout),
            nn.Linear(self.middle_size, self.hidden_size),
            nn.LayerNorm(self.hidden_size)
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.hidden_size, self.middle_size),
            nn.LeakyReLU(),
            nn.Dropout(self.encoder_dropout),
            nn.Linear(self.middle_size, self.patch_length),
        )

        # number of whole patches in the input window / forecast horizon
        self.num_patches = int(self.seq_len // self.patch_length)
        self.num_patches_p = int(self.pred_len // self.patch_length)

        self.var_tcn_blocks = getattr(configs, "var_tcn_blocks", 2)
        # VarTCNBlock pads with kernel // 2 on both sides, which preserves the
        # spatial size only for odd kernels. Round the derived default up to
        # the next odd value (4 -> 5, 6 -> 7) so the residual add is valid.
        default_kernel_vars = min(7, max(3, int(self.hidden_size // 2))) | 1
        self.kernel_vars = getattr(configs, "kernel_vars", default_kernel_vars)
        self.kernel_patches = getattr(configs, "kernel_patches", 3)
        self.ff_ratio = getattr(configs, "var_ff_ratio", 4)
        self.var_dropout = getattr(configs, "var_dropout", self.encoder_dropout)
        self.use_revin = getattr(configs, "use_revin", False)

        # build VarTCN stack
        self.var_tcn = nn.Sequential(*[
            VarTCNBlock(self.hidden_size, kernel_vars=self.kernel_vars,
                        kernel_patches=self.kernel_patches,
                        ff_ratio=self.ff_ratio, dropout=self.var_dropout)
            for _ in range(self.var_tcn_blocks)
        ])

        # predictor: maps all input-patch embeddings of one variable to the
        # embeddings of the forecast patches
        in_dim = self.hidden_size * self.num_patches
        out_dim = self.hidden_size * self.num_patches_p
        ff = getattr(configs, "d_ff", max(256, in_dim // 2))
        self.fc_predictor = nn.Sequential(
            nn.Linear(in_dim, ff),
            nn.GELU(),
            nn.Dropout(self.encoder_dropout),
            nn.Linear(ff, out_dim),
        )

    def revin_norm(self, x):
        """Standardize ``x`` [B, L, C] over time; returns (x_norm, mean, std)."""
        mean = x.mean(dim=1, keepdim=True)        # [B,1,C]
        # biased variance + EPS keeps the divisor strictly positive
        std = x.var(dim=1, keepdim=True, unbiased=False).add(EPS).sqrt()  # [B,1,C]
        x_norm = (x - mean) / std
        return x_norm, mean, std

    def revin_denorm(self, x_norm, mean, std):
        """Invert ``revin_norm``: scale by ``std`` and shift by ``mean``."""
        return x_norm * std + mean

    def forward(self, x):
        """
        Input: x [B, L, C]
        Return: (out, slices, decoded_slices)
          - out: [B, num_patches_p * patch_length, C] forecast in the input
            scale (equals pred_len when pred_len is a multiple of patch_length)
          - slices: [B, C, S * patch_length] normalized sliding patches,
            flattened (S = number of sliding windows)
          - decoded_slices: [B, C, S * patch_length] encoder->decoder
            reconstruction of those patches

        When L is larger than num_patches * patch_length, the forecasting
        branch uses only the most recent num_patches * patch_length steps.
        """
        seq_last = x[:, -1:, :].detach()
        # 1) normalization
        if self.use_revin:
            x_proc, mean, std = self.revin_norm(x)
        else:
            x_proc = x - seq_last
            mean, std = None, None

        # [B, L, C] -> [B, C, L]
        x_proc = x_proc.permute(0, 2, 1).contiguous()

        # 2) reconstruction branch on sliding patches: [B, C, S, p]
        for_enc = x_proc.unfold(-1, self.patch_length, self.slice_stride).contiguous()
        B, C, S, p = for_enc.shape
        sliding_slices = for_enc.view(B, C, S * p)

        encoded_sliding = self.encoder(for_enc.view(B * C * S, p))  # [B*C*S, H]
        H = encoded_sliding.shape[-1]
        decoded_slice = self.decoder(encoded_sliding).view(B, C, S * p)  # [B, C, S*p]

        # 3) forecasting branch on non-overlapping patches: [B*C*P, p]
        P = self.num_patches
        # Keep only the most recent P whole patches so the reshape below also
        # works when L is not an exact multiple of the patch length.
        recent = x_proc[..., -P * self.patch_length:].contiguous()
        patches_pred = recent.view(B * C * P, self.patch_length)
        enc_stack = self.encoder(patches_pred).view(B, C, P, H)  # [B, C, P, H]

        # ---- Variable-TCN on (C, P) plane ----
        var_in = enc_stack.permute(0, 3, 1, 2).contiguous()   # [B, H, C, P]
        var_out = self.var_tcn(var_in)                        # [B, H, C, P]
        enc_mixed = var_out.permute(0, 2, 3, 1).contiguous()  # [B, C, P, H]

        # ---- per-variable temporal prediction ----
        flat = enc_mixed.view(B * C, P * H)                   # [B*C, P*H]
        pred_flat = self.fc_predictor(flat)                   # [B*C, P_p*H]
        pred = pred_flat.view(B * C * self.num_patches_p, H)  # [B*C*P_p, H]

        # decode predicted patches and stitch them along time
        decoded_pred = self.decoder(pred)                     # [B*C*P_p, p]
        predictions = decoded_pred.view(B, C, self.num_patches_p * self.patch_length)

        out = predictions.permute(0, 2, 1).contiguous()       # [B, P_p*p, C]

        # 4) undo normalization
        if self.use_revin:
            out = self.revin_denorm(out, mean, std)
        else:
            out = out + seq_last

        return out, sliding_slices, decoded_slice
