"""
Real-World Dataset Evaluation Script
=====================================

Runs Stable-QDA evaluation on real-world datasets and saves results to JSON.

Usage:
    python run_evaluation.py --dataset htru2 --output results/
    python run_evaluation.py --dataset creditcard --subsample 50000
    python run_evaluation.py --all --output results/
"""

import numpy as np
import pandas as pd
import json
import argparse
import os
from datetime import datetime
from pathlib import Path
from scipy import stats
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import (
    accuracy_score, f1_score, confusion_matrix,
    precision_recall_curve, auc, roc_auc_score
)
import warnings

import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))

from stable_qda import StableQDA, GaussianQDA
from alpha_estimation import estimate_alpha


# =============================================================================
# Dataset Configuration
# =============================================================================

DATASET_CONFIGS = {
    'htru2': {
        'name': 'HTRU2 Pulsar Detection',
        'url': 'https://archive.ics.uci.edu/ml/datasets/HTRU2',
        'target': 'class',
    },
    'creditcard': {
        'name': 'Credit Card Fraud Detection',
        'url': 'https://www.kaggle.com/mlg-ulb/creditcardfraud',
        'target': 'Class',
        # Used when no explicit --subsample is given.
        'default_subsample': 50000,
    },
    'ionosphere': {
        'name': 'Ionosphere Radar Returns',
        'url': 'https://archive.ics.uci.edu/ml/datasets/ionosphere',
        'target': 'class',
    },
    'weekly': {
        'name': 'Weekly Stock Returns (ISLR)',
        'url': 'https://www.statlearning.com/',
        'target': 'Direction',
        # In the ISLR Weekly data, Direction is the sign of Today's return, so
        # keeping 'Today' as a feature would leak the label into X.
        'drop_columns': ['Today'],
    },
}


# =============================================================================
# Data Loading
# =============================================================================

def load_dataset(name: str, data_dir: str = 'data/') -> tuple:
    """Load a configured dataset from ``<data_dir>/<name>.csv``.

    The target column is matched case-insensitively against the config's
    ``'target'`` entry. Any columns listed under the optional
    ``'drop_columns'`` config key are removed (case-insensitively; absent
    columns are ignored) before the feature matrix is built.

    Args:
        name: Key into ``DATASET_CONFIGS`` (case-insensitive).
        data_dir: Directory containing the CSV files.

    Returns:
        ``(X, y, config)``: a float feature matrix, integer labels produced by
        ``LabelEncoder`` (0 and 1), and the dataset's config dict.

    Raises:
        ValueError: Unknown dataset, missing target column, non-numeric or
            missing feature values, or a target without exactly two classes.
        FileNotFoundError: The CSV file does not exist.
    """
    key = name.lower()
    config = DATASET_CONFIGS.get(key)
    if config is None:
        raise ValueError(f"Unknown dataset: {name}. Available: {list(DATASET_CONFIGS.keys())}")

    filepath = os.path.join(data_dir, f"{key}.csv")

    if not os.path.exists(filepath):
        raise FileNotFoundError(
            f"Dataset not found: {filepath}\n"
            f"Download from: {config['url']}"
        )

    df = pd.read_csv(filepath)

    # First column whose name matches the configured target, ignoring case.
    # str() guards against non-string column labels (e.g. integer headers).
    wanted = config['target'].lower()
    target_col = next((col for col in df.columns if str(col).lower() == wanted), None)
    if target_col is None:
        raise ValueError(f"Target column '{config['target']}' not found")

    drop_lower = {c.lower() for c in config.get('drop_columns', ())}
    extra_drops = [
        col for col in df.columns
        if str(col).lower() in drop_lower and col != target_col
    ]
    features = df.drop(columns=[target_col, *extra_drops])

    # Fail early with a clear message instead of an obscure error in the scaler.
    non_numeric = [
        col for col in features.columns
        if not pd.api.types.is_numeric_dtype(features[col])
    ]
    if non_numeric:
        raise ValueError(f"Non-numeric feature columns in {filepath}: {non_numeric}")

    X = features.to_numpy(dtype=float)
    if np.isnan(X).any():
        raise ValueError(f"Feature matrix in {filepath} contains missing values")

    # Encode labels
    le = LabelEncoder()
    y = le.fit_transform(df[target_col].values)

    # Downstream code (diagnostic, binary F1, proba[:, 1]) assumes two classes.
    if len(le.classes_) != 2:
        raise ValueError(
            f"Expected a binary target in '{target_col}', found {len(le.classes_)} classes"
        )

    return X, y, config


