import jax
import jax.numpy as jnp
from jax import jit, vmap
from functools import partial
import numpy as np
from scipy.special import softmax as softmax_np # Only for main/testing, not calibrator core
import os
import json
from pathlib import Path
import time
from datetime import datetime
import argparse
from collections import defaultdict
import csv
import torch # Assuming available for data loading as in real_data_calibration.py
from tqdm import tqdm
# --- Import from utility_cal.py (ensure this file is in PYTHONPATH) ---
from utility_cal import (
    _labels_to_one_hot_jax,
    _get_class_wise_util_arrays_jit,
    _compute_ranks_jit, # Import this helper
    calculate_uc_top_class,
    calculate_uc_class_wise,
    calculate_uc_top_k_overall,
)

# --- Data Loading and Splitting (Placeholders/Simplified from previous script) ---
def load_model_data(model_path_str: str):
    """Load saved model logits and labels as NumPy arrays.

    The dataset is inferred from the last path component of
    ``model_path_str`` (case-insensitive substring match). For a recognised
    dataset the files ``<dataset>_logits.pt`` and ``<dataset>_labels.pt`` are
    read with ``torch.load``; otherwise the generic ``logits.npy`` and
    ``labels.npy`` files are read with ``np.load``.

    Args:
        model_path_str: Directory that contains the saved logits/labels.

    Returns:
        A ``(logits, labels)`` tuple of NumPy arrays.

    Raises:
        FileNotFoundError: If the expected files are not present.
    """
    model_path = Path(model_path_str)
    upper_name = model_path.name.upper()

    # "CIFAR100" must be tested before "CIFAR10": the latter is a substring
    # of the former and would otherwise shadow it.
    dataset_name = None
    for marker, candidate in (("CIFAR100", "cifar100"), ("CIFAR10", "cifar10"), ("IMAGENET", "imagenet")):
        if marker in upper_name:
            dataset_name = candidate
            break

    if dataset_name is None:
        print(f"Warning: Could not infer standard dataset type from '{model_path.name}'. Attempting to load generic 'logits.npy' and 'labels.npy'.")
        logits_path_npy = model_path / "logits.npy"
        labels_path_npy = model_path / "labels.npy"
        if not (logits_path_npy.exists() and labels_path_npy.exists()):
            raise FileNotFoundError(f"Could not determine standard dataset type for {model_path_str} and generic 'logits.npy'/'labels.npy' not found.")
        logits = np.load(logits_path_npy)
        labels = np.load(labels_path_npy)
        print(f"Loaded generic .npy files: logits shape {logits.shape}, labels shape {labels.shape}")
        return logits, labels

    logits_path = model_path / f"{dataset_name}_logits.pt"
    labels_path = model_path / f"{dataset_name}_labels.pt"
    if not logits_path.exists() or not labels_path.exists():
        raise FileNotFoundError(f"Could not find {logits_path.name} or {labels_path.name} in {model_path}. Attempted with dataset_name='{dataset_name}'.")

    def _tensor_file_to_numpy(path):
        # map_location="cpu" lets tensors saved from an accelerator load on any
        # machine; detach() is required because .numpy() refuses tensors that
        # still require grad.
        return torch.load(path, map_location="cpu").detach().numpy()

    return _tensor_file_to_numpy(logits_path), _tensor_file_to_numpy(labels_path)

def split_data(logits, labels, val_ratio=0.5, random_seed=42):
    """Randomly partition ``logits``/``labels`` into a validation and a test part.

    A permutation of the sample indices is drawn from a ``RandomState`` seeded
    with ``random_seed``; the first ``int(n * val_ratio)`` permuted indices
    form the validation part and the remainder forms the test part.

    Returns:
        ``((val_logits, val_labels), (test_logits, test_labels))``.
    """
    total = len(logits)
    rng = np.random.RandomState(random_seed)
    shuffled = rng.permutation(total)

    cut = int(total * val_ratio)
    val_idx = shuffled[:cut]
    test_idx = shuffled[cut:]

    val_part = (logits[val_idx], labels[val_idx])
    test_part = (logits[test_idx], labels[test_idx])
    return val_part, test_part

def create_results_directory_structure(model_path_str, method_name):
    """Create (if needed) and return ``<model_path_str>/results/<method_name>``."""
    target_dir = Path(model_path_str).joinpath("results", method_name)
    # parents=True builds the intermediate "results" directory as well;
    # exist_ok=True makes repeated calls a no-op.
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir

