from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import torch

if TYPE_CHECKING:  # pragma: no cover
    from numpy.typing import NDArray


def _nan_to_num_np(x: np.ndarray) -> np.ndarray:
    """Return ``x`` with every NaN, +inf and -inf entry replaced by 0.0.

    Finite entries are passed through unchanged; the replacement is delegated
    to ``np.nan_to_num`` with all three substitution values set to zero.
    """
    zero_fill = {"nan": 0.0, "posinf": 0.0, "neginf": 0.0}
    sanitized = np.nan_to_num(x, **zero_fill)
    return sanitized


def corr_loss_and_dLtdv_numpy(
    output_vs: np.ndarray,
    target_corrcoef: np.ndarray,
    lr_start: int,
    lr_end: int,
) -> tuple[float, np.ndarray]:
    """
    Legacy numpy reference implementation.

    The loss is ``0.5 * mean((target_corrcoef - corrcoef(window)) ** 2)`` where
    ``window = output_vs[:, lr_start:lr_end]``; entries whose error is NaN
    (e.g. rows with zero variance) contribute 0. The gradient is accumulated
    pair by pair with population (ddof=0) mean/std statistics, exactly as the
    historical code did.

    Args:
      - output_vs: 2-D array indexed as ``[output, time]``.
      - target_corrcoef: target correlation matrix, broadcast against
        ``np.corrcoef`` of the window.
      - lr_start, lr_end: half-open column window ``[lr_start, lr_end)``;
        both are coerced with ``int()``.

    Returns:
      - mean_error: scalar float
      - dLtdv: (N_output, T_window) float32 ndarray; non-finite values are
        replaced by 0.
    """
    lr_start = int(lr_start)
    lr_end = int(lr_end)
    t_window = lr_end - lr_start
    window = output_vs[:, lr_start:lr_end]

    output_corrcoef = np.corrcoef(window)
    mean_error_mat = np.abs(target_corrcoef - output_corrcoef)
    mean_error_mat[np.isnan(mean_error_mat)] = 0.0
    mean_error = float(0.5 * np.mean(mean_error_mat**2))

    # Derivative: numerically identical to the historical implementation
    # (same ddof choices, same explicit pair loop, same expression order).
    # Only the per-row quantities, which the old inner loop recomputed for
    # every (i, j) pair, are now evaluated once per row.
    dLdcorr = np.array(-(target_corrcoef - output_corrcoef))  # (N_output, N_output)
    dLdcorr = np.nan_to_num(dLdcorr, nan=0.0, posinf=0.0, neginf=0.0)
    n_output = int(output_vs.shape[0])
    dcorrdxt = np.zeros((n_output, n_output, t_window), dtype=np.float64)

    # One pass over the rows: (centered row, population std, re-centered row).
    # The second centering is mathematically a no-op but is kept because it
    # is part of the historical arithmetic.
    row_stats = []
    for k in range(n_output):
        v = window[k]
        centered = v - np.mean(v)
        recentered = centered - np.mean(centered)
        row_stats.append((centered, np.std(v), recentered))

    for i in range(n_output):
        x_centered, std_x, x_recentered = row_stats[i]
        dstd_xdxt = x_recentered / (std_x * t_window)
        for j in range(n_output):
            y_centered, std_y, y_recentered = row_stats[j]
            cov_xy = np.mean(x_centered * y_centered)
            dcov_xydxt = y_recentered / t_window
            # Quotient rule for cov / (std_x * std_y) w.r.t. x_t.
            dcorrdxt[i, j, :] = (
                dcov_xydxt * std_x * std_y - cov_xy * (dstd_xdxt * std_y)
            ) / (std_x * std_y) ** 2

    dLtdv = np.sum(dLdcorr[:, :, np.newaxis] * dcorrdxt, axis=1)  # (N_output, T_window)
    # A zero-variance row yields NaN/inf above; those are zeroed here.
    dLtdv = np.nan_to_num(dLtdv, nan=0.0, posinf=0.0, neginf=0.0)
    dLtdv = dLtdv.astype(np.float32, copy=False)
    return mean_error, dLtdv