def preprocess_data(X, y, scale=True, subsample=None):
    """Optionally subsample and standardize a dataset.

    Args:
        X: Feature matrix.
        y: Labels; also used to stratify the subsample.
        scale: When True, standardize features with ``StandardScaler`` fitted
            on all rows passed in.
        subsample: Number of rows to keep. It is ignored when falsy or when it
            is not smaller than ``len(y)``.

    Returns:
        ``(X, y, log)``, where ``log`` lists the steps that were applied.
    """
    steps = []

    wants_subsample = bool(subsample) and subsample < len(y)
    if wants_subsample:
        # Stratified draw keeps class proportions; the "test" remainder is discarded.
        X_kept, _, y_kept, _ = train_test_split(
            X, y, train_size=subsample, stratify=y, random_state=42
        )
        X, y = X_kept, y_kept
        steps.append(f'Subsampled to n={subsample}')

    if scale:
        X = StandardScaler().fit_transform(X)
        steps.append('StandardScaler')

    return X, y, steps


# =============================================================================
# Diagnostic
# =============================================================================

def _diagnose_class(X_c, d):
    """Compute tail and scale statistics for the rows ``X_c`` of one class."""
    alpha = estimate_alpha(X_c)

    with warnings.catch_warnings():
        # np.cov warns (and yields NaN) for degenerate classes such as a single row.
        warnings.simplefilter("ignore")
        cov = np.cov(X_c.T)
    # With d == 1, np.cov returns a 0-d array; promote it to 1x1 for trace/det/inv.
    if cov.ndim == 0:
        cov = np.array([[cov]])

    trace = float(np.trace(cov))
    # The determinant is only computed for d <= 50 (NaN otherwise).
    det = float(np.linalg.det(cov)) if d <= 50 else np.nan

    # Outlier rate: percentage of rows whose squared Mahalanobis distance
    # exceeds the chi-square(d) 95% quantile (about 5% for Gaussian data).
    try:
        cov_inv = np.linalg.inv(cov + 1e-8 * np.eye(d))
    except (np.linalg.LinAlgError, ValueError):
        outlier_rate = np.nan
    else:
        centered = X_c - np.mean(X_c, axis=0)
        mahal_sq = np.sum((centered @ cov_inv) * centered, axis=1)
        threshold = stats.chi2.ppf(0.95, d)
        outlier_rate = float(np.mean(mahal_sq > threshold) * 100)

    # Relative distance between mean and median vectors (a skew/outlier signal).
    mean_c, median_c = np.mean(X_c, axis=0), np.median(X_c, axis=0)
    denom = np.linalg.norm(mean_c) + np.linalg.norm(median_c) + 1e-10
    shift = float(np.linalg.norm(mean_c - median_c) / denom)

    return {
        'alpha': alpha,
        'trace': trace,
        'det': det,
        'outlier_rate': outlier_rate,
        'mean_median_shift': shift,
    }


def _tyler_alpha_threshold(det_ratio, scale_ratio):
    """Return the alpha cut-off for ``tyler_is_safe``; it shrinks as det_ratio grows."""
    if not np.isfinite(det_ratio):
        return 1.8 if scale_ratio < 2 else 1.7
    for upper, thresh in ((10, 2.0), (50, 1.9), (100, 1.8), (1000, 1.7)):
        if det_ratio < upper:
            return thresh
    return 1.6


def _recommend_estimator(avg_alpha, likely_heavy, tyler_safe):
    """Map summary statistics to 'gaussian', 'robust' or 'standard'."""
    if not likely_heavy and avg_alpha > 1.7:
        return 'gaussian'
    if tyler_safe or avg_alpha < 1.5:
        return 'robust'
    if likely_heavy:
        return 'standard'
    return 'gaussian'