def save_calibration_outputs(method_results_dir, split_idx, test_probs_np, test_labels_np, metrics_dict):
    """Persist one split's test probabilities, labels and metrics.

    Files are written into ``method_results_dir`` with a 1-based split suffix:
    ``test_probs_<k>.npy``, ``test_labels_<k>.npy`` and ``metrics_<k>.json``.

    Args:
        method_results_dir: Existing directory (a ``Path``) to write into.
        split_idx: 0-based split index; the file suffix uses ``split_idx + 1``.
        test_probs_np: Array of test probabilities, saved with ``np.save``.
        test_labels_np: Array of test labels, saved with ``np.save``.
        metrics_dict: Mapping of metric name to value. NumPy/JAX scalars and
            arrays are converted to plain Python values before JSON encoding.
    """

    def _to_json_value(value):
        # Arrays are handled first: every ndarray has ``.item``, but it only
        # works for single-element arrays, so larger arrays must use tolist().
        if isinstance(value, (np.ndarray, jnp.ndarray)):
            if value.size == 1:
                return value.item()
            return np.asarray(value).tolist()
        # NumPy scalar types (e.g. np.float64) expose ``.item`` as well.
        if hasattr(value, 'item'):
            return value.item()
        return value

    split_suffix = f"_{split_idx + 1}"
    np.save(method_results_dir / f"test_probs{split_suffix}.npy", test_probs_np)
    np.save(method_results_dir / f"test_labels{split_suffix}.npy", test_labels_np)

    serializable_metrics = {name: _to_json_value(value) for name, value in metrics_dict.items()}
    with open(method_results_dir / f"metrics{split_suffix}.json", 'w') as f:
        json.dump(serializable_metrics, f, indent=2)

# --- JAX Helper Functions ---
@jit
def project_to_simplex_jax(v):
    """Project a 1-D vector onto the probability simplex (sort-based method).

    The entries are sorted in descending order, the largest support size whose
    shifted cumulative-sum test is positive is selected (at least 1), and the
    resulting threshold is subtracted from ``v`` before clipping at zero.
    """
    dim = v.shape[0]
    descending = jnp.sort(v)[::-1]
    shifted_cumsum = jnp.cumsum(descending) - 1.0
    positions = jnp.arange(dim) + 1

    stays_positive = (descending - shifted_cumsum / positions) > 0
    # Largest 1-based position passing the test; clamp to 1 so the index
    # below is always valid.
    largest_passing = jnp.max(jnp.where(stays_positive, positions, 0))
    support_size = jnp.maximum(largest_passing, 1)

    threshold = shifted_cumsum[support_size - 1] / support_size
    return jnp.maximum(v - threshold, 0)

@jit
def _estimate_uc_error_interval_signed_jit_local(probabilities, vectorized_u_array, true_labels_one_hot):
    n_samples = probabilities.shape[0]
    vec_Y_minus_f_X = true_labels_one_hot - probabilities 
    B_terms = jnp.sum(vec_Y_minus_f_X * vectorized_u_array, axis=1) 
    v_u_values = jnp.sum(probabilities * vectorized_u_array, axis=1) 
    sorted_indices = jnp.argsort(v_u_values)
    sorted_B_terms = B_terms[sorted_indices]; sorted_v_u_values = v_u_values[sorted_indices]
    prefix_sum_B = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(sorted_B_terms)])
    idx_i, idx_j = jnp.triu_indices(n_samples) 
    sum_B_in_intervals = prefix_sum_B[idx_j + 1] - prefix_sum_B[idx_i] 
    abs_sum_B_in_intervals = jnp.abs(sum_B_in_intervals)
    argmax_idx_flat = jnp.argmax(abs_sum_B_in_intervals)
    error_abs = jnp.max(abs_sum_B_in_intervals) / n_samples 
    signed_expectation = sum_B_in_intervals[argmax_idx_flat] / n_samples 
    worst_interval_v_u_min = sorted_v_u_values[idx_i[argmax_idx_flat]]
    worst_interval_v_u_max = sorted_v_u_values[idx_j[argmax_idx_flat]]
    return error_abs, signed_expectation, worst_interval_v_u_min, worst_interval_v_u_max

@partial(jit, static_argnames=["n_classes"])
def _calculate_brier_score_jax(probs, true_labels_one_hot, n_classes):
    return jnp.mean(jnp.sum(jnp.square(probs - true_labels_one_hot), axis=1))
@jit
def _calculate_accuracy_jax(probs, true_labels_indices):
    return jnp.mean(jnp.argmax(probs, axis=1) == true_labels_indices)

def calculate_brier_score_np(probs_np, labels_np, n_classes):
    """NumPy Brier score: mean over samples of the summed squared error to one-hot labels."""
    label_indices = labels_np.astype(int)
    # Row-indexing the identity matrix yields the one-hot encoding.
    one_hot = np.eye(n_classes)[label_indices]
    residual = probs_np - one_hot
    per_sample_score = np.sum(np.square(residual), axis=1)
    return np.mean(per_sample_score)
def calculate_accuracy_np(probs_np, labels_np):
    """NumPy accuracy: fraction of rows whose arg-max class equals the label."""
    predicted_class = np.argmax(probs_np, axis=1)
    is_correct = predicted_class == labels_np
    return np.mean(is_correct)

