import numpy as np
import torch
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score
from scipy.stats import mode
import builtins
import os
import sys
import torch.nn as nn
import torch.optim as optim
from typing import Dict
from typing import Optional

sys.path.append(builtins.ROOT_PATH)
from exp.dataloader import check_device, log

def mmd_u(K, n, m, is_var=False):
    """Unbiased (U-statistic) estimate of the squared MMD from a pooled Gram matrix.

    ``K`` is the kernel matrix over the concatenation of sample X (first ``n``
    rows/columns) and sample Y (remaining rows/columns), i.e. laid out as
    ``[[K_XX, K_XY], [K_YX, K_YY]]``.

    Args:
        K: square torch tensor holding the pooled kernel matrix. It is NOT
            modified by this function.
        n: number of X samples (size of the leading block).
        m: number of Y samples (size of the trailing block).
        is_var: accepted for backward compatibility; currently ignored.

    Returns:
        Scalar tensor: mean off-diagonal k(x, x') + mean off-diagonal k(y, y')
        - 2 * mean k(x, y).

    Raises:
        ValueError: if ``n < 2`` or ``m < 2``, where the within-sample averages
            over distinct pairs are undefined.
    """
    if n < 2 or m < 2:
        raise ValueError(f"mmd_u needs at least 2 samples per group; got n={n}, m={m}")

    # Slices of K are views; clone the within-sample blocks so that zeroing
    # their diagonals does not silently overwrite the caller's matrix.
    K_XX = K[:n, :n].clone()
    K_YY = K[n:, n:].clone()
    K_XY = K[:n, n:]

    # Drop self-similarities k(x_i, x_i) so the within-sample terms are unbiased.
    K_XX.fill_diagonal_(0)
    K_YY.fill_diagonal_(0)

    term_xx = K_XX.sum() / (n * (n - 1))
    term_yy = K_YY.sum() / (m * (m - 1))
    term_xy = 2 * K_XY.sum() / (n * m)
    return term_xx + term_yy - term_xy



