"""Dual-space RFM (Recursive Feature Machine) for high-dimensional features.

Simplified from rfm_guidance/dual_rfm.py

Memory-efficient RFM training that computes guidance directions from diffusion
model activations. It uses the "dual trick" -- working with the N x N Gram
matrix of per-sample gradients instead of the D x D AGOP -- to avoid O(D²)
memory.
"""

import torch
import numpy as np
from typing import Tuple, Optional
from dataclasses import dataclass


@dataclass
class DualRFMResult:
    """Everything produced by a dual-space RFM fit.

    The eigenvectors are expressed in the normalized feature space, i.e. they
    act on ``(x - mean) / std`` using the ``mean`` / ``std`` stored here.
    ``sign`` is reported separately and is NOT pre-applied: multiply the
    eigenvectors by it to make them point toward the target class.
    """

    eigenvectors: torch.Tensor  # [D, k] top-k AGOP eigenvectors, one per column
    eigenvalues: torch.Tensor  # [k] eigenvalue for each eigenvector column
    mean: np.ndarray  # [D] per-feature mean used for normalization
    std: np.ndarray  # [D] per-feature std used for normalization
    best_auc: float  # validation AUC of the selected iteration; -1 without validation
    best_iter: int  # iteration that produced best_auc (-1 without validation)
    sign: int = 1  # +1 or -1; multiply the eigenvectors by this value
    pearson_corr: float = 0.0  # correlation between top-eigenvector projections and labels


def determine_direction_sign(
    X: np.ndarray,
    y: np.ndarray,
    direction: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
) -> Tuple[int, float]:
    """Determine the sign that makes the direction point toward the target class.

    The samples are normalized with ``mean`` / ``std`` and projected onto
    ``direction``. The sign of the Pearson correlation between the projections
    and the labels ``y`` decides the result: +1 when the correlation is >= 0,
    otherwise -1.

    Args:
        X: [N, D] raw (unnormalized) features.
        y: [N] labels; larger values denote the target class.
        direction: [D] direction in normalized feature space.
        mean: [D] normalization mean.
        std: [D] normalization std.

    Returns:
        ``(sign, corr)`` where ``corr`` is a plain float. When the correlation
        is undefined (constant projections or constant labels) the result is
        ``(1, 0.0)`` rather than a NaN correlation paired with an arbitrary sign.
    """
    from scipy.stats import pearsonr

    X_norm = (X - mean) / std
    projections = X_norm @ direction

    # Pearson correlation is undefined for constant inputs: pearsonr would
    # return NaN, and `NaN >= 0` is False, silently yielding sign -1.
    if np.ptp(projections) == 0 or np.ptp(np.asarray(y, dtype=np.float64)) == 0:
        return 1, 0.0

    corr, _ = pearsonr(projections, y)
    corr = float(corr)
    if not np.isfinite(corr):
        # e.g. NaNs in the projections; no meaningful orientation available.
        return 1, 0.0
    return (1 if corr >= 0 else -1), corr


class LaplaceKernelGradients:
    """Gradients of a Laplace-kernel predictor under a Mahalanobis metric.

    The predictor is ``f(z) = sum_i c_i * exp(-||x_i - z||_M / bandwidth)``
    with ``||v||_M = sqrt(v^T M v)``. Callers hand in the points together with
    their M-transformed copies, so M itself is never materialized here.
    """

    def __init__(self, bandwidth: float = 100.0, eps: float = 1e-10):
        self.bandwidth = bandwidth
        self.eps = eps

    def compute_grads(
        self,
        x: torch.Tensor,
        z: torch.Tensor,
        xm: torch.Tensor,
        zm: torch.Tensor,
        coefs: torch.Tensor,
    ) -> torch.Tensor:
        """Gradient of the predictor at every row of ``z``.

        Args:
            x: [N, D] kernel centres.
            z: [P, D] evaluation points.
            xm: ``x`` right-multiplied by the (symmetric) metric M.
            zm: ``z`` right-multiplied by M.
            coefs: [N] predictor coefficients, one per centre.

        Returns:
            [P, D] tensor whose row j is grad f(z_j). Centre/point pairs closer
            than ``eps`` contribute nothing.
        """
        bw = self.bandwidth

        # ||x_i - z_j||_M^2 expanded as x^T M x - 2 x^T M z + z^T M z.
        sq_dist = (
            (xm * x).sum(dim=1).unsqueeze(1)
            - 2 * xm @ z.T
            + (zm * z).sum(dim=1).unsqueeze(0)
        )
        dist = sq_dist.clamp(min=0).sqrt()
        kern = torch.exp(-dist / bw)

        # d/dz exp(-d/bw) = -kern / (bw * d) * M (z - x); drop near-zero d.
        factor = kern / (bw * dist.clamp(min=self.eps))
        factor = factor * (dist >= self.eps).float()
        w = coefs.unsqueeze(1) * factor

        # pull_j = sum_i w_ij * (zm_j - xm_i), which is minus the gradient.
        pull = w.sum(dim=0).unsqueeze(1) * zm
        pull = pull - (w.T @ xm)
        return -pull