@partial(jit, static_argnames=["n_classes"])
def _get_top_k_util_arrays_jit_local(probabilities, true_labels_one_hot, k_val: int, n_classes: int):
    """Build the top-k utility arrays for a batch of predictions.

    The vectorized utility is 1 for every class whose rank (as returned by
    ``_compute_ranks_jit``; presumably 0-indexed -- confirm against
    ``utility_cal``) is below ``k_val`` and 0 otherwise. The realized utility
    is that array evaluated at each sample's true class, taken as the arg-max
    of ``true_labels_one_hot``.

    Returns:
        ``(realized_utility_array, vectorized_u_array)``.
    """
    row_ids = jnp.arange(probabilities.shape[0])
    label_idx = jnp.argmax(true_labels_one_hot, axis=1)

    class_ranks = _compute_ranks_jit(probabilities, n_classes)
    within_top_k = class_ranks < k_val
    vectorized_u_array = within_top_k.astype(probabilities.dtype)

    realized_utility_array = vectorized_u_array[row_ids, label_idx]
    return realized_utility_array, vectorized_u_array

@partial(jit, static_argnames=["util_type_static", "param_static", "n_cls_static", "n_samp_static"])
def get_patch_vector_direction_jit_standalone(
    f_current_probs, util_type_static: str, param_static: int,
    I_min, I_max, n_cls_static: int, n_samp_static: int,
):
    """Return the patch direction ``u * 1[v_u in [I_min, I_max]]`` for every row.

    The vectorized utility ``u`` is built for the requested utility type
    ("class_wise" or "top_k"; any other value yields all zeros) using an
    all-zero label array, since only the vectorized utility output is used
    here. Rows whose predicted utility ``<f, u>`` lies outside the closed
    interval ``[I_min, I_max]`` are zeroed out.
    """
    # Only the second output (the vectorized utility) is consumed, so the
    # labels passed to the builders are placeholders.
    placeholder_labels = jnp.zeros((n_samp_static, n_cls_static), dtype=f_current_probs.dtype)

    utility_builders = {
        "class_wise": _get_class_wise_util_arrays_jit,
        "top_k": _get_top_k_util_arrays_jit_local,
    }
    builder = utility_builders.get(util_type_static)
    if builder is None:
        vec_u = jnp.zeros_like(f_current_probs)
    else:
        vec_u = builder(f_current_probs, placeholder_labels, param_static, n_cls_static)[1]

    predicted_utility = jnp.sum(f_current_probs * vec_u, axis=1)
    inside_interval = jnp.logical_and(predicted_utility >= I_min, predicted_utility <= I_max)
    row_mask = inside_interval.astype(f_current_probs.dtype)
    return vec_u * row_mask[:, jnp.newaxis]

@partial(jit, static_argnames=["n_classes"])
def _kernel_process_class_wise(c_idx, f_current_jax, Y_one_hot_jnp, n_classes: int):
    """Score the class-wise utility for class ``c_idx``.

    Returns:
        ``(err_abs, signed_E, I_min, I_max, c_idx)`` -- the worst-interval
        statistics for this utility followed by the class index itself.
    """
    util_arrays = _get_class_wise_util_arrays_jit(f_current_jax, Y_one_hot_jnp, c_idx, n_classes)
    class_utility = util_arrays[1]  # the realized-utility output is not needed
    interval_stats = _estimate_uc_error_interval_signed_jit_local(f_current_jax, class_utility, Y_one_hot_jnp)
    worst_err, worst_signed, lower, upper = interval_stats
    return worst_err, worst_signed, lower, upper, c_idx

@partial(jit, static_argnames=["n_classes"])
def _kernel_process_top_k(k_val, f_current_jax, Y_one_hot_jnp, n_classes: int):
    """Score the top-k utility for ``k_val``.

    Returns:
        ``(err_abs, signed_E, I_min, I_max, k_val)`` -- the worst-interval
        statistics for this utility followed by ``k_val`` itself.
    """
    util_arrays = _get_top_k_util_arrays_jit_local(f_current_jax, Y_one_hot_jnp, k_val, n_classes)
    top_k_utility = util_arrays[1]  # the realized-utility output is not needed
    interval_stats = _estimate_uc_error_interval_signed_jit_local(f_current_jax, top_k_utility, Y_one_hot_jnp)
    worst_err, worst_signed, lower, upper = interval_stats
    return worst_err, worst_signed, lower, upper, k_val

