import numpy as np
import torch
import json
from pathlib import Path
from collections import defaultdict
import argparse
from datetime import datetime
from scipy.special import softmax
import re  # For parsing split numbers

# Removed: import matplotlib.pyplot as plt

# JAX for utility calibration calculations
import jax
import jax.numpy as jnp

# Import from existing project files
from utility_cal import (
    calculate_uc_linear_distribution,
    calculate_uc_rank_based_distribution,
)


def _extract_split_number(filename_stem):
    match = re.search(r"split_(\d+)_", filename_stem)
    if match:
        return int(match.group(1))
    match = re.search(r"_(\d+)$", filename_stem)  # for test_probs_1.npy or similar
    if match:
        return int(match.group(1))
    if filename_stem.isdigit():
        return int(filename_stem)
    return None


def load_model_data(model_path_str):
    """Locate and load the saved logits and labels for a model directory.

    The dataset name is resolved by trying, in order:

    1. the underscore-separated parts of the directory name,
    2. the underscore-separated parts of each parent directory name,
    3. the presence of ``<dataset>_logits.pt`` for a known dataset,
    4. any ``*_logits.pt`` file in the directory (heuristic; prints a warning
       when the resulting name is not a known dataset).

    ``<dataset>_logits.pt`` / ``<dataset>_labels.pt`` are then looked up in the
    model directory and, if the logits file is missing there, in each parent
    directory up to the filesystem root.

    Args:
        model_path_str: Path to the model directory (string or path-like).

    Returns:
        A tuple ``(logits, labels, dataset_name, data_dir)`` where ``logits``
        and ``labels`` are NumPy arrays, ``dataset_name`` is the lower-case
        dataset identifier and ``data_dir`` is the directory in which the
        tensor files were found.

    Raises:
        ValueError: If no dataset name can be determined, or if the number of
            logits and labels differs.
        FileNotFoundError: If the logits or labels file cannot be found.
    """
    model_path = Path(model_path_str)
    possible_dataset_names = ["cifar100", "cifar10", "imagenet"]
    # Built once; name parts are compared case-insensitively.
    known_upper = {ds.upper() for ds in possible_dataset_names}

    def _dataset_from_component(component_name):
        # Exact match on an underscore-separated part, e.g. "ResNet56_CIFAR100".
        for name_part in component_name.upper().split("_"):
            if name_part in known_upper:
                return name_part.lower()
        return None

    dataset_name = _dataset_from_component(model_path.name)

    if not dataset_name:
        for parent in model_path.parents:
            dataset_name = _dataset_from_component(parent.name)
            if dataset_name:
                break

    if not dataset_name:
        for ds_cand in possible_dataset_names:
            if (model_path / f"{ds_cand}_logits.pt").exists():
                dataset_name = ds_cand
                break

    if not dataset_name:
        # Sorted so the heuristic pick does not depend on filesystem order.
        pt_files = sorted(model_path.glob("*_logits.pt"))
        if pt_files:
            dataset_name = pt_files[0].name.replace("_logits.pt", "")
            if dataset_name not in possible_dataset_names:
                print(
                    f"Warning: Dataset name heuristically guessed as '{dataset_name}' from file {pt_files[0].name}. Please verify."
                )

    if not dataset_name:
        raise ValueError(
            f"Could not determine dataset name from path: {model_path}. "
            f"Looked for {possible_dataset_names} in path components and for files like 'cifar10_logits.pt'."
        )
    print(f"Determined dataset name: {dataset_name} for model path {model_path}")

    logits_path = model_path / f"{dataset_name}_logits.pt"
    labels_path = model_path / f"{dataset_name}_labels.pt"
    current_search_path = model_path
    # Walk up towards the filesystem root until the logits file shows up.
    while (
        not logits_path.exists() and current_search_path.parent != current_search_path
    ):
        current_search_path = current_search_path.parent
        logits_path = current_search_path / f"{dataset_name}_logits.pt"
        labels_path = current_search_path / f"{dataset_name}_labels.pt"
        if logits_path.exists():
            print(f"Found logits/labels in parent directory: {current_search_path}")
            break

    if not logits_path.exists() or not labels_path.exists():
        raise FileNotFoundError(
            f"Could not find {dataset_name}_logits.pt or {dataset_name}_labels.pt. "
            f"Searched starting from {model_path} and its parents. Final paths: {logits_path}, {labels_path}"
        )

    # detach() first: Tensor.numpy() raises for tensors that were saved with
    # requires_grad=True.
    cpu = torch.device("cpu")
    logits = torch.load(logits_path, map_location=cpu).detach().numpy()
    labels = torch.load(labels_path, map_location=cpu).detach().numpy()

    if len(logits) != len(labels):
        raise ValueError("Number of logits and labels do not match.")

    return logits, labels, dataset_name, current_search_path


