# layers/rlc.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional


def _safe_index(x, idx):
    """Return the step that lies `idx` positions from the end of `x`.

    `x` is unpacked as [B, L, C] and `idx` is a 1-based lag (1 selects the
    last step). When the lag is non-positive or reaches past the start of
    the sequence, a zero tensor of shape [B, C] is returned on the same
    device and with the same dtype as `x`.
    """
    batch, length, channels = x.shape
    lag_in_range = 0 < idx <= length
    if lag_in_range:
        return x[:, -idx, :]
    # Out-of-range lag: there is no such step, so substitute zeros.
    return torch.zeros((batch, channels), device=x.device, dtype=x.dtype)


class RLCRegressor(nn.Module):
    """Map lag/window statistics of a history window to a k-dim latent `z`.

    The module builds, per input channel, a feature vector made of the mean
    over the whole history, single-step values at the configured lags, and
    means over the configured trailing windows. These features are projected
    to `k` dimensions either linearly (parameter `W`) or through a small MLP
    (`project_net`). Two heads operate alongside the projection:

    * `future_proj` maps a pooled future target [B, c_in] to [B, k] and is
      used by `correlation_loss`.
    * `aux_head` maps `z` [B, k] back to [B, c_in] and is used by the
      auxiliary MSE term in `compute_losses`.

    Args:
        c_in: Number of input channels C.
        k: Latent dimensionality.
        init_scale: Std-dev used to initialise `W`, the last MLP layer weight
            and the `future_proj` weight.
        use_mlp: Use a two-layer MLP instead of the linear map `W`.
        mlp_hidden: Hidden width of that MLP.
        lags: 1-based lags to sample; values below 1 are dropped, duplicates
            removed, result sorted. Defaults to [1, 24, 168].
        window_sizes: Trailing window lengths to average; filtered the same
            way as `lags`. Defaults to [3, 24, 168].
        aux_pred_mode: Stored on the instance. Every value currently builds
            the same pooled-target head ("residual_seq" falls back to it).
        aux_hidden: Hidden width of the auxiliary head.
        orth_reg: Whether `compute_losses` reports the orthogonality penalty.
    """

    def __init__(
        self,
        c_in: int,
        k: int = 8,
        init_scale: float = 1e-2,
        use_mlp: bool = False,
        mlp_hidden: int = 64,
        lags: Optional[List[int]] = None,
        window_sizes: Optional[List[int]] = None,
        aux_pred_mode: str = "pooled",  # "pooled" or "residual_seq"
        aux_hidden: int = 64,
        orth_reg: bool = True,
    ):
        super().__init__()
        self.c_in = int(c_in)
        self.k = int(k)
        self.use_mlp = bool(use_mlp)
        self.orth_reg_flag = bool(orth_reg)

        # Default lags and windows if not provided.
        if lags is None:
            lags = [1, 24, 168] if self.c_in >= 1 else [1]
        self.lags = sorted({int(x) for x in lags if x >= 1})

        if window_sizes is None:
            window_sizes = [3, 24, 168]
        self.window_sizes = sorted({int(x) for x in window_sizes if x >= 1})

        # Per channel: one full-history mean + one value per lag + one mean
        # per window.
        self.feat_per_channel = len(self.lags) + len(self.window_sizes) + 1
        self.feat_dim = self.c_in * self.feat_per_channel

        # Projection from lag features to latent z. Construction order is
        # kept stable so seeded initialisation is reproducible.
        if self.use_mlp:
            self.project_net = nn.Sequential(
                nn.Linear(self.feat_dim, mlp_hidden),
                nn.ReLU(),
                nn.Linear(mlp_hidden, self.k),
            )
            nn.init.normal_(self.project_net[-1].weight, std=init_scale)
            nn.init.zeros_(self.project_net[-1].bias)
            self.W = None
        else:
            # Linear map: W has shape [feat_dim, k].
            self.W = nn.Parameter(torch.randn(self.feat_dim, self.k) * init_scale)
            self.project_net = None

        # Future projection head (for correlation loss; pooled future -> k).
        self.future_proj = nn.Linear(self.c_in, self.k)
        nn.init.normal_(self.future_proj.weight, std=init_scale)
        nn.init.zeros_(self.future_proj.bias)

        # Auxiliary predictor z -> pooled future [B, C]. Modes other than
        # "pooled" are not implemented separately and fall back to the same
        # head, so a single construction covers every mode.
        self.aux_pred_mode = aux_pred_mode
        self.aux_head = nn.Sequential(
            nn.Linear(self.k, aux_hidden),
            nn.ReLU(),
            nn.Linear(aux_hidden, self.c_in),
        )

    def _pool_future(self, future_y: torch.Tensor) -> torch.Tensor:
        """Reduce a future target to [B, c_in] for the loss heads.

        3-D input is averaged over dim 1, 2-D input is used as is, anything
        else is flattened to [B, -1]. The last dimension is then truncated or
        right-padded with zeros so that it equals `c_in`.
        """
        if future_y.dim() == 3:
            pooled = future_y.mean(dim=1)
        elif future_y.dim() == 2:
            pooled = future_y
        else:
            # reshape (not view) so non-contiguous inputs are accepted too.
            pooled = future_y.reshape(future_y.size(0), -1)

        width = pooled.size(-1)
        if width > self.c_in:
            pooled = pooled[:, :self.c_in]
        elif width < self.c_in:
            filler = pooled.new_zeros(pooled.size(0), self.c_in - width)
            pooled = torch.cat([pooled, filler], dim=-1)
        return pooled

    def _compute_lag_features(self, x_hist: torch.Tensor) -> torch.Tensor:
        """Build the flat lag/window feature vector for a history window.

        Args:
            x_hist: History of shape [B, L, C]. Inputs of another rank are
                reshaped to [B, -1, last_dim].

        Returns:
            Tensor [B, C * feat_per_channel] laid out as
            [full mean | value at each lag | mean of each window], each
            segment being C wide.
        """
        if x_hist.dim() != 3:
            x_hist = x_hist.reshape(x_hist.size(0), -1, x_hist.size(-1))

        steps = x_hist.size(1)
        # Mean over the entire available history (no window truncation).
        full_mean = x_hist.mean(dim=1)  # [B, C]

        feats = [full_mean]
        # Single-step values at each lag; lags beyond the history give zeros.
        feats.extend(_safe_index(x_hist, lag) for lag in self.lags)
        # Trailing-window means; a window longer than the history falls back
        # to the full-history mean.
        for w in self.window_sizes:
            feats.append(x_hist[:, -w:, :].mean(dim=1) if w <= steps else full_mean)

        # Each entry is [B, C]; concatenating on the last dim gives
        # [B, C * feat_per_channel].
        return torch.cat(feats, dim=-1)

    def project(self, x_hist: torch.Tensor, mode: str = "mean"):
        """
        Create latent representation z from history using lag features.
        x_hist: [B, L, C]
        mode: accepted for interface compatibility; not used by this method.
        returns: z [B, k]
        """
        feat = self._compute_lag_features(x_hist)  # [B, feat_dim]
        if self.use_mlp:
            return self.project_net(feat)
        return feat @ self.W

    def correlation_loss(self, z: torch.Tensor, future_y: torch.Tensor, reduce: str = "sum"):
        """Negative squared Pearson correlation between z and projected future.

        The future target is pooled to [B, c_in], mapped through
        `future_proj` to [B, k], and correlated with `z` per latent
        dimension across the batch.

        Args:
            z: Latent tensor [B, k].
            future_y: Future target; pooled as described in `_pool_future`.
            reduce: "sum" or "mean" return a scalar; any other value returns
                the per-dimension vector [k].

        Returns:
            `-r**2` reduced as requested. With fewer than two samples the
            correlation is undefined, so zeros are returned instead of NaN.

        Raises:
            ValueError: if `z` or `future_y` is None, or `z` is not 2-D.
        """
        if z is None:
            raise ValueError("z is None")
        if z.dim() != 2:
            raise ValueError("z must be [B,k]")
        if future_y is None:
            raise ValueError("future_y is None")

        batch = z.size(0)
        y_proj = self.future_proj(self._pool_future(future_y))  # [B, k]

        zc = z - z.mean(dim=0, keepdim=True)
        yc = y_proj - y_proj.mean(dim=0, keepdim=True)

        if batch < 2:
            # The unbiased std of a single sample is NaN, which would poison
            # the total loss. Return zeros that stay attached to the graph so
            # a backward pass through this term remains valid.
            r2 = (zc * yc).sum(dim=0) * 0.0
        else:
            cov = (zc * yc).sum(dim=0) / (batch - 1.0 + 1e-12)
            std_z = zc.std(dim=0, unbiased=True)
            std_y = yc.std(dim=0, unbiased=True)
            r = cov / (std_z * std_y + 1e-8)
            r2 = r * r

        if reduce == "sum":
            return -r2.sum()
        if reduce == "mean":
            return -r2.mean()
        return -r2  # vector

    def orthogonality_loss(self):
        """Return the squared Frobenius norm of `W^T W - I`.

        When the MLP projection is used there is no `W`, and a zero scalar
        on the module's device is returned.
        """
        if self.W is None:
            return torch.tensor(0.0, device=next(self.parameters()).device)
        gram = torch.matmul(self.W.t(), self.W)  # [k, k]
        identity = torch.eye(self.k, device=gram.device, dtype=gram.dtype)
        return ((gram - identity) ** 2).sum()

    def aux_predict(self, z: torch.Tensor):
        """Run the auxiliary head on z [B, k]; returns [B, c_in]."""
        return self.aux_head(z)

    def compute_losses(self, x_hist: torch.Tensor, future_y: torch.Tensor, mode: str = "mean"):
        """Project the history and, if a target is given, compute all losses.

        Args:
            x_hist: History [B, L, C].
            future_y: Future target, or None to skip loss computation.
            mode: Forwarded to `project`.

        Returns:
            `(z, losses)` where `losses` is empty when `future_y` is None and
            otherwise holds 'rlc_corr', 'rlc_aux' and 'rlc_orth'.
        """
        z = self.project(x_hist, mode=mode)
        losses = {}
        if future_y is None:
            return z, losses

        losses['rlc_corr'] = self.correlation_loss(z, future_y, reduce='sum')

        # Auxiliary target: the same pooled/aligned future used by the
        # correlation term, so both losses see an identical target.
        y_pool = self._pool_future(future_y)
        y_hat = self.aux_predict(z)
        losses['rlc_aux'] = F.mse_loss(y_hat, y_pool)

        if self.orth_reg_flag:
            losses['rlc_orth'] = self.orthogonality_loss()
        else:
            losses['rlc_orth'] = torch.tensor(0.0, device=y_hat.device)
        return z, losses
