

from __future__ import annotations

import os
import time
import warnings
import numpy as np
import pandas as pd
from datetime import datetime
from typing import Dict, List, Optional

# Globally silence FutureWarning and UserWarning. No module filter is given,
# so this applies to every library (sklearn, pandas, project code), not just
# this script — presumably to keep the per-run console output readable.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

from sklearn.preprocessing import StandardScaler

# Local imports
from GCQR_3 import run_gcqr3_pipeline
from config import get_train_config_gcqr
from train_cqr_v2_utils import split_train_cal_test
from load_dataset_grinsztajn import load_dataset

# TR preprocessing
from tr_retr import fit_tr_and_standardize, tr_transform
from tr_retr_conditional import (
    fit_stabilized_conditional_tr_and_standardize,
    conditional_tr_transform,
)


# =============================================================================
# CONFIGURATION
# =============================================================================
# Configs to evaluate, as (dataset, width, depth) triples; width and depth
# are passed to run_gcqr3_pipeline for every run of the config.
CONFIGS = [
    ("scm20d", 16, 2),
    ("scm1d", 32, 4),
    ("rf1", 32, 4),
    ("rf2", 16, 2),
    ("sgemm", 144, 5),
]

# Independent runs per config (one seed from SEEDS each)
N_RUNS = 10

# Seeds step by 111 from BASE_SEED (same pattern as VSPS/OT_CP):
# [42, 153, 264, 375, 486, ...]
BASE_SEED = 42
SEEDS = list(range(BASE_SEED, BASE_SEED + 111 * N_RUNS, 111))

# Miscoverage level alpha (target coverage is 1 - ALPHA)
ALPHA = 0.1

# Target (Y) transformation settings.
# With TR_RETR_BOOL = False, Y is standardized with a plain StandardScaler.
# TR_MODE: main() implements "stabilized" and "global"; "conditional" is not
# implemented in this script.
TR_RETR_BOOL = True
TR_MODE = "stabilized"
TR_KNN_K = 100                   # k for the kNN-based stabilized transform
TR_SHRINKAGE_ALPHA = 0.5         # shrinkage_alpha for the stabilized transform
TR_RIDGE = 1e-3                  # ridge term, used by both stabilized and global modes
TR_EIGENVALUE_SHRINKAGE = 0.0    # Default from tr_retr_conditional
TR_FALLBACK_THRESHOLD = 1e6      # Default from tr_retr_conditional

# Per-dataset test / calibration fractions passed to split_train_cal_test
TEST_SIZE = {
    "scm20d": 0.2,
    "sgemm": 0.2,
    "rf1": 0.2,
    "rf2": 0.2,
    "scm1d": 0.2,
}
CAL_SIZE = {
    "scm20d": 0.4,
    "sgemm": 0.4,
    "rf1": 0.4,
    "rf2": 0.4,
    "scm1d": 0.4,
}

# mu (convexity parameter) per dataset, passed to run_gcqr3_pipeline
MU_CONFIG = {
    "scm20d": 0.005,
    "sgemm": 0.005,
    "rf1": 0.005,
    "rf2": 0.005,
    "scm1d": 0.005,
}

# Each run subsamples to at most this many rows (None disables subsampling)
MAX_SAMPLES = 25000

# Volume correction method passed to run_gcqr3_pipeline
VOLUME_CORRECTION_METHOD = "jacobian"

# Directory receiving the per-config and summary CSV files
OUTPUT_DIR = "results_gcqr"


# =============================================================================
# HELPER FUNCTIONS
# =============================================================================

def save_results_to_csv(
    results: List[Dict],
    output_dir: str,
    filename: str,
) -> str:
    """Write per-run results to ``output_dir/filename`` as a CSV file.

    Only a fixed subset of metric columns is written (aligned with the
    VSPS/OT_CP multi-run analysis). Columns that do not appear in *results*
    are skipped rather than raising. Floats are written with 6 decimals.
    ``output_dir`` is created if it does not exist.

    Args:
        results: One dict per run.
        output_dir: Destination directory.
        filename: Name of the CSV file inside ``output_dir``.

    Returns:
        The path of the written file.
    """
    columns_of_interest = (
        'run', 'seed', 'coverage', 'wsc_min',
        'norm_log_vol_scaled', 'norm_log_vol_transformed',
        'margin', 'inside_ratio', 'time',
    )

    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, filename)

    frame = pd.DataFrame(results)
    present = [col for col in columns_of_interest if col in frame.columns]
    frame[present].to_csv(out_path, index=False, float_format='%.6f')

    print(f"  ✓ Results saved to: {out_path}")
    return out_path


