import numpy as np
import torch
from scipy.special import softmax
from tqdm import tqdm
import csv
import os
from collections import defaultdict
import argparse
from pathlib import Path
import time
from datetime import datetime
import json
from sklearn.metrics import make_scorer, log_loss

# Import calibration methods
from calibration_methods.temperature_scaling import TemperatureScaling
from calibration_methods.vector_scaling import VectorScaling
from calibration_methods.matrix_scaling import MatrixScaling, MatrixScalingODIR
from calibration_methods.dirichlet_calibration import DirichletL2, DirichletODIR
from calibration_methods.ensemble_temperature_scaling import EnsembleTemperatureScaling
from calibration_methods.irova import IROvA
from calibration_methods.irova_ts import IROvATS
from calibration_methods.irm import IRM
from calibration_methods.irm_ts import IRMTS


# Define negative log loss scorer for Bayesian optimization
def neg_log_loss(y_true, y_pred):
    """
    Return the sklearn log loss of ``y_pred`` against ``y_true``, negated.

    Negating the loss turns it into a score where larger is better, which is
    the convention for objectives that a hyper-parameter search maximizes.
    """
    loss = log_loss(y_true, y_pred)
    return -loss


# sklearn scorer wrapping neg_log_loss; greater_is_better=True because the
# wrapped value is already negated (higher means lower log loss).
neg_log_loss_scorer = make_scorer(neg_log_loss, greater_is_better=True)

# Import evaluation utilities
from uncertainty_measures import (
    compute_classwise_scores,
    get_uncertainty_measures,
)


def load_model_data(model_path):
    """
    Load logits and labels from a model directory.

    Expected structure::

        model_path/
            {dataset_name}_logits.pt  # shape: (n_samples, n_classes)
            {dataset_name}_labels.pt  # shape: (n_samples,)

    where ``dataset_name`` is one of cifar10, cifar100 or imagenet. It is
    inferred case-insensitively from the directory name.

    Args:
        model_path: Path (str or ``Path``) to the model directory.

    Returns:
        Tuple ``(logits, labels)`` of NumPy arrays.

    Raises:
        ValueError: If the dataset cannot be inferred from the directory name,
            the logits are not 2-D, the lengths differ, or a label lies
            outside ``[0, n_classes - 1]``.
        FileNotFoundError: If either expected file is missing.
    """
    model_path = Path(model_path)
    dir_name = model_path.name.upper()

    # CIFAR100 must be tested before CIFAR10, which is a prefix of it. All
    # checks are case-insensitive so e.g. "resnet_imagenet" is recognised.
    if "CIFAR100" in dir_name:
        dataset_name = "cifar100"
    elif "CIFAR10" in dir_name:
        dataset_name = "cifar10"
    elif "IMAGENET" in dir_name:
        dataset_name = "imagenet"
    else:
        raise ValueError(f"Could not determine dataset name from path: {model_path}")

    logits_path = model_path / f"{dataset_name}_logits.pt"
    labels_path = model_path / f"{dataset_name}_labels.pt"

    if not logits_path.exists() or not labels_path.exists():
        raise FileNotFoundError(
            f"Could not find {logits_path.name} or {labels_path.name} in {model_path}"
        )

    # map_location="cpu" lets tensors that were saved on a GPU load on any
    # machine (including CPU-only ones).
    logits = torch.load(logits_path, map_location="cpu")
    labels = torch.load(labels_path, map_location="cpu")

    # Convert to numpy; detach/cpu make this safe for any saved tensor.
    if isinstance(logits, torch.Tensor):
        logits = logits.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    logits = np.asarray(logits)
    labels = np.asarray(labels)

    # Basic validation
    if logits.ndim != 2:
        raise ValueError(
            f"Expected 2-D logits of shape (n_samples, n_classes), got shape {logits.shape}"
        )
    if len(logits) != len(labels):
        raise ValueError(
            f"Number of logits ({len(logits)}) does not match number of labels ({len(labels)})"
        )

    n_classes = logits.shape[1]
    # Check both bounds: negative labels are just as invalid as too-large ones.
    if labels.size and (labels.min() < 0 or labels.max() >= n_classes):
        raise ValueError(f"Labels should be in range [0, {n_classes-1}]")

    return logits, labels