def run_diagnostic(X, y):
    """Run a heavy-tail / scale diagnostic on a labelled dataset.

    Per class this computes the estimated alpha, covariance trace and
    determinant, outlier rate and mean-median shift. The summary compares the
    first two classes (trace and det ratios), averages the per-class signals
    over all classes, and derives a model recommendation.

    Args:
        X: 2-D feature matrix.
        y: Class labels aligned with the rows of X.

    Returns:
        A JSON-serializable dict with keys 'n', 'd', 'n_classes', 'classes'
        (per-class stats keyed by int label) and 'summary'.

    Raises:
        ValueError: If y contains fewer than two classes.
    """
    n, d = X.shape
    classes = np.unique(y)
    if len(classes) < 2:
        raise ValueError(f"run_diagnostic needs at least two classes, got {len(classes)}")

    diagnostic = {'n': n, 'd': d, 'n_classes': len(classes), 'classes': {}}

    per_class = {}
    for c in classes:
        X_c = X[y == c]
        st = _diagnose_class(X_c, d)
        per_class[c] = st
        diagnostic['classes'][int(c)] = {
            'n': int(len(X_c)),
            'alpha': float(st['alpha']),
            'trace': st['trace'],
            'det': st['det'] if np.isfinite(st['det']) else None,
            'outlier_rate': st['outlier_rate'] if np.isfinite(st['outlier_rate']) else None,
            'mean_median_shift': st['mean_median_shift'],
        }

    # Pairwise ratios use the first two classes only.
    s0, s1 = per_class[classes[0]], per_class[classes[1]]
    scale_ratio = max(s0['trace'], s1['trace']) / max(min(s0['trace'], s1['trace']), 1e-10)

    det0, det1 = s0['det'], s1['det']
    if np.isfinite(det0) and np.isfinite(det1) and min(det0, det1) > 0:
        det_ratio = max(det0, det1) / min(det0, det1)
    else:
        det_ratio = np.nan

    # Python floats (not NumPy scalars) so derived booleans are plain bools.
    avg_alpha = float(np.mean([per_class[c]['alpha'] for c in classes]))
    with warnings.catch_warnings():
        # nanmean warns when every outlier rate is NaN; the NaN result is handled below.
        warnings.simplefilter("ignore", RuntimeWarning)
        avg_outlier = float(np.nanmean([per_class[c]['outlier_rate'] for c in classes]))
    avg_shift = float(np.mean([per_class[c]['mean_median_shift'] for c in classes]))

    # Heavy-tail signals
    signals = int(avg_alpha < 1.8) + int(avg_outlier > 7) + int(avg_shift > 0.2)
    likely_heavy = signals >= 2

    tyler_thresh = _tyler_alpha_threshold(det_ratio, scale_ratio)
    # bool(): a numpy.bool_ here would make json.dump fail when results are saved.
    tyler_safe = bool(avg_alpha < tyler_thresh)

    diagnostic['summary'] = {
        'avg_alpha': avg_alpha,
        'scale_ratio': float(scale_ratio),
        'det_ratio': float(det_ratio) if np.isfinite(det_ratio) else None,
        'avg_outlier_rate': avg_outlier if np.isfinite(avg_outlier) else None,
        'avg_mean_median_shift': avg_shift,
        'heavy_tail_signals': signals,
        'likely_heavy_tailed': likely_heavy,
        'tyler_threshold': float(tyler_thresh),
        'tyler_is_safe': tyler_safe,
        'recommendation': _recommend_estimator(avg_alpha, likely_heavy, tyler_safe),
    }

    return diagnostic


# =============================================================================
# Evaluation
# =============================================================================