def print_dataset_summary(dataset_results: List[Dict], dataset_name: str, n_runs: int):
    """Print mean ± std of the per-run metrics for one config.

    Every metric is aggregated NaN-aware. A run whose value for a metric is
    missing, ``None`` or NaN (for example a run whose pipeline call failed)
    is left out of that metric's mean/std, instead of turning the whole
    statistic into NaN. A metric with no finite values at all is reported
    as ``nan ± nan``, and no RuntimeWarning is emitted.

    Args:
        dataset_results: One dict per run. The keys ``coverage``, ``wsc_min``,
            ``norm_log_vol_scaled`` and ``time`` are read when present.
        dataset_name: Label printed in the header.
        n_runs: Number of runs shown in the header.
    """
    def _finite(key: str) -> np.ndarray:
        # dtype=float maps None to NaN, so it is filtered like a missing value.
        values = np.asarray([r.get(key, np.nan) for r in dataset_results], dtype=float)
        return values[~np.isnan(values)]

    def _mean_std(values: np.ndarray):
        # An explicit empty check avoids numpy's "Mean of empty slice" warning.
        if values.size == 0:
            return np.nan, np.nan
        return values.mean(), values.std()

    coverages = _finite('coverage')
    wscs = _finite('wsc_min')
    vols = _finite('norm_log_vol_scaled')
    times = _finite('time')

    cov_mean, cov_std = _mean_std(coverages)
    wsc_mean, wsc_std = _mean_std(wscs)
    vol_mean, vol_std = _mean_std(vols)
    time_mean, time_std = _mean_std(times)

    print(f"\n  {'='*60}")
    print(f"  {dataset_name} Summary ({n_runs} runs)")
    print(f"  {'='*60}")
    # A run counts as valid when it produced a finite coverage value.
    print(f"    Valid runs:   {coverages.size}/{len(dataset_results)}")
    print(f"    Coverage:     {cov_mean:.4f} ± {cov_std:.4f}")
    print(f"    WSC:          {wsc_mean:.4f} ± {wsc_std:.4f}")
    print(f"    Log-vol:      {vol_mean:.4f} ± {vol_std:.4f}")
    print(f"    Time per run: {time_mean:.1f}s ± {time_std:.1f}s")


# =============================================================================
# MAIN
# =============================================================================

# TR modes that main() knows how to apply to the target arrays.
_IMPLEMENTED_TR_MODES = ("stabilized", "global")


def _load_full_dataset(dataset: str):
    """Load *dataset* via ``load_dataset`` and return ``(X_full, Y_full)`` as float32 arrays.

    ``rf1`` is loaded with ``use_rf1_resplit=True``. Every dataset is loaded
    with ``random_state=BASE_SEED``.
    """
    print(f"\n  Loading dataset: {dataset}")
    if dataset == "rf1":
        X_full, Y_full = load_dataset(dataset, use_rf1_resplit=True, random_state=BASE_SEED)
    else:
        X_full, Y_full = load_dataset(dataset, random_state=BASE_SEED)

    # X may come back as a pandas object (anything with ``.values``) or an array-like.
    X_full = X_full.values.astype(np.float32) if hasattr(X_full, "values") else np.asarray(X_full, dtype=np.float32)
    Y_full = np.asarray(Y_full, dtype=np.float32)
    print(f"  Full dataset: {X_full.shape[0]} samples, X: {X_full.shape[1]} features, Y: {Y_full.shape[1]} outputs")
    return X_full, Y_full


