
from __future__ import annotations

import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler, RobustScaler
from datetime import datetime
from typing import Dict, Optional, Tuple, List
from dataclasses import dataclass, field

# Local imports
from load_dataset_grinsztajn import load_dataset
from geo_conditional_quantile_icnn_gpu_amp_stable_widecore import (
    ICNNGeometricQuantileMap,
    sample_uniform_ball,
)
from GCQR_utils import (
    ConformalPICNN,
    approximate_region_boundary,
    distance_to_region_uopt,
)
from train_cqr_v2_utils import (
    split_train_cal_test,
    TrainCQRv2Config,
    train_phase_A_Q_only,
)
from config import get_train_config_gcqr
from math import gamma
# WSC utilities
try:
    from worst_slab_cov_gpu import calculer_wsc_regression
    WSC_AVAILABLE = True
except Exception:
    WSC_AVAILABLE = False
    print("[ConformalPICNN Eval] WSC not available - will be skipped")


# ==============================================================================
# DIAGNOSTICS DATACLASS
# ==============================================================================

@dataclass
class GCQR3Diagnostics:
    """Container for the diagnostics gathered during one GCQR3 pipeline run.

    Groups Q-spread tracking, TR (transformation) conditioning statistics,
    per-point volume lists, calibration-score statistics and model metadata.
    Every field has a default, so an empty instance can be filled in
    incrementally and printed with ``print_summary``.
    """

    # --- Q-spread tracking (expected to be > 0.5 for a non-collapsed model) ---
    q_spread_initial: float = 0.0
    q_spread_final: float = 0.0
    q_spread_history: List[float] = field(default_factory=list)

    # --- TR conditioning ---
    # One of "global", "conditional", "stabilized", "none".
    tr_type: str = "none"
    tr_log_det_mean: float = 0.0
    tr_log_det_std: float = 0.0
    tr_cond_num_mean: float = 1.0
    tr_cond_num_max: float = 1.0
    # Fallback rate; relevant for StabilizedConditionalTR.
    tr_fallback_rate: float = 0.0

    # --- Per-point volume metrics (for per-point correction) ---
    log_jacobians: List[float] = field(default_factory=list)
    volumes_transformed: List[float] = field(default_factory=list)
    volumes_corrected: List[float] = field(default_factory=list)

    # --- Calibration score statistics ---
    cal_scores_mean: float = 0.0
    cal_scores_std: float = 0.0
    cal_scores_min: float = 0.0
    cal_scores_max: float = 0.0

    # --- Model metadata ---
    model_params: int = 0
    width: int = 0
    depth: int = 0
    mu: float = 0.0

    def is_collapsed(self) -> bool:
        """Return True when the final Q-spread is below 0.1 (Q looks collapsed)."""
        return self.q_spread_final < 0.1

    def is_tr_ill_conditioned(self) -> bool:
        """Return True when the largest recorded TR condition number exceeds 1e6."""
        return self.tr_cond_num_max > 1e6

    def print_summary(self):
        """Write a human-readable diagnostics report to stdout.

        TR details are only included when ``tr_type`` is not "none"; warning
        lines are added for a collapsed Q and an ill-conditioned TR.
        """
        rule = "=" * 60
        report = ["\n" + rule, " GCQR3 DIAGNOSTICS SUMMARY", rule]

        report.append(
            f"  Q-spread:    initial={self.q_spread_initial:.4f}, final={self.q_spread_final:.4f}"
        )
        if self.is_collapsed():
            report.append("    ⚠️ WARNING: Q collapsed! Increase width/depth/mu.")

        report.append(f"  TR type:     {self.tr_type}")
        if self.tr_type != "none":
            report += [
                f"    log|det|:  mean={self.tr_log_det_mean:+.2f}, std={self.tr_log_det_std:.2f}",
                f"    cond num:  mean={self.tr_cond_num_mean:.2e}, max={self.tr_cond_num_max:.2e}",
            ]
            if self.is_tr_ill_conditioned():
                report.append("    ⚠️ WARNING: High condition number!")
            if self.tr_fallback_rate > 0:
                report.append(f"    fallback:  {self.tr_fallback_rate*100:.1f}%")

        report.append(
            f"  Cal scores:  mean={self.cal_scores_mean:.4f}, std={self.cal_scores_std:.4f}"
        )
        report.append(
            f"  Model:       width={self.width}, depth={self.depth}, mu={self.mu}, params={self.model_params}"
        )
        report.append(rule)

        # A single print of the newline-joined report emits exactly the same
        # text as one print() call per line.
        print("\n".join(report))


# ==============================================================================
# Volume Estimation via Monte Carlo - aligned with the CGRR3 style
# (the estimator itself, estimate_region_volume_mc_gcqr, is defined further below)
# ==============================================================================


# ==============================================================================
# Helper functions for center/radius estimation (from CGRR3)
# ==============================================================================

@torch.no_grad()
def estimate_conditional_center(
    model_Q: ICNNGeometricQuantileMap,
    x_point: torch.Tensor,
    device: torch.device,
    n_samples: int = 200
) -> np.ndarray:
    """
    Estimate the centre of the conditional distribution P(Y | X = x).

    Draws ``n_samples`` latent points u uniformly from the ball (via
    ``sample_uniform_ball``), evaluates Q(x, u) for the single input
    ``x_point`` and returns the coordinate-wise median of the outputs.

    Aligned with CGRR3's ``estimate_conditional_center``.

    Args:
        model_Q: quantile map Q(x, u) -> y; its ``y_dim`` sets the latent dimension.
        x_point: single input point (presumably shape (d_x,)).
        device: device on which sampling and the forward pass run.
        n_samples: number of Monte Carlo draws.

    Returns:
        Median of the sampled outputs taken over the sample axis (axis 0).
    """
    model_Q.eval()

    latent = sample_uniform_ball(n_samples, model_Q.y_dim, device=device)
    # Repeat the single input row once per latent draw.
    x_batch = x_point.unsqueeze(0).expand(n_samples, -1).to(device)
    outputs = model_Q(x_batch, latent).cpu().numpy()

    return np.median(outputs, axis=0)


@torch.no_grad()
def estimate_local_radius(
    model_Q: ICNNGeometricQuantileMap,
    x_point: torch.Tensor,
    device: torch.device,
    n_samples: int = 200,
    quantile: float = 0.99,
    max_u_norm: float = 2.0
) -> np.ndarray:
    """
    Estimate the per-dimension local spread of Q(x, .) around its median.

    Latent points are drawn with ``sample_uniform_ball`` and scaled by
    ``max_u_norm`` so the exploration reaches beyond the nominal ball.

    Aligned with CGRR3's ``estimate_local_radius``.

    Args:
        model_Q: ICNNGeometricQuantileMap; its ``y_dim`` sets the latent dimension.
        x_point: torch.Tensor of shape (d_x,).
        device: torch device.
        n_samples: number of Monte Carlo samples.
        quantile: quantile of the absolute deviations used as the radius
            (0.99 = cover 99% of the sampled points per dimension).
        max_u_norm: scale applied to the latent samples for exploration.

    Returns:
        np.ndarray with one radius per output dimension. Each radius is the
        ``quantile`` of |y - median(y)| along that dimension, inflated by a
        factor of 2 and floored at 1e-4.
    """
    model_Q.eval()

    # Explore widely in U so the samples span the whole conditional region.
    u_wide = max_u_norm * sample_uniform_ball(n_samples, model_Q.y_dim, device=device)
    x_batch = x_point.unsqueeze(0).expand(n_samples, -1).to(device)
    outputs = model_Q(x_batch, u_wide).cpu().numpy()

    abs_dev = np.abs(outputs - np.median(outputs, axis=0))
    # Factor 2 is a safety inflation; the floor keeps radii strictly positive.
    spread = 2.0 * np.quantile(abs_dev, quantile, axis=0)
    return np.maximum(spread, 1e-4)


