#!/usr/bin/env python3
"""
Risk-aligned statistical validation for CNCRC vs baseline methods.

Key improvement: Implements risk alignment strategy where all methods
are calibrated to achieve the same target non-coverage risk R0.

This matches the successful experimental design from EXPERIMENT_LOG_REFINED.md
"""

import os
import sys
import numpy as np
import pandas as pd
from scipy import stats
from typing import Dict, List, Tuple, Optional
import json
from datetime import datetime
import logging

# Silence verbose logs: raise the root logger's threshold so that only
# WARNING-level records and above are emitted.
logging.getLogger().setLevel(logging.WARNING)

# Put the current working directory on the import path so that the `src.cncrc`
# imports below can resolve.
# NOTE(review): presumably the script is meant to be launched from the
# repository root -- confirm.
sys.path.append(os.path.abspath('.'))

from src.cncrc.core.calibration import calibrate_quantile
from src.cncrc.core.risk_weighted_score import calculate_risk_weighted_score
from src.cncrc.drug_task.fixed_realistic_predictor import create_fixed_trained_predictor
from src.cncrc.drug_task.cost_mapping import create_mock_cost_matrix
from src.cncrc.drug_task.non_coverage_cost import (
    build_table_from_icd_guidelines,
    export_cost_nc_vector,
    NonCoverageCostConfig,
)
from src.cncrc.core.data_structures import ClinicalContext


def generate_samples(predictor, n_samples: int, seed: int = 0) -> List[Dict]:
    """Draw a synthetic cohort of clinical cases with sampled ground-truth labels.

    Every case gets a random age, gender and single diagnosis, is passed to
    ``predictor.predict_probabilities`` and then receives a label sampled from
    the returned probability vector.

    Args:
        predictor: Object exposing ``predict_probabilities(context)``.
        n_samples: Number of cases to generate.
        seed: Seed for the NumPy random generator driving all draws.

    Returns:
        A list of dicts with the keys ``'context'``, ``'probs'`` and
        ``'true_label'``, one per generated case.
    """
    rng = np.random.default_rng(seed)
    gender_pool = ['M', 'F']
    diagnosis_pool = ["hypertension", "diabetes", "atrial_fibrillation", "heart_failure", "pneumonia", "copd"]

    def _draw_case(index: int) -> Dict:
        # The draw order (age, gender, diagnosis, label) determines the random
        # stream, so it must stay fixed for a given seed to be reproducible.
        patient_age = int(rng.uniform(25, 85))
        patient_gender = rng.choice(gender_pool)
        patient_diagnosis = rng.choice(diagnosis_pool)

        case_context = ClinicalContext(
            patient_id=f"P{index:04d}",
            age=patient_age,
            gender=patient_gender,
            diagnoses=[patient_diagnosis]
        )

        label_probs = predictor.predict_probabilities(case_context)
        sampled_label = rng.choice(len(label_probs), p=label_probs)

        return {
            'context': case_context,
            'probs': label_probs,
            'true_label': sampled_label
        }

    return [_draw_case(case_index) for case_index in range(n_samples)]