# --- Core Data Computation Logic ---


def _select_splits_for_aggregation(
    split_indices_files_info, max_splits_to_eval, target_split_override
):
    """Choose which ``(split_number, indices_path)`` pairs to aggregate.

    Returns the selected list, or ``None`` when the request cannot be satisfied
    (unknown target split, or a negative ``max_splits_to_eval``); an explanatory
    message is printed in that case.
    """
    if target_split_override is not None:
        for split_entry in split_indices_files_info:
            if split_entry[0] == target_split_override:
                print(
                    f"  Aggregating data from specific split number: {target_split_override}."
                )
                return [split_entry]
        print(
            f"  Error: Target split number {target_split_override} for aggregation not found. Available: {[s[0] for s in split_indices_files_info]}"
        )
        return None

    num_available_splits = len(split_indices_files_info)
    if max_splits_to_eval < 0:
        # A negative bound used as a slice end would silently drop trailing splits.
        print(
            f"  Error: max_splits_to_eval must be >= 0 (0 means all), got {max_splits_to_eval}."
        )
        return None
    if max_splits_to_eval == 0:  # 0 means all
        print(f"  Aggregating data from all {num_available_splits} available splits.")
        return list(split_indices_files_info)
    if num_available_splits > max_splits_to_eval:
        print(
            f"  Aggregating data from the first {max_splits_to_eval} splits out of {num_available_splits} available."
        )
    return split_indices_files_info[:max_splits_to_eval]


def _collect_uncalibrated_split_data(all_logits, all_labels, selected_splits):
    """Gather softmax probabilities and 1-D labels for each selected split."""
    probs_list, labels_list = [], []
    for split_num, split_file_path in selected_splits:
        print(
            f"  Aggregating Uncalibrated data from split {split_num} (indices: {split_file_path.name})"
        )
        current_test_indices = np.load(split_file_path)
        if current_test_indices.size == 0:
            print(
                f"    Warning: Split {split_num} has no test indices. Skipping for Uncalibrated."
            )
            continue

        split_logits = all_logits[current_test_indices]
        split_labels = all_labels[current_test_indices]

        if (
            split_logits.shape[0] == 0
            or split_labels.shape[0] == 0
            or split_logits.shape[0] != split_labels.shape[0]
        ):
            print(
                f"    Warning: Data issue for Uncalibrated in split {split_num} (Logits: {split_logits.shape}, Labels: {split_labels.shape}). Skipping."
            )
            continue

        if split_labels.ndim != 1:  # Ensure labels are 1D
            split_labels = np.argmax(split_labels, axis=1)

        probs_list.append(softmax(split_logits, axis=1))
        labels_list.append(split_labels)
    return probs_list, labels_list