def find_worst_patch_standalone(
    f_current_jax, Y_one_hot_jnp, 
    n_classes: int, utility_class_type: str 
):
    """Search a utility family for the utility/interval with the largest absolute error.

    For "class_wise" every class index ``0..n_classes-1`` is scored, for
    "top_k" every ``k`` in ``1..n_classes``, and for "union_cw_topk" both
    (class-wise candidates first). Ties keep the earliest candidate.

    Returns:
        ``(max_err_abs, patch_info)`` where ``patch_info`` is
        ``(util_type, param, I_min, I_max, signed_E)``. If the utility class
        type matches none of the families, ``(-1.0, None)`` is returned.
    """
    candidates = []

    if utility_class_type in ("class_wise", "union_cw_topk"):
        batched_cw = vmap(_kernel_process_class_wise, in_axes=(0, None, None, None), out_axes=0)
        cw_err, cw_signed, cw_low, cw_high, cw_param = batched_cw(
            jnp.arange(n_classes), f_current_jax, Y_one_hot_jnp, n_classes
        )
        candidates.extend(
            (cw_err[i], cw_signed[i], cw_low[i], cw_high[i], cw_param[i], "class_wise")
            for i in range(n_classes)
        )

    if utility_class_type in ("top_k", "union_cw_topk"):
        batched_tk = vmap(_kernel_process_top_k, in_axes=(0, None, None, None), out_axes=0)
        tk_err, tk_signed, tk_low, tk_high, tk_param = batched_tk(
            jnp.arange(1, n_classes + 1), f_current_jax, Y_one_hot_jnp, n_classes
        )
        candidates.extend(
            (tk_err[i], tk_signed[i], tk_low[i], tk_high[i], tk_param[i], "top_k")
            for i in range(n_classes)
        )

    if not candidates:
        return -1.0, None

    best_err = -jnp.inf
    best_info = None
    for err_abs, signed_E, I_min, I_max, param, util_type in candidates:
        # Pull the error to a host scalar so the comparison is a plain float
        # comparison.
        err_value = err_abs.item() if hasattr(err_abs, 'item') else err_abs
        if err_value > best_err:
            best_err = err_value
            best_info = (util_type, param.item(), I_min, I_max, signed_E)

    return best_err, best_info