def split_data(logits, labels, val_ratio=0.5, random_seed=42):
    """
    Randomly split data into validation (calibration) and test sets.

    Args:
        logits: Array indexable by an integer index array, length n_samples.
        labels: Array with the same length as ``logits``.
        val_ratio: Fraction in [0, 1] of samples assigned to the validation
            set; the resulting count is truncated with ``int``.
        random_seed: Seed for the permutation, so the same seed always
            yields the same split.

    Returns:
        ``((val_logits, val_labels), (test_logits, test_labels))``.

    Raises:
        ValueError: If ``val_ratio`` is outside [0, 1] or the lengths of
            ``logits`` and ``labels`` differ.
    """
    # A negative ratio would otherwise turn into a negative slice bound and
    # silently produce a meaningless split.
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be in [0, 1], got {val_ratio}")

    n_samples = len(logits)
    if len(labels) != n_samples:
        raise ValueError(
            f"logits ({n_samples}) and labels ({len(labels)}) must have the same length"
        )

    indices = np.random.RandomState(random_seed).permutation(n_samples)
    val_size = int(n_samples * val_ratio)

    val_indices = indices[:val_size]
    test_indices = indices[val_size:]

    val_logits = logits[val_indices]
    val_labels = labels[val_indices]
    test_logits = logits[test_indices]
    test_labels = labels[test_indices]

    return (val_logits, val_labels), (test_logits, test_labels)


def evaluate_calibration(probs, labels, nbins=15):
    """
    Score predicted probabilities with ``get_uncertainty_measures``.

    Args:
        probs: NumPy array of predicted probabilities.
        labels: NumPy array of ground-truth labels.
        nbins: Accepted for interface compatibility only; it is not forwarded
            to ``get_uncertainty_measures`` here, so it currently has no effect.

    Returns:
        The result of ``get_uncertainty_measures`` (callers in this file
        iterate it with ``.items()`` as a metric-name -> value mapping).
    """
    # torch.from_numpy shares memory with the input arrays (no copy).
    return get_uncertainty_measures(torch.from_numpy(probs), torch.from_numpy(labels))


def get_calibrated_probs(calibrator, logits_tensor, method_name):
    """
    Return calibrated probabilities from a fitted calibrator as NumPy data.

    Args:
        calibrator: Fitted calibrator exposing ``predict_proba``.
        logits_tensor: Logits as a torch tensor (on any device).
        method_name: Name of the calibration method. It must be one of the
            supported names and selects the input format passed to
            ``predict_proba``.

    Returns:
        The calibrated probabilities (torch outputs are converted to NumPy),
        or ``None`` if the method is unsupported or prediction raised. The
        error is printed; callers treat ``None`` as "skip this method".
    """
    supported_methods = frozenset(
        {
            "TemperatureScaling",
            "TemperatureScalingMSE",
            "VectorScaling",
            "MatrixScaling",
            "MatrixScalingODIR",
            "DirichletL2",
            "DirichletODIR",
            "EnsembleTemperatureScaling",
            "IRM",
            "IRMTS",
            "IROvA",
            "IROvATS",
        }
    )
    # fit_calibrator fits these methods on CPU NumPy arrays, so they are given
    # CPU NumPy input at prediction time too. A tensor that lives on a GPU
    # cannot be converted to a NumPy array implicitly.
    numpy_input_methods = frozenset({"MatrixScalingODIR", "DirichletL2", "DirichletODIR"})

    try:
        if method_name not in supported_methods:
            raise NotImplementedError(
                f"Calibrator {method_name} is not supported in get_calibrated_probs."
            )

        inputs = logits_tensor
        if method_name in numpy_input_methods and isinstance(inputs, torch.Tensor):
            inputs = inputs.detach().cpu().numpy()

        probs_cal = calibrator.predict_proba(inputs)
        if isinstance(probs_cal, torch.Tensor):
            probs_cal = probs_cal.detach().cpu().numpy()
        return probs_cal
    except Exception as e:
        # Best-effort by design: the caller skips the method on None.
        print(f"Error getting calibrated probabilities for {method_name}: {str(e)}")
        return None