def _transform_targets(X_train, X_cal, X_test, Y_train, Y_cal, Y_test):
    """Fit the configured Y transform on the training split and apply it to all splits.

    Returns:
        ``(Y_train, Y_cal, Y_test, tr, tr_cond)``. ``tr`` is set only in
        "global" TR mode and ``tr_cond`` only in "stabilized" TR mode; both
        are None when TR is disabled (a plain StandardScaler is used then).

    Raises:
        ValueError: if TR is enabled and ``TR_MODE`` is not implemented here.
    """
    tr = None
    tr_cond = None

    if not TR_RETR_BOOL:
        scaler_Y = StandardScaler()
        Y_train = scaler_Y.fit_transform(Y_train).astype(np.float32)
        Y_cal = scaler_Y.transform(Y_cal).astype(np.float32)
        Y_test = scaler_Y.transform(Y_test).astype(np.float32)
    elif TR_MODE == "stabilized":
        tr_cond, Y_train = fit_stabilized_conditional_tr_and_standardize(
            X_train, Y_train,
            method="knn",
            k=TR_KNN_K,
            ridge=TR_RIDGE,
            shrinkage_alpha=TR_SHRINKAGE_ALPHA,
            eigenvalue_shrinkage=TR_EIGENVALUE_SHRINKAGE,
            fallback_threshold=TR_FALLBACK_THRESHOLD,
        )
        Y_cal = conditional_tr_transform(X_cal, Y_cal, tr_cond)
        Y_test = conditional_tr_transform(X_test, Y_test, tr_cond)
    elif TR_MODE == "global":
        tr, Y_train = fit_tr_and_standardize(Y_train, ridge=TR_RIDGE)
        Y_cal = tr_transform(Y_cal, tr)
        Y_test = tr_transform(Y_test, tr)
    else:
        # Never fall through with untransformed Y: that would silently
        # produce results on a different scale than the mode advertises.
        raise ValueError(
            f"Unsupported TR_MODE={TR_MODE!r}; expected one of {_IMPLEMENTED_TR_MODES}"
        )

    return Y_train, Y_cal, Y_test, tr, tr_cond