class IterativeUtilityCalibrator:
    """Iteratively patch predicted probabilities to reduce utility-calibration error.

    Each iteration searches the configured utility family for the utility and
    interval with the largest absolute calibration error (via
    ``find_worst_patch_standalone``), moves the probabilities along that patch
    direction with the configured step size, and projects every row back with
    ``project_to_simplex_jax``. The same patches are applied to the test
    predictions while fitting.
    """

    def __init__(
        self,
        utility_class_type="class_wise", 
        max_iters=100,
        n_classes=None,
        inputs_are_probabilities=True, 
        verbose=False, 
        print_every_iters=10,
        n_samples_update_iter=500, 
        update_subsampling_seed_base=0,
        stepsize_type="fixed", 
        fixed_stepsize_value=0.01
    ):
        """Store and validate the calibrator configuration.

        Args:
            utility_class_type: "class_wise", "top_k" or "union_cw_topk".
            max_iters: Maximum number of patch iterations (cast to ``int``).
            n_classes: Number of classes; required.
            inputs_are_probabilities: If False, a softmax is applied to the
                inputs of ``fit``.
            verbose: Print progress information during ``fit``.
            print_every_iters: Logging period (in iterations) when verbose.
            n_samples_update_iter: Size of the per-iteration subsample used to
                find the worst patch; ``None``, a non-positive value or a value
                not smaller than the calibration set disables subsampling.
            update_subsampling_seed_base: Base seed; iteration ``t`` uses
                ``base + t`` for its subsample.
            stepsize_type: "dynamic_alpha", "dynamic_C" or "fixed".
            fixed_stepsize_value: Step magnitude for ``stepsize_type="fixed"``.

        Raises:
            ValueError: If ``n_classes`` is missing or an option is unknown.
        """
        if n_classes is None:
            raise ValueError("n_classes must be provided.")
        if utility_class_type not in ["class_wise", "top_k", "union_cw_topk"]:
            raise ValueError(f"Unknown utility_class_type: {utility_class_type}")
        if stepsize_type not in ["dynamic_alpha", "dynamic_C", "fixed"]:
            raise ValueError(f"Unknown stepsize_type: {stepsize_type}")

        self.utility_class_type = utility_class_type
        self.max_iters = int(max_iters)
        self.n_classes = int(n_classes)
        self.inputs_are_probabilities = inputs_are_probabilities
        self.verbose_internal = verbose 
        self.print_every_iters = print_every_iters
        self.n_samples_update_iter = n_samples_update_iter
        self.update_subsampling_seed_base = update_subsampling_seed_base
        self.stepsize_type = stepsize_type
        self.fixed_stepsize_value = fixed_stepsize_value

    def fit(self, X_cal_np: np.ndarray, Y_cal_labels_np: np.ndarray, X_test_np: np.ndarray):
        """Fit on the calibration data and transform the test data in lock-step.

        Args:
            X_cal_np: Calibration inputs, one row per sample (probabilities, or
                logits when ``inputs_are_probabilities`` is False).
            Y_cal_labels_np: Integer class labels for the calibration rows.
            X_test_np: Test inputs in the same representation as ``X_cal_np``.

        Returns:
            The transformed test probabilities as a NumPy array, or an empty
            array when there are no test samples.
        """
        verbose = self.verbose_internal
        n_samples_cal_total = X_cal_np.shape[0]
        n_samples_test = X_test_np.shape[0]

        X_cal_jnp = jnp.asarray(X_cal_np)
        f_current_jax = X_cal_jnp if self.inputs_are_probabilities else jax.nn.softmax(X_cal_jnp, axis=1)
        Y_cal_labels_jnp = jnp.asarray(Y_cal_labels_np)
        Y_cal_one_hot_jnp = _labels_to_one_hot_jax(Y_cal_labels_jnp, self.n_classes)

        X_test_jnp = jnp.asarray(X_test_np)
        f_test_jax = X_test_jnp if self.inputs_are_probabilities else jax.nn.softmax(X_test_jnp, axis=1)

        # -1.0 is a sentinel meaning "no Brier score logged yet".
        brier_prev_log_step = -1.0

        if verbose:
            print(f"Input to fit (calibration data) is {'probabilities' if self.inputs_are_probabilities else 'logits, applying softmax'}.")
            brier_cal_initial = _calculate_brier_score_jax(f_current_jax, Y_cal_one_hot_jnp, self.n_classes)
            acc_cal_initial = _calculate_accuracy_jax(f_current_jax, Y_cal_labels_jnp)
            print(f"Initial Cal Set ({n_samples_cal_total} samples): Brier={float(brier_cal_initial):.4f}, Acc={float(acc_cal_initial):.4f}")
            brier_prev_log_step = brier_cal_initial

        # Whether to subsample does not change between iterations.
        use_subsample = (
            self.n_samples_update_iter is not None
            and 0 < self.n_samples_update_iter < n_samples_cal_total
        )

        for t in tqdm(range(self.max_iters)):
            if verbose:
                print(f"\n--- Iteration {t + 1}/{self.max_iters} ---")

            f_iter_jax_for_patch_finding = f_current_jax
            Y_iter_one_hot_jnp_for_patch_finding = Y_cal_one_hot_jnp
            n_samples_for_iter_patch_finding = n_samples_cal_total

            if use_subsample:
                # A fresh, reproducible subsample per iteration.
                iter_seed = self.update_subsampling_seed_base + t
                key_subsample = jax.random.PRNGKey(iter_seed)
                subsample_indices = jax.random.choice(
                    key_subsample, n_samples_cal_total,
                    shape=(self.n_samples_update_iter,), replace=False,
                )
                f_iter_jax_for_patch_finding = f_current_jax[subsample_indices]
                Y_iter_one_hot_jnp_for_patch_finding = Y_cal_one_hot_jnp[subsample_indices]
                n_samples_for_iter_patch_finding = self.n_samples_update_iter
                if verbose:
                    print(f"  Using subsample of {n_samples_for_iter_patch_finding} for patch finding.")

            max_err_abs_val_t, worst_patch_details = find_worst_patch_standalone(
                f_iter_jax_for_patch_finding, Y_iter_one_hot_jnp_for_patch_finding,
                self.n_classes, self.utility_class_type
            )

            # Stop regardless of verbosity; only the message is optional.
            if worst_patch_details is None:
                if verbose:
                    print("  No patch found. Stopping.")
                break

            util_type_t, param_t, I_min_t, I_max_t, signed_E_t = worst_patch_details
            if verbose:
                print(f"  Worst utility (from {'subsample' if n_samples_for_iter_patch_finding < n_samples_cal_total else 'full cal set'}): type={util_type_t}, param={int(param_t)}, Interval=[{float(I_min_t):.3f}, {float(I_max_t):.3f}]")
                print(f"    Signed Expectation E_hat[<Y_sub-f_sub, u*1_I>]: {float(signed_E_t):.6f} (Error Mag: {float(max_err_abs_val_t):.6f})")

            # Convergence check must not depend on verbosity either.
            if max_err_abs_val_t < 1e-9:
                if verbose:
                    print(f"  Max error magnitude {max_err_abs_val_t:.2e} is below tolerance. Stopping.")
                break

            if self.stepsize_type == "dynamic_alpha":
                # Only this step-size rule needs the patch direction on the
                # (sub)sample used for patch finding.
                patch_vector_direction_subsample_t = get_patch_vector_direction_jit_standalone(
                    f_iter_jax_for_patch_finding, util_type_t, int(param_t), I_min_t, I_max_t,
                    self.n_classes, n_samples_for_iter_patch_finding,
                )
                R_t_sq_cal = jnp.mean(jnp.sum(jnp.square(patch_vector_direction_subsample_t), axis=1))
                # The small constant guards against division by zero.
                effective_eta_t = signed_E_t / (R_t_sq_cal + 1e-9)
                if verbose:
                    print(f"    Stepsize type: dynamic_alpha, R_t_sq_cal (E_sub[||patch_dir_sub||^2]): {float(R_t_sq_cal):.4g}")
            elif self.stepsize_type == "dynamic_C":
                effective_eta_t = signed_E_t / self.n_classes
                if verbose:
                    print(f"    Stepsize type: dynamic_C (denominator: {self.n_classes})")
            elif self.stepsize_type == "fixed":
                effective_eta_t = self.fixed_stepsize_value * jnp.sign(signed_E_t)
                if verbose:
                    print(f"    Stepsize type: fixed (value: {self.fixed_stepsize_value})")
            else:
                raise ValueError(f"Internal error: Unknown stepsize_type {self.stepsize_type}")

            if verbose:
                print(f"    Effective_eta (signed_E_sub/denominator): {float(effective_eta_t):.4g}")

            # Apply the patch to the full calibration set and re-project.
            patch_vector_direction_cal_full_t = get_patch_vector_direction_jit_standalone(
                f_current_jax, util_type_t, int(param_t), I_min_t, I_max_t,
                self.n_classes, n_samples_cal_total,
            )
            f_updated_unprojected_cal = f_current_jax + effective_eta_t * patch_vector_direction_cal_full_t
            f_current_jax = jax.vmap(project_to_simplex_jax)(f_updated_unprojected_cal)

            # Apply the same patch to the test predictions.
            if n_samples_test > 0:
                patch_vector_direction_test_t = get_patch_vector_direction_jit_standalone(
                    f_test_jax, util_type_t, int(param_t), I_min_t, I_max_t,
                    self.n_classes, n_samples_test,
                )
                f_updated_unprojected_test = f_test_jax + effective_eta_t * patch_vector_direction_test_t
                f_test_jax = jax.vmap(project_to_simplex_jax)(f_updated_unprojected_test)

            if verbose:
                is_last_iter = t == self.max_iters - 1
                # Truthiness guard avoids a modulo by zero/None when periodic
                # logging is disabled.
                is_log_iter = bool(self.print_every_iters) and (t + 1) % self.print_every_iters == 0
                if is_log_iter or is_last_iter:
                    brier_cal_current = _calculate_brier_score_jax(f_current_jax, Y_cal_one_hot_jnp, self.n_classes)
                    acc_cal = _calculate_accuracy_jax(f_current_jax, Y_cal_labels_jnp)
                    f_np_cal = np.array(f_current_jax)
                    Y_np_cal = np.array(Y_cal_labels_jnp)

                    brier_decrease_str = ""
                    if brier_prev_log_step > -1.0:
                        brier_decrease = brier_prev_log_step - brier_cal_current
                        brier_decrease_str = f", Brier Dec: {float(brier_decrease):.4g}"
                    brier_prev_log_step = brier_cal_current

                    uc_tc_cal, _, _, _ = calculate_uc_top_class(f_np_cal, Y_np_cal, use_subsampling=False)
                    uc_cw_cal, _, _, _, _ = calculate_uc_class_wise(f_np_cal, Y_np_cal, use_subsampling=False)
                    uc_tk_cal, _, _, _, _ = calculate_uc_top_k_overall(f_np_cal, Y_np_cal, use_subsampling=False)
                    print(f"  Iter {t+1} Cal Set: Brier={float(brier_cal_current):.4f}{brier_decrease_str}, Acc={float(acc_cal):.4f}, UC-TC={float(uc_tc_cal):.4f}, UC-CW={float(uc_cw_cal):.4f}, UC-TK={float(uc_tk_cal):.4f}")

        if verbose:
            print(f"Fitting complete. Test probabilities are transformed.")
        return np.array(f_test_jax) if n_samples_test > 0 else np.array([])