# ==============================================================================
# Euclidean distance and inside test functions
# ==============================================================================

@torch.no_grad()
def _compute_euclidean_distance_to_boundary(
    samples: torch.Tensor,      # (N, d) - MC samples
    boundary: torch.Tensor,     # (M, d) - boundary points
    chunk_size: int = 1000,
) -> torch.Tensor:
    """
    Compute the Euclidean distance from each sample to its nearest boundary point.

        d(y, ∂C) = min_{z ∈ boundary} ||y - z||

    Samples are processed in chunks of at most ``chunk_size`` rows, so the
    pairwise distance matrix held in memory never exceeds ``(chunk_size, M)``.

    Args:
        samples: (N, d) points to evaluate (cast to float32).
        boundary: (M, d) points approximating the boundary (cast to float32).
        chunk_size: maximum number of samples processed per chunk (>= 1).

    Returns:
        (N,) float32 tensor of distances (always >= 0); empty when N == 0.

    Raises:
        ValueError: if ``chunk_size < 1``, or if ``boundary`` is empty while
            ``samples`` is not (the nearest-point distance is undefined).
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    # Use one consistent dtype for both operands.
    samples = samples.to(torch.float32)
    boundary = boundary.to(torch.float32)

    n_points = samples.shape[0]
    if n_points == 0:
        # No queries -> no distances; also avoids a zero range() step.
        return samples.new_empty((0,))
    if boundary.shape[0] == 0:
        raise ValueError("boundary must contain at least one point")

    min_dists = []
    for start in range(0, n_points, chunk_size):
        chunk = samples[start:start + chunk_size]  # (chunk, d)
        # Exact (non-matmul) Euclidean distances of shape (chunk, M); unlike
        # broadcasting + norm, this does not build a (chunk, M, d) tensor.
        pairwise = torch.cdist(
            chunk,
            boundary,
            p=2.0,
            compute_mode="donot_use_mm_for_euclid_dist",
        )
        min_dists.append(pairwise.min(dim=1).values)  # (chunk,)

    return torch.cat(min_dists, dim=0)


# NOTE: _is_inside_convex_region and _compute_signed_distances_mc removed
# They used unreliable geometric heuristics that fail in high dimensions.
# Now using distance_to_region_uopt from GCQR_utils which uses gradient optimization.
    



def _unit_ball_volume(d: int) -> float:
    """Return the volume of the unit ball in R^d.

    Closed form: V_d(1) = pi^(d/2) / Gamma(d/2 + 1).
    """
    half_dim = d / 2
    numerator = np.pi ** half_dim
    return numerator / gamma(half_dim + 1)


def estimate_region_volume_mc_gcqr(
    model_Q: ICNNGeometricQuantileMap,
    x_point: torch.Tensor,
    radius: float,
    margin: float,
    device: torch.device,
    n_samples: int = 20000,
    n_boundary: int = 5000,
    margin_factor: float = 1.5,
) -> Dict:
    """
    Monte Carlo estimate of the Y-space volume of the final GCQR region at one x.

    Why ball sampling: in high dimensions (d > 10), sampling from a bounding
    hypercube is extremely inefficient. The inscribed ball occupies a vanishing
    fraction of the cube; V_ball / V_cube decays faster than exponentially in d.
    Instead, this function:
      1. Samples uniformly from a ball slightly larger than the expected region.
         The ball is centred on the centroid of the approximated base-region
         boundary.
      2. Tests membership in the final GCQR region with the OPTIMIZATION-based
         signed distance ``distance_to_region_uopt`` from GCQR_utils.
      3. Returns volume = inside_ratio × sampling_ball_volume.

    The ball is sized tightly around the region so that inside_ratio stays
    non-negligible. NOTE(review): an earlier note quoted ~10-50%; the actual
    ratio depends on the region's shape and is not checked here.

    GCQR membership test (Equation 8):
    ─────────────────────────────────────────────────────────────────────────
    - C_base = Q(x, radius · B^d) = {Q(x, u) : ||u|| <= radius}
    - score  = signed distance of y to ∂C_base (+ outside, - inside)

    UNIFIED RULE: y ∈ C_final ⟺ signed_dist(y) <= margin
    - margin >= 0 (expansion): C_base plus an outer shell of width margin
    - margin <  0 (contraction): only points at depth >= |margin| inside C_base
    ─────────────────────────────────────────────────────────────────────────

    Randomness: the MC draws use a freshly seeded generator (seed=42) on every
    call. Every point therefore reuses the same directions and radii, which is
    reproducible but correlates MC error across points.

    Args:
        model_Q: quantile map Q(x, u) -> y; its ``y_dim`` sets the dimension d.
        x_point: single input point of shape (d_x,).
        radius: radius of the base ball in U-space. The pipeline passes the
            calibrated base radius, nominally 1 - α.
        margin: calibrated conformal margin (q_hat); may be negative.
        device: torch device used for sampling and distance evaluation.
        n_samples: number of MC samples drawn in the sampling ball.
        n_boundary: number of boundary points passed to
            ``approximate_region_boundary``. They locate the centroid and the
            base-region radii that size the sampling ball.
        margin_factor: upper cap on the safety factor applied to the sampling
            ball radius. The factor actually used is
            max(min(1 + 0.1·ln(d+1)/ln(16), margin_factor), 1.05),
            so it is never below 1.05, even when margin_factor is smaller.

    Returns:
        Dict with volume metrics.

        On collapse (mean base-boundary radius < 1e-4), the dict contains only:
          - volume, inside_ratio, sampling_ball_volume: all NaN
          - collapse_warning: True
          - center
          - mean_radius_base

        Otherwise it contains:
          - volume, inside_ratio, sampling_ball_volume, sampling_ball_radius
          - effective_radius, margin_factor_used
          - collapse_warning: False
          - center: the boundary centroid as an np.ndarray
          - mean_radius_base, max_radius_base
    """
    model_Q.eval()
    d_y = model_Q.y_dim
    x_point = x_point.to(device)

    # =========================================================================
    # Step 1: Compute boundary of C_base = Q(x, radius * B^d)
    # =========================================================================
    boundary = approximate_region_boundary(
        model_Q,
        x_point.unsqueeze(0),
        radius,
        n_boundary_samples=n_boundary,
    ).squeeze(0)  # (M, d_y)

    centroid = boundary.mean(dim=0)  # (d_y,)
    centroid_np = centroid.cpu().numpy()

    # =========================================================================
    # Step 2: Estimate region size from boundary
    # =========================================================================
    boundary_centered = boundary - centroid
    boundary_radii = torch.linalg.norm(boundary_centered, dim=-1)  # (M,)
    mean_radius_base = float(boundary_radii.mean())
    max_radius_base = float(boundary_radii.max())

    # =========================================================================
    # Step 3: Detect model collapse
    # =========================================================================
    # Early exit: the base region is numerically a single point, so a volume
    # estimate would be meaningless. (The dict literal sets collapse_warning
    # directly; the local assignment is kept for readability only.)
    collapse_warning = False
    if mean_radius_base < 1e-4:
        collapse_warning = True
        return {
            "volume": np.nan,
            "inside_ratio": np.nan,
            "sampling_ball_volume": np.nan,
            "collapse_warning": True,
            "center": centroid_np,
            "mean_radius_base": mean_radius_base,
        }

    # =========================================================================
    # Step 4: Define sampling ball radius (P1 FIX: handle expansion/contraction)
    # =========================================================================
    # For expansion (margin >= 0): effective_radius = max_radius_base + margin
    # For contraction (margin < 0): effective_radius = max_radius_base + margin (smaller)
    #
    # Rationale: every point of C_final lies within max_radius_base + margin of
    # the centroid. This bound is approximate, since max_radius_base comes from
    # sampled boundary points.
    #
    # BUG FIX: Previously, contraction kept effective_radius = max_radius_base,
    # which meant sampling from a much larger ball than the actual region.
    # In d=8 with r_region=0.009 and r_sample=0.048: hit rate = (0.009/0.048)^8 ≈ 0
    effective_radius = max_radius_base + margin  # Works for both expansion AND contraction
    effective_radius = max(effective_radius, 1e-6)  # Clamp to positive

    # Sampling ball radius with a TIGHT safety factor.
    # In high dimensions, even small factors cause a huge volume blowup, since
    # the hit-rate ceiling is (1/factor)^d, which goes to 0 quickly.
    # For d=16: factor=1.5 gives ≈0.15%; factor=1.1 gives ≈22%.
    margin_factor_tight = 1.0 + 0.1 * np.log(d_y + 1) / np.log(16)  # ~1.1 for d=16
    margin_factor_tight = min(margin_factor_tight, margin_factor)
    margin_factor_tight = max(margin_factor_tight, 1.05)  # At least 5% margin

    sampling_ball_radius = effective_radius * margin_factor_tight

    # Compute reference ball volume: V_d(r) = V_d(1) * r^d
    unit_ball_vol = _unit_ball_volume(d_y)
    sampling_ball_volume = unit_ball_vol * (sampling_ball_radius ** d_y)

    # =========================================================================
    # Step 5: MC sampling from a BALL (not hypercube!) centered on centroid
    # =========================================================================
    # Sample uniformly from the d-dimensional ball (rejection-free):
    # 1. Sample a direction uniformly on the sphere.
    # 2. Sample a radius with CDF r^d (so the pdf is ~ r^(d-1)).
    # The generator is re-seeded on every call; the draw order
    # (directions, then radii) fixes the sample stream.
    rng = np.random.default_rng(seed=42)

    # Random directions (normalize Gaussian vectors)
    directions = rng.standard_normal((n_samples, d_y)).astype(np.float32)
    directions = directions / (np.linalg.norm(directions, axis=1, keepdims=True) + 1e-12)

    # Random radii with proper distribution for uniform ball sampling
    # CDF: F(r) = (r/R)^d, so r = R * U^(1/d) where U ~ Uniform(0,1)
    u_radii = rng.uniform(0, 1, size=(n_samples, 1)).astype(np.float32)
    radii = sampling_ball_radius * (u_radii ** (1.0 / d_y))

    # Combine to get uniform samples in ball
    Y_samples_np = centroid_np + directions * radii
    Y_samples = torch.from_numpy(Y_samples_np).to(device, non_blocking=True)

    # =========================================================================
    # Step 6: Compute signed distances via OPTIMIZATION (P0 FIX: robust method)
    # =========================================================================
    # Use gradient-based optimization instead of a geometric heuristic; this is
    # more reliable in high dimensions and for non-convex regions.
    # enable_grad is used so that distance_to_region_uopt can run its gradient
    # refinement (refine_steps / lr), even if the caller disabled autograd.
    signed_dists_list = []
    batch_mc = 256  # Process in batches to avoid OOM

    for i in range(0, n_samples, batch_mc):
        y_batch = Y_samples[i:i+batch_mc]
        x_batch = x_point.unsqueeze(0).expand(len(y_batch), -1)

        # Use optimization-based distance (from GCQR_utils)
        with torch.enable_grad():
            signed_dist, _, _ = distance_to_region_uopt(
                Q=model_Q,
                x=x_batch,
                y=y_batch,
                r0=radius,
                n_sphere=128,
                n_ball=128,
                refine_steps=15,
                lr=0.05,
                eps_u=1e-3,
            )
        signed_dists_list.append(signed_dist.detach())

    signed_dists = torch.cat(signed_dists_list)

    # =========================================================================
    # Step 7: Apply GCQR Equation 8 membership test (UNIFIED FORMULA)
    # =========================================================================
    # Unified formula: y in C_final iff signed_dist <= margin
    # - margin >= 0 (expansion): covers base + external shell
    # - margin < 0 (contraction): covers only deep interior
    inside = (signed_dists <= margin).float()
    inside_ratio = float(inside.mean())

    # =========================================================================
    # Step 8: Compute volume = inside_ratio × sampling_ball_volume
    # =========================================================================
    volume = inside_ratio * sampling_ball_volume

    return {
        "volume": volume,
        "inside_ratio": inside_ratio,
        "sampling_ball_volume": sampling_ball_volume,
        "sampling_ball_radius": sampling_ball_radius,
        "effective_radius": effective_radius,
        "margin_factor_used": margin_factor_tight,
        "collapse_warning": collapse_warning,
        "center": centroid_np,
        "mean_radius_base": mean_radius_base,
        "max_radius_base": max_radius_base,
    }


# ==============================================================================
# Main Pipeline
# ==============================================================================

def run_gcqr3_pipeline(
    X_train: np.ndarray,
    Y_train: np.ndarray,
    X_cal: np.ndarray,
    Y_cal: np.ndarray,
    X_test: np.ndarray,
    Y_test: np.ndarray,
    alpha: float = 0.1,
    device: str = "cuda",
    batch_size: int = 256,
    width: int = 64,
    depth: int = 4,
    cfg: Optional[TrainCQRv2Config] = None,
    dataset_name: Optional[str] = None,
    n_boundary_samples: int = 256,
    mu: float = 0.01,
    tr: Optional[Dict] = None,
    tr_cond: Optional["ConditionalTR"] = None,  # NEW: conditional TR object
    volume_correction_method: str = "jacobian",  # "jacobian", "retransform", or "none"
    n_vol_samples: int = 20,
    n_mc_samples: int = 10000,  # NEW: MC samples for volume estimation
    return_model: bool = False,  # NEW: return trained model for plotting
) -> Dict:
    """
    Full GCQR3 Pipeline.
    
    1. Train PICNN Q(x, u) using Phase A from train_cqr_v2
    2. Define initial region as Q(x, (1-α)B^d)
    3. Compute nonconformity scores s(x,y) = d(y, Q(x, (1-α)B^d))
    4. Calibrate margin q via split conformal prediction
    5. Evaluate coverage and compute metrics
    
    Args:
        X_train, Y_train: training data
        X_cal, Y_cal: calibration data
        X_test, Y_test: test data
        alpha: miscoverage level
        device: torch device
        batch_size: batch size
        width, depth: PICNN architecture
        cfg: TrainCQRv2Config
        dataset_name: name for WSC delta selection
        n_boundary_samples: samples for boundary approximation
        mu: strong convexity parameter for PICNN
        tr: transformation dict from fit_tr_and_standardize (None if no TR transform)
        tr_cond: ConditionalTR object for X-dependent transformation (None if not used)
        volume_correction_method: how to correct volume for TR transform
            - "jacobian": multiply volume by |det(sqrtS)| (fast, recommended)
            - "retransform": retransform center to original space (for interpretability)
            - "none": no correction (volume in transformed space)
    
    Returns:
        Dict with metrics
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    
    d_x = X_train.shape[1]
    d_y = Y_train.shape[1]
    
    # Adapt n_boundary_samples to dimension (256 is too few for d > 8)
    # Formula: max(256, 500 * d_y) ensures good sphere coverage in high-d
    n_boundary_samples_adapted = max(n_boundary_samples, 500 * d_y)
    n_boundary_samples = n_boundary_samples_adapted
    
    # Adapt n_mc_samples to dimension (exponential scaling)
    # MC volume estimation needs many more samples in high dimensions
    if d_y > 10:
        factor = 2 ** ((d_y - 10) / 10)
        n_mc_samples_adapted = int(n_mc_samples * factor)
    else:
        n_mc_samples_adapted = n_mc_samples
    n_mc_samples = n_mc_samples_adapted
    
    print("=" * 70)
    print(" GCQR3 Pipeline")
    print("=" * 70)
    print(f"  d_x={d_x}, d_y={d_y}, α={alpha}")
    print(f"  Architecture: width={width}, depth={depth}")
    print(f"  n_boundary_samples: {n_boundary_samples} (adapted for d_y={d_y})")
    print(f"  n_mc_samples: {n_mc_samples} (adapted for d_y={d_y})")
    print("mu:", mu)
    # Default config
    if cfg is None:
        cfg = TrainCQRv2Config(
            epochs_A=100,
            lr_Q=1e-4,
            n_u_geom=16,
            p_shell=0.6,
            rho_shell=0.95,
            jitter_shell=0.02,
            clip_grad=5.0,
            amp=True,
            log_every=10,
        )
    
    # 1. Build PICNN model
    print("\n[1] Building PICNN model...")
    model_Q = ICNNGeometricQuantileMap(
        x_dim=d_x,
        u_dim=d_y,
        y_dim=d_y,
        width=width,
        depth=depth,
        mu=mu,
    ).to(device)
    
    # 2. Train Q(x, u) using Phase A
    print("\n[2] Training Q(x, u) via geometric quantile loss...")
    train_loader = DataLoader(
        TensorDataset(
            torch.from_numpy(X_train.astype(np.float32)),
            torch.from_numpy(Y_train.astype(np.float32)),
        ),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=(device.type == "cuda"),
    )
    
    model_Q = train_phase_A_Q_only(model_Q, train_loader, device, d_y, cfg)
    
    # 3. Create ConformalPICNN and calibrate
    print("\n[3] Calibrating conformal margin...")
    X_cal_t = torch.from_numpy(X_cal.astype(np.float32))
    Y_cal_t = torch.from_numpy(Y_cal.astype(np.float32))
    
    cp = ConformalPICNN(
        model_Q,
        alpha=alpha,
        n_boundary_samples=n_boundary_samples,
        batch_size=batch_size,
    )
    result = cp.calibrate(X_cal_t, Y_cal_t)
    
    print(f"  Base radius (1-α): {result.radius:.4f}")
    print(f"  Calibrated margin q: {result.margin:.4f}")
    # Diagnostic: Check if model collapsed
    with torch.no_grad():
        x_sample = X_cal_t[:10].to(device)
        u_zero = torch.zeros(10, d_y, device=device)
        u_edge = torch.randn(10, d_y, device=device)
        u_edge = 0.9 * u_edge / u_edge.norm(dim=-1, keepdim=True)
        
        q_zero = model_Q(x_sample, u_zero)
        q_edge = model_Q(x_sample, u_edge)
        
        spread = (q_edge - q_zero).norm(dim=-1).mean()
        print(f"  [Diagnostic] Q spread: {spread:.4f} (should be > 0.5 for non-collapsed model)")
        if spread < 0.1:
            print("spread:", spread, "⚠️ WARNING: Model appears collapsed! Increase width/depth.")
    # 4. Evaluate on test set
    print("\n[4] Evaluating on test set...")
    X_test_t = torch.from_numpy(X_test.astype(np.float32))
    Y_test_t = torch.from_numpy(Y_test.astype(np.float32))
    
    coverage, is_covered = cp.evaluate(X_test_t, Y_test_t)
    print(f"  Coverage: {coverage:.4f} (target: {1 - alpha:.2f})")
    
    # 5. Worst Slab Coverage (WSC)
    print("\n[5] Computing Worst Slab Coverage (WSC)...")
    wsc_min = np.nan
    
    if WSC_AVAILABLE:
        # Select delta based on dataset
        delta_map = {
            "scm20d": 0.10,
            "sgemm": 0.01,
            "rf1": 0.10,
        }
        delta = delta_map.get(dataset_name, 0.10)
        M_directions = 100
        
        try:
            coverage_status = is_covered.numpy().astype(np.float32)
            wsc_min, wsc_list = calculer_wsc_regression(
                X_test, coverage_status, delta=delta, M=M_directions, random_state=42
            )
            print(f"  WSC_min: {wsc_min:.4f} (delta={delta})")
        except Exception as e:
            print(f"  WSC computation failed: {e}")
    else:
        print("  WSC not available")
    
    # 6. Volume estimation - Ball sampling (efficient for high dimensions)
    print("\n[6] Volume estimation (ball sampling, efficient for high-d)...")
    n_vol_samples = min(n_vol_samples, len(X_test_t))
    indices = np.random.choice(np.arange(len(X_test_t)), size=n_vol_samples, replace=False)
    all_vol_results = []
    
    print(f"  Config: n_points={n_vol_samples}, n_mc_samples={n_mc_samples}")
    print(f"  Base region: radius={result.radius:.4f}")
    print(f"  Margin: {result.margin:.4f}")
    print(f"  NOTE: margin_factor auto-adjusted for d={d_y} to avoid volume blowup")
    print(f"")
    print(f"  {'Pt':>3} | {'inside_ratio':>12} | {'volume':>14} | {'r_base':>8} | {'r_eff':>8} | {'mf':>5}")
    print(f"  {'-'*3}-+-{'-'*12}-+-{'-'*14}-+-{'-'*8}-+-{'-'*8}-+-{'-'*5}")
    
    for i in range(n_vol_samples):
        x_point = X_test_t[indices[i]]
        
        # Estimation via ball sampling (efficient for high dimensions)
        vol_res = estimate_region_volume_mc_gcqr(
            model_Q,
            x_point,
            radius=result.radius,
            margin=result.margin,
            device=device,
            n_samples=n_mc_samples,
            n_boundary=2000,
            margin_factor=1.5,
        )
        all_vol_results.append(vol_res)
        
        # Per-point log
        if not vol_res.get('collapse_warning', False):
            r_base = vol_res.get('mean_radius_base', 0)
            r_eff = vol_res.get('effective_radius', 0)
            mf = vol_res.get('margin_factor_used', 0)
            print(f"  {i+1:3d} | {vol_res['inside_ratio']:12.4f} | "
                  f"{vol_res['volume']:14.6e} | {r_base:8.4f} | {r_eff:8.4f} | {mf:5.2f}")
        else:
            print(f"  {i+1:3d} | {'COLLAPSE':>12} | {'N/A':>14} | {'N/A':>8} | {'N/A':>8} | {'N/A':>5}")
    
    print(f"  {'-'*3}-+-{'-'*12}-+-{'-'*14}-+-{'-'*8}-+-{'-'*8}-+-{'-'*5}")
    
    # Filter out collapsed results
    valid_results = [r for r in all_vol_results if not r.get('collapse_warning', False)]
    
    # =========================================================================
    # Volume correction for TR transformation (LOG-SPACE for numerical stability)
    # =========================================================================
    # Instead of computing |det(sqrtS)| which can overflow (e.g., 10^33),
    # we compute log|det(sqrtS)| = sum(log|eigenvalues|) and apply correction
    # directly to log-volume: (1/d) log V_orig = (1/d) log V_transf + (1/d) log|det|
    
    log_jacobian = 0.0  # log|det(sqrtS)|, 0 means no correction
    jacobian_per_dim = 1.0  # |det(sqrtS)|^(1/d) = geometric mean of eigenvalues
    condition_number = 1.0
    tr_type = "none"  # Track which TR type is used
    
    if tr_cond is not None and volume_correction_method != "none":
        # =====================================================================
        # CONDITIONAL TR: average log-jacobian across test points
        # =====================================================================
        tr_type = "conditional"
        print(f"  [Conditional TR Correction] method={volume_correction_method}")
        
        # Compute log-jacobian for each test point used in volume estimation
        log_jacobians = []
        cond_nums = []
        for i, res in enumerate(valid_results):
            x_idx = indices[i] if i < len(indices) else 0
            x_point = X_test[x_idx]
            log_jac_i = tr_cond.get_log_det_jacobian(x_point)
            cond_num_i = tr_cond.get_condition_number(x_point)
            log_jacobians.append(log_jac_i)
            cond_nums.append(cond_num_i)
            # Store per-point jacobian for potential per-point correction
            res['log_jacobian_local'] = log_jac_i
        
        # Use mean log-jacobian for aggregate volume correction
        log_jacobian = float(np.mean(log_jacobians))
        jacobian_per_dim = float(np.exp(log_jacobian / d_y))
        condition_number = float(np.mean(cond_nums))
        
        print(f"    log|det(sqrtS(x))| across {len(log_jacobians)} points:")
        print(f"      mean = {log_jacobian:+.2f}, std = {np.std(log_jacobians):.2f}")
        print(f"      range = [{min(log_jacobians):+.2f}, {max(log_jacobians):+.2f}]")
        print(f"    |det|^(1/d) (mean)  = {jacobian_per_dim:.4f} (scale factor per dim)")
        print(f"    Condition number (mean) = {condition_number:.2e}")
        
        if condition_number > 1e6:
            print(f"    ⚠️ WARNING: High condition number! Local covariance is ill-conditioned.")
        
        if volume_correction_method == "retransform":
            # Also retransform centers to original space for interpretability
            for i, res in enumerate(valid_results):
                if 'center' in res and res['center'] is not None:
                    x_idx = indices[i] if i < len(indices) else 0
                    x_point = X_test[x_idx]
                    center_orig = tr_cond.inverse_transform(
                        x_point.reshape(1, -1), 
                        res['center'].reshape(1, -1)
                    ).squeeze()
                    res['center_original'] = center_orig
                    res['center_transformed'] = res['center']
    
    elif tr is not None and volume_correction_method != "none":
        # =====================================================================
        # GLOBAL TR: single jacobian for all points (original behavior)
        # =====================================================================
        tr_type = "global"
        
        # Compute eigenvalues of sqrtS for stable log-determinant
        sqrtS = tr["sqrtS"]
        eigvals = np.linalg.eigvalsh(sqrtS)  # Real eigenvalues (sqrtS is symmetric)
        eigvals_abs = np.abs(eigvals)
        
        # Numerical stability: clip tiny eigenvalues
        eigvals_clipped = np.clip(eigvals_abs, 1e-30, None)
        
        # log|det(sqrtS)| = sum(log|λi|)
        log_jacobian = float(np.sum(np.log(eigvals_clipped)))
        
        # Geometric mean = |det|^(1/d) = exp((1/d) * log|det|)
        jacobian_per_dim = float(np.exp(log_jacobian / d_y))
        
        # Condition number for diagnostics
        condition_number = float(eigvals_clipped.max() / eigvals_clipped.min())
        
        print(f"  [Global TR Correction] method={volume_correction_method}")
        print(f"    log|det(sqrtS)|     = {log_jacobian:+.2f}")
        print(f"    |det|^(1/d)         = {jacobian_per_dim:.4f} (scale factor per dim)")
        print(f"    Condition number    = {condition_number:.2e}")
        print(f"    Eigenvalue range    = [{eigvals_clipped.min():.2e}, {eigvals_clipped.max():.2e}]")
        
        if condition_number > 1e6:
            print(f"    ⚠️ WARNING: High condition number! Covariance is ill-conditioned.")
        
        if volume_correction_method == "retransform":
            # Also retransform centers to original space for interpretability
            from tr_retr import tr_retransform
            for res in valid_results:
                if 'center' in res and res['center'] is not None:
                    center_orig = tr_retransform(res['center'].reshape(1, -1), tr).squeeze()
                    res['center_original'] = center_orig
                    res['center_transformed'] = res['center']
    
    if len(valid_results) > 0:
        # Aggregate results
        volumes_transformed = np.array([r['volume'] for r in valid_results])
        inside_ratios = np.array([r['inside_ratio'] for r in valid_results])
        mean_r_base = np.mean([r.get('mean_radius_base', 0) for r in valid_results])
        mean_r_eff = np.mean([r.get('effective_radius', 0) for r in valid_results])
        mean_mf = np.mean([r.get('margin_factor_used', 0) for r in valid_results])
        
        mean_volume_transformed = float(np.mean(volumes_transformed))
        mean_inside_ratio = float(np.mean(inside_ratios))
        
        # =====================================================================
        # Compute normalized log-volume WITH PER-POINT CORRECTION
        # =====================================================================
        # CORRECT APPROACH: For each point i, compute corrected volume in original space
        #   V_orig_i = V_transf_i × |det(sqrtS(x_i))|
        # Then average in log-space:
        #   norm_log_vol = (1/d) × mean(log V_orig_i)
        #                = (1/d) × mean(log V_transf_i + log|det(sqrtS_i)|)
        #
        # This is more accurate than: (1/d) × log(mean(V_transf)) + mean(log|det|)/d
        # especially when there's variance in jacobians across points.
        
        eps_vol = 1e-30
        
        if (tr_cond is not None or tr is not None) and volume_correction_method != "none":
            # Per-point volume correction
            log_vols_corrected = []
            for i, res in enumerate(valid_results):
                log_vol_transf_i = np.log(max(res['volume'], eps_vol))
                
                if tr_cond is not None:
                    # Conditional TR: use per-point jacobian
                    log_jac_i = res.get('log_jacobian_local', 0.0)
                else:
                    # Global TR: same jacobian for all points
                    log_jac_i = log_jacobian
                
                log_vol_corrected_i = log_vol_transf_i + log_jac_i
                log_vols_corrected.append(log_vol_corrected_i)
                res['log_vol_corrected'] = log_vol_corrected_i
            
            # Average in log-space, then normalize by d
            mean_log_vol_corrected = float(np.mean(log_vols_corrected))
            norm_log_vol_final = mean_log_vol_corrected / d_y
            
            # Also compute the old way for comparison
            norm_log_vol_transformed = (1 / d_y) * np.log(max(mean_volume_transformed, eps_vol))
            norm_log_vol_old = norm_log_vol_transformed + (log_jacobian / d_y)
            
            # For display: mean corrected volume (may overflow)
            mean_volume = float(np.exp(mean_log_vol_corrected)) if mean_log_vol_corrected < 700 else np.inf
            
            print(f"  {'Avg':>3} | {mean_inside_ratio:12.4f} | "
                  f"{mean_volume_transformed:14.6e} | {mean_r_base:8.4f} | {mean_r_eff:8.4f} | {mean_mf:5.2f}")
            
            print(f"")
            print(f"  Summary (d_y={d_y}):")
            print(f"    Valid samples:       {len(valid_results)}/{n_vol_samples}")
            print(f"    Mean inside ratio:   {mean_inside_ratio*100:.2f}%")
            print(f"    Mean vol (transf):   {mean_volume_transformed:.6e}")
            print(f"    TR type:             {tr_type}")
            print(f"    Per-point correction: YES (more accurate)")
            print(f"    Mean vol (original): {mean_volume:.6e}" if mean_volume != np.inf else f"    Mean vol (original): overflow (use log)")
            print(f"    (1/d) log vol (transf):  {norm_log_vol_transformed:+.4f}")
            print(f"    (1/d) log vol (per-pt):  {norm_log_vol_final:+.4f}  <-- FINAL (per-point correction)")
            print(f"    (1/d) log vol (mean-jac):{norm_log_vol_old:+.4f}  (old method, for comparison)")
            if abs(norm_log_vol_final - norm_log_vol_old) > 0.1:
                print(f"    ⚠️ NOTE: Large diff between methods ({norm_log_vol_final - norm_log_vol_old:+.2f})")
        else:
            # No TR correction
            norm_log_vol_transformed = (1 / d_y) * np.log(max(mean_volume_transformed, eps_vol))
            norm_log_vol_final = norm_log_vol_transformed
            mean_volume = mean_volume_transformed
            
            print(f"  {'Avg':>3} | {mean_inside_ratio:12.4f} | "
                  f"{mean_volume_transformed:14.6e} | {mean_r_base:8.4f} | {mean_r_eff:8.4f} | {mean_mf:5.2f}")
            
            print(f"")
            print(f"  Summary (d_y={d_y}):")
            print(f"    Valid samples:       {len(valid_results)}/{n_vol_samples}")
            print(f"    Mean inside ratio:   {mean_inside_ratio*100:.2f}%")
            print(f"    (1/d) log vol:       {norm_log_vol_final:+.4f}")
    else:
        print(f"  WARNING: All samples collapsed!")
        mean_volume = np.nan
        mean_volume_transformed = np.nan
        mean_inside_ratio = np.nan
        norm_log_vol_final = np.nan
        norm_log_vol_transformed = np.nan
    
    output = {
        "coverage": coverage,
        "target": 1 - alpha,
        "base_radius": result.radius,
        "margin": result.margin,
        "vol_final": mean_volume if 'mean_volume' in dir() else np.nan,
        "vol_transformed": mean_volume_transformed if 'mean_volume_transformed' in dir() else np.nan,
        "inside_ratio": mean_inside_ratio,
        "norm_log_vol_final": norm_log_vol_final,
        "norm_log_vol_transformed": norm_log_vol_transformed if 'norm_log_vol_transformed' in dir() else np.nan,
        "wsc_min": wsc_min,
        "cal_scores_mean": float(result.cal_scores.mean()),
        "cal_scores_std": float(result.cal_scores.std()),
        "log_jacobian": log_jacobian,
        "jacobian_per_dim": jacobian_per_dim,
        "condition_number": condition_number,
        "tr_type": tr_type,
        "volume_correction_method": volume_correction_method if (tr is not None or tr_cond is not None) else "none",
    }
    
    # Optionally return the trained model for plotting
    if return_model:
        output["model_Q"] = model_Q
    
    return output


