from __future__ import annotations

import os
import time
import warnings
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional
from sklearn.preprocessing import StandardScaler, RobustScaler
from datetime import datetime

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Import PyTorch
try:
    import torch
    HAS_TORCH = True
except ImportError:
    torch = None
    HAS_TORCH = False
    print("Warning: PyTorch not available. VSPS evaluation not possible.")

# Import VSPS utilities
try:
    from VSPS_utils import (
        eval_vsps_cp,
        sample_y_and_logdet_fwd,
        select_topk_centers,
        min_dist_to_centers,
    )
    HAS_VSPS = True
except ImportError:
    eval_vsps_cp = None
    HAS_VSPS = False
    print("Warning: VSPS_utils not available. VSPS evaluation not possible.")

# Import dataset loader
from load_dataset_grinsztajn import load_dataset

# Import split function
from train_cqr_v2_utils import split_train_cal_test

# Import TR utilities
from tr_retr import fit_tr_and_standardize, tr_transform, tr_retransform
from tr_retr_conditional import (
    ConditionalTR,
    fit_conditional_tr_and_standardize,
    conditional_tr_transform,
    conditional_tr_retransform,
    StabilizedConditionalTR,
    fit_stabilized_conditional_tr_and_standardize,
)

# Import WSC calculation
try:
    from worst_slab_cov_gpu import worst_slab_coverage_multireg_random_dirs_gpu
    WSC_AVAILABLE = True
except ImportError:
    WSC_AVAILABLE = False
    print("Warning: worst_slab_cov_gpu not available. WSC will not be computed.")


def compute_wsc_for_method(
    X_test: np.ndarray,
    coverage_status: np.ndarray,
    M: int = 500,
    delta: float = 0.1,
    min_points: int = 30,
    random_state: int = 42,
) -> float:
    """
    Compute Worst Slab Coverage (WSC) for a conformal prediction method.

    Parameters
    ----------
    X_test : Array forwarded to the WSC routine as ``Z``
    coverage_status : Per-point coverage indicator; cast to float before use
    M, delta, min_points, random_state : Forwarded unchanged to the WSC routine

    Returns
    -------
    The first element returned by the WSC routine, or ``np.nan`` when the
    routine could not be imported or raised while running.
    """
    if WSC_AVAILABLE:
        try:
            # Keep the cast inside the try so a bad `coverage_status` is
            # reported as a WSC failure rather than propagating.
            slab_outputs = worst_slab_coverage_multireg_random_dirs_gpu(
                coverage_status=coverage_status.astype(float),
                Z=X_test,
                M=M,
                delta=delta,
                min_points=min_points,
                random_state=random_state,
            )
            # The routine returns four values; only the first one is used here.
            worst_coverage, _, _, _ = slab_outputs
        except Exception as err:
            print(f"    Warning: WSC computation failed: {err}")
        else:
            return worst_coverage

    # Backend missing or computation failed.
    return np.nan