# --- Main Experiment Execution ---
def run_experiments(args):
    """Run the calibration experiments described by the parsed CLI ``args``.

    For each of ``args.n_splits`` random calibration/test splits the
    uncalibrated test metrics are evaluated, then one
    ``IterativeUtilityCalibrator`` per requested utility type is fitted and
    its transformed test probabilities are evaluated. Per-split outputs are
    saved under ``<model_path>/results/<method>/``; an aggregated summary is
    printed and written to a timestamped CSV in ``<model_path>/results/``.

    Args:
        args: Namespace with the attributes defined by this script's argument
            parser (``model_path``, ``n_splits``, ``val_ratio``, ...).
    """
    np.random.seed(args.random_seed)
    script_verbose = args.verbose

    if script_verbose:
        print(f"Loading data from: {args.model_path}")
    all_logits_np, all_labels_np = load_model_data(args.model_path)
    n_total_samples, n_classes = all_logits_np.shape
    if script_verbose:
        print(f"Loaded {n_total_samples} samples with {n_classes} classes.")

    calibrator_base_config = {
        "max_iters": args.max_iters_cal,
        "n_classes": n_classes,
        "inputs_are_probabilities": False,
        "verbose": script_verbose,
        "print_every_iters": args.print_every_iters,
        "n_samples_update_iter": args.n_samples_update_iter,
        "update_subsampling_seed_base": args.random_seed,
        "stepsize_type": args.stepsize_type,
        "fixed_stepsize_value": args.fixed_stepsize_value,
    }

    # CLI shortcode -> (calibrator utility_class_type, method-name prefix).
    utility_type_map = {
        "cw": ("class_wise", "PostHocUC_CW"),
        "tk": ("top_k", "PostHocUC_TopK"),
        "union": ("union_cw_topk", "PostHocUC_Union"),
    }

    calibrator_configs = {}
    for utype_shortcode in args.utility_types:
        actual_utility_type, name_prefix = utility_type_map[utype_shortcode]
        method_name = f"{name_prefix}_iters{args.max_iters_cal}_step_{args.stepsize_type}"
        if args.stepsize_type == 'fixed':
            method_name += f"_{args.fixed_stepsize_value}"
        if args.n_samples_update_iter is not None:
            method_name += f"_sub{args.n_samples_update_iter}"
        calibrator_configs[method_name] = {**calibrator_base_config, "utility_class_type": actual_utility_type}

    # method name -> metric name -> list of per-split values.
    aggregated_results = defaultdict(lambda: defaultdict(list))

    for split_idx in range(args.n_splits):
        if script_verbose:
            print(f"\n--- Running Split {split_idx + 1}/{args.n_splits} ---")
        split_seed = args.random_seed + split_idx
        (cal_logits, cal_labels), (test_logits, test_labels) = split_data(
            all_logits_np, all_labels_np, val_ratio=args.val_ratio, random_seed=split_seed
        )

        if cal_logits.shape[0] == 0 or test_logits.shape[0] == 0:
            print(f"Skipping split {split_idx+1} due to empty calibration or test set after splitting.")
            continue

        X_cal_fit, Y_cal_fit = cal_logits, cal_labels
        if args.n_samples_cal_overall is not None and args.n_samples_cal_overall < len(cal_logits):
            if script_verbose:
                print(f"Subsampling full calibration data to {args.n_samples_cal_overall} samples for this split.")
            if args.n_samples_cal_overall == 0:
                print("Warning: n_samples_cal_overall is 0. Skipping fitting for this split.")
                continue
            cal_indices = np.random.choice(len(cal_logits), args.n_samples_cal_overall, replace=False)
            X_cal_fit, Y_cal_fit = cal_logits[cal_indices], cal_labels[cal_indices]

        if X_cal_fit.shape[0] == 0:
            print("Warning: Calibration set for fit is empty after overall subsampling. Skipping fitting for this split.")
            continue

        if script_verbose:
            print(f"Cal set for fit: {X_cal_fit.shape[0]}, Test set: {test_logits.shape[0]}")

        # Baseline: plain softmax of the test logits.
        uncal_probs = softmax_np(test_logits, axis=1)
        uncal_metrics = evaluate_all_metrics(uncal_probs, test_labels, n_classes, "Uncalibrated")
        print(f"\nSplit {split_idx+1} - Uncalibrated Test Metrics:", {m: f"{v:.4f}" for m, v in uncal_metrics.items()})
        for m, v in uncal_metrics.items():
            aggregated_results["Uncalibrated"][m].append(v)
        save_calibration_outputs(
            create_results_directory_structure(args.model_path, "Uncalibrated"),
            split_idx, uncal_probs, test_labels, uncal_metrics,
        )

        for method_name, config in calibrator_configs.items():
            if script_verbose:
                print(f"\n--- Calibrating with: {method_name} ---")
            start_time = time.time()
            method_dir = create_results_directory_structure(args.model_path, method_name)
            calibrator = IterativeUtilityCalibrator(**config)

            if script_verbose:
                print("Fitting calibrator and transforming test data simultaneously...")
            calibrated_probs_test_np = calibrator.fit(X_cal_fit, Y_cal_fit, test_logits)

            if script_verbose:
                print("\nEvaluating calibrated model on test set...")
            metrics = evaluate_all_metrics(calibrated_probs_test_np, test_labels, n_classes, method_name)
            print(f"Split {split_idx+1} - {method_name} Test Metrics:", {m: f"{v:.4f}" for m, v in metrics.items()})
            for m, v in metrics.items():
                aggregated_results[method_name][m].append(v)
            save_calibration_outputs(method_dir, split_idx, calibrated_probs_test_np, test_labels, metrics)
            if script_verbose:
                print(f"Finished {method_name} in {time.time() - start_time:.2f}s")

    # Aggregate per-split values into mean / std / range per method and metric.
    final_summary = defaultdict(dict)
    for method, metrics_lists in aggregated_results.items():
        for metric, values in metrics_lists.items():
            if not values:
                mean, std, max_dev = np.nan, np.nan, np.nan
            else:
                mean, std = np.mean(values), np.std(values)
                max_dev = np.max(values) - np.min(values) if len(values) > 1 else 0.0
            final_summary[method][metric] = {"mean": float(mean), "std": float(std), "max_dev": float(max_dev)}

    print("\n--- Aggregated Results (Mean ± Std (Max Deviation)) ---")
    all_metrics_names = sorted(list(set(k for v in final_summary.values() for k in v.keys())))
    header = ["Method"] + all_metrics_names
    col_w = 45  # wide enough for the long generated method names
    print("Method".ljust(col_w) + " | " + " | ".join(m.ljust(col_w) for m in all_metrics_names))
    print("-" * (col_w + 3 + (col_w + 3) * len(all_metrics_names)-3))
    for method in sorted(final_summary.keys()):
        parts = [method.ljust(col_w)]
        for metric_name in all_metrics_names:
            if metric_name in final_summary[method] and not np.isnan(final_summary[method][metric_name]['mean']):
                stats = final_summary[method][metric_name]
                parts.append(f"{stats['mean']:.4f} ± {stats['std']:.4f} (md: {stats['max_dev']:.4f})".ljust(col_w))
            else:
                parts.append("N/A".ljust(col_w))
        print(" | ".join(parts))

    summary_csv_filename = f"summary_posthoc_uc_{Path(args.model_path).name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
    summary_csv_path = Path(args.model_path) / "results" / summary_csv_filename
    # The "results" directory is only created as a side effect of saving a
    # split; make sure it exists even when every split was skipped.
    summary_csv_path.parent.mkdir(parents=True, exist_ok=True)
    with open(summary_csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        for method in sorted(final_summary.keys()):
            row = [method]
            for m_name in all_metrics_names:
                if m_name in final_summary[method] and not np.isnan(final_summary[method][m_name]['mean']):
                    s = final_summary[method][m_name]
                    row.append(f"{s['mean']:.4f} +/- {s['std']:.4f} (md: {s['max_dev']:.4f})")
                else:
                    row.append("N/A")
            writer.writerow(row)
    print(f"\nSummary results saved to: {summary_csv_path}")

def evaluate_all_metrics(probs_np, labels_np, n_classes, method_name_for_log=""):
    """Compute accuracy, Brier score and the three utility-calibration errors.

    Args:
        probs_np: Predicted probabilities, one row per sample.
        labels_np: Integer class labels.
        n_classes: Number of classes (used for the Brier score one-hot).
        method_name_for_log: Name used in the empty-input warning.

    Returns:
        Dict with keys ``accuracy``, ``brier_score``, ``uc_top_class``,
        ``uc_class_wise`` and ``uc_top_k_overall``. All values are NaN when
        ``probs_np`` has no rows.
    """
    if probs_np.shape[0] == 0:
        print(f"Warning: Empty predictions for {method_name_for_log}, returning NaN/default metrics.")
        return {
            "accuracy": np.nan,
            "brier_score": np.nan,
            "uc_top_class": np.nan,
            "uc_class_wise": np.nan,
            "uc_top_k_overall": np.nan,
        }

    # The utility-calibration estimators are run with subsampling enabled.
    use_subsampling_eval = True

    accuracy = calculate_accuracy_np(probs_np, labels_np)
    brier = calculate_brier_score_np(probs_np, labels_np, n_classes)

    # Only the first return value (the error) of each estimator is kept.
    uc_tc_err, _, _, _ = calculate_uc_top_class(probs_np, labels_np, use_subsampling=use_subsampling_eval)
    uc_cw_err, _, _, _, _ = calculate_uc_class_wise(probs_np, labels_np, use_subsampling=use_subsampling_eval)
    uc_tk_err, _, _, _, _ = calculate_uc_top_k_overall(probs_np, labels_np, use_subsampling=use_subsampling_eval)

    return {
        "accuracy": accuracy,
        "brier_score": brier,
        "uc_top_class": float(uc_tc_err),
        "uc_class_wise": float(uc_cw_err),
        "uc_top_k_overall": float(uc_tk_err),
    }

if __name__ == "__main__":
    # Command-line entry point: declare the options as a table, register them
    # in order, then hand the parsed namespace to run_experiments.
    cli_options = [
        ("--model-path", dict(type=str, default="./logits/ViT_Base_P16_224_ImageNet1k", help="Path to model dir.")),
        ("--n-splits", dict(type=int, default=5, help="Num random train/test splits.")),
        ("--val-ratio", dict(type=float, default=0.7, help="Ratio for calibration set.")),
        ("--n-samples-cal-overall", dict(type=int, default=None, help="Num samples for the initial calibration split. Default: all of val_ratio.")),
        ("--n-samples-update-iter", dict(type=int, default=500, help="Num samples from cal set for each iteration's update. Default: 500.")),
        ("--max-iters-cal", dict(type=int, default=125, help="Max iterations for calibrator.")),
        ("--print-every-iters", dict(type=int, default=10, help="Print metrics every N iters in fit.")),
        ("--random-seed", dict(type=int, default=42, help="Random seed.")),
        ("--verbose", dict(action="store_true", default=False, help="Enable verbose output (quiet by default).")),
        ("--stepsize-type", dict(type=str, default="fixed", choices=["dynamic_alpha", "dynamic_C", "fixed"], help="Type of stepsize calculation.")),
        ("--fixed-stepsize-value", dict(type=float, default=0.01, help="Value for fixed stepsize if stepsize_type is 'fixed'.")),
        ("--utility-types", dict(nargs='*', default=['union'], choices=['cw', 'tk', 'union'], help="Types of utility classes for calibrator: cw (class-wise), tk (top-k), union (union_cw_topk). Default: ['union'].")),
    ]

    parser = argparse.ArgumentParser(description="PostHoc Utility Calibration Experiments")
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    run_experiments(args)
