"""
Dataset Diagnostic Script
=========================

Run this on your dataset to understand its characteristics and get
estimator recommendations.

Usage:
    python diagnose_dataset.py --data your_data.csv --target label_column
    python diagnose_dataset.py --data your_data.csv --target label --scale
    python diagnose_dataset.py --demo

Output:
    - Tail index (α) estimation per class
    - Scale differences between classes
    - Heavy-tail signal detection
    - Estimator recommendation
"""

import numpy as np
import argparse
import warnings
from scipy import stats

import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))

from alpha_estimation import estimate_alpha


def tyler_m_estimator(X, max_iter=100, tol=1e-6):
    """Compute Tyler's M-estimator of scatter.

    The data are centred at the coordinate-wise median. The fixed-point
    iteration

        Σ ← (d / n) Σ_i x_i x_iᵀ / (x_iᵀ Σ⁻¹ x_i),   then   Σ ← d · Σ / trace(Σ)

    is run from the identity. It stops when two successive iterates differ
    by less than ``tol`` in Frobenius norm, or after ``max_iter`` iterations.

    Args:
        X: Array of shape (n, d), one observation per row.
        max_iter: Maximum number of fixed-point iterations.
        tol: Convergence tolerance on the Frobenius norm of the update.

    Returns:
        A (d, d) scatter matrix normalised so that its trace equals d.
        The overall magnitude of X is discarded by this normalisation. If
        the data are degenerate (all rows at the median), the last valid
        iterate is returned instead of a NaN matrix.
    """
    n, d = X.shape
    X_centered = X - np.median(X, axis=0)
    eye = np.eye(d)
    Sigma = eye.copy()

    for _ in range(max_iter):
        # Small ridge keeps the inverse defined if Sigma becomes near-singular.
        Sigma_inv = np.linalg.inv(Sigma + 1e-8 * eye)
        mahal_sq = np.sum((X_centered @ Sigma_inv) * X_centered, axis=1)
        # Points sitting exactly at the median would otherwise get infinite weight.
        mahal_sq = np.maximum(mahal_sq, 1e-10)
        weights = d / mahal_sq
        Sigma_new = (X_centered.T * weights) @ X_centered / n

        trace = np.trace(Sigma_new)
        if not np.isfinite(trace) or trace <= 0:
            # Degenerate data: normalising would divide by zero and poison
            # every later iterate with NaNs, so keep the last valid Sigma.
            break
        Sigma_new = Sigma_new * (d / trace)

        converged = np.linalg.norm(Sigma_new - Sigma, 'fro') < tol
        # Always adopt the newest iterate, even on the converging step.
        Sigma = Sigma_new
        if converged:
            break

    return Sigma