def find_alpha_for_target_risk(samples: List[Dict], cost_matrix: np.ndarray, cost_nc: np.ndarray,
                              method: str, target_risk: float, tau: float = 0.0) -> Tuple[float, Dict]:
    """Grid-search the alpha whose validation non-coverage risk is closest to a target.

    The first third of ``samples`` is used for calibration and the second
    third for validation; the remaining samples are not touched here. For each
    of 50 evenly spaced alpha candidates in [0.01, 0.50] a threshold is
    calibrated with ``calibrate_quantile`` and the mean non-coverage cost on
    the validation split is compared with ``target_risk``.

    Args:
        samples: Dicts carrying at least ``'probs'`` and ``'true_label'``.
        cost_matrix: Pairwise cost matrix, indexed ``[label, other_label]``.
        cost_nc: Per-label cost charged when the true label is not covered.
        method: One of ``'standard_cp'``, ``'cncrc_max'`` or ``'cncrc_sum'``.
        target_risk: Non-coverage risk the calibration should match.
        tau: Minimum probability a label needs to enter the prediction set.

    Returns:
        ``(best_alpha, best_result)`` where ``best_result`` holds ``'alpha'``,
        ``'threshold'``, ``'actual_risk'``, ``'target_risk'`` and
        ``'risk_diff'``. Both are ``None`` when no candidate yields a
        comparable risk (e.g. the validation split is empty, so the mean
        risk is NaN).

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    # Validate once, up front, so an unknown method fails the same way
    # regardless of how many samples end up in each split.
    if method not in ('standard_cp', 'cncrc_max', 'cncrc_sum'):
        raise ValueError(f"Unknown method: {method}")

    def _score(probs, label):
        """Return the nonconformity score of ``label`` under ``method``."""
        if method == 'standard_cp':
            return 1 - probs[label]
        if method == 'cncrc_max':
            return calculate_risk_weighted_score(probs, cost_matrix, label)
        return sum(probs[j] * cost_matrix[label, j] for j in range(len(probs)) if j != label)

    n_cal = len(samples) // 3
    n_val = len(samples) // 3

    cal_samples = samples[:n_cal]
    val_samples = samples[n_cal:n_cal + n_val]

    # Scores do not depend on alpha, so compute them once instead of once per
    # candidate.
    cal_scores = [_score(sample['probs'], sample['true_label']) for sample in cal_samples]

    # The validation risk only depends on whether the TRUE label makes it into
    # the prediction set, i.e. on the true label's score and probability; the
    # other labels never influence the non-coverage cost.
    val_checks = [
        (
            _score(sample['probs'], sample['true_label']),
            sample['probs'][sample['true_label']],
            cost_nc[sample['true_label']],
        )
        for sample in val_samples
    ]

    best_alpha = None
    best_result = None
    min_risk_diff = float('inf')

    for alpha in np.linspace(0.01, 0.50, 50):
        # Hand over a fresh copy so the shared score list stays intact even if
        # the calibration routine modifies its argument.
        q = calibrate_quantile(list(cal_scores), alpha)

        nc_risks = [
            0.0 if (true_score <= q and true_prob >= tau) else miss_cost
            for true_score, true_prob, miss_cost in val_checks
        ]

        actual_risk = np.mean(nc_risks)
        risk_diff = abs(actual_risk - target_risk)

        # Strict '<' keeps the smallest alpha among equally good candidates.
        if risk_diff < min_risk_diff:
            min_risk_diff = risk_diff
            best_alpha = alpha
            best_result = {
                'alpha': alpha,
                'threshold': q,
                'actual_risk': actual_risk,
                'target_risk': target_risk,
                'risk_diff': risk_diff
            }

    return best_alpha, best_result


def evaluate_method_with_alpha(samples: List[Dict], cost_matrix: np.ndarray, cost_nc: np.ndarray,
                              method: str, alpha: float, tau: float = 0.0) -> Dict:
    """Evaluate one method at a fixed alpha on the held-out test split.

    The first third of ``samples`` calibrates the threshold via
    ``calibrate_quantile``; everything after the first two thirds is the test
    split. A label enters the prediction set when its score is at most the
    threshold and its probability is at least ``tau``.

    Args:
        samples: Dicts carrying at least ``'probs'`` and ``'true_label'``.
        cost_matrix: Pairwise cost matrix, indexed ``[true_label, other_label]``.
        cost_nc: Per-label cost charged when the true label is not covered.
        method: One of ``'standard_cp'``, ``'cncrc_max'`` or ``'cncrc_sum'``.
        alpha: Level passed to ``calibrate_quantile``.
        tau: Minimum probability a label needs to enter the prediction set.

    Returns:
        Dict with the configuration (``'method'``, ``'alpha'``, ``'tau'``,
        ``'threshold'``), the mean metrics (``'coverage'``, ``'avg_set_size'``,
        ``'nc_risk'``, ``'amb_cost'``) and the per-sample lists
        (``'amb_costs_raw'``, ``'coverages_raw'``, ``'set_sizes_raw'``,
        ``'nc_risks_raw'``).

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    # Reject unknown methods explicitly instead of failing later on an
    # unassigned score variable (or silently, when a split is empty).
    if method not in ('standard_cp', 'cncrc_max', 'cncrc_sum'):
        raise ValueError(f"Unknown method: {method}")

    def _score(probs, label):
        """Return the nonconformity score of ``label`` under ``method``."""
        if method == 'standard_cp':
            return 1 - probs[label]
        if method == 'cncrc_max':
            return calculate_risk_weighted_score(probs, cost_matrix, label)
        return sum(probs[j] * cost_matrix[label, j] for j in range(len(probs)) if j != label)

    n_cal = len(samples) // 3
    n_val = len(samples) // 3

    cal_samples = samples[:n_cal]
    test_samples = samples[n_cal + n_val:]

    # Calibrate the threshold on the true-label scores of the calibration split.
    cal_scores = [_score(sample['probs'], sample['true_label']) for sample in cal_samples]
    q = calibrate_quantile(cal_scores, alpha)

    coverages = []
    set_sizes = []
    nc_risks = []
    amb_costs = []

    for sample in test_samples:
        probs = sample['probs']
        y_true = sample['true_label']

        pred_set = [
            y for y in range(len(probs))
            if _score(probs, y) <= q and probs[y] >= tau
        ]

        covered = y_true in pred_set
        coverages.append(covered)
        set_sizes.append(len(pred_set))

        # Non-coverage risk: the miss cost is only paid when the truth is excluded.
        nc_risks.append(0.0 if covered else cost_nc[y_true])

        # Ambiguity cost: worst pairwise cost between the truth and any other
        # label in the set; zero for singleton or non-covering sets.
        if covered and len(pred_set) > 1:
            amb_costs.append(max(cost_matrix[y_true, y] for y in pred_set if y != y_true))
        else:
            amb_costs.append(0.0)

    return {
        'method': method,
        'alpha': alpha,
        'tau': tau,
        'threshold': q,
        'coverage': np.mean(coverages),
        'avg_set_size': np.mean(set_sizes),
        'nc_risk': np.mean(nc_risks),
        'amb_cost': np.mean(amb_costs),
        'amb_costs_raw': amb_costs,
        'coverages_raw': coverages,
        'set_sizes_raw': set_sizes,
        'nc_risks_raw': nc_risks
    }


