import numpy as np
import torch
import json
import csv
from pathlib import Path
from collections import defaultdict
import argparse
from datetime import datetime
from scipy.special import softmax
import re

# Import evaluation utilities
from uncertainty_measures import (
    compute_classwise_scores,
    get_uncertainty_measures,
)
from utility_cal import (
    calculate_uc_top_class,
    calculate_uc_class_wise,
    calculate_uc_top_k_overall,
)


def format_result_with_std(mean, std):
    """Render ``mean ± std`` with both values rounded to 3 significant figures.

    Each value is rounded by formatting it with the ``.3g`` spec and parsing
    the text back into a float, so the output uses Python's default float
    representation (for example ``1230.0`` rather than ``1.23e+03``).

    Args:
        mean: Central value to display.
        std: Spread value to display after the ``±`` sign.

    Returns:
        A string of the form ``"<mean> ± <std>"``.
    """
    rounded_mean, rounded_std = (
        float(format(value, ".3g")) for value in (mean, std)
    )
    return "{} ± {}".format(rounded_mean, rounded_std)


def load_model_data(model_path):
    """
    Load logits and labels from a model directory.

    Expected structure:
    model_path/
        {dataset_name}_logits.pt  # shape: (n_samples, n_classes)
        {dataset_name}_labels.pt  # shape: (n_samples,)
    where dataset_name is one of: cifar10, cifar100, imagenet

    The dataset name is inferred from the directory name by a
    case-insensitive substring match.

    Args:
        model_path: Path (str or Path) to the model directory.

    Returns:
        Tuple ``(logits, labels)``. Values loaded as torch tensors are
        converted to numpy arrays; anything else is returned as loaded.

    Raises:
        ValueError: If the dataset cannot be inferred from the directory
            name, or if logits and labels have different lengths.
        FileNotFoundError: If either expected ``.pt`` file is missing.
    """
    model_path = Path(model_path)

    # "CIFAR10" is a substring of "CIFAR100", so the longer marker must be
    # tested first. All markers are matched case-insensitively.
    upper_name = model_path.name.upper()
    for marker in ("CIFAR100", "CIFAR10", "IMAGENET"):
        if marker in upper_name:
            dataset_name = marker.lower()
            break
    else:
        raise ValueError(f"Could not determine dataset name from path: {model_path}")

    logits_path = model_path / f"{dataset_name}_logits.pt"
    labels_path = model_path / f"{dataset_name}_labels.pt"

    if not logits_path.exists() or not labels_path.exists():
        raise FileNotFoundError(
            f"Could not find {logits_path.name} or {labels_path.name} in {model_path}"
        )

    # map_location="cpu" lets tensors that were saved from an accelerator be
    # loaded on a machine that does not have one.
    logits = torch.load(logits_path, map_location="cpu")
    labels = torch.load(labels_path, map_location="cpu")

    # detach() drops any autograd history, which .numpy() refuses to handle.
    if isinstance(logits, torch.Tensor):
        logits = logits.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()

    # Basic validation
    if len(logits) != len(labels):
        raise ValueError(
            f"Number of logits ({len(logits)}) does not match number of labels ({len(labels)})"
        )

    return logits, labels


def evaluate_calibration(
    probs,
    labels,
    nbins=15,
    full_evaluation=True,
    uc_use_subsampling=True,
    uc_max_samples=1000,
    uc_num_subsamples=5,
    uc_seed=42,
):
    """
    Compute calibration metrics for one set of predictions.

    The three utility calibration errors are always computed. When
    ``full_evaluation`` is true, the result additionally contains whatever
    ``get_uncertainty_measures`` returns plus the three values returned by
    ``compute_classwise_scores``, stored under "ece", "mce" and "ace".

    Args:
        probs: Probability predictions (numpy array)
        labels: True labels (numpy array)
        nbins: Number of bins forwarded to compute_classwise_scores
        full_evaluation: If False, only the utility calibration metrics are evaluated
        uc_use_subsampling: Whether to use subsampling for utility calibration metrics
        uc_max_samples: Maximum number of samples to use for utility calibration metrics
        uc_num_subsamples: Number of subsamples to use for utility calibration metrics
        uc_seed: Random seed for utility calibration subsampling

    Returns:
        Mapping from metric name to value; always includes "uc_top_class",
        "uc_class_wise" and "uc_top_k" as Python floats.
    """
    # The tensor views are built up front, even if only the utility
    # calibration metrics end up being requested.
    probs_tensor = torch.from_numpy(probs)
    labels_tensor = torch.from_numpy(labels)

    if full_evaluation:
        # Start from the comprehensive uncertainty measures ...
        metrics = get_uncertainty_measures(probs_tensor, labels_tensor)

        # ... then add the class-wise scores computed with the requested binning.
        classwise_scores = compute_classwise_scores(
            probs_tensor, labels_tensor, nbins=nbins
        )
        cws_first, cws_second, cws_third = classwise_scores
        metrics.update({"ece": cws_first, "mce": cws_second, "ace": cws_third})
    else:
        metrics = {}

    # Every utility calibration routine receives the same subsampling setup.
    subsampling_options = {
        "use_subsampling": uc_use_subsampling,
        "max_samples": uc_max_samples,
        "num_subsamples": uc_num_subsamples,
        "seed": uc_seed,
    }

    # Only the first value of each returned tuple is used here.
    top_class_error, _, _, _ = calculate_uc_top_class(
        probs, labels, **subsampling_options
    )
    class_wise_error, _, _, _, _ = calculate_uc_class_wise(
        probs, labels, **subsampling_options
    )
    top_k_error, _, _, _, _ = calculate_uc_top_k_overall(
        probs, labels, **subsampling_options
    )

    utility_metrics = {
        "uc_top_class": float(top_class_error),
        "uc_class_wise": float(class_wise_error),
        "uc_top_k": float(top_k_error),
    }
    metrics.update(utility_metrics)

    return metrics