def _run_once(
    dataset: str,
    cfg_width: int,
    cfg_depth: int,
    X_full: np.ndarray,
    Y_full: np.ndarray,
    run_idx: int,
    seed: int,
) -> Dict:
    """Execute one seeded run of a config and return its results dict.

    If the pipeline call (or reading its results) raises, the failure is
    printed with its traceback. The run is then recorded with NaN metrics
    and an ``error`` message, so the benchmark continues with the next run.
    """
    import traceback

    print(f"\n  --- Run {run_idx + 1}/{N_RUNS} (seed={seed}) ---")
    run_start_time = time.time()

    # Work on copies so the cached full arrays are never modified.
    X = X_full.copy()
    Y = Y_full.copy()

    # Seed the global NumPy RNG on every run, not only when subsampling.
    # This keeps later global-RNG draws in the run reproducible for datasets
    # smaller than MAX_SAMPLES too. The subsample indices match what seeding
    # immediately before np.random.choice would give.
    np.random.seed(seed)
    n_samples = X.shape[0]
    if MAX_SAMPLES is not None and n_samples > MAX_SAMPLES:
        print(f"    Subsampling to {MAX_SAMPLES} samples...")
        indices = np.random.choice(n_samples, MAX_SAMPLES, replace=False)
        X = X[indices]
        Y = Y[indices]

    # Split with run-specific seed
    X_train, X_cal, X_test, Y_train, Y_cal, Y_test = split_train_cal_test(
        X, Y,
        test_size=TEST_SIZE[dataset],
        cal_size=CAL_SIZE[dataset],
        random_state=seed,
    )
    print(f"    Split: train={len(X_train)}, cal={len(X_cal)}, test={len(X_test)}")

    # X is standardized (fit on train only) for every dataset except sgemm.
    if dataset != "sgemm":
        scaler_X = StandardScaler()
        X_train = scaler_X.fit_transform(X_train).astype(np.float32)
        X_cal = scaler_X.transform(X_cal).astype(np.float32)
        X_test = scaler_X.transform(X_test).astype(np.float32)

    Y_train, Y_cal, Y_test, tr, tr_cond = _transform_targets(
        X_train, X_cal, X_test, Y_train, Y_cal, Y_test
    )

    # Rescale all Y splits so the largest training target norm becomes 1.
    norm_max_train = float(np.linalg.norm(Y_train, axis=1).max())
    scale_factor = 1.0 / norm_max_train
    Y_train = Y_train * scale_factor
    Y_cal = Y_cal * scale_factor
    Y_test = Y_test * scale_factor

    cfg = get_train_config_gcqr(dataset)

    try:
        results = run_gcqr3_pipeline(
            X_train, Y_train,
            X_cal, Y_cal,
            X_test, Y_test,
            alpha=ALPHA,
            width=cfg_width,
            depth=cfg_depth,
            mu=MU_CONFIG.get(dataset, 0.005),
            cfg=cfg,
            dataset_name=dataset,
            tr=tr,
            tr_cond=tr_cond,
            volume_correction_method=VOLUME_CORRECTION_METHOD,
            n_vol_samples=50,
        )

        # Run metadata
        results['run'] = run_idx + 1
        results['seed'] = seed
        results['time'] = time.time() - run_start_time
        results['dataset'] = dataset
        results['width'] = cfg_width
        results['depth'] = cfg_depth

        # Shift by log(norm_max_train) to undo the 1/norm_max_train rescaling
        # above. The result is on the scale before max-norm rescaling, which is
        # still the standardized/TR space, not raw Y. NOTE(review): the plain
        # additive shift assumes norm_log_vol_transformed is a per-dimension
        # (normalized) log-volume — confirm against run_gcqr3_pipeline.
        if 'norm_log_vol_transformed' in results and not np.isnan(results['norm_log_vol_transformed']):
            results['norm_log_vol_scaled'] = results['norm_log_vol_transformed'] + np.log(norm_max_train)
        else:
            results['norm_log_vol_scaled'] = np.nan

        print(f"    Coverage: {results['coverage']:.4f}, WSC: {results.get('wsc_min', np.nan):.4f}, "
              f"Log-vol: {results.get('norm_log_vol_scaled', np.nan):.4f}, Time: {results['time']:.1f}s")

    except Exception as e:  # run boundary: record the failure, keep benchmarking
        print(f"    ERROR: {e}")
        traceback.print_exc()
        results = {
            'run': run_idx + 1,
            'seed': seed,
            'time': time.time() - run_start_time,
            'dataset': dataset,
            'width': cfg_width,
            'depth': cfg_depth,
            'coverage': np.nan,
            'wsc_min': np.nan,
            'norm_log_vol_scaled': np.nan,
            'norm_log_vol_transformed': np.nan,
            'margin': np.nan,
            'inside_ratio': np.nan,
            'error': str(e),
        }

    return results


def _summarize_config(
    config_name: str,
    dataset: str,
    cfg_width: int,
    cfg_depth: int,
    config_results: List[Dict],
) -> Dict:
    """Aggregate one config's runs into a summary-table row.

    Runs with a missing or NaN value are excluded from that metric's mean/std.
    """
    def _finite(key: str) -> list:
        return [r[key] for r in config_results if not np.isnan(r.get(key, np.nan))]

    coverages = _finite('coverage')
    wscs = _finite('wsc_min')
    vols = _finite('norm_log_vol_scaled')

    return {
        'config': config_name,
        'dataset': dataset,
        'width': cfg_width,
        'depth': cfg_depth,
        'coverage_mean': np.mean(coverages) if coverages else np.nan,
        'coverage_std': np.std(coverages) if coverages else np.nan,
        'wsc_mean': np.mean(wscs) if wscs else np.nan,
        'wsc_std': np.std(wscs) if wscs else np.nan,
        'log_vol_mean': np.mean(vols) if vols else np.nan,
        'log_vol_std': np.std(vols) if vols else np.nan,
    }