def bootstrap_ci(data: np.ndarray, n_bootstrap: int = 1000, confidence: float = 0.95,
                 seed: int = 42) -> Tuple[float, float]:
    """Compute a percentile-bootstrap confidence interval for the mean of ``data``.

    Args:
        data: Observations to resample (with replacement, same size as input).
        n_bootstrap: Number of bootstrap resamples.
        confidence: Two-sided confidence level, between 0 and 1 inclusive.
        seed: Seed for the NumPy random generator; the default reproduces the
            previously hard-coded behaviour.

    Returns:
        ``(lower, upper)`` percentile bounds of the bootstrap means, or
        ``(nan, nan)`` when there is nothing to resample.

    Raises:
        ValueError: If ``confidence`` lies outside [0, 1].
    """
    if not 0.0 <= confidence <= 1.0:
        raise ValueError(f"confidence must be within [0, 1], got {confidence}")

    n_obs = len(data)
    # Nothing to resample: report an undefined interval directly rather than
    # taking the mean of empty draws over and over.
    if n_obs == 0 or n_bootstrap <= 0:
        return np.nan, np.nan

    rng = np.random.default_rng(seed)
    bootstrap_means = [
        np.mean(rng.choice(data, size=n_obs, replace=True))
        for _ in range(n_bootstrap)
    ]

    alpha = 1 - confidence
    lower = np.percentile(bootstrap_means, 100 * alpha / 2)
    upper = np.percentile(bootstrap_means, 100 * (1 - alpha / 2))

    return lower, upper


def paired_t_test(data1: np.ndarray, data2: np.ndarray) -> Tuple[float, float]:
    """Run a paired-sample t-test on two equally long measurement series.

    Args:
        data1: First series of paired observations.
        data2: Second series of paired observations.

    Returns:
        ``(statistic, p_value)`` as reported by ``scipy.stats.ttest_rel``.
    """
    outcome = stats.ttest_rel(data1, data2)
    return outcome[0], outcome[1]