class DualRFM:
    """Dual-space RFM for high-dimensional features.

    Each iteration fits Laplace-kernel ridge regression under the current
    metric M, takes the gradients of the fitted predictor at the training
    points, and replaces M by their AGOP. The AGOP is never formed as a D x D
    matrix: its spectrum is read off the N x N Gram matrix of the gradients,
    and M is kept in factored form ``V @ diag(evals) @ V.T``.

    Memory: O(N²) instead of O(D²)

    Args:
        bandwidth: Laplace kernel bandwidth (default: 100.0)
        reg: KRR regularization (default: 1e-3)
        center_grads: Center gradients before AGOP (default: True)
        verbose: Print progress (default: True)
    """

    # Above this many samples the Gram-matrix SVD switches to the randomized
    # torch.svd_lowrank and keeps at most this many eigenpairs.
    MAX_DIM_FOR_SVD = 5000

    def __init__(
        self,
        bandwidth: float = 100.0,
        reg: float = 1e-3,
        center_grads: bool = True,
        verbose: bool = True,
    ):
        self.bandwidth = bandwidth
        self.reg = reg
        self.center_grads = center_grads
        self.verbose = verbose
        self.kernel_grads = LaplaceKernelGradients(bandwidth=bandwidth)

    def _log(self, msg: str):
        if self.verbose:
            print(msg)

    def _mahalanobis_transform(
        self,
        X: torch.Tensor,
        V: Optional[torch.Tensor],
        evals: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Return X @ M with M = V @ diag(evals) @ V.T (X itself when V is None)."""
        if V is None:
            return X
        V = V.to(device=X.device, dtype=X.dtype)
        evals = evals.to(device=X.device, dtype=X.dtype)
        return ((X @ V) * evals.unsqueeze(0)) @ V.T

    def _laplace_kernel(
        self,
        A: torch.Tensor,
        A_m: torch.Tensor,
        B: torch.Tensor,
        B_m: torch.Tensor,
    ) -> torch.Tensor:
        """Kernel matrix exp(-||a_i - b_j||_M / bandwidth) between rows of A and B.

        ``A_m`` / ``B_m`` must be ``A`` / ``B`` transformed with the SAME M.
        """
        a_sq = (A_m * A).sum(dim=1, keepdim=True)
        b_sq = (B_m * B).sum(dim=1, keepdim=True)
        dists_sq = a_sq + b_sq.T - 2 * (A_m @ B.T)
        return torch.exp(-dists_sq.clamp(min=0).sqrt() / self.bandwidth)

    def _agop_eigendecomposition(
        self, grads: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Top eigenpairs of the AGOP ``grads.T @ grads`` via the dual Gram matrix.

        Args:
            grads: [N, D] per-sample gradients.

        Returns:
            ``(eigenvalues, eigenvectors)``; eigenvectors is [D, k] with
            unit-norm columns, eigenvalues are those of the AGOP divided by its
            largest diagonal entry (plus a 1e-8 ridge), in descending order.
        """
        N = grads.shape[0]

        # Gram matrix G = grads @ grads.T shares its non-zero spectrum with
        # the AGOP; scale by the largest AGOP diagonal entry for conditioning.
        G = grads @ grads.T
        agop_max = (grads ** 2).sum(dim=0).max().clamp(min=1e-30)
        G = G / agop_max
        G.diagonal().add_(1e-8)

        # G is symmetric PSD, so its SVD is its eigendecomposition. SVD runs
        # in the working dtype (no forced float32) so U matches grads below.
        max_k = min(self.MAX_DIM_FOR_SVD, N)
        if N > self.MAX_DIM_FOR_SVD:
            self._log(f"  Using svd_lowrank (N={N} > {self.MAX_DIM_FOR_SVD}), q={max_k}")
            U, eigenvalues, _ = torch.svd_lowrank(G, q=max_k)
        else:
            U, eigenvalues, _ = torch.linalg.svd(G, full_matrices=False)
            U = U[:, :max_k]
            eigenvalues = eigenvalues[:max_k]

        # Map dual eigenvectors u to primal ones: v = grads.T u / sqrt(lambda * agop_max)
        # has unit norm because ||grads.T u||^2 = agop_max * lambda.
        V_raw = grads.T @ U
        eigenvectors = V_raw / (eigenvalues.clamp(min=1e-10) * agop_max).sqrt().unsqueeze(0)
        return eigenvalues, eigenvectors

    def fit(
        self,
        X_train: torch.Tensor,
        y_train: torch.Tensor,
        X_val: Optional[torch.Tensor] = None,
        y_val: Optional[torch.Tensor] = None,
        n_iters: int = 5,
        top_k: Optional[int] = None,
        early_stop: bool = True,
        early_stop_threshold: float = 1.1,
    ) -> DualRFMResult:
        """Fit dual RFM to learn discriminative directions.

        Args:
            X_train: [N, D] training features
            y_train: [N] binary labels (1=target, 0=other)
            X_val: Optional [N_val, D] validation features
            y_val: Optional [N_val] validation labels
            n_iters: RFM iterations (default: 5)
            top_k: Number of eigenvectors to return (None = all computed)
            early_stop: Stop if validation AUC degrades
            early_stop_threshold: Early stop threshold (default: 1.1)

        Returns:
            DualRFMResult with eigenvectors and sign correction. With
            validation data, the eigenvectors are the AGOP of the predictor
            with the highest validation AUC; otherwise those of the last
            iteration.

        Raises:
            ValueError: if ``n_iters < 1`` or ``top_k < 1``.
        """
        if n_iters < 1:
            raise ValueError(f"n_iters must be >= 1, got {n_iters}")
        if top_k is not None and top_k < 1:
            raise ValueError(f"top_k must be >= 1 or None, got {top_k}")

        device = X_train.device
        N, D = X_train.shape

        # Work in at least float32 (float64 stays float64): numpy has no
        # bfloat16, and solve/matmul need all operands in one dtype.
        # detach(): the numpy conversions below fail on grad-tracking tensors.
        dtype = torch.promote_types(X_train.dtype, torch.float32)
        X_train = X_train.detach().to(dtype)
        y_train = torch.as_tensor(y_train).detach().to(device=device, dtype=dtype)

        # Normalize
        mean_t = X_train.mean(dim=0)
        std_t = X_train.std(dim=0) + 1e-8
        X_norm = (X_train - mean_t) / std_t
        mean = mean_t.cpu().numpy()
        std = std_t.cpu().numpy()

        self._log(f"DualRFM: N={N}, D={D}")
        self._log(f"  Memory: {N*N*4/1e6:.1f}MB (dual) vs {D*D*4/1e9:.1f}GB (primal)")

        # Validation setup
        has_val = X_val is not None and y_val is not None
        if has_val:
            from sklearn.metrics import roc_auc_score

            X_val = torch.as_tensor(X_val).detach().to(device=device, dtype=dtype)
            X_val_norm = (X_val - mean_t) / std_t
            y_val_np = y_val.detach().cpu().numpy() if torch.is_tensor(y_val) else np.asarray(y_val)

        # M starts as the identity (represented by None).
        M_eigenvectors = None
        M_eigenvalues = None

        # Best model tracking
        best_auc = -1.0
        best_iter = -1
        best_eigenvectors = None
        best_eigenvalues = None

        for it in range(n_iters):
            self._log(f"\n  Iteration {it+1}/{n_iters}")

            # KRR under the current M: alpha = (K + reg*I)^-1 @ y.
            Xm = self._mahalanobis_transform(X_norm, M_eigenvectors, M_eigenvalues)
            K = self._laplace_kernel(X_norm, Xm, X_norm, Xm)
            K.diagonal().add_(self.reg)  # in place: no extra N x N identity
            alpha = torch.linalg.solve(K, y_train.unsqueeze(1)).squeeze(1)
            del K  # free N x N before the gradient step allocates its own

            # Score the predictor just fitted. Its kernel must use the same M
            # that alpha was solved with, on both the validation and train side.
            if has_val:
                X_val_m = self._mahalanobis_transform(X_val_norm, M_eigenvectors, M_eigenvalues)
                K_val = self._laplace_kernel(X_val_norm, X_val_m, X_norm, Xm)
                y_pred = (K_val @ alpha).cpu().numpy()
                iter_auc = roc_auc_score(y_val_np, y_pred)
                self._log(f"  Val AUC: {iter_auc:.4f}")

            # New M = AGOP of the fitted predictor at the training points.
            grads = self.kernel_grads.compute_grads(X_norm, X_norm, Xm, Xm, alpha)
            if self.center_grads:
                grads = grads - grads.mean(dim=0, keepdim=True)
            M_eigenvalues, M_eigenvectors = self._agop_eigendecomposition(grads)
            del grads

            # Keep (a copy of) the requested number of eigenpairs on CPU.
            n_avail = M_eigenvalues.numel()
            k = n_avail if top_k is None else min(top_k, n_avail)
            curr_eigenvectors = M_eigenvectors[:, :k].cpu()
            curr_eigenvalues = M_eigenvalues[:k].cpu()

            if not has_val:
                # Without validation the last iteration wins.
                best_eigenvectors = curr_eigenvectors
                best_eigenvalues = curr_eigenvalues
                continue

            if iter_auc > best_auc:
                best_auc = iter_auc
                best_iter = it + 1
                best_eigenvectors = curr_eigenvectors.clone()
                best_eigenvalues = curr_eigenvalues.clone()
                self._log(f"  New best at iter {best_iter}!")

            if early_stop and it > 0 and iter_auc < best_auc / early_stop_threshold:
                self._log(f"  Early stop: AUC ({iter_auc:.4f}) < {best_auc:.4f}/{early_stop_threshold}")
                break

        # Determine sign for top eigenvector (on held-out data when available).
        top_eigvec = best_eigenvectors[:, 0].numpy()
        if has_val:
            X_for_sign = X_val.cpu().numpy()
            y_for_sign = y_val_np
        else:
            X_for_sign = X_train.cpu().numpy()
            y_for_sign = y_train.cpu().numpy()

        sign, pearson_corr = determine_direction_sign(X_for_sign, y_for_sign, top_eigvec, mean, std)
        self._log(f"\n  Direction sign: {sign} (Pearson corr: {pearson_corr:.4f})")

        return DualRFMResult(
            eigenvectors=best_eigenvectors,
            eigenvalues=best_eigenvalues,
            mean=mean,
            std=std,
            best_auc=best_auc,
            best_iter=best_iter,
            sign=sign,
            pearson_corr=pearson_corr,
        )


def compute_guidance_direction(
    X_target: torch.Tensor,
    X_other: torch.Tensor,
    X_val_target: Optional[torch.Tensor] = None,
    X_val_other: Optional[torch.Tensor] = None,
    bandwidth: float = 100.0,
    reg: float = 1e-3,
    n_iters: int = 5,
    center_grads: bool = True,
    verbose: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, int]:
    """Convenience function to compute guidance direction.

    Fits a DualRFM on ``X_target`` (label 1) vs ``X_other`` (label 0) and
    returns its top eigenvector. Validation data is used only when BOTH
    ``X_val_target`` and ``X_val_other`` are given.

    Returns:
        direction: [D] normalized guidance direction (sign-corrected)
        mean: [D] normalization mean
        std: [D] normalization std
        auc: validation AUC (-1 if no validation)
        sign: +1 or -1 (already applied to direction)

    Raises:
        ValueError: if either training class is empty, or if the top
            eigenvector has zero / non-finite norm and cannot be normalized.
    """
    if len(X_target) == 0 or len(X_other) == 0:
        raise ValueError("X_target and X_other must each contain at least one sample")

    X_train = torch.cat([X_target, X_other], dim=0)
    # Float labels in the features' dtype (at least float32) so the KRR solve
    # receives operands of matching dtype, e.g. float64 features + float64 labels.
    label_dtype = torch.promote_types(X_train.dtype, torch.float32)
    y_train = torch.cat([
        torch.ones(len(X_target), device=X_train.device, dtype=label_dtype),
        torch.zeros(len(X_other), device=X_train.device, dtype=label_dtype),
    ])

    X_val = None
    y_val = None
    if X_val_target is not None and X_val_other is not None:
        X_val = torch.cat([X_val_target, X_val_other], dim=0)
        y_val = torch.cat([
            torch.ones(len(X_val_target), device=X_val.device, dtype=label_dtype),
            torch.zeros(len(X_val_other), device=X_val.device, dtype=label_dtype),
        ])

    rfm = DualRFM(
        bandwidth=bandwidth,
        reg=reg,
        center_grads=center_grads,
        verbose=verbose,
    )

    result = rfm.fit(X_train, y_train, X_val, y_val, n_iters=n_iters, top_k=1)

    direction = result.eigenvectors[:, 0].numpy()
    norm = np.linalg.norm(direction)
    if not np.isfinite(norm) or norm == 0:
        # Dividing would silently yield a NaN/inf guidance direction.
        raise ValueError("DualRFM produced a degenerate (zero or non-finite) top eigenvector")
    direction = direction / norm
    direction = direction * result.sign

    return direction, result.mean, result.std, result.best_auc, result.sign