def compute_metrics(y_true, y_pred, y_proba):
    """Compute binary classification metrics for one evaluation fold.

    Args:
        y_true: True labels, assumed to be 0/1 (as produced by LabelEncoder).
        y_pred: Predicted labels.
        y_proba: Score for the positive class (label 1), or None.

    Returns:
        A dict with 'accuracy', 'f1', 'tpr' and 'tnr'. When ``y_proba`` is
        given it also has 'pr_auc', 'roc_auc' and 'recall_at_prec95'. These
        three are None when they cannot be computed, e.g. when the fold
        contains only one class.
    """
    metrics = {
        'accuracy': float(accuracy_score(y_true, y_pred)),
        'f1': float(f1_score(y_true, y_pred, zero_division=0))
    }

    # Pin the label set so the matrix is always 2x2 and every fold reports
    # tpr/tnr. Without it, a single-class fold yields a 1x1 matrix, the keys
    # go missing, and cross-fold aggregation breaks.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    metrics['tpr'] = float(tp / (tp + fn)) if (tp + fn) > 0 else 0.0
    metrics['tnr'] = float(tn / (tn + fp)) if (tn + fp) > 0 else 0.0

    if y_proba is not None:
        try:
            prec, rec, _ = precision_recall_curve(y_true, y_proba)
            # NOTE: trapezoidal area under the PR curve. sklearn's
            # average_precision_score is a less optimistic alternative.
            metrics['pr_auc'] = float(auc(rec, prec))
            metrics['roc_auc'] = float(roc_auc_score(y_true, y_proba))

            # Recall @ Precision >= 95%
            r_at_p95 = max((r for p, r in zip(prec, rec) if p >= 0.95), default=0.0)
            metrics['recall_at_prec95'] = float(r_at_p95)
        except ValueError:
            # e.g. roc_auc_score when y_true holds a single class.
            metrics['pr_auc'] = None
            metrics['roc_auc'] = None
            metrics['recall_at_prec95'] = None

    return metrics


def _summarize_folds(folds):
    """Build {'per_fold', 'mean', 'std'} over every metric seen in any fold (None skipped)."""
    # Ordered union of metric names, so folds with differing keys cannot cause a KeyError.
    metric_names = list(dict.fromkeys(k for f in folds for k in f if k != 'fold'))
    mean, std = {}, {}
    for metric in metric_names:
        values = [f[metric] for f in folds if f.get(metric) is not None]
        if values:
            mean[metric] = float(np.mean(values))
            std[metric] = float(np.std(values))
    return {'per_fold': folds, 'mean': mean, 'std': std}


def _paired_p_value(model_folds, baseline_folds, metric):
    """Return the paired t-test p-value of model vs baseline on ``metric``, or None.

    Folds are paired by position. A fold is skipped when either side lacks
    the metric, so values from different folds are never compared.
    """
    pairs = [
        (m.get(metric), b.get(metric))
        for m, b in zip(model_folds, baseline_folds)
        if m.get(metric) is not None and b.get(metric) is not None
    ]
    if len(pairs) < 2:
        return None
    model_vals, base_vals = zip(*pairs)
    _, p = stats.ttest_rel(model_vals, base_vals)
    return float(p)


def run_cv_evaluation(X, y, n_splits=5, alpha=1.5):
    """Compare Gaussian QDA with two Stable-QDA variants using stratified CV.

    Args:
        X: Feature matrix, indexed by fold index arrays.
        y: Integer labels.
        n_splits: Number of shuffled stratified folds (random_state=42).
        alpha: Fixed alpha passed to both StableQDA variants.

    Returns:
        A dict keyed by model name ('gaussian', 'stable_standard',
        'stable_robust'). Each entry has 'per_fold', 'mean' and 'std'. The
        non-Gaussian entries also have 'p_values' (paired t-test vs Gaussian
        on accuracy and pr_auc) and 'delta' (differences vs Gaussian).
    """
    # Factories, so each fold fits a fresh model.
    models = {
        'gaussian': lambda: GaussianQDA(),
        'stable_standard': lambda: StableQDA(alpha=alpha, estimator='standard'),
        'stable_robust': lambda: StableQDA(alpha=alpha, estimator='robust'),
    }

    fold_results = {name: [] for name in models}
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

    for fold, (train_idx, test_idx) in enumerate(skf.split(X, y)):
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        for name, model_fn in models.items():
            clf = model_fn()
            clf.fit(X_train, y_train)
            y_pred = clf.predict(X_test)

            y_proba = None
            if hasattr(clf, 'predict_proba'):
                proba = clf.predict_proba(X_test)
                # Column 1 is the positive-class score when two columns exist.
                y_proba = proba[:, 1] if proba.shape[1] > 1 else proba[:, 0]

            metrics = compute_metrics(y_test, y_pred, y_proba)
            metrics['fold'] = fold
            fold_results[name].append(metrics)

    results = {name: _summarize_folds(folds) for name, folds in fold_results.items()}

    # Paired t-tests vs Gaussian
    gauss_folds = fold_results['gaussian']
    for name in models:
        if name == 'gaussian':
            continue
        results[name]['p_values'] = {}
        for metric in ['accuracy', 'pr_auc']:
            p = _paired_p_value(fold_results[name], gauss_folds, metric)
            if p is not None:
                results[name]['p_values'][metric] = p

    # Deltas vs Gaussian
    g_acc = results['gaussian']['mean']['accuracy']
    g_pr = results['gaussian']['mean'].get('pr_auc', 0) or 0

    for name in models:
        if name == 'gaussian':
            continue

        s_acc = results[name]['mean']['accuracy']
        s_pr = results[name]['mean'].get('pr_auc', 0) or 0

        results[name]['delta'] = {
            'accuracy_pct': float((s_acc - g_acc) * 100),
            'pr_auc_pct': float((s_pr - g_pr) / g_pr * 100) if g_pr > 0 else None,
            # Relative reduction of the error rate (1 - accuracy) vs Gaussian.
            'error_reduction_pct': float((s_acc - g_acc) / (1 - g_acc) * 100) if g_acc < 1 else 0.0,
        }

    return results