def _collect_method_split_data(method_dir_path, selected_splits):
    """Load a method's saved ``test_probs_<N>.npy`` / ``test_labels_<N>.npy`` per split."""
    method_name = method_dir_path.name
    probs_list, labels_list = [], []
    for split_num, _ in selected_splits:  # only the split number is needed here
        print(f"  Attempting to aggregate {method_name} data from split {split_num}")
        test_probs_path = method_dir_path / f"test_probs_{split_num}.npy"
        test_labels_path = method_dir_path / f"test_labels_{split_num}.npy"

        # A method may not have been run on every split; skip those quietly.
        if not test_probs_path.exists() or not test_labels_path.exists():
            continue

        current_method_probs = np.load(test_probs_path)
        current_method_labels = np.load(test_labels_path)

        if (
            current_method_probs.shape[0] == 0
            or current_method_labels.shape[0] == 0
            or current_method_probs.shape[0] != current_method_labels.shape[0]
        ):
            print(
                f"    Warning: Data issue for {method_name} in split {split_num} (Probs: {current_method_probs.shape}, Labels: {current_method_labels.shape}). Skipping this split for this method."
            )
            continue

        if current_method_labels.ndim != 1:  # Ensure labels are 1D
            current_method_labels = np.argmax(current_method_labels, axis=1)

        probs_list.append(current_method_probs)
        labels_list.append(current_method_labels)
        print(
            f"    Aggregated {current_method_probs.shape[0]} samples from {method_name}, split {split_num}"
        )
    return probs_list, labels_list


def _compute_error_distributions(
    master_key,
    method_name,
    probs_list,
    labels_list,
    num_utility_samples,
    distribution_kwargs,
):
    """Concatenate per-split data and compute linear and rank-based errors.

    Returns ``(new_master_key, errors)`` where ``errors`` is a
    ``(linear_errors, rank_errors)`` pair of NumPy arrays, or ``None`` when no
    data was aggregated. The master key is split once per distribution.
    """
    if not (probs_list and labels_list):
        print(f"  {method_name}: No data aggregated. Skipping ECDF computation.")
        return master_key, None

    final_probs = np.concatenate(probs_list, axis=0)
    final_labels = np.concatenate(labels_list, axis=0)
    print(
        f"  {method_name}: Aggregated {final_probs.shape[0]} samples. Computing ECDF..."
    )

    # Only the first element of each returned 5-tuple is used.
    master_key, subkey_linear = jax.random.split(master_key)
    linear_errors_dist, _, _, _, _ = calculate_uc_linear_distribution(
        subkey_linear,
        final_probs,
        final_labels,
        num_utility_samples,
        **distribution_kwargs,
    )

    master_key, subkey_rank = jax.random.split(master_key)
    rank_errors_dist, _, _, _, _ = calculate_uc_rank_based_distribution(
        subkey_rank,
        final_probs,
        final_labels,
        num_utility_samples,
        **distribution_kwargs,
    )
    print(f"    {method_name}: ECDF computation complete.")
    return master_key, (np.array(linear_errors_dist), np.array(rank_errors_dist))


def _save_error_arrays(errors_by_method, log_dir, kind):
    """Write ``<method>_<kind>_errors.npy`` for every non-empty error array."""
    for method_name, errors_array in errors_by_method.items():
        if errors_array.size > 0:
            save_path = log_dir / f"{method_name}_{kind}_errors.npy"
            np.save(save_path, errors_array)
            print(
                f"  Saved {kind} errors for '{method_name}' to {save_path} (shape: {errors_array.shape})"
            )
        else:
            print(
                f"Warning: No {kind} errors collected for method '{method_name}' to save (array was empty or not created)."
            )