# ==============================================================================
# Results Saving
# ==============================================================================

def save_results_to_csv(
    results_list: list,
    output_dir: str = "results",
    filename: str = "results_conformal_picnn.csv",
    *,
    columns: Optional[List[str]] = None,
) -> str:
    """Write selected fields of a list of result dicts to a CSV file.

    Args:
        results_list: One dict per run (e.g. per dataset); each becomes a row.
        output_dir: Directory to write into; created if it does not exist.
        filename: Name of the CSV file inside ``output_dir``.
        columns: Columns to keep, in output order. Defaults to the five key
            metrics ``dataset, coverage, wsc_min, norm_log_vol_scaled,
            norm_log_vol_transformed``. Requested columns absent from the
            results are skipped (a note is printed) rather than raising.

    Returns:
        The path of the written CSV file.
    """
    default_cols = (
        'dataset',
        'coverage',
        'wsc_min',
        'norm_log_vol_scaled',
        'norm_log_vol_transformed',
    )
    keep_cols = list(default_cols if columns is None else columns)

    os.makedirs(output_dir, exist_ok=True)
    df = pd.DataFrame(results_list)

    present = [c for c in keep_cols if c in df.columns]
    missing = [c for c in keep_cols if c not in df.columns]
    # Make silently-dropped columns visible, but stay quiet for an empty run.
    if missing and len(df) > 0:
        print(f"  [save_results_to_csv] Skipping missing columns: {missing}")
    df = df[present]

    output_path = os.path.join(output_dir, filename)
    df.to_csv(output_path, index=False, float_format='%.6f')

    print(f"\n✓ Results saved to: {output_path}")
    return output_path