class KernelMeanEmbeddingWitness:
    """Squared RKHS distance from each query point to the reference kernel mean embedding.

    Uses a Gaussian RBF kernel. When ``bandwidth`` is falsy (``None`` or 0) the
    median heuristic on ``X_ref`` is used instead.
    """

    def __init__(self, X_ref: torch.Tensor, bandwidth: Optional[float] = None):
        self.X_ref = X_ref
        # `or` deliberately treats both None and 0 as "use the median heuristic".
        self.bandwidth = bandwidth or self._median_heuristic(X_ref)
        self.device = X_ref.device

        # mean_{x,x'} k(x, x') over all reference pairs (diagonal included);
        # a constant offset shared by every query point.
        K_xx = self._compute_kernel(X_ref, X_ref)
        self.mean_K_xx = K_xx.mean().item()

    def _median_heuristic(self, X: torch.Tensor) -> float:
        """Median of the strictly positive pairwise distances in ``X``.

        Returns 1.0 when every point coincides (no positive distances), where
        the median of an empty selection would otherwise be undefined.
        """
        with torch.no_grad():
            dists = torch.cdist(X, X)
            positive = dists[dists > 0]
            if positive.numel() == 0:
                return 1.0
            return torch.median(positive).item()

    def _compute_kernel(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        """Gaussian RBF kernel - CHARACTERISTIC (guarantees consistency)."""
        dists = torch.cdist(X, Y)
        return torch.exp(-dists**2 / (2 * self.bandwidth**2))

    def __call__(self, Z: torch.Tensor) -> torch.Tensor:
        """
        Compute squared distance to kernel mean embedding for each point.

        w(z) = k(z,z) - 2·mean_x k(z,x) + mean_{x,x'} k(x,x')
             = 1 - 2·mean_x k(z,x) + const  (for RBF, k(z,z)=1)

        Higher values = further from the reference distribution's center.
        For a 2-D ``Z`` of shape (m, d) the result has shape (m,).
        """
        K_zx = self._compute_kernel(Z, self.X_ref)
        mean_K_zx = K_zx.mean(dim=1)

        # The constant mean_K_xx does not affect rankings, but it is kept so
        # that the returned value is the actual squared RKHS distance.
        return 1 - 2 * mean_K_zx + self.mean_K_xx


class MultiScaleKMEWitness:
    """
    Multi-scale Kernel Mean Embedding for better finite-sample power.

    Uses multiple bandwidths to capture both local and global structure.
    This improves power at small sample sizes while maintaining consistency.

    Bandwidths are ``base_bw * 2**(k - 2)`` for ``k = 0 .. n_scales-1``
    (0.25x, 0.5x, 1x, 2x, 4x, 8x, ...), where ``base_bw`` is the median
    positive pairwise distance of ``X_ref``. Note that small ``n_scales``
    keeps only the finest scales (e.g. 3 -> 0.25x, 0.5x, 1x).
    """

    def __init__(self, X_ref: torch.Tensor, n_scales: int = 5):
        if n_scales < 1:
            # With no scales __call__ would divide by zero and return NaNs.
            raise ValueError(f"n_scales must be >= 1, got {n_scales}")
        self.X_ref = X_ref
        self.device = X_ref.device

        # Base bandwidth: median heuristic over distinct-point distances.
        with torch.no_grad():
            dists = torch.cdist(X_ref, X_ref)
            positive = dists[dists > 0]
            # Fallback when all reference points coincide (empty selection).
            base_bw = torch.median(positive).item() if positive.numel() > 0 else 1.0

        # Geometric ladder; the first five multipliers equal the historical
        # [0.25, 0.5, 1.0, 2.0, 4.0] list, larger n_scales extend it.
        scales = [2.0 ** (k - 2) for k in range(n_scales)]
        self.bandwidths = [base_bw * s for s in scales]

        # Pre-compute mean_{x,x'} k(x, x') for each scale (diagonal included).
        self.mean_K_xx = []
        for bw in self.bandwidths:
            K_xx = torch.exp(-dists**2 / (2 * bw**2))
            self.mean_K_xx.append(K_xx.mean().item())

    def __call__(self, Z: torch.Tensor) -> torch.Tensor:
        """Average KME distance across scales (higher = further from X_ref)."""
        dists_zx = torch.cdist(Z, self.X_ref)

        total = torch.zeros(len(Z), device=self.device)
        for bw, mean_kxx in zip(self.bandwidths, self.mean_K_xx):
            K_zx = torch.exp(-dists_zx**2 / (2 * bw**2))
            mean_K_zx = K_zx.mean(dim=1)
            total += 1 - 2 * mean_K_zx + mean_kxx

        return total / len(self.bandwidths)


class KernelWitness:
    """Mean Gaussian RBF similarity of each query point to the reference set.

    Unlike the outlierness witnesses in this module, larger values here mean
    *closer* to the reference data.
    """

    def __init__(self, X_ref: torch.Tensor, bandwidth: Optional[float] = None):
        self.X_ref = X_ref
        # A falsy bandwidth (None or 0) falls back to the median heuristic.
        self.bandwidth = bandwidth or self._median_heuristic(X_ref)

    def _median_heuristic(self, X: torch.Tensor) -> float:
        """Median positive pairwise distance; 1.0 if all points coincide."""
        with torch.no_grad():
            dists = torch.cdist(X, X)
            positive = dists[dists > 0]
            if positive.numel() == 0:
                return 1.0
            return torch.median(positive).item()

    def __call__(self, Z: torch.Tensor) -> torch.Tensor:
        """Return mean_x k(z, x) for each row z of ``Z``."""
        dists = torch.cdist(Z, self.X_ref)
        K = torch.exp(-dists**2 / (2 * self.bandwidth**2))
        return K.mean(dim=1)


class MahalanobisWitness:
    """Squared Mahalanobis distance from each query point to the reference mean.

    The covariance is the unbiased sample covariance of ``X_ref`` with a
    1e-6 ridge on the diagonal, inverted via the pseudo-inverse.
    """

    def __init__(self, X_ref: torch.Tensor):
        num_rows, num_features = X_ref.shape
        self.mu = X_ref.mean(dim=0)
        deviations = X_ref - self.mu
        ridge = 1e-6 * torch.eye(num_features, device=X_ref.device)
        covariance = deviations.T @ deviations / (num_rows - 1) + ridge
        self.cov_inv = torch.linalg.pinv(covariance)

    def __call__(self, Z: torch.Tensor) -> torch.Tensor:
        """Return (z - mu)^T Sigma^+ (z - mu) for every row z of ``Z``."""
        offset = Z - self.mu
        projected = offset @ self.cov_inv
        return (projected * offset).sum(dim=1)


class StableLOFWitness:
    """Local-outlier-factor style score that uses medians instead of means.

    Score = median local reachability density (lrd) of the query's k nearest
    reference points divided by the query's own lrd; larger means the query
    sits in a sparser region than its neighbours.
    """

    def __init__(self, X_ref: torch.Tensor, k: int = 30):
        self.X_ref = X_ref
        # Excluding the point itself, a reference point has at most n-1 neighbours.
        self.k = min(k, len(X_ref) - 1)
        self._compute_reference_lrd()

    def _compute_reference_lrd(self):
        """Cache each reference point's k-distance and median-based lrd."""
        pairwise = torch.cdist(self.X_ref, self.X_ref)
        # k+1 nearest; column 0 is normally the point itself (distance 0) and is skipped.
        nn_dist, nn_index = torch.topk(pairwise, self.k + 1, largest=False, dim=1)
        self.k_dist_ref = nn_dist[:, -1]
        neighbour_dist = nn_dist[:, 1:]
        neighbour_kdist = self.k_dist_ref[nn_index[:, 1:]]
        reachability = torch.maximum(neighbour_dist, neighbour_kdist)
        typical_reach = reachability.median(dim=1).values
        self.lrd_ref = 1.0 / (typical_reach + 1e-10)

    def __call__(self, Z: torch.Tensor) -> torch.Tensor:
        """Return the median-LOF score of each row of ``Z`` w.r.t. ``X_ref``."""
        pairwise = torch.cdist(Z, self.X_ref)
        nn_dist, nn_index = torch.topk(pairwise, self.k, largest=False, dim=1)
        reachability = torch.maximum(nn_dist, self.k_dist_ref[nn_index])
        query_lrd = 1.0 / (reachability.median(dim=1).values + 1e-10)
        neighbour_lrd = self.lrd_ref[nn_index].median(dim=1).values
        return neighbour_lrd / (query_lrd + 1e-10)


class KNNDistanceWitness:
    """Average distance from each query point to its k nearest reference points."""

    def __init__(self, X_ref: torch.Tensor, k: int = 10):
        self.X_ref = X_ref
        # Never request more neighbours than there are reference points.
        self.k = min(k, len(X_ref))

    def __call__(self, Z: torch.Tensor) -> torch.Tensor:
        """Return the mean k-NN distance for every row of ``Z``."""
        nearest, _ = torch.cdist(Z, self.X_ref).topk(self.k, dim=1, largest=False)
        return nearest.mean(dim=1)


class LOTT:
    """Permutation two-sample test over an ensemble of outlierness witnesses.

    Witnesses are built from ``X_train``. Their per-point scores are
    standardised with the mean/std on ``X_calib`` and summed per point. The
    statistic is the mean of that sum over ``Y``. ``test`` calibrates it by
    drawing random size-``len(Y)`` subsets from the pooled ``X_hold`` + ``Y``
    scores. The test is one-sided: larger means more evidence for H1.
    """

    def __init__(self, alpha: float = 0.05, n_permutations: int = 500):
        self.alpha = alpha
        self.n_permutations = n_permutations

    @staticmethod
    def _median_bandwidth(X: torch.Tensor) -> float:
        """Median heuristic over distinct-point distances (zero self-distances excluded).

        Falls back to 1.0 when all points coincide.
        """
        with torch.no_grad():
            dists = torch.cdist(X, X)
            positive = dists[dists > 0]
            if positive.numel() == 0:
                return 1.0
            return torch.median(positive).item()

    def fit(self, X_train: torch.Tensor, X_calib: torch.Tensor, X_hold: torch.Tensor):
        """Build witnesses and pre-compute calibration/holdout statistics.

        Args:
            X_train: reference sample used to construct the witnesses.
            X_calib: sample used to estimate each witness's score mean/std.
            X_hold: sample whose scores are pooled with Y in ``test``.
        """
        self.device = X_train.device
        # Excluding the n zero self-distances keeps this consistent with the
        # witnesses' own median heuristic (they would otherwise bias it down).
        base_bw = self._median_bandwidth(X_train)

        # Balanced witness ensemble for both global and local alternatives.
        self.witnesses = [
            # Consistent witness for location-type alternatives.
            KernelMeanEmbeddingWitness(X_train, bandwidth=base_bw),
            # Local-structure witnesses (covariance/density alternatives).
            StableLOFWitness(X_train, k=20),
            KNNDistanceWitness(X_train, k=5),
            # Global second-moment witness.
            MahalanobisWitness(X_train),
        ]
        self.K = len(self.witnesses)

        calib_scores = self._compute_scores(X_calib)
        self.mu = calib_scores.mean(axis=0)
        self.std = calib_scores.std(axis=0) + 1e-10

        # Every witness is oriented so that larger = more outlying (evidence
        # for H1). Sized from the witness list so edits to it stay consistent.
        self.signs = np.ones(self.K)

        self.hold_scores = self._compute_scores(X_hold)

    def _compute_scores(self, X: torch.Tensor) -> np.ndarray:
        """Per-point scores of every witness, shape (len(X), K)."""
        scores = []
        for w in self.witnesses:
            scores.append(w(X).detach().cpu().numpy())
        return np.stack(scores, axis=1)

    def _compute_stat(self, scores: np.ndarray) -> float:
        """Standardized sum with sign correction, averaged over samples."""
        standardized = (scores - self.mu) / self.std * self.signs
        return standardized.sum(axis=1).mean()

    def test(self, Y: torch.Tensor) -> Dict:
        """Run the permutation test of ``Y`` against the fitted reference.

        Returns:
            dict with ``reject`` (p < alpha), ``p_value`` and ``statistic``.

        Raises:
            ValueError: if ``Y`` is empty (``perm[-0:]`` would select the
                whole pool and the statistic would be NaN).
        """
        m = len(Y)
        if m == 0:
            raise ValueError("Y must contain at least one sample")
        y_scores = self._compute_scores(Y)
        test_stat = self._compute_stat(y_scores)

        # Pool holdout + Y for permutation (CRITICAL for Type-I control)
        pooled = np.vstack([self.hold_scores, y_scores])
        n_total = len(pooled)

        null_stats = []
        for _ in range(self.n_permutations):
            perm = np.random.permutation(n_total)
            pseudo_Y = pooled[perm[-m:]]
            null_stats.append(self._compute_stat(pseudo_Y))

        null_stats = np.array(null_stats)
        # +1 in numerator and denominator counts the observed split itself.
        p_value = (np.sum(null_stats >= test_stat) + 1) / (self.n_permutations + 1)

        return {
            'reject': p_value < self.alpha,
            'p_value': p_value,
            'statistic': test_stat
        }


class LOTTWithSelection:
    """LOTT variant that selects or weights witnesses by their stability.

    Besides the four base witnesses (KME, LOF, KNN, Mahalanobis), ``M``
    landmark witnesses are built from random subsets of the training data.
    A witness's stability is the variance of its mean score across repeated
    half-size subsamples of the calibration set. Low-variance witnesses are
    selected, or all witnesses are precision-weighted. The statistic is a
    weighted sum of per-witness mean squared standardized scores, calibrated
    by permutation against a holdout set.

    NOTE(review): variances are compared on each witness's raw score scale,
    so witnesses with small-magnitude scores (e.g. kernel similarities in
    [0, 1]) are favoured. Confirm this is intended.
    """

    # Names of the base witnesses. fit() sets an instance-level WITNESS_NAMES
    # that also lists the landmark witnesses.
    WITNESS_NAMES = ['KME', 'LOF', 'KNN', 'Mahalanobis']

    # Number of random half-size subsamples used to estimate witness variance.
    _N_SUBSAMPLE_ROUNDS = 20

    def __init__(self, alpha: float = 0.05, n_permutations: int = 500,
                 selection_method: str = 'top_n', n_select: int = 2,
                 variance_threshold: float = None, verbose: bool = True):
        """
        Args:
            alpha: Significance level
            n_permutations: Number of permutations for p-value estimation
            selection_method: 'top_n', 'precision_weight', or 'threshold'
            n_select: Number of witnesses to select (for 'top_n')
            variance_threshold: Max variance allowed (for 'threshold'); None
                means "use the median witness variance of each fit".
            verbose: Print selection details
        """
        self.alpha = alpha
        self.n_permutations = n_permutations
        self.selection_method = selection_method
        self.n_select = n_select
        self.variance_threshold = variance_threshold
        self.verbose = verbose

    @staticmethod
    def _median_bandwidth(X: torch.Tensor) -> float:
        """Median heuristic over distinct-point distances (zero self-distances excluded).

        Falls back to 1.0 when all points coincide.
        """
        with torch.no_grad():
            dists = torch.cdist(X, X)
            positive = dists[dists > 0]
            if positive.numel() == 0:
                return 1.0
            return torch.median(positive).item()

    def fit(self, X_train: torch.Tensor, X_calib: torch.Tensor, X_hold: torch.Tensor,
            M: int = 10, subset_size: int = 10):
        """
        Fit the LOTT model with witness selection.

        Steps:
        1. Construct all witnesses (including M landmark witnesses)
        2. Compute subsample-mean scores on the calibration set
        3. Compute variance for each witness
        4. Select/weight witnesses based on stability
        5. Standardize on calibration data and pre-score the holdout set

        Args:
            X_train: Training data
            X_calib: Calibration data
            X_hold: Holdout data
            M: Number of landmark witnesses to create (default: 10)
            subset_size: Fixed number of samples for each landmark witness
                (default: 10; must not exceed len(X_train))
        """
        self.device = X_train.device
        # Excluding zero self-distances matches the witnesses' own heuristic.
        base_bw = self._median_bandwidth(X_train)
        self.M = M
        self.subset_size = subset_size

        landmark_cfe = LandmarkKernelCFE(X_train, M=M, subset_size=subset_size, bandwidth=base_bw)
        landmark_witnesses = landmark_cfe.get_individual_witnesses()

        # Build witness list: base witnesses + M landmark witnesses
        self.all_witnesses = [
            KernelMeanEmbeddingWitness(X_train, bandwidth=base_bw),
            StableLOFWitness(X_train, k=20),
            KNNDistanceWitness(X_train, k=5),
            MahalanobisWitness(X_train),
        ] + landmark_witnesses

        base_names = ['KME', 'LOF', 'KNN', 'Mahalanobis']
        landmark_names = [f'Landmark_{i}' for i in range(len(landmark_witnesses))]
        self.WITNESS_NAMES = base_names + landmark_names

        # All witnesses have positive sign (higher = more different from reference)
        self.all_signs = np.ones(len(self.all_witnesses))

        all_calib_scores = self._compute_all_scores(X_calib)
        self.witness_variances = all_calib_scores.var(axis=0)

        self._select_witnesses(all_calib_scores)

        calib_scores = self._compute_scores(X_calib)
        self.mu = calib_scores.mean(axis=0)
        self.std = calib_scores.std(axis=0) + 1e-10

        self.hold_scores = self._compute_scores(X_hold)

        if self.verbose:
            self._print_selection_info()

    def _select_witnesses(self, calib_scores: np.ndarray):
        """Choose the active witnesses and their weights.

        Args:
            calib_scores: output of ``_compute_all_scores``. Currently unused
                (selection relies on ``self.witness_variances``), but kept for
                interface compatibility.

        Sets ``selected_indices``, ``weights``, ``witnesses``, ``signs``, ``K``
        and ``applied_variance_threshold`` (the threshold actually used by the
        'threshold' method, else None).

        Raises:
            ValueError: for an unknown ``selection_method``.
        """
        variances = self.witness_variances
        n_witnesses = len(self.all_witnesses)
        self.applied_variance_threshold = None

        if self.selection_method == 'top_n':
            # Select top-n witnesses with lowest variance, equally weighted.
            n_select = min(self.n_select, n_witnesses)
            self.selected_indices = np.argsort(variances)[:n_select]
            self.weights = np.ones(n_select)

        elif self.selection_method == 'precision_weight':
            # Weight every witness by 1/variance, normalised to sum to n_witnesses.
            self.selected_indices = np.arange(n_witnesses)
            precisions = 1.0 / (variances + 1e-10)
            self.weights = precisions / precisions.sum() * n_witnesses

        elif self.selection_method == 'threshold':
            # Resolve the default locally. Writing the median back into
            # self.variance_threshold would freeze the first fit's value and
            # silently reuse it on later fit() calls.
            if self.variance_threshold is None:
                threshold = np.median(variances)
            else:
                threshold = self.variance_threshold
            self.applied_variance_threshold = threshold
            selected = np.where(variances <= threshold)[0]
            if len(selected) == 0:
                # Fallback: at least select the lowest variance witness
                selected = np.array([np.argmin(variances)])
            self.selected_indices = selected
            self.weights = np.ones(len(selected))
        else:
            raise ValueError(f"Unknown selection method: {self.selection_method}")

        self.witnesses = [self.all_witnesses[i] for i in self.selected_indices]
        self.signs = self.all_signs[self.selected_indices]
        self.K = len(self.witnesses)

    def _print_selection_info(self):
        """Print information about witness selection."""
        print("\n" + "="*60)
        print("WITNESS SELECTION SUMMARY")
        print("="*60)
        print(f"Selection method: {self.selection_method}")
        print("\nAll witnesses variance on calibration set:")
        for i, (name, var) in enumerate(zip(self.WITNESS_NAMES, self.witness_variances)):
            selected = "✓" if i in self.selected_indices else " "
            print(f"  [{selected}] {name}: variance = {var:.6f}")

        print(f"\nSelected {len(self.selected_indices)} witnesses:")
        for idx, w_idx in enumerate(self.selected_indices):
            weight_str = f", weight={self.weights[idx]:.3f}" if self.selection_method == 'precision_weight' else ""
            print(f"  - {self.WITNESS_NAMES[w_idx]} (variance={self.witness_variances[w_idx]:.6f}{weight_str})")
        print("="*60 + "\n")

    def _compute_all_scores(self, X: torch.Tensor) -> np.ndarray:
        """Subsample-mean score of every witness, used to estimate stability.

        For each of ``_N_SUBSAMPLE_ROUNDS`` rounds, ``len(X) // 2`` rows
        (at least one) are drawn without replacement using the global NumPy
        RNG. The witness's mean score on that subsample is then recorded.

        Returns:
            Array of shape (rounds, n_witnesses).
        """
        n = len(X)
        # Guard against a zero-size subsample (mean of empty -> NaN) when n < 2.
        sample_size = max(n // 2, 1)
        per_witness = []
        for w in self.all_witnesses:
            subsample_means = []
            for _ in range(self._N_SUBSAMPLE_ROUNDS):
                idx = np.random.choice(n, size=sample_size, replace=False)
                subsample_means.append(w(X[idx]).detach().cpu().numpy().mean(axis=0))
            per_witness.append(np.array(subsample_means))
        return np.stack(per_witness, axis=1)

    def _compute_scores(self, X: torch.Tensor) -> np.ndarray:
        """Per-point scores for the selected witnesses only, shape (len(X), K)."""
        scores = []
        for w in self.witnesses:
            scores.append(w(X).detach().cpu().numpy())
        return np.stack(scores, axis=1)

    def _compute_stat(self, scores: np.ndarray) -> float:
        """
        Weighted sum of mean squared scores for each witness.

        Aggregation method:
        1. Standardize: z_ik = (score_ik - mu_k) / std_k * sign_k
        2. For each witness k, compute mean of squares over samples: MSS_k = (1/m) Σ_i (z_ik)²
        3. Weighted sum over witnesses: stat = Σ_k w_k * MSS_k

        Because of the squaring, ±1 signs do not affect the result.

        Args:
            scores: shape (m, K) where m = samples, K = selected witnesses
        Returns:
            test statistic (scalar)
        """
        standardized = (scores - self.mu) / self.std * self.signs  # Shape (m, K)
        mean_of_squares = (standardized ** 2).mean(axis=0)  # Shape (K,)
        return (self.weights * mean_of_squares).sum()

    def test(self, Y: torch.Tensor) -> Dict:
        """
        Perform the two-sample test using selected witnesses.

        Raises:
            ValueError: if ``Y`` is empty (``perm[-0:]`` would select the
                whole pool and the statistic would be NaN).
        """
        m = len(Y)
        if m == 0:
            raise ValueError("Y must contain at least one sample")
        y_scores = self._compute_scores(Y)
        test_stat = self._compute_stat(y_scores)

        # Pool holdout + Y for permutation
        pooled = np.vstack([self.hold_scores, y_scores])
        n_total = len(pooled)

        null_stats = []
        for _ in range(self.n_permutations):
            perm = np.random.permutation(n_total)
            pseudo_Y = pooled[perm[-m:]]
            null_stats.append(self._compute_stat(pseudo_Y))

        null_stats = np.array(null_stats)
        p_value = (np.sum(null_stats >= test_stat) + 1) / (self.n_permutations + 1)

        return {
            'reject': p_value < self.alpha,
            'p_value': p_value,
            'statistic': test_stat,
            'selected_witnesses': [self.WITNESS_NAMES[i] for i in self.selected_indices],
            'witness_variances': dict(zip(self.WITNESS_NAMES, self.witness_variances))
        }

    def get_selection_stats(self) -> Dict:
        """Return detailed statistics about witness selection."""
        return {
            'method': self.selection_method,
            'all_variances': dict(zip(self.WITNESS_NAMES, self.witness_variances)),
            'selected_indices': self.selected_indices.tolist(),
            'selected_names': [self.WITNESS_NAMES[i] for i in self.selected_indices],
            'weights': self.weights.tolist(),
            'variance_ranking': [self.WITNESS_NAMES[i] for i in np.argsort(self.witness_variances)]
        }


class SubsetKernelWitness:
    """Mean Gaussian RBF similarity between query points and a fixed reference subset."""

    def __init__(self, subset: torch.Tensor, bandwidth: float, device: torch.device):
        """
        Args:
            subset: reference points for this witness, shape (n_subset, d)
            bandwidth: RBF kernel bandwidth
            device: device on which queries are evaluated
        """
        self.subset = subset
        self.bandwidth = bandwidth
        self.device = device

    def __call__(self, Z: torch.Tensor) -> torch.Tensor:
        """
        Return mean_{x in subset} k(z, x) for every row z of ``Z``.

        Args:
            Z: tensor of shape (m, d); moved to ``self.device`` if needed
        Returns:
            tensor of shape (m,)
        """
        queries = Z if Z.device == self.device else Z.to(self.device)
        sq_dists = torch.cdist(queries, self.subset) ** 2
        kernel = torch.exp(-sq_dists / (2.0 * self.bandwidth ** 2))
        return kernel.mean(dim=1)


class LandmarkKernelCFE:
    """Kernel similarity features to M random "landmark" subsets of a reference set.

    Landmark l is a subset S_l of ``subset_size`` rows drawn without replacement
    from ``X_ref``. Feature l of a query point z is mean_{x in S_l} k(z, x),
    where k is a Gaussian RBF kernel.
    """

    def __init__(
        self,
        X_ref: torch.Tensor,
        M: int = 10,
        subset_size: int = 100,
        bandwidth: Optional[float] = None,
        seed: Optional[int] = None,
    ):
        """
        Args:
            X_ref: reference tensor of shape (n, d)
            M: number of landmark witnesses to create (default: 10)
            subset_size: fixed number of samples for each subset (default: 100)
            bandwidth: RBF bandwidth; if None, uses median heuristic on X_ref
            seed: optional random seed for reproducible subsets; global RNG
                state is left untouched

        Raises:
            ValueError: if X_ref is not 2-D or subset_size is outside [1, n].
        """
        if X_ref.ndim != 2:
            raise ValueError(f"X_ref must be 2D (n,d). Got shape={tuple(X_ref.shape)}")
        n = X_ref.shape[0]
        if not (1 <= subset_size <= n):
            raise ValueError(f"subset_size must be in [1, n] where n={n}. Got {subset_size}")

        self.X_ref = X_ref
        self.device = X_ref.device
        self.M = M
        self.subset_size = subset_size
        self.bandwidth = float(bandwidth) if bandwidth is not None else self._median_heuristic(X_ref)

        self.subsets = self._create_subsets(X_ref, M, self.subset_size, seed)

    def _create_subsets(
        self,
        X_ref: torch.Tensor,
        M: int,
        subset_size: int,
        seed: Optional[int],
    ) -> list:
        """Create M random subsets of X_ref, each of size subset_size.

        With a ``seed``, a private ``np.random.RandomState(seed)`` draws the
        indices. This keeps the subsets reproducible without reseeding NumPy's
        or PyTorch's global RNGs, which would perturb unrelated randomness such
        as later permutation tests. ``RandomState`` is the generator behind
        the legacy global ``np.random`` functions, so a given seed yields the
        same subsets as ``np.random.seed(seed)`` + ``np.random.choice`` did.
        Without a seed the global NumPy RNG is used, as before.
        """
        n = X_ref.shape[0]
        rng = np.random.RandomState(seed) if seed is not None else np.random

        subsets = []
        for _ in range(M):
            # Sampling without replacement within each subset.
            indices = rng.choice(n, size=subset_size, replace=False)
            subsets.append(X_ref[indices])  # Shape (subset_size, d)
        return subsets

    def _median_heuristic(self, X: torch.Tensor) -> float:
        """Median heuristic for RBF bandwidth on reference sample."""
        with torch.no_grad():
            dists = torch.cdist(X, X)
            d = dists[dists > 0]
            if d.numel() == 0:
                # All points identical; fallback bandwidth
                return 1.0
            return torch.median(d).item()

    def _compute_kernel(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        """Gaussian RBF kernel."""
        dists = torch.cdist(X, Y)
        return torch.exp(-dists**2 / (2.0 * self.bandwidth**2))

    def __call__(self, Z: torch.Tensor) -> torch.Tensor:
        """
        Return landmark similarity scores for each z in Z.

        Args:
            Z: tensor of shape (m, d)

        Returns:
            scores: tensor of shape (m, M), where scores[:, l] = mean_{x ∈ S_l} k(z, x)

        Raises:
            ValueError: if Z is not 2-D.
        """
        if Z.ndim != 2:
            raise ValueError(f"Z must be 2D (m,d). Got shape={tuple(Z.shape)}")
        if Z.device != self.device:
            Z = Z.to(self.device)

        all_scores = []
        for subset in self.subsets:
            K = self._compute_kernel(Z, subset)  # Shape (m, subset_size)
            all_scores.append(K.mean(dim=1))  # Shape (m,)

        return torch.stack(all_scores, dim=1)  # Shape (m, M)

    def get_individual_witnesses(self) -> list:
        """
        Return M individual SubsetKernelWitness objects.
        Each witness returns shape (m,) for input shape (m, d).

        Use this method to expand LandmarkKernelCFE into M separate witnesses
        for use in LOTTWithSelection.

        Returns:
            List of M SubsetKernelWitness objects
        """
        witnesses = []
        for l in range(self.M):
            witness = SubsetKernelWitness(
                subset=self.subsets[l],
                bandwidth=self.bandwidth,
                device=self.device
            )
            witnesses.append(witness)
        return witnesses

    def get_subsets(self) -> list:
        """Return list of M subsets, each of shape (subset_size, d)."""
        return self.subsets

    def extra_repr(self) -> str:
        return f"M={self.M}, subset_size={self.subset_size}, bandwidth={self.bandwidth:.4g}, device={self.device}"