def compute_and_save_ecdf_data(
    model_dir_str,
    output_log_parent_dir_str,
    max_splits_to_eval,
    num_utility_samples,
    jax_seed,
    target_split_override=None,
    data_subsample_trigger_n=30000,
    data_subsample_batch_n=1000,
    data_num_subsample_batches=5,
    data_subsampling_seed=42,
    utility_sample_chunk_size=50,
):
    """Compute linear and rank-based utility calibration error distributions.

    Data from the selected splits is aggregated per method ("Uncalibrated" from
    the softmax of the raw logits, every other method from the
    ``test_probs_<N>.npy`` / ``test_labels_<N>.npy`` files in its sub-directory
    of ``results``). For each method one linear and one rank-based error array
    is written to ``<output_log_parent_dir>/<model>/cdf_plot_logs`` together
    with a ``_metadata.json`` describing the run.

    The function reports problems by printing a message and returning early
    rather than raising (missing data, no ``results`` directory, no usable
    splits).

    Args:
        model_dir_str: Model directory passed to ``load_model_data``.
        output_log_parent_dir_str: Parent directory for the output logs.
        max_splits_to_eval: Number of splits (lowest split numbers first) to
            aggregate; ``0`` means all. Ignored when ``target_split_override``
            is given. Negative values are rejected.
        num_utility_samples: Forwarded to the ``calculate_uc_*`` functions.
        jax_seed: Seed for the master ``jax.random.PRNGKey``.
        target_split_override: If not ``None``, aggregate only this split number.
        data_subsample_trigger_n: Forwarded as ``data_subsample_trigger_size``.
        data_subsample_batch_n: Forwarded as ``data_subsample_max_samples``.
        data_num_subsample_batches: Forwarded as ``data_num_subsamples``.
        data_subsampling_seed: Forwarded as ``data_subsampling_seed``.
        utility_sample_chunk_size: Forwarded as ``utility_batch_size``.

    Returns:
        None.
    """
    print(f"DEBUG: compute_and_save_ecdf_data called with model_dir: {model_dir_str}")
    print(f"DEBUG:   num_utility_samples: {num_utility_samples}")
    print(f"DEBUG:   data_subsample_trigger_n: {data_subsample_trigger_n}")
    print(f"DEBUG:   data_subsample_batch_n: {data_subsample_batch_n}")
    print(f"DEBUG:   data_num_subsample_batches: {data_num_subsample_batches}")
    print(f"DEBUG:   utility_sample_chunk_size: {utility_sample_chunk_size}")
    try:
        all_logits, all_labels, dataset_name, effective_model_dir = load_model_data(
            model_dir_str
        )
    except Exception as e:
        print(f"Error loading data for {model_dir_str}: {e}")
        return

    results_dir = effective_model_dir / "results"
    if not results_dir.is_dir():
        print(
            f"Warning: No 'results' directory found at {results_dir}. Skipping data generation for {model_dir_str}"
        )
        return

    # Define the specific log directory for this model's ECDF data
    model_specific_log_dir = (
        Path(output_log_parent_dir_str) / effective_model_dir.name / "cdf_plot_logs"
    )
    model_specific_log_dir.mkdir(parents=True, exist_ok=True)
    print(f"Saving ECDF data logs to: {model_specific_log_dir}")

    master_key = jax.random.PRNGKey(jax_seed)

    # Final error arrays, one per method.
    all_linear_errors_by_method = {}
    all_rank_errors_by_method = {}

    # Discover the available splits, ordered by split number.
    split_indices_files_info = []
    for f_path in results_dir.glob("split_*_test_indices.npy"):
        split_num_cand = _extract_split_number(f_path.stem)
        if split_num_cand is not None:
            split_indices_files_info.append((split_num_cand, f_path))
    split_indices_files_info.sort()

    if not split_indices_files_info:
        print(
            f"Warning: No 'split_*_test_indices.npy' files found in {results_dir}. Cannot select splits."
        )
        return

    selected_splits_for_aggregation = _select_splits_for_aggregation(
        split_indices_files_info, max_splits_to_eval, target_split_override
    )
    if selected_splits_for_aggregation is None:
        return
    if not selected_splits_for_aggregation:
        print(
            f"No splits selected for data aggregation for {effective_model_dir.name}. Skipping data generation."
        )
        return

    print(
        f"Aggregating data from {len(selected_splits_for_aggregation)} splits for {effective_model_dir.name}."
    )

    # Same subsampling/chunking settings for every distribution computed below.
    distribution_kwargs = dict(
        utility_batch_size=utility_sample_chunk_size,
        use_subsampling=True,
        data_subsample_trigger_size=data_subsample_trigger_n,
        data_subsample_max_samples=data_subsample_batch_n,
        data_num_subsamples=data_num_subsample_batches,
        data_subsampling_seed=data_subsampling_seed,
    )

    # --- 1. "Uncalibrated": softmax of the raw logits on the selected splits ---
    method_name_uncal = "Uncalibrated"
    print(f"\nProcessing method: {method_name_uncal} by aggregating data...")
    uncal_probs_list, uncal_labels_list = _collect_uncalibrated_split_data(
        all_logits, all_labels, selected_splits_for_aggregation
    )
    master_key, uncal_errors = _compute_error_distributions(
        master_key,
        method_name_uncal,
        uncal_probs_list,
        uncal_labels_list,
        num_utility_samples,
        distribution_kwargs,
    )
    if uncal_errors is not None:
        (
            all_linear_errors_by_method[method_name_uncal],
            all_rank_errors_by_method[method_name_uncal],
        ) = uncal_errors

    # --- 2. Calibrated methods: one sub-directory of `results` per method ---
    # Sorted because iterdir() yields entries in arbitrary order and each method
    # consumes subkeys from the shared master key; a fixed order keeps the
    # results reproducible for a given seed.
    method_dirs = sorted(
        d
        for d in results_dir.iterdir()
        if d.is_dir() and d.name != "Uncalibrated" and not d.name.startswith("run_")
    )

    for method_dir_path in method_dirs:
        method_name_cal = method_dir_path.name
        print(f"\nProcessing method: {method_name_cal} by aggregating data...")
        cal_probs_list, cal_labels_list = _collect_method_split_data(
            method_dir_path, selected_splits_for_aggregation
        )
        master_key, cal_errors = _compute_error_distributions(
            master_key,
            method_name_cal,
            cal_probs_list,
            cal_labels_list,
            num_utility_samples,
            distribution_kwargs,
        )
        if cal_errors is not None:
            (
                all_linear_errors_by_method[method_name_cal],
                all_rank_errors_by_method[method_name_cal],
            ) = cal_errors

    # --- Save Consolidated ECDF Data ---
    print(
        f"\nConsolidating and saving final ECDF data for {effective_model_dir.name}..."
    )
    _save_error_arrays(all_linear_errors_by_method, model_specific_log_dir, "linear")
    _save_error_arrays(all_rank_errors_by_method, model_specific_log_dir, "rank")

    # Metadata describing how the saved arrays were produced.
    metadata = {
        "model_name": effective_model_dir.name,
        "dataset_name": dataset_name,
        "num_splits_aggregated_from": len(selected_splits_for_aggregation),
        "target_split_override_used": target_split_override,
        "max_splits_config_used": (
            "All"
            if max_splits_to_eval == 0 and target_split_override is None
            else max_splits_to_eval
        ),
        "num_utility_samples": num_utility_samples,
        "jax_seed_master": jax_seed,
        "data_subsample_trigger_n": data_subsample_trigger_n,
        "data_subsample_batch_n": data_subsample_batch_n,
        "data_num_subsample_batches": data_num_subsample_batches,
        "data_subsampling_seed_base": data_subsampling_seed,
        "utility_sample_chunk_size": utility_sample_chunk_size,
        "timestamp": datetime.now().isoformat(),
    }
    metadata_path = model_specific_log_dir / "_metadata.json"
    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=4)
    print(f"  Saved metadata to {metadata_path}")