# ==============================================================================
# Main
# ==============================================================================

def _max_row_norm(Y: np.ndarray):
    """Return the Euclidean norm of the row of ``Y`` with the largest norm.

    Computed as argmax over row norms followed by the norm of that row, so the
    value is numerically identical to the original inline computation.
    """
    row = Y[np.argmax(np.linalg.norm(Y, axis=1))]
    return np.linalg.norm(row)


def _print_conditional_tr_diagnostics(
    tr_cond,
    X_ref: np.ndarray,
    n_points: int,
    rng: np.random.Generator,
) -> None:
    """Print log|det(sqrtS(x))| and condition-number stats of ``tr_cond`` at a few rows of ``X_ref``.

    A dedicated ``rng`` is used so that this purely diagnostic sampling does
    not consume the global NumPy random stream used elsewhere in the run.
    """
    n_pick = min(n_points, len(X_ref))
    if n_pick == 0:
        print("    (no points available for diagnostics)")
        return
    sample_idx = rng.choice(len(X_ref), n_pick, replace=False)
    log_dets = []
    cond_nums = []
    for idx in sample_idx:
        log_dets.append(tr_cond.get_log_det_jacobian(X_ref[idx]))
        cond_nums.append(tr_cond.get_condition_number(X_ref[idx]))
    print(f"    log|det(sqrtS(x))|: mean={np.mean(log_dets):+.2f}, std={np.std(log_dets):.2f}")
    print(f"    Condition numbers:  mean={np.mean(cond_nums):.2e}, max={np.max(cond_nums):.2e}")