# =============================================================================
# Main
# =============================================================================

def _json_default(obj):
    """``json`` fallback that converts NumPy scalars and arrays to Python types."""
    # numpy.bool_ / integer scalars do not subclass bool/int, so json rejects them.
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def evaluate_dataset(
    dataset_name: str,
    data_dir: str = 'data/',
    output_dir: str = 'results/',
    subsample: int = None,
    n_splits: int = 5,
    alpha: float = 1.5,
    verbose: bool = True
):
    """Evaluate Stable-QDA on a single dataset and save the results as JSON.

    The steps are: load, preprocess (optional subsample, then standardize),
    diagnostic, and cross-validated comparison against Gaussian QDA. The
    output is written to ``<output_dir>/results_<dataset_name>_<timestamp>.json``.

    Args:
        dataset_name: Key into ``DATASET_CONFIGS``.
        data_dir: Directory containing the dataset CSVs.
        output_dir: Directory for the JSON results (created if missing).
        subsample: Row count to subsample to. When falsy, the dataset's
            ``default_subsample`` (if any) is used.
        n_splits: Number of CV folds.
        alpha: Fixed alpha for the StableQDA models.
        verbose: Print progress and summary tables.

    Returns:
        The dict that was written to disk.
    """
    if verbose:
        print("=" * 70)
        print(f"EVALUATING: {dataset_name.upper()}")
        print("=" * 70)

    # Load data
    X, y, config = load_dataset(dataset_name, data_dir)

    if verbose:
        labels, counts = np.unique(y, return_counts=True)
        print(f"\nDataset: {config['name']}")
        print(f"Original: n={X.shape[0]}, d={X.shape[1]}")
        print(f"Classes: { {int(k): int(v) for k, v in zip(labels, counts)} }")

    # Preprocess
    actual_sub = subsample if subsample else config.get('default_subsample')
    X, y, preprocess_log = preprocess_data(X, y, scale=True, subsample=actual_sub)

    if verbose and preprocess_log:
        print(f"Preprocessing: {', '.join(preprocess_log)}")
        print(f"Final: n={X.shape[0]}, d={X.shape[1]}")

    # Diagnostic
    if verbose:
        print("\n" + "-" * 70)
        print("DIAGNOSTIC")
        print("-" * 70)

    diagnostic = run_diagnostic(X, y)

    if verbose:
        s = diagnostic['summary']
        print(f"Average α: {s['avg_alpha']:.2f}")
        print(f"Scale ratio: {s['scale_ratio']:.2f}")
        if s['det_ratio'] is not None:
            print(f"Det ratio: {s['det_ratio']:.1f}")
        else:
            print("Det ratio: N/A")
        print(f"Heavy-tail signals: {s['heavy_tail_signals']}/3")
        print(f"Recommendation: {s['recommendation']}")

    # Evaluation
    if verbose:
        print("\n" + "-" * 70)
        print(f"{n_splits}-FOLD CROSS-VALIDATION")
        print("-" * 70)

    results = run_cv_evaluation(X, y, n_splits=n_splits, alpha=alpha)

    if verbose:
        print(f"\n{'Method':<20} {'Accuracy':<12} {'PR-AUC':<12} {'Δ Acc':<10}")
        print("-" * 54)
        for name in results:
            acc = results[name]['mean']['accuracy']
            pr = results[name]['mean'].get('pr_auc', 0) or 0
            delta = results[name].get('delta', {}).get('accuracy_pct', 0)
            print(f"{name:<20} {acc:.4f}       {pr:.4f}       {delta:+.2f}%")

    # One clock reading, so the JSON timestamp and the file name agree.
    now = datetime.now()

    # Prepare output
    output = {
        'dataset': dataset_name,
        'config_name': config['name'],
        'timestamp': now.isoformat(),
        'settings': {
            'n': int(X.shape[0]),
            'd': int(X.shape[1]),
            'preprocessing': preprocess_log,
            'n_splits': n_splits,
            'alpha': alpha,
        },
        'class_distribution': {int(k): int(v) for k, v in zip(*np.unique(y, return_counts=True))},
        'diagnostic': diagnostic,
        'results': results,
    }

    # Best non-baseline method by mean PR-AUC (Gaussian has no 'delta' entry).
    best = max(
        [k for k in results if k != 'gaussian'],
        key=lambda k: results[k]['mean'].get('pr_auc', 0) or 0
    )
    output['best_method'] = best

    # Serialize before touching the file, so a serialization error cannot
    # leave a truncated JSON file behind.
    payload = json.dumps(output, indent=2, default=_json_default)

    # Save
    os.makedirs(output_dir, exist_ok=True)
    stamp = now.strftime('%Y%m%d_%H%M%S')
    output_path = os.path.join(output_dir, f'results_{dataset_name}_{stamp}.json')

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(payload)

    if verbose:
        print(f"\nResults saved: {output_path}")

    return output