def fit_calibrator(calibrator, val_logits_tensor, val_labels_tensor, method_name):
    """
    Fit ``calibrator`` on validation logits/labels in the format its method expects.

    Args:
        calibrator: Unfitted calibrator exposing ``fit``.
        val_logits_tensor: Validation logits as a torch tensor.
        val_labels_tensor: Validation labels as a torch tensor.
        method_name: Calibration method name; selects the input conversion.

    Returns:
        ``True`` if ``fit`` completed, ``False`` if anything raised (the error
        is printed rather than propagated).
    """
    # These calibrators are tuned with BayesSearchCV, which works on NumPy
    # arrays, so their inputs are moved to the CPU and converted first.
    numpy_fitted = ("MatrixScalingODIR", "DirichletL2", "DirichletODIR")

    try:
        if method_name in numpy_fitted:
            features = val_logits_tensor.cpu().numpy()
            targets = val_labels_tensor.cpu().numpy()
        elif hasattr(calibrator, "predict_proba"):
            # Usual case: tensors passed through, integer class labels.
            features = val_logits_tensor
            targets = val_labels_tensor.long()
        else:
            # Calibrators without predict_proba receive float labels.
            features = val_logits_tensor.float()
            targets = val_labels_tensor.float()
        calibrator.fit(features, targets)
    except Exception as e:
        print(f"Error during fitting {method_name}: {str(e)}")
        return False
    return True


def format_result_with_std(mean, std, decimals=3):
    """
    Format a ``mean ± std`` pair for reporting.

    Both values are rounded with ``np.round`` to ``decimals`` decimal places
    (decimal places, not significant figures).

    Args:
        mean: Mean value.
        std: Standard deviation.
        decimals: Number of decimal places to keep (default 3).

    Returns:
        A string such as ``"0.123 ± 0.007"``.
    """
    return f"{np.round(mean, decimals)} ± {np.round(std, decimals)}"


def create_results_directory(model_path):
    """
    Ensure ``<model_path>/results/<method>/`` exists for every known method.

    Existing directories are reused, never versioned, so later saves into
    them overwrite earlier results.

    Args:
        model_path: Path (str or ``Path``) to the model directory.

    Returns:
        ``Path`` of the ``results`` directory.
    """
    results_root = Path(model_path) / "results"
    results_root.mkdir(parents=True, exist_ok=True)

    method_names = (
        "Uncalibrated",
        "TemperatureScaling",
        "VectorScaling",
        "MatrixScaling",
        "MatrixScalingODIR",
        "DirichletL2",
        "DirichletODIR",
        "EnsembleTemperatureScaling",
        "IRM",
        "IRMTS",
        "IROvA",
        "IROvATS",
    )
    for name in method_names:
        (results_root / name).mkdir(exist_ok=True)

    return results_root


def _to_json_compatible(value):
    """
    ``json`` ``default=`` hook that converts NumPy/torch values to Python types.

    NumPy scalars/arrays and torch tensors expose ``tolist()``, which returns a
    plain Python number for 0-d values or nested lists otherwise. Any other
    unsupported type still raises ``TypeError``, as ``json`` normally would.
    """
    tolist = getattr(value, "tolist", None)
    if callable(tolist):
        return tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def save_calibration_outputs(
    results_dir,
    method_name,
    split_idx,
    val_probs=None,
    test_probs=None,
    val_labels=None,
    test_labels=None,
    metrics=None,
):
    """
    Save calibration outputs for a specific method and split.

    Files are written to ``results_dir/method_name/`` with a 1-based split
    suffix, e.g. ``test_probs_1.npy`` and ``metrics_1.json``. Existing files
    with the same name are overwritten.

    Args:
        results_dir: Path (str or ``Path``) to the results directory.
        method_name: Name of the calibration method (sub-directory name). The
            directory is created if it does not exist yet.
        split_idx: Index of the current split (0-based).
        val_probs: Calibrated probabilities for the validation set.
        test_probs: Calibrated probabilities for the test set.
        val_labels: Validation set labels.
        test_labels: Test set labels.
        metrics: Mapping of metric name -> value. Keys get the split suffix
            too (``acc`` -> ``acc_1``). NumPy/torch scalar values are converted
            to plain numbers for JSON.
    """
    method_dir = Path(results_dir) / method_name
    # Create on demand so methods missing from the pre-created list still save.
    method_dir.mkdir(parents=True, exist_ok=True)

    split_suffix = f"_{split_idx + 1}"  # 1-based index for readability

    # Save any probabilities/labels that were provided.
    arrays = {
        "val_probs": val_probs,
        "test_probs": test_probs,
        "val_labels": val_labels,
        "test_labels": test_labels,
    }
    for stem, array in arrays.items():
        if array is not None:
            np.save(method_dir / f"{stem}{split_suffix}.npy", array)

    if metrics is not None:
        metrics_with_split = {f"{k}{split_suffix}": v for k, v in metrics.items()}
        # Serialise fully before writing so an unserialisable value cannot
        # leave a truncated metrics file behind.
        payload = json.dumps(metrics_with_split, indent=2, default=_to_json_compatible)
        (method_dir / f"metrics{split_suffix}.json").write_text(payload, encoding="utf-8")