if __name__ == "__main__":
    # Configuration matching GCQR4
    from tr_retr import fit_tr_and_standardize, tr_transform, tr_retransform
    from tr_retr_conditional import (
        ConditionalTR,
        fit_conditional_tr_and_standardize,
        conditional_tr_transform,
        conditional_tr_retransform,
        StabilizedConditionalTR,
        fit_stabilized_conditional_tr_and_standardize,
    )
    datasets = ["scm20d", "rf1", "sgemm", "rf2", "scm1d"]
    alpha = 0.1
    seed = 42
    tr_retr_bool = True
    mode = "tr_retr" if tr_retr_bool else "standard"
    ridge = 1e-3
    # Conditional TR (depends on X) vs Global TR
    # - "conditional": uses KNN-based local covariance Sigma(x)
    # - "stabilized": conditional + shrinkage + fallback (recommended for rf1)
    # - "global": uses global covariance (original behavior)
    tr_mode = "stabilized"  # "conditional", "stabilized", or "global"
    tr_knn_k = 100  # number of neighbors for conditional TR
    tr_shrinkage_alpha = 0.5  # blend factor with global cov (0=local, 1=global)
    tr_fallback_threshold = 1e6  # condition number threshold for fallback

    # Volume correction method when using TR transform:
    # - "jacobian": multiply volume by |det(sqrtS)| (fast, recommended)
    # - "retransform": same as jacobian + retransform centers (for interpretability)
    # - "none": no correction (volume in transformed/whitened space)
    volume_correction_method = "jacobian"

    # Dataset-specific split fractions (currently identical for all datasets)
    test_size = {"scm20d": 0.2, "sgemm": 0.2, "rf1": 0.2, "rf2": 0.2, "scm1d": 0.2}
    cal_size = {"scm20d": 0.4, "sgemm": 0.4, "rf1": 0.4, "rf2": 0.4, "scm1d": 0.4}

    # Width/depth: tuned per dataset via grid search
    width = {"scm20d": 112, "sgemm": 112, "rf1": 48, "rf2": 16, "scm1d": 112}
    depth = {"scm20d": 6, "sgemm": 6, "rf1": 4, "rf2": 2, "scm1d": 7}

    # mu (convexity parameter): lower = more flexibility, higher = more stability.
    # Currently the same value (0.005) for every dataset.
    mu_config = {"scm20d": 0.005, "sgemm": 0.005, "rf1": 0.005, "rf2": 0.005, "scm1d": 0.005}

    MAX_SAMPLES = 25000

    print("=" * 70)
    print(" ConformalPICNN Evaluation")
    print(" (Same setup as GCQR4)")
    print("=" * 70)

    all_results = []

    for dataset in datasets:
        print(f"\n{'='*70}")
        print(f" Dataset: {dataset}")
        print(f"{'='*70}")

        # Load dataset
        print(f"\nLoading dataset: {dataset}")
        if dataset == "rf1":
            X, Y = load_dataset(dataset, use_rf1_resplit=True, random_state=seed)
        else:
            X, Y = load_dataset(dataset, random_state=seed)

        X = X.values.astype(np.float32) if hasattr(X, "values") else np.asarray(X, dtype=np.float32)
        Y = np.asarray(Y, dtype=np.float32)

        # Subsample if needed (global seed kept as-is to match GCQR4's subsample)
        n_samples = X.shape[0]
        if MAX_SAMPLES is not None and n_samples > MAX_SAMPLES:
            print(f"  Subsampling to {MAX_SAMPLES} samples...")
            np.random.seed(seed)
            indices = np.random.choice(n_samples, MAX_SAMPLES, replace=False)
            X = X[indices]
            Y = Y[indices]

        # Dedicated RNG for diagnostic sampling only, so printing diagnostics
        # does not shift the global NumPy random stream seen by the pipeline.
        diag_rng = np.random.default_rng(seed)

        # Split (same as GCQR4)
        X_train, X_cal, X_test, Y_train, Y_cal, Y_test = split_train_cal_test(
            X, Y,
            test_size=test_size[dataset],
            cal_size=cal_size[dataset],
            random_state=seed,
        )

        # Standardize X (same as GCQR4; sgemm features are left unscaled)
        if dataset != "sgemm":
            scaler_X = StandardScaler()
            X_train = scaler_X.fit_transform(X_train).astype(np.float32)
            X_cal = scaler_X.transform(X_cal).astype(np.float32)
            X_test = scaler_X.transform(X_test).astype(np.float32)

        # TR transformation (whitening) or standard scaling of Y
        tr = None  # set only for global TR
        tr_cond = None  # set only for conditional / stabilized TR

        if tr_retr_bool:
            if tr_mode == "stabilized":
                # =====================================================
                # STABILIZED CONDITIONAL TR: shrinkage + fallback
                # =====================================================
                print(f"  Using STABILIZED CONDITIONAL TR (k={tr_knn_k}, shrinkage={tr_shrinkage_alpha}, fallback_thresh={tr_fallback_threshold:.0e})...")
                tr_cond, Y_train = fit_stabilized_conditional_tr_and_standardize(
                    X_train, Y_train,
                    method="knn",
                    k=tr_knn_k,
                    ridge=ridge,
                    shrinkage_alpha=tr_shrinkage_alpha,
                    fallback_threshold=tr_fallback_threshold,
                    eigenvalue_shrinkage=0.0,
                )
                Y_cal = conditional_tr_transform(X_cal, Y_cal, tr_cond)
                Y_test = conditional_tr_transform(X_test, Y_test, tr_cond)

                print(f"  Stabilized Conditional TR diagnostics:")
                _print_conditional_tr_diagnostics(tr_cond, X_train, 10, diag_rng)
                tr_cond.print_diagnostics()

            elif tr_mode == "conditional":
                # =====================================================
                # CONDITIONAL TR: Sigma(x) depends on X via KNN
                # =====================================================
                print(f"  Using CONDITIONAL TR (KNN, k={tr_knn_k})...")
                tr_cond, Y_train = fit_conditional_tr_and_standardize(
                    X_train, Y_train,
                    method="knn",
                    k=tr_knn_k,
                    ridge=ridge,
                )
                Y_cal = conditional_tr_transform(X_cal, Y_cal, tr_cond)
                Y_test = conditional_tr_transform(X_test, Y_test, tr_cond)

                print(f"  Conditional TR diagnostics (sampled points):")
                _print_conditional_tr_diagnostics(tr_cond, X_train, 5, diag_rng)

            else:
                # =====================================================
                # GLOBAL TR: single Sigma for all points (original)
                # =====================================================
                print(f"  Using GLOBAL TR...")
                tr, Y_train = fit_tr_and_standardize(Y_train, ridge=ridge)
                Y_cal = tr_transform(Y_cal, tr)
                Y_test = tr_transform(Y_test, tr)
                # Stable log-jacobian for diagnostics. Eigenvalues are clipped
                # (as in the pipeline's volume correction) so a zero eigenvalue
                # cannot cause a division by zero in the condition number.
                eigvals = np.abs(np.linalg.eigvalsh(tr['sqrtS']))
                eigvals_clipped = np.clip(eigvals, 1e-30, None)
                log_det = np.sum(np.log(eigvals_clipped))
                scale_per_dim = np.exp(log_det / Y_train.shape[1])
                cond_num = eigvals_clipped.max() / eigvals_clipped.min()
                print(f"  TR transform applied. sqrtS shape: {tr['sqrtS'].shape}")
                print(f"    log|det(sqrtS)|  = {log_det:+.2f}")
                print(f"    |det|^(1/d)      = {scale_per_dim:.4f} (avg scale per dim)")
                print(f"    Condition number = {cond_num:.2e}")
                if cond_num > tr_fallback_threshold:
                    print(f"    ⚠️ WARNING: Ill-conditioned covariance!")
        else:
            scaler_Y = StandardScaler()
            Y_train = scaler_Y.fit_transform(Y_train).astype(np.float32)
            Y_cal = scaler_Y.transform(Y_cal).astype(np.float32)
            Y_test = scaler_Y.transform(Y_test).astype(np.float32)
        print("  Data standardized.")

        # Rescale Y so the largest training response norm becomes 1. Cal/test
        # use the same train-derived factor; norm_max_cal is only printed.
        norm_max_train = _max_row_norm(Y_train)
        norm_max_cal = _max_row_norm(Y_cal)
        print("norm_max_train:", norm_max_train, "norm_max_cal:", norm_max_cal)
        if not np.isfinite(norm_max_train) or norm_max_train <= 0:
            # A zero/non-finite max norm would give an inf/NaN scale factor.
            print(f"  WARNING: degenerate Y_train (max norm = {norm_max_train}); skipping {dataset}.")
            continue
        scale_factor = 1 / norm_max_train
        Y_train = Y_train * scale_factor
        Y_cal = Y_cal * scale_factor
        Y_test = Y_test * scale_factor

        print(f"Shapes: train={X_train.shape}, cal={X_cal.shape}, test={X_test.shape}")

        # Get training config (same as GCQR4)
        cfg = get_train_config_gcqr(dataset)

        # mu from mu_config, falling back to 1e-4 for datasets not listed there
        mu = mu_config.get(dataset, 0.0001)

        # Run pipeline
        results = run_gcqr3_pipeline(
            X_train, Y_train,
            X_cal, Y_cal,
            X_test, Y_test,
            alpha=alpha,
            width=width[dataset],
            depth=depth[dataset],
            mu=mu,  # Pass mu to prevent Q collapse
            cfg=cfg,
            dataset_name=dataset,
            tr=tr,
            tr_cond=tr_cond,
            volume_correction_method=volume_correction_method,
        )

        # Add metadata
        results['dataset'] = dataset
        results['alpha'] = alpha
        results['seed'] = seed
        results['d_x'] = X_train.shape[1]
        results['d_y'] = Y_train.shape[1]
        results['n_train'] = X_train.shape[0]
        results['n_cal'] = X_cal.shape[0]
        results['n_test'] = X_test.shape[0]
        results['tr_mode'] = tr_mode if tr_retr_bool else "none"
        results['tr_knn_k'] = tr_knn_k if (tr_retr_bool and tr_mode in ["conditional", "stabilized"]) else None
        # Undo the 1/norm_max_train rescaling of Y: in d dims the volume scales
        # by norm_max_train**d, so the (1/d)*log-volume shifts by
        # log(norm_max_train). This starts from the pipeline's *transformed*
        # value, i.e. it is not Jacobian-corrected when a TR is active.
        log_scale_correction = np.log(norm_max_train)
        results['norm_log_vol_scaled'] = results['norm_log_vol_transformed'] + log_scale_correction
        all_results.append(results)
        print("correction", log_scale_correction)
        # Print summary
        print("\n" + "=" * 70)
        print(f" GCQR3 RESULTS - {dataset}")
        print("=" * 70)
        for k, v in results.items():
            if isinstance(v, float):
                print(f"  {k}: {v:.4f}")
            else:
                print(f"  {k}: {v}")

    # Save results
    print("\n" + "=" * 70)
    print(" SAVING RESULTS")
    print("=" * 70)
    save_results_to_csv(all_results, output_dir="results_gcqr", filename=f"results_gcqr3_{mode}.csv")

    # Print comparison table (only the columns actually present, so an empty
    # or partial run does not crash with a KeyError at the very end)
    print("\n" + "=" * 70)
    print(" SUMMARY TABLE")
    print("=" * 70)
    df = pd.DataFrame(all_results)
    summary_cols = ['dataset', 'coverage', 'margin', 'norm_log_vol_scaled', 'wsc_min', 'norm_log_vol_transformed']
    available_cols = [c for c in summary_cols if c in df.columns]
    if available_cols:
        print(df[available_cols].to_string(index=False))
    else:
        print("  (no results to summarize)")