def save_results_to_csv(
    results: List[Dict],
    output_dir: str = "results_vsps_tr",
    filename: str = "results_vsps_with_tr.csv",
):
    """
    Write per-run results to ``output_dir/filename`` as CSV.

    Only a fixed, minimal set of columns is written (those of them that are
    actually present in ``results``), in a fixed order. Floats are formatted
    with six decimals. ``output_dir`` is created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    target_path = os.path.join(output_dir, filename)

    # Minimal columns for multi-run analysis, in output order.
    wanted_columns = (
        'run',
        'seed',
        'coverage',
        'wsc',
        'log_volume_normalized_scaled',
        'log_volume_normalized_transformed',
        'time',
    )

    table = pd.DataFrame(results)
    present_columns = [name for name in wanted_columns if name in table.columns]
    table[present_columns].to_csv(target_path, index=False, float_format='%.6f')

    print(f"  Results saved to: {target_path}")


def _nan_vsps_result(alpha: float, scale_factor: float) -> Dict:
    """Return the result schema of `run_vsps_evaluation` with all metrics set to NaN."""
    return {
        "coverage": np.nan,
        "target_coverage": 1.0 - alpha,
        "average_volume_transformed": np.nan,
        "log_volume_normalized_transformed": np.nan,
        "log_volume_normalized_scaled": np.nan,
        "log_volume_normalized_corrected": np.nan,
        "wsc": np.nan,
        "K_star": np.nan,
        "gamma": np.nan,
        "time": np.nan,
        "scale_factor": scale_factor,
        "log_scale_correction": np.nan,
        "log_tr_correction": np.nan,
    }


def _tr_log_volume_correction(tr, tr_cond, X_test: np.ndarray, d_y: int, seed: int) -> float:
    """Return the per-dimension log-volume correction (log|det J| / d_y) for the TR transform."""
    if tr is not None:
        # Global TR: a single Jacobian for all points.
        eigvals = np.abs(np.linalg.eigvalsh(tr["sqrtS"]))
        log_det = np.sum(np.log(np.clip(eigvals, 1e-30, None)))
        # Divide by d_y so the correction is on the same per-dimension scale as
        # the scale correction and the conditional-TR correction below.
        # NOTE(review): assumes `log_volume_normalized` from eval_vsps_cp is
        # log-volume divided by d_y -- confirm against VSPS_utils.
        correction = log_det / d_y
        print(f"    TR Jacobian correction: {correction:+.4f}")
        return correction

    if tr_cond is not None:
        # Conditional TR: the Jacobian depends on x, so estimate the average
        # log|det| from a subsample of at most 100 test points.
        n_sample = min(100, X_test.shape[0])
        # Local seeded generator: reproducible, and leaves the global NumPy
        # random state untouched.
        rng = np.random.default_rng(seed)
        sample_idx = rng.choice(X_test.shape[0], n_sample, replace=False)
        log_dets = [tr_cond.get_log_det_jacobian(X_test[idx]) for idx in sample_idx]
        correction = np.mean(log_dets) / d_y
        print(
            f"    Conditional TR Jacobian correction: {correction:+.4f} (avg over {n_sample} points)"
        )
        return correction

    # No TR transform was applied.
    return 0.0


def run_vsps_evaluation(
    X_train: np.ndarray,
    Y_train: np.ndarray,
    X_cal: np.ndarray,
    Y_cal: np.ndarray,
    X_test: np.ndarray,
    Y_test: np.ndarray,
    alpha: float = 0.1,
    dataset_name: str = "unknown",
    tr: Optional[Dict] = None,
    tr_cond: Optional[ConditionalTR] = None,
    norm_max_train: float = 1.0,
    scale_factor: float = 1.0,
    verbose: bool = True,
    seed: int = 42,
) -> Dict:
    """
    Run VSPS evaluation with proper preprocessing handling.

    Parameters
    ----------
    X_train, Y_train : Training data (in transformed/whitened space)
    X_cal, Y_cal : Calibration data (in transformed/whitened space)
    X_test, Y_test : Test data (in transformed/whitened space)
    alpha : Significance level
    dataset_name : Name of the dataset (selects the WSC ``delta``)
    tr : Global TR transform dict (if using global TR); must provide "sqrtS"
    tr_cond : Conditional TR object (if using conditional/stabilized TR);
        must provide ``get_log_det_jacobian``
    norm_max_train : Max norm from Y_train (for volume correction)
    scale_factor : Scale factor applied to Y (1/norm_max_train); only
        reported, not used in the computation
    verbose : Forwarded to ``eval_vsps_cp``; the progress lines printed by
        this function itself are unconditional
    seed : Seed forwarded to ``eval_vsps_cp``, to the WSC computation and to
        the sampling of test points for the conditional-TR correction

    Returns
    -------
    Dict with metrics: coverage, target_coverage, transformed/scaled/corrected
    normalized log-volume, wsc, K_star, gamma, time and the applied
    corrections. The same keys are returned with NaN metrics when VSPS or
    PyTorch is unavailable, or when the evaluation raises.
    """
    if not HAS_VSPS or not HAS_TORCH:
        # Same schema as the success/failure paths so that callers can index
        # the result uniformly; the two legacy keys are kept for callers that
        # relied on the older, shorter dict.
        unavailable = _nan_vsps_result(alpha, scale_factor)
        unavailable["average_volume"] = np.nan
        unavailable["log_volume_normalized"] = np.nan
        return unavailable

    d_y = Y_train.shape[1]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"\n  [VSPS] Volume-Sorted Prediction Sets (Normalizing Flow)...")
    print(f"    Device: {device}")
    print(f"    Y dimensions: d_y={d_y}")
    print(f"    Scale factor: {scale_factor:.6f}")

    vsps_start = time.time()

    try:
        metrics_vsps, vsps_state = eval_vsps_cp(
            X_train,
            Y_train,
            X_cal,
            Y_cal,
            X_test,
            Y_test,
            # Coverage
            alpha=alpha,
            # Flow training
            n_epochs=50,
            n_layers=8,
            hidden=(256, 256),
            lr=2e-4,
            batch_size_train=512,
            # VSPS sampling
            M=1024,
            K_grid=(1, 2, 4, 8, 16, 32),
            sort_ascending=True,
            # Volume estimation
            n_mc_volume=4096,
            sampling_radius=None,
            n_test_volume_points=min(128, X_test.shape[0]),
            # Validation split
            val_frac=0.5,
            device=device,
            verbose=verbose,
            seed=seed,
        )

        vsps_time = time.time() - vsps_start

        print(f"    Coverage: {metrics_vsps['coverage']:.2%} (target: {1.0-alpha:.0%})")
        print(f"    Avg volume (transformed): {metrics_vsps['average_volume']:.4e}")
        print(f"    Log-vol normalized (transformed): {metrics_vsps['log_volume_normalized']:.4f}")
        print(f"    K* selected: {vsps_state.K_star}")
        print(f"    Gamma: {vsps_state.gamma:.4f}")
        print(f"    Time: {vsps_time:.2f}s")

        log_vol_transformed = metrics_vsps["log_volume_normalized"]

        # 1. Correct for scale factor.
        # Undoing y -> y * scale_factor multiplies the volume by
        # norm_max_train ** d_y, i.e. adds d_y * log(norm_max_train) to log(V)
        # and therefore log(norm_max_train) to the per-dimension log-volume.
        log_scale_correction = np.log(norm_max_train)
        log_vol_scaled = log_vol_transformed + log_scale_correction

        # 2. Correct for the TR transform (global or conditional), per dimension.
        log_tr_correction = _tr_log_volume_correction(tr, tr_cond, X_test, d_y, seed)

        log_vol_corrected = log_vol_scaled + log_tr_correction

        print(f"    Log-vol corrected: {log_vol_corrected:.4f}")
        print(f"      (transformed: {log_vol_transformed:.4f} + scale: {log_scale_correction:+.4f} + TR: {log_tr_correction:+.4f})")

        # =====================================================
        # Compute WSC
        # =====================================================
        # Rebuild the per-point coverage indicator: resample from the fitted
        # flow, select the K* centers and compare each test point's distance
        # to its nearest center against the calibrated threshold gamma.
        y_samp_vsps, ld_vsps = sample_y_and_logdet_fwd(
            vsps_state.flow,
            X_test,
            vsps_state.M_samples,
            device=device,
            dtype=torch.float32,
            batch_size=128,
        )
        centers_vsps = select_topk_centers(
            y_samp_vsps, ld_vsps, vsps_state.K_star, sort_ascending=vsps_state.sort_ascending
        ).numpy()
        test_scores_vsps = min_dist_to_centers(Y_test, centers_vsps)
        coverage_status_vsps = test_scores_vsps <= vsps_state.gamma

        # WSC parameters (dataset-specific); unknown datasets fall back to 0.1.
        wsc_delta = {"scm20d": 0.1, "sgemm": 0.01, "rf1": 0.1, "rf2": 0.1, "scm1d": 0.1}
        wsc_M = 1000
        wsc_min_points = 30

        wsc = compute_wsc_for_method(
            X_test,
            coverage_status_vsps,
            M=wsc_M,
            delta=wsc_delta.get(dataset_name, 0.1),
            min_points=wsc_min_points,
            random_state=seed,
        )
        print(f"    WSC: {wsc:.2%}" if not np.isnan(wsc) else "    WSC: N/A")

        return {
            "coverage": metrics_vsps["coverage"],
            "target_coverage": 1.0 - alpha,
            "average_volume_transformed": metrics_vsps["average_volume"],
            "log_volume_normalized_transformed": log_vol_transformed,
            "log_volume_normalized_scaled": log_vol_scaled,
            "log_volume_normalized_corrected": log_vol_corrected,
            "wsc": wsc,
            "K_star": vsps_state.K_star,
            "gamma": vsps_state.gamma,
            "time": vsps_time,
            "scale_factor": scale_factor,
            "log_scale_correction": log_scale_correction,
            "log_tr_correction": log_tr_correction,
        }

    except Exception as e:
        # Best-effort benchmark: report the failure and return NaN metrics so
        # that the remaining runs/datasets can still be evaluated.
        print(f"    VSPS evaluation failed: {e}")
        import traceback

        traceback.print_exc()
        return _nan_vsps_result(alpha, scale_factor)


# =============================================================================
# MAIN
# =============================================================================

if __name__ == "__main__":
    # Benchmark driver: for every dataset and seed, split and preprocess the
    # data, run the VSPS evaluation, write one CSV per dataset and print a
    # summary across runs.
    print("=" * 80)
    print(" VSPS EVALUATION WITH TR PREPROCESSING")
    print(" (Same preprocessing as GCQR_3.py)")
    print("=" * 80)

    # Configuration (mirroring GCQR_3.py)
    datasets = ["scm20d", "rf1", "sgemm", "rf2", "scm1d"]
    alpha = 0.1
    base_seed = 42
    n_runs = 1  # Number of runs per dataset
    seeds = [base_seed + i * 111 for i in range(n_runs)]  # Reproducible seeds
    tr_retr_bool = False
    mode = "tr_retr" if tr_retr_bool else "standard"
    ridge = 1e-3

    # TR mode: "conditional", "stabilized", or "global"
    # (only used when tr_retr_bool is True)
    tr_mode = "stabilized"
    tr_knn_k = 100
    tr_shrinkage_alpha = 0.5
    tr_fallback_threshold = 1e6

    # Dataset-specific parameters
    test_size = {"scm20d": 0.2, "sgemm": 0.2, "rf1": 0.2, "rf2": 0.2, "scm1d": 0.2}
    cal_size = {"scm20d": 0.4, "sgemm": 0.4, "rf1": 0.4, "rf2": 0.4, "scm1d": 0.4}

    MAX_SAMPLES = 25000

    # Print system info
    print("\n[SYSTEM INFO]")
    print(f"  NumPy version: {np.__version__}")
    if HAS_TORCH:
        print(f"  PyTorch version: {torch.__version__}")
        if torch.cuda.is_available():
            print(f"  CUDA available: {torch.cuda.get_device_name(0)}")
        else:
            print("  CUDA: Not available")
    else:
        print("  PyTorch: Not available")
    print(f"  VSPS available: {HAS_VSPS}")
    print(f"  WSC available: {WSC_AVAILABLE}")

    print("\n[CONFIGURATION]")
    print(f"  TR mode: {tr_mode}")
    print(f"  TR KNN k: {tr_knn_k}")
    print(f"  TR shrinkage alpha: {tr_shrinkage_alpha}")
    print(f"  Alpha (CP): {alpha}")
    print(f"  Max samples: {MAX_SAMPLES}")
    print(f"  Number of runs: {n_runs}")
    print(f"  Seeds: {seeds}")

    print("\n" + "=" * 80)

    for dataset in datasets:
        dataset_results = []  # Results for this dataset only
        print(f"\n{'='*80}")
        print(f" Dataset: {dataset}")
        print(f"{'='*80}")

        # Load dataset once (before run loop)
        print(f"\nLoading dataset: {dataset}")
        if dataset == "rf1":
            X_full, Y_full = load_dataset(dataset, use_rf1_resplit=True, random_state=base_seed)
        else:
            X_full, Y_full = load_dataset(dataset, random_state=base_seed)

        # Accept either a DataFrame-like object (has `.values`) or an array-like.
        X_full = X_full.values.astype(np.float32) if hasattr(X_full, "values") else np.asarray(X_full, dtype=np.float32)
        Y_full = np.asarray(Y_full, dtype=np.float32)
        print(f"  Full dataset shape: X={X_full.shape}, Y={Y_full.shape}")

        for run_idx, seed in enumerate(seeds):
            print(f"\n  --- Run {run_idx + 1}/{n_runs} (seed={seed}) ---")
            run_start_time = time.time()

            # Copy data for this run
            X = X_full.copy()
            Y = Y_full.copy()

            # Subsample if needed (with run-specific seed)
            n_samples = X.shape[0]
            if MAX_SAMPLES is not None and n_samples > MAX_SAMPLES:
                print(f"    Subsampling to {MAX_SAMPLES} samples...")
                np.random.seed(seed)
                indices = np.random.choice(n_samples, MAX_SAMPLES, replace=False)
                X = X[indices]
                Y = Y[indices]

            # Split (with run-specific seed)
            X_train, X_cal, X_test, Y_train, Y_cal, Y_test = split_train_cal_test(
                X,
                Y,
                test_size=test_size[dataset],
                cal_size=cal_size[dataset],
                random_state=seed,
            )

            # Standardize X (same as GCQR4); sgemm features are left unscaled.
            if dataset != "sgemm":
                scaler_X = StandardScaler()
                X_train = scaler_X.fit_transform(X_train).astype(np.float32)
                X_cal = scaler_X.transform(X_cal).astype(np.float32)
                X_test = scaler_X.transform(X_test).astype(np.float32)

            # TR transformation (whitening) or standard scaling.
            # At most one of `tr` / `tr_cond` is set; both stay None for
            # plain standard scaling.
            tr = None
            tr_cond = None

            if tr_retr_bool:
                if tr_mode == "stabilized":
                    print(f"    Using STABILIZED CONDITIONAL TR (k={tr_knn_k}, shrinkage={tr_shrinkage_alpha})...")
                    tr_cond, Y_train = fit_stabilized_conditional_tr_and_standardize(
                        X_train,
                        Y_train,
                        method="knn",
                        k=tr_knn_k,
                        ridge=ridge,
                        shrinkage_alpha=tr_shrinkage_alpha,
                        fallback_threshold=tr_fallback_threshold,
                        eigenvalue_shrinkage=0.0,
                    )
                    Y_cal = conditional_tr_transform(X_cal, Y_cal, tr_cond)
                    Y_test = conditional_tr_transform(X_test, Y_test, tr_cond)

                elif tr_mode == "conditional":
                    print(f"    Using CONDITIONAL TR (KNN, k={tr_knn_k})...")
                    tr_cond, Y_train = fit_conditional_tr_and_standardize(
                        X_train,
                        Y_train,
                        method="knn",
                        k=tr_knn_k,
                        ridge=ridge,
                    )
                    Y_cal = conditional_tr_transform(X_cal, Y_cal, tr_cond)
                    Y_test = conditional_tr_transform(X_test, Y_test, tr_cond)

                else:
                    # Global TR
                    print(f"    Using GLOBAL TR...")
                    tr, Y_train = fit_tr_and_standardize(Y_train, ridge=ridge)
                    Y_cal = tr_transform(Y_cal, tr)
                    Y_test = tr_transform(Y_test, tr)
            else:
                scaler_Y = StandardScaler()
                Y_train = scaler_Y.fit_transform(Y_train).astype(np.float32)
                Y_cal = scaler_Y.transform(Y_cal).astype(np.float32)
                Y_test = scaler_Y.transform(Y_test).astype(np.float32)

            # Additional scaling by max norm (same as GCQR_3.py): the factor
            # is fitted on the training targets only and reused for cal/test.
            norm_max_train = np.linalg.norm(Y_train, axis=1).max()
            scale_factor = 1.0 / norm_max_train
            Y_train = Y_train * scale_factor
            Y_cal = Y_cal * scale_factor
            Y_test = Y_test * scale_factor

            print(f"    Shapes: train={X_train.shape}, cal={X_cal.shape}, test={X_test.shape}")

            # Run VSPS evaluation
            vsps_results = run_vsps_evaluation(
                X_train,
                Y_train,
                X_cal,
                Y_cal,
                X_test,
                Y_test,
                alpha=alpha,
                dataset_name=dataset,
                tr=tr,
                tr_cond=tr_cond,
                norm_max_train=norm_max_train,
                scale_factor=scale_factor,
                verbose=False,  # Less verbose for multi-run
            )

            run_time = time.time() - run_start_time

            # Add run metadata. 'time' is overwritten with the wall-clock time
            # of the whole run (subsampling, split and preprocessing included).
            vsps_results['run'] = run_idx + 1
            vsps_results['seed'] = seed
            vsps_results['time'] = run_time

            dataset_results.append(vsps_results)

            # Print summary for this run
            print(f"    Run {run_idx + 1} complete: coverage={vsps_results['coverage']:.4f}, "
                  f"wsc={vsps_results['wsc']:.4f}, time={run_time:.1f}s")

        # Save results for this dataset (one CSV per dataset)
        save_results_to_csv(
            dataset_results,
            output_dir="results_vsps",
            filename=f"results_vsps_{dataset}_{mode}.csv"
        )

        # Print dataset summary
        print(f"\n  === {dataset} Summary ({n_runs} runs) ===")
        coverages = [r['coverage'] for r in dataset_results]
        wscs = [r['wsc'] for r in dataset_results]
        # `.get` because the result returned when VSPS/PyTorch is unavailable
        # may not carry this key; a missing value is treated as NaN.
        vols = [r.get('log_volume_normalized_scaled', np.nan) for r in dataset_results]
        # NaN-aware statistics throughout, so one failed run (NaN metrics)
        # does not wipe out the summary of the others.
        print(f"    Coverage: {np.nanmean(coverages):.4f} ± {np.nanstd(coverages):.4f}")
        print(f"    WSC:      {np.nanmean(wscs):.4f} ± {np.nanstd(wscs):.4f}")
        print(f"    Log-vol:  {np.nanmean(vols):.4f} ± {np.nanstd(vols):.4f}")

    # Final summary
    print("\n" + "=" * 80)
    print(" BENCHMARK COMPLETE")
    print("=" * 80)
    print(f"  Datasets evaluated: {len(datasets)}")
    print(f"  Runs per dataset: {n_runs}")
    print(f"  Results saved to: results_vsps/results_vsps_<dataset>_{mode}.csv")
    print("=" * 80 + "\n")