def recompute_metrics(
    model_dir,
    method_name_filter=None,
    uc_use_subsampling=True,
    uc_max_samples=1000,
    uc_num_subsamples=5,
    uc_seed=42,
):
    """
    Recompute metrics for every method stored under ``model_dir/results``,
    aggregating results across multiple splits if available.

    Each method directory is scanned for ``test_probs_<idx>.npy`` files; every
    such file is paired with ``test_labels_<idx>.npy`` from the same
    directory. Splits that are incomplete, inconsistent, or that fail during
    evaluation are reported on stdout and skipped.

    Args:
        model_dir: Path to the model directory (e.g., ./logits/ResNet20_CIFAR10)
                   This directory should contain a 'results' subdirectory.
        method_name_filter: If specified, only process this method name.
        uc_use_subsampling: Whether to use subsampling for utility calibration metrics
        uc_max_samples: Maximum number of samples to use for utility calibration metrics
        uc_num_subsamples: Number of subsamples to use for utility calibration metrics
        uc_seed: Random seed for utility calibration subsampling

    Returns:
        Dictionary containing aggregated results for each metric of each
        method that had at least one valid split.
        Structure: {method_name: {metric_name: {'mean': ..., 'std': ..., 'max_dev': ...}}}
        'std' is numpy's default (population) standard deviation and
        'max_dev' is the spread max - min across splits.

    Raises:
        ValueError: If ``model_dir/results`` does not exist.
    """
    model_dir = Path(model_dir)
    results_dir = model_dir / "results"
    if not results_dir.exists():
        raise ValueError(f"No results directory found at {results_dir}")

    aggregated_results = defaultdict(dict)

    # Sorted so the processing order does not depend on the file system.
    method_dirs = sorted(d for d in results_dir.iterdir() if d.is_dir())
    if method_name_filter:
        method_dirs = [d for d in method_dirs if d.name == method_name_filter]
        if not method_dirs:
            print(f"Specified method '{method_name_filter}' not found in {results_dir}")
            return aggregated_results
    elif not method_dirs:
        print(f"No method directories found in {results_dir}")
        return aggregated_results

    print(f"Found method directories: {[d.name for d in method_dirs]}")

    # Compiled once; the whole file name must match so that a partial match
    # cannot pair a probability file with the wrong labels file.
    split_pattern = re.compile(r"test_probs_(\d+)\.npy")

    for method_dir in method_dirs:
        method_name = method_dir.name
        print(f"\nProcessing {method_name}...")

        split_metrics_list = []

        # Discover splits by looking for test_probs_*.npy files
        prob_files = sorted(method_dir.glob("test_probs_*.npy"))

        if not prob_files:
            print(
                f"No test probability files (test_probs_*.npy) found for {method_name}."
            )
            continue

        for test_probs_path in prob_files:
            # Extract split index from filename, e.g., test_probs_1.npy -> 1
            match = split_pattern.fullmatch(test_probs_path.name)
            if not match:
                print(f"Could not parse split index from {test_probs_path.name}")
                continue
            split_idx_str = match.group(1)

            test_labels_path = method_dir / f"test_labels_{split_idx_str}.npy"

            if not test_labels_path.exists():
                print(
                    f"Labels file {test_labels_path.name} not found for split {split_idx_str} in {method_name}. Skipping this split."
                )
                continue

            print(f"  Processing split {split_idx_str}...")
            # Best effort: one broken split must not abort the other splits
            # or methods, so any failure is reported and the split dropped.
            try:
                test_probs = np.load(test_probs_path)
                test_labels = np.load(test_labels_path)

                if len(test_probs) != len(test_labels):
                    print(
                        f"    Warning: Number of probabilities ({len(test_probs)}) does not match number of labels ({len(test_labels)}) for split {split_idx_str}. Skipping."
                    )
                    continue
                if test_probs.ndim == 2:
                    # Both bounds are checked: a negative label is just as
                    # invalid as one beyond the last class index.
                    num_classes = test_probs.shape[1]
                    labels_in_range = (test_labels >= 0) & (
                        test_labels < num_classes
                    )
                    if not np.all(labels_in_range):
                        print(
                            f"    Warning: Labels for split {split_idx_str} seem out of range for {test_probs.shape[1]} classes. Skipping."
                        )
                        continue

                metrics = evaluate_calibration(
                    test_probs,
                    test_labels,
                    uc_use_subsampling=uc_use_subsampling,
                    uc_max_samples=uc_max_samples,
                    uc_num_subsamples=uc_num_subsamples,
                    uc_seed=uc_seed,
                )  # nbins is default
                split_metrics_list.append(metrics)
            except Exception as e:
                print(
                    f"    Error processing split {split_idx_str} for {method_name}: {e}"
                )

        if not split_metrics_list:
            print(f"No valid splits processed for {method_name}.")
            continue

        # Group the per-split values by metric name. Every list created here
        # holds at least one value, so no empty-list fallback is needed.
        collected_values_per_metric = defaultdict(list)
        for metrics_dict in split_metrics_list:
            for metric_name, value in metrics_dict.items():
                collected_values_per_metric[metric_name].append(value)

        for metric_name, values_list in collected_values_per_metric.items():
            aggregated_results[method_name][metric_name] = {
                "mean": float(np.mean(values_list)),
                "std": float(np.std(values_list)),
                "max_dev": float(np.max(values_list) - np.min(values_list)),
            }
        print(
            f"Finished processing {method_name}. Found {len(split_metrics_list)} valid split(s)."
        )

    return aggregated_results