# Values used when the corresponding CLI option is not given.
_DEFAULT_TARGET_SPLIT_NUM = 1
_DEFAULT_MAX_SPLITS = 1


def _build_arg_parser():
    """Create the command-line parser for the ECDF data computation script."""
    parser = argparse.ArgumentParser(
        description="Compute and save ECDF data (utility calibration error distributions) for models."
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default="./logits/ViT_Base_P16_224_ImageNet1k",
        help="Path to the main model directory (e.g., ./logits/ResNet56_CIFAR100).",
    )
    parser.add_argument(
        "--output-log-parent-dir",
        type=str,
        default="./experiment_cdf_data",
        help="Parent directory to save the ECDF data logs. A subdirectory for each model will be created here. (Default: ./experiment_cdf_data).",
    )
    # Both split options default to None so that main() can tell "not given"
    # apart from an explicit value.
    parser.add_argument(
        "--max-splits",
        type=int,
        default=None,
        help="Maximum number of data splits to evaluate (0 for all). Only used when --target-split-num is not given; if neither option is given, split number 1 is processed.",
    )
    parser.add_argument(
        "--target-split-num",
        type=int,
        default=None,
        help="Specify a single, specific split number to process (e.g., 0, 1, 2...). If set, --max-splits is ignored. Defaults to split 1 when neither this nor --max-splits is given.",
    )
    parser.add_argument(
        "--num-utility-samples",
        type=int,
        default=500,
        help="Number of utility vectors to sample for ECDF distributions (default: %(default)s).",
    )
    parser.add_argument(
        "--jax-seed", type=int, default=42, help="Seed for JAX PRNGKey (default: 42)."
    )
    parser.add_argument(
        "--uc_data_subsample_trigger_n",
        type=int,
        default=5500,
        help="Data size threshold to trigger subsampling in ECDF calculations (default: %(default)s).",
    )
    parser.add_argument(
        "--uc_data_subsample_batch_n",
        type=int,
        default=5000,
        help="Target batch size for data subsampling in ECDF calculations (default: %(default)s).",
    )
    parser.add_argument(
        "--uc_data_num_subsample_batches",
        type=int,
        default=5,
        help="Number of subsample batches if subsampling is triggered (default: 5).",
    )
    parser.add_argument(
        "--uc_data_subsampling_seed",
        type=int,
        default=42,
        help="Seed for data subsampling if triggered (default: 42).",
    )
    parser.add_argument(
        "--uc_utility_sample_chunk_size",
        type=int,
        default=5,
        help="Chunk size for utility samples to avoid JAX memory issues (default: %(default)s).",
    )
    return parser