def _class_covariance(X_c):
    """Return the sample covariance of ``X_c`` (one observation per row) as a 2-D array.

    Warnings emitted by ``np.cov`` are silenced, for example a single-sample
    class, whose covariance comes out as NaN. With a single feature,
    ``np.cov`` returns a 0-d array, which is promoted here to shape (1, 1).
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        cov = np.cov(X_c.T)
    return np.atleast_2d(cov)


def _tyler_threshold(det_ratio, scale_ratio):
    """Return the α cutoff below which Tyler's estimator is considered safe (binary case).

    Tyler's trace normalisation discards between-class scale information.
    The larger the determinant ratio between the two classes, the lower the
    cutoff. When the determinant ratio is unavailable (non-finite), fall
    back to the trace-based ``scale_ratio``.
    """
    if not np.isfinite(det_ratio):
        return 1.8 if scale_ratio < 2 else 1.7
    for upper, thresh in ((10, 2.0), (50, 1.9), (100, 1.8), (1000, 1.7)):
        if det_ratio < upper:
            return thresh
    return 1.6


def diagnose_dataset(X, y):
    """
    Run diagnostic analysis on a dataset.

    Prints a sectioned report and then an estimator recommendation. The
    sections are: tail index per class, per-class covariance scale, sample
    covariance vs. Tyler scatter, Mahalanobis outlier rate, and
    mean-median shift.

    Args:
        X: Feature matrix of shape (n, d).
        y: Label array of length n; each unique value is treated as a class.

    Returns:
        dict with keys 'n', 'd', 'n_classes', 'classes' (per-class statistics
        keyed by label) and 'summary'. The key 'scale_ratio' is also present
        when there are at least two classes.
    """
    n, d = X.shape
    classes = np.unique(y)
    n_classes = len(classes)

    print("=" * 70)
    print("DATASET DIAGNOSTIC REPORT")
    print("=" * 70)
    print(f"\nDataset: n={n}, d={d}, classes={n_classes}")

    report = {'n': n, 'd': d, 'n_classes': n_classes, 'classes': {}}

    # 1. Estimate α for each class
    print("\n" + "-" * 70)
    print("1. TAIL INDEX (α) ESTIMATION")
    print("-" * 70)
    print("   α < 1.5: Very heavy tails → Robust estimators likely help")
    print("   α ∈ [1.5, 1.8]: Moderate tails → May or may not need robust")
    print("   α > 1.8: Near-Gaussian → Standard estimators likely fine")
    print()

    alphas = {}
    for c in classes:
        X_c = X[y == c]
        alphas[c] = estimate_alpha(X_c)
        report['classes'][c] = {'alpha': alphas[c], 'n': len(X_c)}

        verdict = "HEAVY" if alphas[c] < 1.5 else "MODERATE" if alphas[c] < 1.8 else "LIGHT"
        print(f"   Class {c}: α ≈ {alphas[c]:.2f} ({verdict} tails), n={len(X_c)}")

    # 2. Check scale differences
    print("\n" + "-" * 70)
    print("2. SCALE DIFFERENCES BETWEEN CLASSES")
    print("-" * 70)
    print("   If trace(Σ) differs much between classes → Scale is discriminative")
    print("   Tyler normalizes trace=d → Loses this information!")
    print()

    # Sample covariances are computed once here and reused in sections 3 and 4.
    covs, traces, dets = {}, {}, {}
    for c in classes:
        covs[c] = _class_covariance(X[y == c])
        traces[c] = np.trace(covs[c])
        # det is only computed for d <= 50 (presumably to avoid
        # under/overflow in high dimensions).
        dets[c] = np.linalg.det(covs[c]) if d <= 50 else np.nan
        report['classes'][c]['trace'] = traces[c]
        report['classes'][c]['det'] = dets[c]
        print(f"   Class {c}: trace(Σ) = {traces[c]:.2f}, det(Σ) = {dets[c]:.2e}")

    # Largest / smallest class trace. Defined for any number of classes
    # (1.0 for a single class), because the summary below always reports it.
    trace_vals = np.array([traces[c] for c in classes], dtype=float)
    with np.errstate(divide='ignore', invalid='ignore'):
        scale_ratio = float(trace_vals.max() / trace_vals.min())

    if n_classes >= 2:
        report['scale_ratio'] = scale_ratio
        print(f"\n   Scale ratio: {scale_ratio:.2f}")
        if scale_ratio > 2:
            print("   ⚠ LARGE scale difference → Tyler may lose discriminative info!")
        else:
            print("   ✓ Similar scales → Tyler's normalization less harmful")

    # 3. Compare sample covariance vs Tyler
    print("\n" + "-" * 70)
    print("3. SAMPLE COVARIANCE vs TYLER'S M-ESTIMATOR")
    print("-" * 70)

    for c in classes:
        X_c = X[y == c]
        tyler_cov = tyler_m_estimator(X_c - np.median(X_c, axis=0))

        # Tyler's trace is normalised to d, so the relative difference mostly
        # reflects how far the sample trace is from d.
        sample_trace = np.trace(covs[c])
        tyler_trace = np.trace(tyler_cov)
        trace_diff = abs(sample_trace - tyler_trace) / max(sample_trace, tyler_trace)

        print(f"   Class {c}:")
        print(f"      Sample trace: {sample_trace:.2f}")
        print(f"      Tyler trace:  {tyler_trace:.2f} (normalized to d={d})")
        print(f"      Relative diff: {trace_diff*100:.1f}%")

    # 4. Outlier analysis
    print("\n" + "-" * 70)
    print("4. OUTLIER ANALYSIS (Mahalanobis Distance)")
    print("-" * 70)
    print("   Expected outlier rate (χ² 95%): 5%")
    print("   Rate > 7% suggests heavy tails")
    print()

    # The χ² cutoff depends only on d, so compute it once.
    threshold = stats.chi2.ppf(0.95, d)
    outlier_rates = {}
    for c in classes:
        X_c = X[y == c]
        diff = X_c - np.mean(X_c, axis=0)

        try:
            # Tiny ridge guards against (near-)singular class covariances.
            cov_inv = np.linalg.inv(covs[c] + 1e-8 * np.eye(d))
        except np.linalg.LinAlgError:
            outlier_rate = np.nan
        else:
            mahal_sq = np.sum((diff @ cov_inv) * diff, axis=1)
            outlier_rate = np.mean(mahal_sq > threshold) * 100

        outlier_rates[c] = outlier_rate
        report['classes'][c]['outlier_rate'] = outlier_rate

        if np.isnan(outlier_rate):
            status = "? N/A"
        elif outlier_rate > 7:
            status = "⚠ ELEVATED"
        else:
            status = "✓ Normal"
        print(f"   Class {c}: {outlier_rate:.1f}% outliers {status}")

    # 5. Mean-median shift
    print("\n" + "-" * 70)
    print("5. MEAN-MEDIAN SHIFT")
    print("-" * 70)
    print("   Large shift (> 0.2) indicates asymmetry/outliers")
    print()

    mean_median_shifts = {}
    for c in classes:
        X_c = X[y == c]
        mean_c = np.mean(X_c, axis=0)
        median_c = np.median(X_c, axis=0)
        # Normalised by the combined norms; the epsilon avoids 0/0 for centred data.
        denom = np.linalg.norm(mean_c) + np.linalg.norm(median_c) + 1e-10
        shift = np.linalg.norm(mean_c - median_c) / denom
        mean_median_shifts[c] = shift
        report['classes'][c]['mean_median_shift'] = shift

        status = "⚠ LARGE" if shift > 0.2 else "✓ Small"
        print(f"   Class {c}: shift = {shift:.3f} {status}")

    # 6. Summary and recommendation
    print("\n" + "=" * 70)
    print("SUMMARY AND RECOMMENDATION")
    print("=" * 70)

    avg_alpha = np.mean([alphas[c] for c in classes])
    avg_outlier = np.nanmean([outlier_rates[c] for c in classes])
    avg_shift = np.mean([mean_median_shifts[c] for c in classes])

    # Heavy-tail signals
    signals = []
    if avg_alpha < 1.8:
        signals.append(f"α={avg_alpha:.2f} < 1.8")
    if avg_outlier > 7:
        signals.append(f"outlier rate={avg_outlier:.1f}% > 7%")
    if avg_shift > 0.2:
        signals.append(f"mean-median shift={avg_shift:.2f} > 0.2")

    likely_heavy = len(signals) >= 2

    # Tyler threshold (det-ratio based only for binary problems)
    if n_classes == 2:
        c0, c1 = classes[0], classes[1]
        if np.isfinite(dets[c0]) and np.isfinite(dets[c1]) and min(dets[c0], dets[c1]) > 0:
            det_ratio = max(dets[c0], dets[c1]) / min(dets[c0], dets[c1])
        else:
            det_ratio = np.nan
        tyler_thresh = _tyler_threshold(det_ratio, scale_ratio)
    else:
        det_ratio = np.nan
        tyler_thresh = 1.8

    tyler_safe = avg_alpha < tyler_thresh
    det_str = f"{det_ratio:.1f}" if np.isfinite(det_ratio) else "N/A"

    print(f"""
   Summary statistics:
   - Average α across classes: {avg_alpha:.2f}
   - Scale ratio (trace): {scale_ratio:.2f}
   - Det ratio: {det_str}
   - Average outlier rate: {avg_outlier:.1f}% (expected 5%)
   - Average mean-median shift: {avg_shift:.3f}
   
   Heavy-tail signals: {len(signals)}/3 → {"Likely heavy-tailed" if likely_heavy else "Likely light-tailed"}
   {chr(10).join('     • ' + s for s in signals) if signals else '     (none)'}
   
   Tyler threshold: α < {tyler_thresh:.1f} (based on det ratio {det_str})
   Your α ≈ {avg_alpha:.2f} → Tyler {"IS SAFE ✓" if tyler_safe else "MAY HURT ⚠"}
    """)

    # Final recommendation
    if not likely_heavy and avg_alpha > 1.7:
        rec = "Gaussian QDA"
        reason = "Light tails, standard assumptions hold"
    elif tyler_safe or avg_alpha < 1.5:
        rec = "Stable-QDA with ROBUST estimators (spatial median + Tyler)"
        reason = "Heavy tails warrant robust estimation"
    elif likely_heavy:
        rec = "Stable-QDA with STANDARD estimators (mean + Ledoit-Wolf)"
        reason = "Moderate tails but scale differences are discriminative"
    else:
        rec = "Gaussian QDA"
        reason = "Near-Gaussian with heteroscedasticity"

    print(f"""
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
   RECOMMENDATION: {rec}
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
   
   Reason: {reason}
    """)

    report['summary'] = {
        'avg_alpha': avg_alpha,
        'scale_ratio': scale_ratio,
        'det_ratio': det_ratio if np.isfinite(det_ratio) else None,
        'avg_outlier_rate': avg_outlier,
        'avg_mean_median_shift': avg_shift,
        'heavy_tail_signals': len(signals),
        'likely_heavy_tailed': likely_heavy,
        'tyler_threshold': tyler_thresh,
        'tyler_is_safe': tyler_safe,
        'recommendation': rec,
    }

    return report


def main(argv=None):
    """Command-line entry point: parse arguments and run ``diagnose_dataset``.

    Demo mode (``--demo``, or no ``--data``):
        A synthetic two-class Student-t (df=3) dataset is analysed.

    CSV mode:
        1. The CSV given by ``--data`` is loaded.
        2. The ``--target`` column becomes the label. An exact name match is
           tried first, then a case-insensitive one.
        3. The remaining numeric/boolean columns become the features.
        4. Optional stratified subsampling (``--subsample``) and
           standardisation (``--scale``) are applied before the diagnosis.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, i.e. the normal CLI behaviour.

    Raises:
        SystemExit: (via ``parser.error``) if ``--data`` is given without ``--target``.
        ValueError: if the target column is missing or no numeric features remain.
    """
    parser = argparse.ArgumentParser(description="Diagnose dataset for Stable-QDA")
    parser.add_argument('--data', type=str, help='Path to CSV file')
    parser.add_argument('--target', type=str, help='Name of target column (required with --data)')
    parser.add_argument('--demo', action='store_true', help='Run demo with synthetic data')
    parser.add_argument('--scale', action='store_true', help='Standardize features')
    parser.add_argument('--subsample', type=int, default=None, help='Subsample size')
    args = parser.parse_args(argv)

    if args.demo or args.data is None:
        print("Running demo with synthetic heavy-tailed data...\n")

        np.random.seed(42)
        n = 500
        d = 10

        # Class 0: smaller spread
        X_0 = np.random.standard_t(df=3, size=(n, d))

        # Class 1: larger spread, shifted
        X_1 = np.random.standard_t(df=3, size=(n, d)) * 2 + 1

        X = np.vstack([X_0, X_1])
        y = np.array([0] * n + [1] * n)

        diagnose_dataset(X, y)
        return

    if args.target is None:
        # A CSV has no implicit label column; fail with a usage message
        # instead of an AttributeError further down.
        parser.error("--target is required when --data is given")

    # These dependencies are only needed for the CSV path.
    import pandas as pd
    from sklearn.preprocessing import StandardScaler
    from sklearn.model_selection import train_test_split

    df = pd.read_csv(args.data)

    # Find target column: an exact name wins, otherwise the first
    # case-insensitive match.
    if args.target in df.columns:
        target_col = args.target
    else:
        wanted = args.target.lower()
        target_col = next(
            (col for col in df.columns if str(col).lower() == wanted), None
        )

    if target_col is None:
        raise ValueError(f"Target '{args.target}' not found. Available: {list(df.columns)}")

    y = df[target_col].values
    features = df.drop(columns=[target_col])
    # The diagnostics are purely numeric; string/categorical/date columns
    # would break them, so drop those columns with a notice.
    numeric = features.select_dtypes(include=["number", "bool"])
    dropped = [col for col in features.columns if col not in numeric.columns]
    if dropped:
        print(f"Dropping non-numeric feature columns: {dropped}")
    if numeric.shape[1] == 0:
        raise ValueError("No numeric feature columns found besides the target.")
    X = numeric.to_numpy(dtype=float)

    print(f"Loaded: n={X.shape[0]}, d={X.shape[1]}")

    # Subsample (stratified so class proportions are preserved)
    if args.subsample and args.subsample < len(y):
        X, _, y, _ = train_test_split(X, y, train_size=args.subsample,
                                      stratify=y, random_state=42)
        print(f"Subsampled to n={args.subsample}")

    # Scale
    if args.scale:
        scaler = StandardScaler()
        X = scaler.fit_transform(X)
        print("Applied StandardScaler")

    diagnose_dataset(X, y)


if __name__ == "__main__":
    main()