def corr_loss_and_dLtdv_torch(
    output_vs: np.ndarray,
    target_corrcoef: np.ndarray,
    lr_start: int,
    lr_end: int,
    *,
    device: "str | torch.device" = "cuda:0",
    eps: float = 1e-12,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Torch/GPU implementation of:
      - corrcoef-based mean_error (matches existing scalar loss)
      - dL/dv for TARGET_MODE='corr' (vectorized, avoids Python nested loops)

    All arithmetic is done in float32 on the selected device.

    Args:
      - output_vs: 2-D array indexed as ``[output, time]``.
      - target_corrcoef: (N, N) target correlation matrix.
      - lr_start, lr_end: half-open column window ``[lr_start, lr_end)``.
      - device: device string or ``torch.device``; a CUDA request silently
        falls back to CPU when CUDA is not available.
      - eps: small constant added to the std terms to avoid division by zero.

    Returns:
      - mean_error: scalar tensor
      - dLtdv: (N_output, T_window) tensor on `device`

    Raises:
      - ValueError: if ``lr_end <= lr_start``.

    Pairs that involve a zero-variance (or NaN) row have no defined
    correlation. As in the numpy reference, where ``np.corrcoef`` yields NaN
    that is then zeroed, they contribute nothing to the loss or to dL/dcorr.
    Without this mask such rows would be treated as ``corr == 0`` and their
    gradient would be scaled by ``1 / eps``.

    NOTE(review): the legacy numpy loops additionally zero *every* row of the
    gradient as soon as one row is constant (0 * NaN propagates through the
    sum over j). That artifact is intentionally not reproduced here; confirm
    this is acceptable for callers that compare against the reference.
    """
    lr_start = int(lr_start)
    lr_end = int(lr_end)
    if lr_end <= lr_start:
        raise ValueError(f"invalid lr window: [{lr_start}, {lr_end})")

    # str() lets callers pass either a device string or a torch.device.
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"
    dev = torch.device(device)

    X = torch.as_tensor(output_vs[:, lr_start:lr_end], dtype=torch.float32, device=dev)  # (N, T)
    T = int(X.shape[1])
    if T <= 1:
        # The slice can be shorter than the requested window; with fewer than
        # two samples there is no correlation to speak of.
        mean_error = torch.tensor(0.0, dtype=torch.float32, device=dev)
        dLtdv = torch.zeros_like(X)
        return mean_error, dLtdv

    target = torch.as_tensor(target_corrcoef, dtype=torch.float32, device=dev)  # (N, N)

    # --- corrcoef (np.corrcoef-like, unbiased) for the loss/dLdcorr ---
    mean = X.mean(dim=1, keepdim=True)
    Xc = X - mean
    gram = Xc @ Xc.t()  # shared by the unbiased and the mean-based covariance
    cov1 = gram / float(T - 1)
    var1 = torch.diagonal(cov1, 0)
    std1 = torch.sqrt(torch.clamp(var1, min=0.0))  # (N,)
    denom1 = std1[:, None] * std1[None, :] + eps

    # Rows with zero (or NaN) variance have an undefined correlation.
    has_variance = var1 > 0  # (N,)
    valid = has_variance[:, None] & has_variance[None, :]  # (N, N)
    zeros = torch.zeros_like(cov1)

    # np.corrcoef clips its result to [-1, 1]; do the same against float32 rounding.
    corr = torch.where(valid, torch.clamp(cov1 / denom1, min=-1.0, max=1.0), zeros)
    corr = torch.nan_to_num(corr, nan=0.0, posinf=0.0, neginf=0.0)

    diff = torch.where(valid, torch.abs(target - corr), zeros)
    diff = torch.nan_to_num(diff, nan=0.0, posinf=0.0, neginf=0.0)
    mean_error = 0.5 * torch.mean(diff * diff)

    # dL/dcorr matches legacy sign: dLdcorr = -(target - corr) = (corr - target)
    dLdcorr = torch.where(valid, corr - target, zeros)  # (N, N)
    dLdcorr = torch.nan_to_num(dLdcorr, nan=0.0, posinf=0.0, neginf=0.0)

    # --- derivative part: vectorized equivalent of historical python loops ---
    cov0 = gram / float(T)  # (N, N), mean-based
    var0 = (Xc * Xc).mean(dim=1)  # (N,), mean-based
    std0 = torch.sqrt(torch.clamp(var0, min=0.0)) + eps  # (N,)

    # A[i] = sum_j dLdcorr[i, j] * Xc[j] / (T * std_i * std_j)
    W = dLdcorr / (std0[:, None] * std0[None, :])
    A = (W @ Xc) / float(T)

    # B[i] = Xc[i] * sum_j dLdcorr[i, j] * cov[i, j] / (T * std_i**3 * std_j)
    S = torch.sum(dLdcorr * (cov0 / std0[None, :]), dim=1)  # (N,)
    B = Xc * (S / (float(T) * (std0**3)))[:, None]  # (N, T)

    dLtdv = A - B
    dLtdv = torch.nan_to_num(dLtdv, nan=0.0, posinf=0.0, neginf=0.0)
    return mean_error, dLtdv