def main():
    """Command-line entry point: evaluate one dataset (--dataset) or all (--all).

    A failure on one dataset is reported on stderr and the remaining datasets
    still run.

    Returns:
        Process exit status: 1 if any requested dataset failed, otherwise 0
        (including when only the help text is printed).
    """
    parser = argparse.ArgumentParser(description='Stable-QDA Real-World Evaluation')
    parser.add_argument('--dataset', type=str, help='Dataset name')
    parser.add_argument('--all', action='store_true', help='Run all datasets')
    parser.add_argument('--data_dir', type=str, default='data/', help='Data directory')
    parser.add_argument('--output', type=str, default='results/', help='Output directory')
    parser.add_argument('--subsample', type=int, default=None, help='Subsample size')
    parser.add_argument('--n_splits', type=int, default=5, help='CV folds')
    parser.add_argument('--alpha', type=float, default=1.5, help='Fixed alpha')
    args = parser.parse_args()

    if args.all:
        datasets = list(DATASET_CONFIGS.keys())
    elif args.dataset:
        datasets = [args.dataset]
    else:
        parser.print_help()
        return 0

    all_results = {}
    failed = []
    for dataset in datasets:
        try:
            all_results[dataset] = evaluate_dataset(
                dataset,
                data_dir=args.data_dir,
                output_dir=args.output,
                subsample=args.subsample,
                n_splits=args.n_splits,
                alpha=args.alpha,
            )
        except Exception as e:  # top-level boundary: report, then continue with the rest
            failed.append(dataset)
            print(f"Error evaluating {dataset}: {e}", file=sys.stderr)

    # Summary
    if len(all_results) > 1:
        print("\n" + "=" * 70)
        print("SUMMARY ACROSS DATASETS")
        print("=" * 70)

        print(f"\n{'Dataset':<15} {'Best Method':<20} {'Δ Acc vs Gaussian':<15}")
        print("-" * 50)
        for name, result in all_results.items():
            best = result['best_method']
            delta = result['results'][best]['delta']['accuracy_pct']
            print(f"{name:<15} {best:<20} {delta:+.2f}%")

    return 1 if failed else 0


if __name__ == '__main__':
    sys.exit(main())