def save_results_to_csv(results, output_path):
    """Save aggregated results to a CSV file, updating it if it exists.

    Rows already present in the file are kept. A method that appears in
    ``results`` replaces its previous row entirely. The columns are "Method"
    followed by the sorted union of all metric names; a metric that a method
    lacks is written as "N/A". The file is read and written as UTF-8.

    Args:
        results: Mapping {method_name: {metric_name: {"mean", "std", "max_dev"}}}.
        output_path: Destination CSV path (str or Path). Missing parent
            directories are created.
    """
    output_path = Path(output_path)
    existing_data = {}
    all_metric_names_set = set()

    if output_path.exists():
        try:
            with open(output_path, "r", newline="", encoding="utf-8") as csvfile:
                reader = csv.reader(csvfile)
                # An empty file simply yields no header and no rows.
                header = next(reader, [])
                existing_metric_names = header[1:]  # Exclude 'Method'
                for row in reader:
                    if not row:
                        continue
                    # zip() tolerates ragged rows: cells beyond the header are
                    # dropped instead of invalidating the whole file.
                    existing_data[row[0]] = dict(zip(existing_metric_names, row[1:]))
                all_metric_names_set.update(existing_metric_names)
        except (OSError, UnicodeError, csv.Error) as e:
            print(f"Warning: Could not properly read existing CSV file at {output_path}. It might be overwritten or appended to incorrectly. Error: {e}")
            # Discard everything read so far, including column names, so a
            # half-read file cannot leave stale columns in the output.
            existing_data = {}
            all_metric_names_set = set()

    # Update with new results and collect all metric names
    for method_name, method_results in results.items():
        all_metric_names_set.update(method_results)
        existing_data[method_name] = {
            metric_name: (
                f"{stats['mean']:.3g} ± {stats['std']:.3g}"
                f" (max_dev: {stats['max_dev']:.3g})"
            )
            for metric_name, stats in method_results.items()
        }

    sorted_metric_names = sorted(all_metric_names_set)
    final_header = ["Method"] + sorted_metric_names

    # Create the target directory so results are not lost at the final step.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(final_header)

        for method_name, metrics_dict in sorted(existing_data.items()):
            # Use N/A if metric is missing for a method
            writer.writerow(
                [method_name]
                + [metrics_dict.get(name, "N/A") for name in sorted_metric_names]
            )