def main():
    """Parse the command line and run the ECDF data computation for one model.

    Split selection:
      * ``--target-split-num N`` processes exactly split ``N``;
      * otherwise ``--max-splits K`` processes the first ``K`` splits (0 = all);
      * with neither option, split number 1 is processed.

    Any exception raised by the computation is reported with a traceback on
    the console instead of propagating.
    """
    args = _build_arg_parser().parse_args()

    target_split_num = args.target_split_num
    max_splits = args.max_splits
    if target_split_num is None and max_splits is None:
        # Nothing requested explicitly: fall back to the default target split.
        target_split_num = _DEFAULT_TARGET_SPLIT_NUM
    if max_splits is None:
        max_splits = _DEFAULT_MAX_SPLITS

    Path(args.output_log_parent_dir).mkdir(parents=True, exist_ok=True)

    print(f"Starting ECDF data computation for model directory: {args.model_dir}")
    print("Configuration:")
    print(f"  Output log parent directory: {args.output_log_parent_dir}")
    if target_split_num is not None:
        print(f"  Targeting specific split: {target_split_num}")
    else:
        print(f"  Max splits to evaluate: {'All' if max_splits == 0 else max_splits}")
    print(f"  Number of utility samples per distribution: {args.num_utility_samples}")
    print(f"  JAX PRNG seed: {args.jax_seed}")
    print(f"  Data subsample trigger N: {args.uc_data_subsample_trigger_n}")
    print(f"  Data subsample batch N: {args.uc_data_subsample_batch_n}")
    print(f"  Data num subsample batches: {args.uc_data_num_subsample_batches}")
    print(f"  Data subsampling seed: {args.uc_data_subsampling_seed}")
    print(f"  Utility sample chunk size: {args.uc_utility_sample_chunk_size}")

    try:
        compute_and_save_ecdf_data(
            args.model_dir,
            args.output_log_parent_dir,
            max_splits,
            args.num_utility_samples,
            args.jax_seed,
            target_split_num,
            args.uc_data_subsample_trigger_n,
            args.uc_data_subsample_batch_n,
            args.uc_data_num_subsample_batches,
            args.uc_data_subsampling_seed,
            args.uc_utility_sample_chunk_size,
        )
        print("\nECDF data computation process completed.")
    except Exception as e:
        print(f"Error during ECDF data computation: {str(e)}")
        import traceback

        traceback.print_exc()


# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