def _report_ambcost_comparison(label: str, candidate_amb: np.ndarray, baseline_amb: np.ndarray) -> None:
    """Print the AmbCost improvement and paired t-test of one method vs. Standard CP."""
    if len(candidate_amb) == 0 or len(baseline_amb) == 0:
        return

    t_stat, p_value = paired_t_test(candidate_amb, baseline_amb)

    baseline_mean = np.mean(baseline_amb)
    if baseline_mean != 0:
        improvement = (baseline_mean - np.mean(candidate_amb)) / baseline_mean * 100
        improvement_text = f"{improvement:+.1f}%"
    else:
        # A relative improvement is undefined when the baseline has no
        # ambiguity cost at all (e.g. only singleton prediction sets).
        improvement_text = "N/A (baseline AmbCost is zero)"

    print(f"  {label} vs Standard CP:")
    print(f"    AmbCost improvement: {improvement_text}")
    print(f"    Paired t-test: t={t_stat:.3f}, p={p_value:.4f}")
    print(f"    Significant: {'Yes' if p_value < 0.05 else 'No'}")


def run_risk_aligned_validation():
    """Run the risk-aligned comparison of Standard CP against the CNCRC variants.

    For each target non-coverage risk, every method's alpha is tuned with
    ``find_alpha_for_target_risk`` and the method is then scored with
    ``evaluate_method_with_alpha``. Per-target paired t-tests on the ambiguity
    cost are printed, bootstrap confidence intervals are attached, and the
    results are written to a timestamped directory under ``results/`` as
    ``detailed_results.csv``, ``publication_table.csv`` and
    ``publication_table.tex``.

    Returns:
        Path of the output directory that was created.

    Raises:
        RuntimeError: If risk alignment yields no usable alpha for a method.
    """
    print("🔬 Starting Risk-Aligned Statistical Validation for CNCRC")
    print("=" * 70)

    # Setup
    predictor = create_fixed_trained_predictor()
    n_drugs = len(predictor.drug_vocabulary)
    print(f"📊 Drug vocabulary size: {n_drugs}")

    # Generate cost matrices
    cost_matrix = create_mock_cost_matrix(predictor.drug_vocabulary, interaction_prob=0.2)

    # Generate Cost_NC
    nc_config = NonCoverageCostConfig(normalization='max')
    cost_nc_table = build_table_from_icd_guidelines()
    cost_nc, _ = export_cost_nc_vector(predictor.drug_vocabulary, config=nc_config, table=cost_nc_table)

    print(f"💊 Cost matrix: {cost_matrix.shape}, sparsity: {np.mean(cost_matrix == 0):.1%}")
    print(f"🎯 Cost_NC range: [{cost_nc.min():.3f}, {cost_nc.max():.3f}]")

    # Generate samples (larger for 3-way split)
    samples = generate_samples(predictor, n_samples=600, seed=123)
    # Mirror the split used downstream: two thirds of len // 3, remainder for test.
    third = len(samples) // 3
    n_test = len(samples) - 2 * third
    print(f"📋 Generated {len(samples)} samples (Cal/Val/Test: {third}/{third}/{n_test})")

    # Methods to compare
    methods = ['standard_cp', 'cncrc_max', 'cncrc_sum']
    target_risks = [0.08, 0.10, 0.12]  # Target non-coverage risks

    results = []

    print("\n🎯 Risk Alignment Phase...")
    for target_risk in target_risks:
        print(f"\n📈 Target R_NC = {target_risk}")

        # Find optimal alpha for each method
        method_configs = {}
        for method in methods:
            print(f"  🔍 Finding optimal α for {method}...")

            # CNCRC variants additionally require a minimum label probability.
            tau = 0.05 if method.startswith('cncrc') else 0.0
            best_alpha, alignment = find_alpha_for_target_risk(
                samples, cost_matrix, cost_nc, method, target_risk, tau=tau
            )
            if alignment is None:
                raise RuntimeError(
                    f"Risk alignment found no usable alpha for {method} at target {target_risk}"
                )
            method_configs[method] = {**alignment, 'tau': tau}

            print(f"    α={best_alpha:.3f}, actual R_NC={alignment['actual_risk']:.4f}, "
                  f"diff={alignment['risk_diff']:.4f}")

        # Evaluate all methods on test set
        print(f"\n🧪 Test Set Evaluation (Target R_NC = {target_risk}):")
        target_results = {}

        for method in methods:
            aligned = method_configs[method]
            result = evaluate_method_with_alpha(
                samples, cost_matrix, cost_nc, method,
                aligned['alpha'], aligned['tau']
            )
            target_results[method] = result
            results.append({**result, 'target_risk': target_risk})

            print(f"  {method:12s}: Coverage={result['coverage']:.3f}, "
                  f"APS={result['avg_set_size']:.2f}, "
                  f"R_NC={result['nc_risk']:.4f}, "
                  f"AmbCost={result['amb_cost']:.4f}")

        # Statistical tests for this target risk
        print(f"\n📊 Statistical Tests (Target R_NC = {target_risk}):")

        cp_amb = np.array(target_results['standard_cp']['amb_costs_raw'])
        for label, method_key in (("CNCRC-Max", 'cncrc_max'), ("CNCRC-Sum", 'cncrc_sum')):
            candidate_amb = np.array(target_results[method_key]['amb_costs_raw'])
            _report_ambcost_comparison(label, candidate_amb, cp_amb)

    # Generate summary table
    print("\n📋 Generating Risk-Aligned Statistical Summary...")

    # Create results DataFrame
    df = pd.DataFrame(results)

    # Add bootstrap CIs (NaN where there is nothing to resample)
    ci_bounds = [
        bootstrap_ci(np.array(raw)) if len(raw) > 0 else (np.nan, np.nan)
        for raw in df['amb_costs_raw']
    ]
    df['amb_cost_ci_lower'] = [lower for lower, _ in ci_bounds]
    df['amb_cost_ci_upper'] = [upper for _, upper in ci_bounds]

    # Save results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = f"results/risk_aligned_validation_{timestamp}"
    os.makedirs(output_dir, exist_ok=True)

    # Save detailed results
    df.to_csv(f"{output_dir}/detailed_results.csv", index=False)

    # Create publication table
    pub_table = []
    for target_risk in target_risks:
        target_data = df[df['target_risk'] == target_risk]

        for _, row in target_data.iterrows():
            ci_lower = row.get('amb_cost_ci_lower', np.nan)
            ci_upper = row.get('amb_cost_ci_upper', np.nan)

            pub_table.append({
                'Target_R_NC': target_risk,
                'Method': row['method'].replace('_', '-').upper(),
                'Coverage': f"{row['coverage']:.3f}",
                'APS': f"{row['avg_set_size']:.2f}",
                'Actual_R_NC': f"{row['nc_risk']:.4f}",
                'AmbCost': f"{row['amb_cost']:.4f}",
                'AmbCost_CI': f"[{ci_lower:.4f}, {ci_upper:.4f}]" if not np.isnan(ci_lower) else "N/A"
            })

    pub_df = pd.DataFrame(pub_table)
    pub_df.to_csv(f"{output_dir}/publication_table.csv", index=False)

    # Generate LaTeX table
    latex_table = pub_df.to_latex(index=False, float_format="%.4f")
    with open(f"{output_dir}/publication_table.tex", 'w', encoding="utf-8") as f:
        f.write(latex_table)

    print(f"\n✅ Results saved to: {output_dir}")
    print(f"📊 Publication table: {output_dir}/publication_table.csv")
    print(f"📄 LaTeX table: {output_dir}/publication_table.tex")

    # Print key findings
    print("\n🎯 Key Risk-Aligned Statistical Findings:")
    print("-" * 50)

    # Find best performing method at target R_NC=0.10 (tolerant float match)
    target_10_results = df[np.isclose(df['target_risk'], 0.10)]
    if not target_10_results.empty:
        best_method = target_10_results.loc[target_10_results['amb_cost'].idxmin()]

        print(f"Best method at target R_NC=0.10: {best_method['method'].upper()}")
        print(f"  AmbCost: {best_method['amb_cost']:.4f}")
        print(f"  Coverage: {best_method['coverage']:.3f}")
        print(f"  Actual R_NC: {best_method['nc_risk']:.4f}")

    return output_dir


if __name__ == "__main__":
    # Script entry point: run the full experiment, then report where the
    # generated artefacts were written.
    results_location = run_risk_aligned_validation()
    print("\n🎉 Risk-aligned statistical validation completed!")
    print(f"📁 Results directory: {results_location}")