def print_results(results):
    """Print the aggregated results as a fixed-width text table on stdout.

    One row is printed per method (sorted by name) and one column per metric
    (sorted by name). A cell reads ``mean ± std (md: max_dev)`` with every
    number shown to 3 significant figures, or ``N/A`` when the method has no
    entry for that metric.

    Args:
        results: Mapping {method_name: {metric_name: {"mean", "std", "max_dev"}}}.
    """
    column_gap = " | "
    method_col_width = 20
    metric_col_width = 30  # Wide enough for "mean ± std (md: max_dev)"

    # Union of the metric names over all methods, in sorted order.
    known_metrics = set()
    for per_method in results.values():
        known_metrics.update(per_method.keys())
    sorted_metric_names = sorted(known_metrics)

    print(
        "\n--- Aggregated Calibration Results Summary (Mean ± Std (Max Deviation)) ---"
    )

    padded_titles = [name.ljust(metric_col_width) for name in sorted_metric_names]
    print("Method".ljust(method_col_width) + column_gap + column_gap.join(padded_titles))

    # Method column, its gap, then one (column + gap) per metric minus the
    # gap that would trail the last column.
    rule_length = 20 + 3 + (metric_col_width + 3) * len(sorted_metric_names) - 3
    print("-" * rule_length)

    def render_cell(per_method, metric_name):
        """Return the cell text for one method/metric pair."""
        if metric_name not in per_method:
            return "N/A"
        stats = per_method[metric_name]
        return "{:.3g} ± {:.3g} (md: {:.3g})".format(
            stats["mean"], stats["std"], stats["max_dev"]
        )

    for method_name, per_method in sorted(results.items()):
        cells = [method_name.ljust(method_col_width)]
        cells.extend(
            render_cell(per_method, metric_name).ljust(metric_col_width)
            for metric_name in sorted_metric_names
        )
        print(column_gap.join(cells))


def main():
    """Command-line entry point: recompute, print and save calibration metrics.

    Parses the command-line arguments, recomputes the metrics for the chosen
    model directory, prints a summary table and writes the results to a CSV
    file. When no output path is given, a timestamped file inside the model
    directory is used.

    Raises:
        SystemExit: With status 1 if recomputation, printing or saving fails,
            so that shell scripts and schedulers can detect the failure.
    """
    parser = argparse.ArgumentParser(
        description="Recompute calibration metrics from scratch"
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default="./logits/ViT_Base_P16_224_ImageNet1k",
        help="Path to the model directory, e.g. ./logits/EfficientNet_B0_ImageNet1k",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output CSV file path (default: <model-dir>/recomputed_metrics_<timestamp>.csv)",
    )
    parser.add_argument(
        "--method-name",
        type=str,
        default=None,
        help="Specific method directory name to process (e.g., PostHocUC_...)",
    )
    # Arguments for Utility Calibration subsampling
    parser.add_argument(
        "--uc-disable-subsampling",
        action="store_true",  # If present, set to True, default is False
        help="Disable subsampling for utility calibration metrics. Default is to use subsampling.",
    )
    parser.add_argument(
        "--uc-max-samples",
        type=int,
        default=1000,
        help="Maximum number of samples per subsample for utility calibration. Default: 1000",
    )
    parser.add_argument(
        "--uc-num-subsamples",
        type=int,
        default=5,
        help="Number of subsamples for utility calibration. Default: 5",
    )
    parser.add_argument(
        "--uc-seed",
        type=int,
        default=42,
        help="Random seed for utility calibration subsampling. Default: 42",
    )
    args = parser.parse_args()

    # The flag is phrased negatively, so invert it for the keyword argument.
    uc_use_subsampling = not args.uc_disable_subsampling

    # Set default output path if not provided; the timestamp keeps repeated
    # runs from writing to the same file.
    if args.output is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        args.output = str(Path(args.model_dir) / f"recomputed_metrics_{timestamp}.csv")

    print(f"Recomputing metrics for: {args.model_dir}")
    print(f"Output will be saved to: {args.output}")

    try:
        # Recompute metrics
        results = recompute_metrics(
            args.model_dir,
            method_name_filter=args.method_name,
            uc_use_subsampling=uc_use_subsampling,
            uc_max_samples=args.uc_max_samples,
            uc_num_subsamples=args.uc_num_subsamples,
            uc_seed=args.uc_seed,
        )

        # Print results
        print_results(results)

        # Save to CSV
        save_results_to_csv(results, args.output)
        print(f"\nResults saved to: {args.output}")

    except Exception as e:
        print(f"Error during recomputation: {str(e)}")
        import traceback

        traceback.print_exc()
        # Report the failure through the exit status instead of exiting 0.
        raise SystemExit(1) from e


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
