#!/usr/bin/env python3
"""
Optimized Alpha Sweep Experiment for CNCRC Visualization

Based on the successful quick validation test, this script runs the full alpha sweep
with fixed tau parameters to avoid the computational complexity of grid search.

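Key optimizations:
1. Fixed tau=0.05 for the CNCRC methods (it worked well in the quick validation run)
   and fixed lambda=0.1 for the Cost-Aware baseline, instead of a grid search
2. Simplified calibration: q is picked from a fixed grid of calibration levels by
   matching the target non-coverage risk on a separate validation split
3. A reduced set of alpha values (0.05-0.20 by default)
4. Per-method score distributions and set sizes are stored for visualization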
"""

import os
import sys
import json
from dataclasses import dataclass, asdict
from typing import List, Dict, Tuple, Optional, Any
import numpy as np
import logging
from pathlib import Path

# Silence verbose logs: raise the root logger's level to WARNING so INFO/DEBUG
# output is suppressed for every logger that inherits the root level.
logging.getLogger().setLevel(logging.WARNING)

# Make the project-local `src` package importable. os.path.abspath('.') is the
# current working directory, so this presumably assumes the script is launched
# from the repository root (where `src/` lives) — NOTE(review): confirm.
sys.path.append(os.path.abspath('.'))

from src.cncrc.core.calibration import calibrate_quantile
from src.cncrc.core.risk_weighted_score import calculate_risk_weighted_score
from src.cncrc.drug_task.fixed_realistic_predictor import create_fixed_trained_predictor
from src.cncrc.drug_task.cost_mapping import create_mock_cost_matrix
from src.cncrc.drug_task.non_coverage_cost import (
    build_table_from_icd_guidelines,
    NonCoverageCostConfig,
)
from src.cncrc.core.data_structures import ClinicalContext


@dataclass
class ExperimentResult:
    """Single experiment result for one method at one alpha level.

    The optional fields default to ``None`` when a method has no such
    quantity (e.g. ``tau_threshold`` / ``q_over_tau`` for methods that use no
    probability cutoff) or when the data was not collected.
    """
    method: str
    alpha: float
    target_risk: float
    # Performance metrics
    coverage: float
    aps: float  # Average Prediction Set size
    non_coverage_risk: float
    ambiguity_cost_max: float
    # Bridging quantities
    q_threshold: float
    tau_threshold: Optional[float] = None
    q_over_tau: Optional[float] = None
    # Additional data for visualization (None when not collected)
    score_distribution: Optional[List[float]] = None
    set_sizes: Optional[List[int]] = None
    method_params: Optional[Dict[str, Any]] = None


@dataclass
class AlphaSweepResults:
    """Complete results from an alpha sweep experiment.

    Bundles every per-(method, alpha) ExperimentResult together with the raw
    inputs needed to re-analyse or re-plot the sweep.
    """
    alpha_values: List[float]  # alpha levels that were swept
    methods: List[str]  # names of the evaluated methods
    results: List[ExperimentResult]  # one entry per (alpha, method) pair
    # Raw data for detailed analysis
    calibration_scores: Dict[str, List[float]]  # method name -> calibration-split scores
    test_samples: List[Dict]  # presumably {"context", "probs", "y"} dicts — verify against producer
    cost_matrix: List[List[float]]  # pairwise cost matrix as nested lists
    cost_nc_vector: List[float]  # per-label non-coverage cost
    experiment_config: Dict  # free-form settings (split sizes, seeds, fixed params)


def calculate_cncrc_sum_score(probabilities: np.ndarray, cost_matrix: np.ndarray, y_true: int) -> float:
    """Return the CNCRC-Sum score of label ``y_true``.

    score = sum over j != y_true of P(j|x) * Cost(y_true, j)

    Row ``y_true`` of ``cost_matrix`` weights each class probability. The
    ``j == y_true`` term is always excluded, whatever the diagonal entry holds.
    """
    row_costs = cost_matrix[y_true, :]
    contributions = np.multiply(probabilities, row_costs)
    # The true class never contributes to its own risk.
    contributions[y_true] = 0.0
    return float(contributions.sum())


def calculate_cost_aware_score(probs: np.ndarray, cost_mat: np.ndarray, cost_nc: np.ndarray,
                               y_true: int, lambda_param: float) -> float:
    """Score label ``y_true`` with the Cost-Aware heuristic.

    score = (1 - P(y_true)) + lambda_param * (cost_nc[y_true] + max_j cost_mat[y_true, j])

    The max runs over the entire row, including the ``j == y_true`` entry.
    """
    worst_row_cost = cost_mat[y_true, :].max()
    penalty = cost_nc[y_true] + worst_row_cost
    miss_probability = 1.0 - probs[y_true]
    return miss_probability + lambda_param * penalty


def build_sets_cp(probs: np.ndarray, q: float) -> List[int]:
    """Return every label whose standard-CP score ``1 - p`` is at most ``q``.

    Labels come back in ascending order as plain Python ints; the result may
    be empty (Standard CP has no non-empty fallback).
    """
    within_threshold = (1.0 - probs) <= q
    return np.flatnonzero(within_threshold).tolist()


def build_sets_cncrc_max(probs: np.ndarray, cost_mat: np.ndarray, q: float, tau: float = 0.05) -> List[int]:
    """Build a prediction set for CNCRC-Max with a fixed probability cutoff.

    Labels with ``probs[y] < tau`` are skipped. Each remaining label is kept
    when its risk-weighted score (``calculate_risk_weighted_score``) is
    ``<= q``. If nothing survives, the set falls back to the single most
    probable label, so it is never empty.

    Returns:
        Selected label indices in ascending order, as Python ``int`` values.
    """
    selected = []
    for y in range(len(probs)):
        if probs[y] < tau:
            continue  # too unlikely to be considered, regardless of its score
        if calculate_risk_weighted_score(probs, cost_mat, y) <= q:
            selected.append(y)

    if not selected:
        # np.argmax returns a NumPy integer; cast so the list is uniformly int.
        selected = [int(np.argmax(probs))]

    return selected


def build_sets_cncrc_sum(probs: np.ndarray, cost_mat: np.ndarray, q: float, tau: float = 0.05) -> List[int]:
    """Build a prediction set for CNCRC-Sum with a fixed probability cutoff.

    Labels with ``probs[y] < tau`` are skipped. Each remaining label is kept
    when its CNCRC-Sum score (``calculate_cncrc_sum_score``) is ``<= q``. If
    nothing survives, the set falls back to the single most probable label,
    so it is never empty.

    Returns:
        Selected label indices in ascending order, as Python ``int`` values.
    """
    selected = []
    for y in range(len(probs)):
        if probs[y] < tau:
            continue  # too unlikely to be considered, regardless of its score
        if calculate_cncrc_sum_score(probs, cost_mat, y) <= q:
            selected.append(y)

    if not selected:
        # np.argmax returns a NumPy integer; cast so the list is uniformly int.
        selected = [int(np.argmax(probs))]

    return selected


def build_sets_cost_aware(probs: np.ndarray, cost_mat: np.ndarray, cost_nc: np.ndarray,
                          q: float, lambda_param: float) -> List[int]:
    """Build a prediction set for the Cost-Aware heuristic.

    Every label whose Cost-Aware score (``calculate_cost_aware_score``) is
    ``<= q`` is included. Unlike the CNCRC builders there is no ``tau``
    probability cutoff. If no label qualifies, the set falls back to the
    single most probable label, so it is never empty.

    Returns:
        Selected label indices in ascending order, as Python ``int`` values.
    """
    selected = [
        y for y in range(len(probs))
        if calculate_cost_aware_score(probs, cost_mat, cost_nc, y, lambda_param) <= q
    ]

    if not selected:
        # np.argmax returns a NumPy integer; cast so the list is uniformly int.
        selected = [int(np.argmax(probs))]

    return selected


def evaluate_method(samples: List[Dict], sets_builder, q: float, cost_mat: np.ndarray,
                    cost_nc: np.ndarray, **kwargs) -> Tuple[Dict[str, float], List[int]]:
    """Evaluate a set-valued predictor and return metrics plus set sizes.

    For each sample, ``sets_builder(sample["probs"], q, **kwargs)`` produces a
    prediction set, which is scored against the label ``sample["y"]``.

    Args:
        samples: Dicts with at least ``"probs"`` and ``"y"`` keys.
        sets_builder: Callable ``(probs, q, **kwargs) -> list of labels``.
        q: Threshold forwarded to ``sets_builder``.
        cost_mat: Pairwise cost matrix; ``cost_mat[y, z]`` is read for every
            extra label ``z`` in a set that covers ``y``.
        cost_nc: Per-label non-coverage cost; ``cost_nc[y]`` is charged when
            ``y`` is missing from its set.
        **kwargs: Extra keyword arguments forwarded to ``sets_builder``.

    Returns:
        ``(metrics, set_sizes)``. ``metrics`` contains:

        - ``coverage``: fraction of samples whose label is in its set.
        - ``aps``: mean set size.
        - ``non_coverage_risk``: ``cost_nc[y]`` averaged over all samples,
          with covered samples contributing 0.
        - ``ambiguity_cost_max``: mean, over covered samples whose set has
          more than one label, of the largest ``cost_mat[y, z]`` with
          ``z != y`` (floored at 0.0); 0.0 if there are no such samples.

        ``set_sizes`` lists each sample's set size, in input order.

    Raises:
        ValueError: If ``samples`` is empty, since every metric would be 0/0.
    """
    if not samples:
        raise ValueError("evaluate_method requires at least one sample")

    covered = 0
    set_sizes = []
    nc_risk_sum = 0.0
    amb_cost_sum = 0.0
    amb_count = 0

    for s in samples:
        y = s["y"]
        pred_set = sets_builder(s["probs"], q, **kwargs)
        set_sizes.append(len(pred_set))

        if y not in pred_set:
            nc_risk_sum += float(cost_nc[y])
            continue

        covered += 1
        if len(pred_set) > 1:
            # Worst confusion cost among the extra labels, never below 0.0.
            amb_cost_sum += max([0.0] + [float(cost_mat[y, z]) for z in pred_set if z != y])
            amb_count += 1

    total = len(samples)
    metrics = {
        "coverage": covered / total,
        "aps": float(np.mean(set_sizes)),
        "non_coverage_risk": nc_risk_sum / total,
        "ambiguity_cost_max": (amb_cost_sum / amb_count) if amb_count > 0 else 0.0,
    }

    return metrics, set_sizes


def find_optimal_q(val_samples: List[Dict], scores: np.ndarray, sets_builder,
                   cost_mat: np.ndarray, cost_nc: np.ndarray, target_risk: float,
                   *, alpha_grid: Optional[List[float]] = None, **kwargs) -> float:
    """Pick the calibrated threshold whose validation risk is closest to ``target_risk``.

    Each calibration level in ``alpha_grid`` is turned into a candidate
    threshold with ``calibrate_quantile(scores, level)``. The candidate is then
    evaluated on ``val_samples`` via ``evaluate_method``. The threshold with the
    smallest ``|non_coverage_risk - target_risk|`` wins; on ties the earliest
    grid entry is kept.

    Args:
        val_samples: Validation samples passed to ``evaluate_method``.
        scores: Calibration-split scores passed to ``calibrate_quantile``.
        sets_builder: Callable ``(probs, q, **kwargs)`` building prediction sets.
        cost_mat: Pairwise cost matrix used by ``evaluate_method``.
        cost_nc: Per-label non-coverage cost used by ``evaluate_method``.
        target_risk: Desired validation non-coverage risk.
        alpha_grid: Candidate calibration levels. ``None`` uses the default
            grid 0.05, 0.06, ..., 0.30.
        **kwargs: Forwarded to ``sets_builder`` through ``evaluate_method``.

    Returns:
        The selected threshold ``q``.

    Raises:
        ValueError: If ``alpha_grid`` is empty, or no candidate yields a
            finite risk gap (e.g. every risk is NaN).
    """
    if alpha_grid is None:
        alpha_grid = [i / 100.0 for i in range(5, 31)]  # 0.05 to 0.30

    best_q = None
    best_gap = float('inf')

    for cal_level in alpha_grid:
        q = calibrate_quantile(scores, cal_level)
        metrics, _ = evaluate_method(val_samples, sets_builder, q, cost_mat, cost_nc, **kwargs)
        gap = abs(metrics['non_coverage_risk'] - target_risk)

        # Strict '<' keeps the earliest candidate on ties and ignores NaN gaps.
        if gap < best_gap:
            best_gap = gap
            best_q = q

    if best_q is None:
        # Returning None would only surface later as a confusing TypeError
        # (e.g. in `q / tau` or `score <= q`).
        raise ValueError(
            "find_optimal_q: no candidate threshold produced a finite risk gap; "
            "check alpha_grid and the validation data"
        )

    return best_q


def _generate_samples(predictor, n_samples: int, seed: int) -> List[Dict]:
    """Generate one data split of synthetic patients and predictor outputs.

    Each sample is ``{"context": {...}, "probs": <predictor output>, "y": <label>}``,
    where ``y`` is drawn from ``probs`` itself. The order of RNG draws (age,
    gender, diagnosis, then label) determines the data produced for a given
    ``seed``; changing it changes every split.
    """
    rng = np.random.default_rng(seed)
    all_diagnoses = ["hypertension", "diabetes", "atrial_fibrillation", "heart_failure", "pneumonia", "copd"]
    samples = []

    for i in range(n_samples):
        patient_id = f"P{i:04d}"
        age = int(rng.uniform(25, 85))
        gender = rng.choice(['M', 'F'])
        diagnosis = rng.choice(all_diagnoses)

        clinical_context = ClinicalContext(
            patient_id=patient_id,
            age=age,
            gender=gender,
            diagnoses=[diagnosis]
        )

        probs = predictor.predict_probabilities(clinical_context)
        y = rng.choice(len(probs), p=probs)

        samples.append({
            "context": {
                "patient_id": patient_id,
                "age": age,
                "gender": gender,
                "diagnosis": diagnosis
            },
            "probs": probs,
            "y": y
        })

    return samples


def _to_json_compatible(obj: Any) -> Any:
    """Recursively convert NumPy arrays and NumPy scalars into plain Python objects."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # NumPy scalars (e.g. the np.int64 labels from rng.choice) are not
        # accepted by json.dump; .item() yields the matching Python scalar.
        return obj.item()
    if isinstance(obj, dict):
        return {k: _to_json_compatible(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_json_compatible(item) for item in obj]
    return obj


def run_optimized_alpha_sweep(alpha_values: Optional[List[float]] = None,
                              output_dir: str = "results/alpha_sweep_optimized") -> AlphaSweepResults:
    """Run the optimized alpha sweep and save the results as JSON.

    For every alpha in ``alpha_values`` and each of the four methods
    (Standard CP, CNCRC-Max, CNCRC-Sum, Cost-Aware), a threshold ``q`` is
    chosen on a validation split so that its non-coverage risk is close to
    alpha (see ``find_optimal_q``). The threshold is then evaluated on a
    held-out test split.

    Args:
        alpha_values: Target risk levels to sweep. ``None`` uses the default
            grid 0.05-0.20.
        output_dir: Directory that receives
            ``alpha_sweep_results_optimized.json``; created if missing.

    Returns:
        The in-memory AlphaSweepResults. NumPy objects are converted to plain
        Python types only in the written JSON, not in the returned object.
    """
    if alpha_values is None:
        alpha_values = [0.05, 0.06, 0.08, 0.10, 0.12, 0.15, 0.18, 0.20]

    print("=== Optimized Alpha Sweep Experiment for CNCRC Visualization ===")
    print(f"Alpha values: {alpha_values}")

    # Experiment constants (also recorded in experiment_config below).
    interaction_prob = 0.2
    n_per_split = 200
    seeds = {"cal": 123, "val": 234, "test": 345}
    tau_fixed = 0.05  # Based on successful validation
    lambda_fixed = 0.1  # Reasonable default; shared by calibration and set building

    # Setup components
    predictor = create_fixed_trained_predictor()
    vocab = predictor.get_drug_vocabulary()
    cost_mat = create_mock_cost_matrix(vocab, interaction_prob=interaction_prob)

    # Cost_NC vector (max normalized)
    table = build_table_from_icd_guidelines()
    cost_nc, _ = table.build_vector(vocab, NonCoverageCostConfig(normalization='max'))

    # Fixed seeds keep the splits reproducible across runs.
    cal_samples = _generate_samples(predictor, n_per_split, seed=seeds["cal"])
    val_samples = _generate_samples(predictor, n_per_split, seed=seeds["val"])
    test_samples = _generate_samples(predictor, n_per_split, seed=seeds["test"])

    print(f"Generated {len(cal_samples)} cal, {len(val_samples)} val, {len(test_samples)} test samples")

    # Scores of the true label on the calibration split, per method.
    calibration_scores = {
        'Standard_CP': [1.0 - s["probs"][s["y"]] for s in cal_samples],
        'CNCRC_Max': [calculate_risk_weighted_score(s["probs"], cost_mat, s["y"]) for s in cal_samples],
        'CNCRC_Sum': [calculate_cncrc_sum_score(s["probs"], cost_mat, s["y"]) for s in cal_samples],
        'Cost_Aware': [
            calculate_cost_aware_score(s["probs"], cost_mat, cost_nc, s["y"], lambda_fixed)
            for s in cal_samples
        ],
    }

    print("Pre-calculated calibration scores for all methods")

    # (method name, progress label, summary label, set builder, tau, method params)
    method_specs = [
        ("Standard_CP", "Standard CP", "CP",
         lambda probs, q: build_sets_cp(probs, q),
         None, {}),
        ("CNCRC_Max", "CNCRC-Max", "Max",
         lambda probs, q: build_sets_cncrc_max(probs, cost_mat, q, tau_fixed),
         tau_fixed, {"tau": tau_fixed}),
        ("CNCRC_Sum", "CNCRC-Sum", "Sum",
         lambda probs, q: build_sets_cncrc_sum(probs, cost_mat, q, tau_fixed),
         tau_fixed, {"tau": tau_fixed}),
        ("Cost_Aware", "Cost-Aware", "CA",
         lambda probs, q: build_sets_cost_aware(probs, cost_mat, cost_nc, q, lambda_fixed),
         None, {"lambda": lambda_fixed}),
    ]

    all_results = []

    for alpha in alpha_values:
        print(f"\n--- Alpha = {alpha:.3f} ---")
        summary_parts = []

        for method, label, short_label, builder, tau, params in method_specs:
            print(f"  {label}...")
            scores = calibration_scores[method]
            q = find_optimal_q(val_samples, np.array(scores), builder, cost_mat, cost_nc, alpha)
            metrics, sizes = evaluate_method(test_samples, builder, q, cost_mat, cost_nc)

            all_results.append(ExperimentResult(
                method=method,
                alpha=alpha,
                target_risk=alpha,
                coverage=metrics["coverage"],
                aps=metrics["aps"],
                non_coverage_risk=metrics["non_coverage_risk"],
                ambiguity_cost_max=metrics["ambiguity_cost_max"],
                q_threshold=q,
                tau_threshold=tau,
                q_over_tau=(q / tau) if tau is not None else None,
                score_distribution=scores,
                set_sizes=sizes,
                method_params=dict(params)
            ))
            summary_parts.append(
                f"{short_label}({metrics['non_coverage_risk']:.3f}, APS={metrics['aps']:.1f})"
            )

        print(f"    Results: {', '.join(summary_parts)}")

    # Create comprehensive results object
    sweep_results = AlphaSweepResults(
        alpha_values=alpha_values,
        methods=[spec[0] for spec in method_specs],
        results=all_results,
        calibration_scores=calibration_scores,
        test_samples=test_samples,
        cost_matrix=cost_mat.tolist(),
        cost_nc_vector=cost_nc.tolist(),
        experiment_config={
            "n_cal": len(cal_samples),
            "n_val": len(val_samples),
            "n_test": len(test_samples),
            "interaction_prob": interaction_prob,
            "cost_nc_normalization": "max",
            "seeds": dict(seeds),
            "tau_fixed": tau_fixed,
            "lambda_fixed": lambda_fixed
        }
    )

    # Save results
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, "alpha_sweep_results_optimized.json")

    # json cannot encode NumPy arrays or NumPy scalars (e.g. the np.int64
    # labels stored in test_samples), so convert them to Python objects first.
    results_dict = _to_json_compatible(asdict(sweep_results))

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results_dict, f, indent=2)

    print(f"\nResults saved to: {output_file}")
    print(f"Total experiments: {len(all_results)}")

    return sweep_results


if __name__ == "__main__":
    # Script entry point: run the sweep with its default alpha grid and output directory.
    results = run_optimized_alpha_sweep(alpha_values=None)
    print("Optimized alpha sweep experiment completed successfully!")