def print_metrics(metrics, method_name, split_idx=None):
    """
    Print a header, one ``name: value`` line per metric, and a separator.

    Args:
        metrics: Mapping of metric name -> numeric value (formatted ``.4f``).
        method_name: Name of the method being reported.
        split_idx: Optional 0-based split index, shown 1-based in the header.
    """
    header_suffix = "" if split_idx is None else f" (Split {split_idx+1})"
    print(f"\n--- {method_name}{header_suffix} Metrics ---")

    def _order(entry):
        # Group by the part before the last "_", then break ties by full name.
        name = entry[0]
        return name.rsplit("_", 1)[0], name

    for name, val in sorted(metrics.items(), key=_order):
        print(f"{name:20s}: {val:.4f}")
    print("-" * 50)


def _select_device():
    """
    Pick the compute device (CUDA, then MPS, then CPU) and print what is available.

    Returns:
        The selected ``torch.device``.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using CUDA device: {torch.cuda.get_device_name(0)}")
        print("\nAvailable CUDA devices:")
        for i in range(torch.cuda.device_count()):
            print(f"  Device {i}: {torch.cuda.get_device_name(i)}")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using MPS device (Apple Silicon)")
    else:
        device = torch.device("cpu")
        print("Using CPU device")

    print("\nAll available devices:")
    print(f"  CPU: {torch.device('cpu')}")
    if torch.cuda.is_available():
        print(f"  CUDA: {torch.device('cuda')}")
    if torch.backends.mps.is_available():
        print(f"  MPS: {torch.device('mps')}")
    print(f"  Selected device: {device}\n")
    return device


def _aggregate_results(all_results):
    """
    Summarise per-split metric values as ``(mean, std)`` for every method.

    Args:
        all_results: Mapping method name -> mapping metric name -> list of
            per-split values.

    Returns:
        Mapping method name -> mapping metric name -> ``(mean, std)``. A
        metric with no recorded values maps to ``(nan, nan)``.
    """
    # Metric names are stored exactly as evaluate_calibration returns them (the
    # split suffix is only added to the saved JSON files), so they are
    # aggregated as-is. Stripping a trailing "_<digits>" here would wrongly
    # merge distinct metrics whose names happen to end in a number.
    current_results = {}
    for method_name, metrics_dict in all_results.items():
        summary = {}
        for metric_name, values in metrics_dict.items():
            if values:
                summary[metric_name] = (np.mean(values), np.std(values))
            else:
                summary[metric_name] = (float("nan"), float("nan"))
        current_results[method_name] = summary
    return current_results


def _print_aggregate_results(current_results):
    """Print every method's ``mean ± std`` per metric, skipping NaN means."""
    print("\n--- Current Aggregate Results (Mean ± Std) ---")
    for method_name, metrics in current_results.items():
        print(f"\n{method_name}:")
        for metric_name, (mean, std) in sorted(metrics.items()):
            if not np.isnan(mean):
                print(f"{metric_name:20s}: {mean:.4f} ± {std:.4f}")
        print("-" * 50)