def _print_summary_table(df_summary: pd.DataFrame) -> None:
    """Print one ``mean ± std`` line per config from the summary DataFrame."""
    print(f"\n{'Config':<25} | {'Coverage':<15} | {'WSC':<15} | {'Log-Vol':<15}")
    print("-" * 80)
    for _, row in df_summary.iterrows():
        print(f"{row['config']:<25} | "
              f"{row['coverage_mean']:.4f} ± {row['coverage_std']:.4f} | "
              f"{row['wsc_mean']:.4f} ± {row['wsc_std']:.4f} | "
              f"{row['log_vol_mean']:.4f} ± {row['log_vol_std']:.4f}")


def main():
    """Run every (dataset, width, depth) config in CONFIGS once per seed in SEEDS.

    Writes one per-run CSV per config plus a summary CSV under OUTPUT_DIR.
    Prints progress and a final ``mean ± std`` table.

    Raises:
        ValueError: if TR_RETR_BOOL is set and TR_MODE is not one of the
            implemented modes ("stabilized", "global").
    """
    # Fail fast, before loading any data, on a TR mode this script cannot apply.
    if TR_RETR_BOOL and TR_MODE not in _IMPLEMENTED_TR_MODES:
        raise ValueError(
            f"Unsupported TR_MODE={TR_MODE!r}; expected one of {_IMPLEMENTED_TR_MODES}"
        )

    # NOTE(review): mode does not encode TR_MODE, so "stabilized" and "global"
    # runs write to the same filenames — consider including TR_MODE.
    mode = "tr_retr" if TR_RETR_BOOL else "standard"

    print("=" * 70)
    print(" GCQR3 Multi-Run Benchmark (Multiple Configs)")
    print("=" * 70)
    print(f"  Configs: {len(CONFIGS)} (dataset, width, depth) combinations")
    print(f"  N_RUNS:  {N_RUNS}")
    print(f"  Seeds:   {SEEDS}")
    print(f"  Alpha:   {ALPHA}")
    print(f"  TR mode: {TR_MODE if TR_RETR_BOOL else 'none'}")
    print(f"  Output:  {OUTPUT_DIR}/")
    print("=" * 70)

    total_start_time = time.time()
    all_summaries = []

    # Datasets shared by several configs are loaded only once.
    dataset_cache = {}

    for config_idx, (dataset, cfg_width, cfg_depth) in enumerate(CONFIGS):
        config_name = f"{dataset}_w{cfg_width}_d{cfg_depth}"

        print(f"\n{'#'*70}")
        print(f" CONFIG {config_idx + 1}/{len(CONFIGS)}: {dataset.upper()} (width={cfg_width}, depth={cfg_depth})")
        print(f"{'#'*70}")

        if dataset in dataset_cache:
            X_full, Y_full = dataset_cache[dataset]
            print(f"\n  Using cached dataset: {dataset}")
        else:
            X_full, Y_full = _load_full_dataset(dataset)
            dataset_cache[dataset] = (X_full, Y_full)

        config_results = []
        for run_idx, seed in enumerate(SEEDS):
            config_results.append(
                _run_once(dataset, cfg_width, cfg_depth, X_full, Y_full, run_idx, seed)
            )

        filename = f"results_gcqr3_{config_name}_{mode}.csv"
        save_results_to_csv(config_results, OUTPUT_DIR, filename)

        print_dataset_summary(config_results, config_name, N_RUNS)

        all_summaries.append(
            _summarize_config(config_name, dataset, cfg_width, cfg_depth, config_results)
        )

    total_time = time.time() - total_start_time

    print("\n" + "=" * 70)
    print(" BENCHMARK COMPLETE")
    print("=" * 70)
    print(f"  Total time: {total_time:.1f}s ({total_time/60:.1f} min)")
    print(f"  Results saved to: {OUTPUT_DIR}/")

    print("\n" + "=" * 80)
    print(" SUMMARY TABLE (mean ± std)")
    print("=" * 80)

    df_summary = pd.DataFrame(all_summaries)
    _print_summary_table(df_summary)

    # The per-config writes normally create OUTPUT_DIR, but not when CONFIGS is empty.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    summary_path = os.path.join(OUTPUT_DIR, f"summary_gcqr3_multiconfig_{mode}.csv")
    df_summary.to_csv(summary_path, index=False, float_format='%.6f')
    print(f"\n  ✓ Summary saved to: {summary_path}")

    print("=" * 70)


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