def _write_results_csv(output_path, current_results):
    """
    Write one CSV row per method with a ``mean ± std`` cell per metric.

    Columns follow a preferred metric order, then any remaining metrics
    alphabetically. Methods lacking a metric get ``"N/A"``.
    """
    all_metrics = set()
    for method_results in current_results.values():
        all_metrics.update(method_results.keys())

    preferred_metric_order = [
        "acc",
        "top1_eq_mass",
        "top1_eq_width",
        "cw_eq_mass",
        "cw_eq_width",
        "top1_brier",
        "cw_brier",
        "nll",
        "brier",
    ]
    sorted_metrics = [m for m in preferred_metric_order if m in all_metrics]
    remaining_metrics = sorted(m for m in all_metrics if m not in preferred_metric_order)
    final_metric_order = sorted_metrics + remaining_metrics

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # The cells contain "±" (non-ASCII), so pin UTF-8 instead of relying on
    # the platform's locale encoding.
    with open(output_path, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Method"] + final_metric_order)
        for method_name, metrics_data in current_results.items():
            row = [method_name]
            for metric in final_metric_order:
                if metric in metrics_data:
                    mean, std = metrics_data[metric]
                    row.append(format_result_with_std(mean, std))
                else:
                    row.append("N/A")
            writer.writerow(row)


def main():
    """
    Command-line entry point.

    Loads a model's logits/labels and fits every calibration method on the
    validation part of several random splits. Each method is evaluated on
    the test part. Per-split outputs are saved under ``<model-path>/results``,
    and a running ``mean ± std`` CSV is rewritten after every split.
    """
    start_time = time.time()
    parser = argparse.ArgumentParser(
        description="Evaluate calibration methods on real model outputs"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="",
        help=(
            "Path to model directory containing {dataset}_logits.pt and "
            "{dataset}_labels.pt (dataset inferred from the directory name)"
        ),
    )
    parser.add_argument(
        "--val-ratio",
        type=float,
        default=0.7,
        help="Ratio of data to use for validation/calibration",
    )
    parser.add_argument(
        "--n-splits", type=int, default=5, help="Number of random splits to evaluate"
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output CSV file path (default: <model-path>/results/final_results.csv)",
    )
    args = parser.parse_args()
    # Both sets must be non-empty for fitting and evaluation to be meaningful.
    if not 0.0 < args.val_ratio < 1.0:
        parser.error("--val-ratio must be strictly between 0 and 1")
    if args.n_splits < 1:
        parser.error("--n-splits must be at least 1")

    # Set random seeds for reproducibility
    np.random.seed(42)
    torch.manual_seed(42)

    device = _select_device()

    print(f"Loading data from: {args.model_path}")
    logits, labels = load_model_data(args.model_path)
    n_samples, n_classes = len(logits), logits.shape[1]
    print(f"Loaded {n_samples} samples with {n_classes} classes")
    print(f"Data loading took {time.time() - start_time:.2f} seconds\n")

    # Calibration methods to try: name -> (class, constructor kwargs)
    calibration_methods = {
        "TemperatureScaling": (TemperatureScaling, {"loss": "mse"}),
        "VectorScaling": (VectorScaling, {}),
        "MatrixScaling": (MatrixScaling, {}),
        "EnsembleTemperatureScaling": (EnsembleTemperatureScaling, {"loss": "mse"}),
        "IRM": (IRM, {}),
        "IRMTS": (IRMTS, {}),
        "IROvA": (IROvA, {}),
        "IROvATS": (IROvATS, {}),
        "DirichletL2": (DirichletL2, {"num_classes": n_classes}),
        "DirichletODIR": (DirichletODIR, {"num_classes": n_classes}),
    }

    # Per-method, per-metric lists of values (one entry per successful split)
    all_results = {"Uncalibrated": defaultdict(list)}
    for method_name in calibration_methods:
        all_results[method_name] = defaultdict(list)

    results_dir = create_results_directory(args.model_path)
    print(f"Created results directory at: {results_dir}")

    # Honour --output when given; otherwise keep the historical location.
    output_path = (
        Path(args.output) if args.output else results_dir / "final_results.csv"
    )

    # Save experiment configuration
    config = {
        "model_path": str(args.model_path),
        "val_ratio": args.val_ratio,
        "n_splits": args.n_splits,
        "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
        "random_seed": 42,
    }
    with open(results_dir / "config.json", "w") as f:
        json.dump(config, f, indent=2)

    # Splitting the positions 0..n-1 with split_data applies exactly the same
    # permutation split_data would apply to the data. This yields the true row
    # indices of each split, which are then used both to slice the data and
    # to save the split for reproducibility.
    all_indices = np.arange(n_samples)
    loop_start_time = time.time()

    for split in range(args.n_splits):
        split_start_time = time.time()
        print(f"\n--- Running Split {split+1}/{args.n_splits} ---")
        print(f"Started at: {datetime.now().strftime('%H:%M:%S')}")

        print("Splitting data into validation and test sets...")
        (val_idx, _), (test_idx, _) = split_data(
            all_indices, all_indices, val_ratio=args.val_ratio, random_seed=42 + split
        )
        val_logits, val_labels = logits[val_idx], labels[val_idx]
        test_logits, test_labels = logits[test_idx], labels[test_idx]
        print(f"Split sizes - Validation: {len(val_logits)}, Test: {len(test_logits)}")

        print("Converting data to PyTorch tensors...")
        val_logits_tensor = torch.from_numpy(val_logits).float().to(device)
        val_labels_tensor = torch.from_numpy(val_labels).long().to(device)
        test_logits_tensor = torch.from_numpy(test_logits).float().to(device)

        # Save the row indices (positions in the loaded arrays) of this split
        np.save(results_dir / f"split_{split}_val_indices.npy", val_idx)
        np.save(results_dir / f"split_{split}_test_indices.npy", test_idx)

        # Evaluate uncalibrated performance
        print("\nEvaluating uncalibrated performance...")
        uncal_start_time = time.time()
        probs_uncal = softmax(test_logits, axis=1)
        metrics_uncal = evaluate_calibration(probs_uncal, test_labels)
        print(
            f"Uncalibrated evaluation took {time.time() - uncal_start_time:.2f} seconds"
        )
        print_metrics(metrics_uncal, "Uncalibrated", split)

        save_calibration_outputs(
            results_dir,
            "Uncalibrated",
            split,
            val_probs=softmax(val_logits, axis=1),
            test_probs=probs_uncal,
            val_labels=val_labels,
            test_labels=test_labels,
            metrics=metrics_uncal,
        )
        for metric_name, value in metrics_uncal.items():
            all_results["Uncalibrated"][metric_name].append(value)

        print("\n--- Testing Calibration Methods ---")
        for method_name, (CalibratorClass, init_kwargs) in calibration_methods.items():
            method_start_time = time.time()
            print(f"\nCalibrating with: {method_name}")
            print(f"Started at: {datetime.now().strftime('%H:%M:%S')}")

            try:
                print(f"Initializing {method_name}...")
                calibrator = CalibratorClass(**init_kwargs)

                print(f"Fitting {method_name}...")
                fit_start_time = time.time()
                if not fit_calibrator(
                    calibrator, val_logits_tensor, val_labels_tensor, method_name
                ):
                    print(f"Failed to fit {method_name}")
                    continue
                print(f"Fitting took {time.time() - fit_start_time:.2f} seconds")

                val_probs_cal = get_calibrated_probs(
                    calibrator, val_logits_tensor, method_name
                )
                test_probs_cal = get_calibrated_probs(
                    calibrator, test_logits_tensor, method_name
                )
                if val_probs_cal is None or test_probs_cal is None:
                    print(f"Failed to get calibrated probabilities for {method_name}")
                    continue

                print("Evaluating calibration...")
                eval_start_time = time.time()
                metrics = evaluate_calibration(test_probs_cal, test_labels)
                print(f"Evaluation took {time.time() - eval_start_time:.2f} seconds")
                print_metrics(metrics, method_name, split)

                for metric_name, value in metrics.items():
                    all_results[method_name][metric_name].append(value)

                # Save results immediately after the method completes
                save_calibration_outputs(
                    results_dir,
                    method_name,
                    split,
                    val_probs=val_probs_cal,
                    test_probs=test_probs_cal,
                    val_labels=val_labels,
                    test_labels=test_labels,
                    metrics=metrics,
                )

                print(
                    f"Completed {method_name} in {time.time() - method_start_time:.2f} seconds"
                )

            except Exception as e:
                # One failing method must not abort the remaining methods.
                print(f"Error with {method_name} on split {split+1}: {str(e)}")
                continue

        # After each split, recompute and persist the running statistics
        print("\nComputing current statistics...")
        current_results = _aggregate_results(all_results)
        _print_aggregate_results(current_results)
        _write_results_csv(output_path, current_results)

        print(
            f"\nCompleted split {split+1} in {time.time() - split_start_time:.2f} seconds"
        )
        print(f"Current results saved to {output_path}")
        if split < args.n_splits - 1:
            # Estimate from split time only, excluding data loading/setup.
            elapsed = time.time() - loop_start_time
            remaining = elapsed / (split + 1) * (args.n_splits - split - 1)
            print(f"Estimated time remaining: {remaining:.2f} seconds")

    print(f"\nTotal execution time: {time.time() - start_time:.2f} seconds")
    print(f"Final results saved to {output_path}")


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
