import argparse
import pickle
import csv
import os
import torch
import torchvision
import torchvision.transforms as trn
from transformers import set_seed
import matplotlib.pyplot as plt
import numpy as np
import torchvision.models as models
from collections import defaultdict
from tqdm.auto import tqdm
from torchvision.datasets import Places365
import json


# Use the project's dataset-directory helper when it is importable; otherwise
# warn once at import time and define a local stand-in with the same name so
# later calls to get_dataset_dir() still resolve.
try:
    from examples.utils import get_dataset_dir
except ImportError:
    print("Warning: 'examples.utils.get_dataset_dir' not found. CIFAR dataset path may need to be specified via --cifar_data_root.")
    def get_dataset_dir():
        """Return the fallback dataset directory, ``"./data"``."""
        return "./data"


from torchcp.classification.predictor import SplitPredictor
from torchcp.classification.score import APS, RAPS, SAPS, EnergyAPS, EnergyRAPS, EnergySAPS, LAC, EnergyLAC
from torchvision.models import resnet50





# Model names accepted when the dataset is CIFAR-10 (order is significant for
# help/error output, so it is kept as a list).
CIFAR10_MODEL_NAMES = [
    "resnet20",
    "resnet32",
    "resnet44",
    "resnet56",
    "vgg11_bn",
    "vgg13_bn",
    "vgg16_bn",
    "vgg19_bn",
    "mobilenetv2_x0_5",
    "mobilenetv2_x0_75",
    "mobilenetv2_x1_0",
    "mobilenetv2_x1_4",
    "shufflenetv2_x0_5",
    "shufflenetv2_x1_0",
    "shufflenetv2_x1_5",
    "shufflenetv2_x2_0",
    "repvgg_a0",
    "repvgg_a1",
    "repvgg_a2",
]

# CIFAR-100 accepts exactly the same names; it is an independent list object
# so mutating one list never affects the other.
CIFAR100_MODEL_NAMES = list(CIFAR10_MODEL_NAMES)

# Model names accepted for the imbalanced CIFAR-100 setting.
IMBALANCED_CIFAR100_MODEL_NAMES = [
    "resnet50_a0.005",
    "resnet50_a0.01",
    "resnet50_a0.02",
    "resnet50_a0.03",
]

# Model names accepted for Places365.
PLACES365_MODEL_NAMES = [
    "alexnet",
    "resnet18",
    "resnet50",
    "densenet161",
    "vgg16",
    "googlenet",
    "resnet152",
]

# Torchvision model names accepted for ImageNet, stored as a sorted list of
# unique names.
IMAGENET_TORCHVISION_MODEL_NAMES = sorted({
    # ResNet / ResNeXt / Wide ResNet
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "resnext50_32x4d", "resnext101_32x8d",
    "wide_resnet50_2", "wide_resnet101_2",
    # VGG
    "vgg11", "vgg11_bn", "vgg13", "vgg13_bn",
    "vgg16", "vgg16_bn", "vgg19", "vgg19_bn",
    # DenseNet
    "densenet121", "densenet161", "densenet169", "densenet201",
    # MobileNet
    "mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small",
    # EfficientNet
    "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "efficientnet_b3",
    "efficientnet_b4", "efficientnet_b5", "efficientnet_b6", "efficientnet_b7",
    "efficientnet_v2_s", "efficientnet_v2_m", "efficientnet_v2_l",
    # Vision Transformer
    "vit_b_16", "vit_b_32", "vit_l_16", "vit_l_32",
    # Swin Transformer
    "swin_t", "swin_s", "swin_b",
    # SqueezeNet
    "squeezenet1_0", "squeezenet1_1",
    # ShuffleNet v2
    "shufflenet_v2_x0_5", "shufflenet_v2_x1_0",
    "shufflenet_v2_x1_5", "shufflenet_v2_x2_0",
    # MNASNet
    "mnasnet0_5", "mnasnet0_75", "mnasnet1_0",
    # RegNet
    "regnet_y_400mf", "regnet_y_800mf", "regnet_y_1_6gf", "regnet_y_3_2gf",
    "regnet_y_8gf", "regnet_y_16gf", "regnet_y_32gf",
    "regnet_x_400mf", "regnet_x_800mf", "regnet_x_1_6gf", "regnet_x_3_2gf",
    "regnet_x_8gf", "regnet_x_16gf", "regnet_x_32gf",
})

# Sorted union of every per-dataset list above, without duplicates.
ALL_AVAILABLE_MODEL_NAMES = sorted({
    *CIFAR10_MODEL_NAMES,
    *CIFAR100_MODEL_NAMES,
    *IMBALANCED_CIFAR100_MODEL_NAMES,
    *IMAGENET_TORCHVISION_MODEL_NAMES,
    *PLACES365_MODEL_NAMES,
})

# Model used for a dataset when no model name is given explicitly.
DEFAULT_MODEL_PER_DATASET = dict([
    ("cifar10", "resnet56"),
    ("cifar100", "resnet56"),
    ("imagenet-val", "resnet50"),
    ("places365-val", "resnet50"),
    ("imbalanced_cifar100", "resnet50_a0.01"),
])





def setup_environment(seed_value):
    """Seed the run with ``set_seed`` and pick the torch device to use.

    Args:
        seed_value: Seed forwarded unchanged to ``set_seed``.

    Returns:
        torch.device: ``cuda:0`` when CUDA is available, otherwise ``cpu``.
    """
    set_seed(seed_value)
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")

def create_dataloaders(dataset_name, dataset_instance, cal_split_size, test_split_size=None, batch_size=64, num_workers=4, seed=None):
    """Randomly split a dataset into calibration and test dataloaders.

    Args:
        dataset_name: Human-readable dataset name, used in error messages.
        dataset_instance: Sized, indexable dataset to split.
        cal_split_size: Number of samples for the calibration split.
        test_split_size: Number of samples for the test split. When ``None``,
            every sample not used for calibration goes to the test split.
        batch_size: Batch size of both dataloaders.
        num_workers: Worker-process count of both dataloaders.
        seed: Seed for the split. When ``None`` the split draws from torch's
            global random generator instead of a dedicated seeded one.

    Returns:
        tuple: ``(cal_dataloader, test_dataloader)``, both unshuffled.

    Raises:
        ValueError: If the requested sizes do not fit in the dataset, or if
            no sample would be left for testing.
    """
    total = len(dataset_instance)

    if test_split_size is None:
        # The test split is "everything else", so it must be non-empty.
        if cal_split_size >= total:
            raise ValueError(f"Calibration size ({cal_split_size}) must be less than {dataset_name} size ({total}) to leave data for testing.")
        split_lengths = [cal_split_size, total - cal_split_size]
    else:
        if cal_split_size + test_split_size > total:
            raise ValueError(f"Sum of calibration ({cal_split_size}) and test ({test_split_size}) sizes cannot exceed {dataset_instance.__class__.__name__} size ({total}).")
        split_lengths = [cal_split_size, test_split_size]
        leftover = total - sum(split_lengths)
        if leftover > 0:
            # random_split needs the lengths to cover the whole dataset; the
            # third, unused split absorbs the remaining samples.
            split_lengths.append(leftover)

    # Only pass a dedicated generator when a seed was given:
    # Generator.manual_seed() rejects None.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)

    # Performed for BOTH branches above (previously the test_split_size=None
    # branch never split the data and failed on an unbound variable).
    subsets = torch.utils.data.random_split(dataset_instance, split_lengths, **split_kwargs)
    cal_dataset, test_dataset = subsets[0], subsets[1]

    cal_dataloader = torch.utils.data.DataLoader(cal_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    return cal_dataloader, test_dataloader

def prepare_imagenet_data(val_dir, cal_size, test_size, batch_size, num_workers, imagenet_mean, imagenet_std, seed):
    """Build calibration/test dataloaders from an ImageNet validation folder.

    Images are read with ``ImageFolder`` and preprocessed with resize(256),
    center-crop(224), tensor conversion and normalization by the given
    mean/std. Splitting is delegated to ``create_dataloaders``.

    Raises:
        FileNotFoundError: If ``val_dir`` is not a directory or holds no images.
    """
    print(f"Using ImageNet validation set from: {val_dir}")
    if not os.path.isdir(val_dir):
        raise FileNotFoundError(f"ImageNet validation directory not found: {val_dir}. Please check --imagenet_val_dir.")

    preprocessing_steps = [
        trn.Resize(256),
        trn.CenterCrop(224),
        trn.ToTensor(),
        trn.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    imagenet_val = torchvision.datasets.ImageFolder(root=val_dir, transform=trn.Compose(preprocessing_steps))
    if len(imagenet_val) == 0:
        raise FileNotFoundError(f"No images found in ImageNet validation directory: {val_dir}.")

    return create_dataloaders("ImageNet-val", imagenet_val, cal_size, test_size, batch_size, num_workers, seed)

def _prepare_cifar_data(dataset_cls, dataset_name_str, root_dir, cal_size, test_size, batch_size, num_workers, mean, std, seed, imbalanced_model=None):
    """Load a CIFAR test set and return ``(cal_dataloader, test_dataloader)``.

    By default the data is split uniformly at random via ``create_dataloaders``.
    When ``dataset_name_str`` is ``"CIFAR100"`` and ``imbalanced_model`` is
    given, per-class sample counts for that model are read from the summary
    JSON and used as sampling weights, so calibration and test subsets are
    drawn (without replacement, disjointly) in proportion to that class prior.
    If the summary cannot be read, or the dataset has no ``targets``
    attribute, the function falls back to the uniform split.

    Args:
        dataset_cls: Dataset class called with ``root``, ``train=False``,
            ``download=True`` and ``transform``.
        dataset_name_str: Name used in messages and to enable the imbalanced path.
        root_dir: Dataset root directory.
        cal_size: Requested calibration-set size.
        test_size: Requested test-set size, or ``None`` for "all remaining".
        batch_size: Batch size of both dataloaders.
        num_workers: Worker count of both dataloaders.
        mean: Per-channel normalization mean.
        std: Per-channel normalization std.
        seed: Seed for the split; the weighted path uses 0 when it is ``None``.
        imbalanced_model: Key into the summary JSON, or ``None``.

    Raises:
        ValueError: If the loaded dataset is empty.
    """
    print(f"Using {dataset_name_str} dataset from: {root_dir}")
    preprocessing = trn.Compose([trn.ToTensor(), trn.Normalize(mean=mean, std=std)])

    try:
        dataset = dataset_cls(root=root_dir, train=False, download=True, transform=preprocessing)
    except Exception as e:
        print(f"Error loading {dataset_name_str} dataset from {root_dir}.")
        raise e
    if len(dataset) == 0:
        raise ValueError(f"{dataset_name_str} dataset is empty.")

    def uniform_split():
        # Shared fallback: plain random split handled by create_dataloaders.
        return create_dataloaders(dataset_name_str, dataset, cal_size, test_size, batch_size, num_workers, seed)

    use_class_prior = dataset_name_str == "CIFAR100" and imbalanced_model is not None
    if not use_class_prior:
        return uniform_split()

    # --- Read the class prior recorded for the requested imbalanced model. ---
    summary_file = '../imbalanced_training/cifar100/all_results_summary.json'
    prior = None
    if not os.path.exists(summary_file):
        print(f"File {summary_file} does not exist. Proceeding with balanced split.")
    else:
        try:
            with open(summary_file, 'r') as f:
                summary = json.load(f)
            class_counts = np.array(summary[imbalanced_model]['samples_per_class'])
            class_fractions = list(class_counts / np.sum(class_counts))
            prior = dict(enumerate(class_fractions))
        except Exception as e:
            print(f"Error reading prior from {summary_file}: {e}. Proceeding with balanced split.")

    if prior is None:
        return uniform_split()
    if not hasattr(dataset, 'targets'):
        print("Dataset has no 'targets' attribute; cannot build imbalanced splits. Falling back to balanced split.")
        return uniform_split()

    # --- Weighted sampling: each sample is weighted by its class prior. ---
    labels = torch.tensor(dataset.targets, dtype=torch.long)
    rng = torch.Generator().manual_seed(0 if seed is None else seed)

    n_total = len(dataset)
    n_cal = min(cal_size, n_total)
    n_after_cal = n_total - n_cal
    n_test = n_after_cal if test_size is None else min(test_size, n_after_cal)

    positions = torch.arange(n_total)
    sample_weights = torch.tensor([prior[int(lbl)] for lbl in labels], dtype=torch.double)

    # Calibration indices are drawn first, from the full dataset.
    if n_cal > 0:
        drawn = torch.multinomial(sample_weights, num_samples=n_cal, replacement=False, generator=rng)
        cal_indices = positions[drawn]
    else:
        cal_indices = torch.empty(0, dtype=torch.long)

    # Everything not chosen for calibration stays available for testing.
    available = torch.ones(n_total, dtype=torch.bool)
    if cal_indices.numel() > 0:
        available[cal_indices] = False
    candidate_indices = positions[available]

    # Test indices are drawn second, from the remaining samples only.
    if n_test > 0:
        candidate_labels = labels[candidate_indices]
        candidate_weights = torch.tensor([prior[int(lbl)] for lbl in candidate_labels], dtype=torch.double)
        drawn = torch.multinomial(candidate_weights, num_samples=n_test, replacement=False, generator=rng)
        test_indices = candidate_indices[drawn]
    else:
        test_indices = torch.empty(0, dtype=torch.long)

    def loader_for(index_tensor):
        subset = torch.utils.data.Subset(dataset, index_tensor.tolist())
        return torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    cal_dataloader = loader_for(cal_indices)
    test_dataloader = loader_for(test_indices)
    print("Created imbalanced calibration and test splits using weighted sampling based on provided prior.")
    return cal_dataloader, test_dataloader

def prepare_cifar10_data(root_dir, cal_size, test_size, batch_size, num_workers, cifar10_mean, cifar10_std, seed):
    """Return calibration/test dataloaders for the CIFAR-10 test set.

    Thin wrapper that forwards every argument to ``_prepare_cifar_data`` with
    the torchvision ``CIFAR10`` dataset class and the name ``"CIFAR10"``.
    """
    return _prepare_cifar_data(
        torchvision.datasets.CIFAR10,
        "CIFAR10",
        root_dir,
        cal_size,
        test_size,
        batch_size,
        num_workers,
        cifar10_mean,
        cifar10_std,
        seed,
    )

def prepare_cifar100_data(root_dir, cal_size, test_size, batch_size, num_workers, cifar100_mean, cifar100_std, seed, imbalanced_model=None):
    """Return calibration/test dataloaders for the CIFAR-100 test set.

    Thin wrapper that forwards every argument to ``_prepare_cifar_data`` with
    the torchvision ``CIFAR100`` dataset class and the name ``"CIFAR100"``.
    ``imbalanced_model`` is passed through unchanged.
    """
    return _prepare_cifar_data(
        torchvision.datasets.CIFAR100,
        "CIFAR100",
        root_dir,
        cal_size,
        test_size,
        batch_size,
        num_workers,
        cifar100_mean,
        cifar100_std,
        seed,
        imbalanced_model=imbalanced_model,
    )

def prepare_places365_data(root_dir, cal_size, test_size, batch_size, num_workers, places365_mean, places365_std, download=False, seed=None):
    """Build calibration/test dataloaders from the Places365 validation split.

    The dataset is opened through torchvision's ``Places365`` class with
    ``split='val'`` and ``small=True``; images are resized to 256,
    center-cropped to 224, converted to tensors and normalized. Splitting is
    delegated to ``create_dataloaders``.

    Args:
        root_dir: Places365 root directory. Created when missing and
            ``download`` is true.
        cal_size: Calibration-set size.
        test_size: Test-set size.
        batch_size: Batch size of both dataloaders.
        num_workers: Worker count of both dataloaders.
        places365_mean: Per-channel normalization mean.
        places365_std: Per-channel normalization std.
        download: Forwarded to ``Places365``; also allows creating ``root_dir``.
        seed: Seed forwarded to ``create_dataloaders``.

    Raises:
        FileNotFoundError: If ``root_dir`` is missing and ``download`` is false.
        ValueError: If the loaded dataset is empty.
    """
    print(f"Using Places365 dataset from: {root_dir}")
    root_present = os.path.isdir(root_dir)
    if not root_present and not download:
        raise FileNotFoundError(f"Places365 root directory not found: {root_dir}. Please check --places365_root or use --download_places365.")
    if not root_present:
        print(f"Creating Places365 root directory: {root_dir}")
        os.makedirs(root_dir, exist_ok=True)

    preprocessing = trn.Compose([
        trn.Resize(256),
        trn.CenterCrop(224),
        trn.ToTensor(),
        trn.Normalize(mean=places365_mean, std=places365_std),
    ])
    dataset_options = {
        "root": root_dir,
        "split": 'val',
        "small": True,
        "transform": preprocessing,
        "download": download,
    }

    try:
        places_val = Places365(**dataset_options)
    except Exception as e:
        # Report where loading failed, then let the original error propagate.
        print(f"Error loading Places365 dataset from {root_dir}: {e}")
        raise

    if len(places_val) == 0:
        raise ValueError("Places365 dataset is empty.")

    return create_dataloaders("Places365-val", places_val, cal_size, test_size, batch_size, num_workers, seed)

def _load_imbalanced_cifar100_resnet50(model_name_str, device):
    """Build the modified ResNet-50 and restore its imbalanced-CIFAR-100 checkpoint."""
    checkpoint_path = f'../CIFAR100-Imb-ckpts/best_model_{model_name_str.split("_")[1]}.pth'
    print(f'Loading checkpoint from {checkpoint_path}')
    # map_location makes a checkpoint that was saved from a GPU loadable on a
    # CPU-only host (and places tensors directly on the target device).
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model = resnet50()
    # The checkpoint was trained with this stem: 3x3 stride-1 first conv, no
    # max-pool, and a 100-way classifier (presumably to suit 32x32 CIFAR
    # inputs -- TODO confirm against the training script).
    model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = torch.nn.Identity()
    model.fc = torch.nn.Linear(model.fc.in_features, 100)

    model.load_state_dict(checkpoint['model_state_dict'])
    return model


def _normalize_places365_state_dict(state_dict):
    """Rename checkpoint keys to the parameter names torchvision modules use."""
    import re

    # Same pattern torchvision uses to migrate legacy DenseNet keys such as
    # "...denselayer1.norm.1.weight" to "...denselayer1.norm1.weight".
    legacy_densenet_key = re.compile(
        r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
    )
    renamed = {}
    for key, value in state_dict.items():
        # DataParallel may wrap the whole model ("module.conv1...") or only a
        # sub-module ("features.module.0..."), so drop the segment anywhere.
        key = key.replace('module.', '')
        match = legacy_densenet_key.match(key)
        if match:
            key = match.group(1) + match.group(2)
        renamed[key] = value
    return renamed


def _load_places365_model(model_fn, model_name):
    """Create a 365-class torchvision model and load the released Places365 weights."""
    try:
        # num_classes sizes the classification head for every architecture;
        # assigning ``model.fc`` only works for ResNets (AlexNet and DenseNet
        # expose ``classifier`` instead).
        model = model_fn(weights=None, num_classes=365)

        checkpoint = torch.hub.load_state_dict_from_url(
            f'http://places2.csail.mit.edu/models_places365/{model_name}_places365.pth.tar',
            map_location='cpu', progress=True, check_hash=False
        )
        if 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']

        # strict=False is kept so checkpoints without newer buffers still
        # load, but any other mismatch is reported instead of being silently
        # ignored (it would leave those layers randomly initialised).
        load_result = model.load_state_dict(_normalize_places365_state_dict(checkpoint), strict=False)
        missing = [k for k in load_result.missing_keys if not k.endswith('num_batches_tracked')]
        unexpected = list(load_result.unexpected_keys)
        if missing or unexpected:
            print(
                f"WARNING: Places365 {model_name} checkpoint did not fully match the model "
                f"(missing keys: {len(missing)}, unexpected keys: {len(unexpected)})."
            )
        return model
    except Exception as e:
        raise RuntimeError(f"Error loading Places365 {model_name} model: {e}") from e


def _load_cifar_hub_model(hub_prefix, display_name, model_name_str, valid_names):
    """Load a pretrained CIFAR model from the chenyaofo/pytorch-cifar-models hub repo."""
    if model_name_str not in valid_names:
        raise ValueError(f"Model '{model_name_str}' not in predefined list for {display_name}. Available: {valid_names}")
    hub_model_name = f"{hub_prefix}_{model_name_str}"
    try:
        return torch.hub.load("chenyaofo/pytorch-cifar-models", hub_model_name, pretrained=True, trust_repo=True)
    except Exception as e:
        raise RuntimeError(f"Error loading {display_name} model {hub_model_name} from torch.hub: {e}") from e


def load_model(dataset_name, model_name_str, device):
    """Load a pretrained model for ``dataset_name``, move it to ``device`` and set eval mode.

    Args:
        dataset_name: One of ``"imbalanced_cifar100"``, ``"imagenet-val"``,
            ``"places365-val"``, ``"cifar100"`` or ``"cifar10"``.
        model_name_str: Architecture name; must be in the list for that dataset.
        device: Device the model is moved to.

    Returns:
        torch.nn.Module: The loaded model, in evaluation mode on ``device``.

    Raises:
        ValueError: Unsupported dataset, or a model name not valid for it.
        NotImplementedError: Places365 model that has no downloadable weights here.
        RuntimeError: The underlying download/construction/weight loading failed.
    """
    print(f"Loading pretrained model '{model_name_str}' for dataset '{dataset_name}'...")
    model = None
    if dataset_name == "imbalanced_cifar100":
        if model_name_str not in IMBALANCED_CIFAR100_MODEL_NAMES:
            raise ValueError(f"Model '{model_name_str}' not in predefined list for Imbalanced CIFAR-100. Available: {IMBALANCED_CIFAR100_MODEL_NAMES}")
        model = _load_imbalanced_cifar100_resnet50(model_name_str, device)

    elif dataset_name == "imagenet-val":
        if model_name_str not in IMAGENET_TORCHVISION_MODEL_NAMES:
            raise ValueError(f"Model '{model_name_str}' not in predefined list for ImageNet. Available: {IMAGENET_TORCHVISION_MODEL_NAMES}")
        try:
            model_fn = getattr(models, model_name_str)
            model = model_fn(weights='DEFAULT')
        except AttributeError as e:
            raise ValueError(f"Torchvision does not have model: {model_name_str}") from e
        except Exception as e:
            raise RuntimeError(f"Error loading torchvision model {model_name_str}: {e}") from e

    elif dataset_name == "places365-val":
        if model_name_str not in PLACES365_MODEL_NAMES:
            raise ValueError(f"Model '{model_name_str}' not in predefined list for Places365. Available: {PLACES365_MODEL_NAMES}")

        # Architectures whose Places365 weights can be fetched by URL.
        model_constructors = {
            "alexnet": models.alexnet,
            "resnet18": models.resnet18,
            "resnet50": models.resnet50,
            "densenet161": models.densenet161
        }
        if model_name_str not in model_constructors:
            raise NotImplementedError(
                f"The Places365 {model_name_str} model requires a local path to the model file. "
                f"Please use one of the PyTorch models: {', '.join(model_constructors.keys())}"
            )
        model = _load_places365_model(model_constructors[model_name_str], model_name_str)

    elif dataset_name == "cifar100":
        model = _load_cifar_hub_model("cifar100", "CIFAR-100", model_name_str, CIFAR100_MODEL_NAMES)

    elif dataset_name == "cifar10":
        model = _load_cifar_hub_model("cifar10", "CIFAR-10", model_name_str, CIFAR10_MODEL_NAMES)

    else:
        raise ValueError(f"Unsupported dataset for model loading: {dataset_name}")

    model.to(device)
    model.eval()
    print("Model loaded successfully.")
    return model

def evaluate_conformal_method(model, cal_loader, test_loader, alpha, score_fn_instance, temp_cal=1 , device=None, class_conditional=False, prior=None):
    """Calibrate a ``SplitPredictor`` on ``cal_loader`` and evaluate it on ``test_loader``.

    Args:
        model: Model handed to the predictor.
        cal_loader: Dataloader used for calibration.
        test_loader: Dataloader used for evaluation.
        alpha: Significance level passed to the predictor.
        score_fn_instance: Score function instance. If ``prior`` is truthy and
            the instance has ``set_prior``, it is called with ``(prior, device)``.
        temp_cal: Value passed as the predictor's ``temperature``.
        device: Device forwarded to ``set_prior``.
        class_conditional: Forwarded to the predictor.
        prior: Optional prior for the score function.

    Returns:
        tuple: ``(coverage_rate, average_size, CovGap)`` from the evaluation results.
    """
    prior_applicable = prior and hasattr(score_fn_instance, 'set_prior')
    if not prior_applicable:
        print("WARNING: Prior is not set")
    else:
        score_fn_instance.set_prior(prior, device)

    predictor = SplitPredictor(
        score_function=score_fn_instance,
        model=model,
        alpha=alpha,
        temperature=temp_cal,
        class_conditional=class_conditional,
    )
    predictor.calibrate(cal_loader)
    metrics = predictor.evaluate(test_loader)

    reported_keys = ('coverage_rate', 'average_size', 'CovGap')
    return tuple(metrics[key] for key in reported_keys)

def evaluate_conformal_method_with_logits(cal_logits, cal_labels, test_logits, test_labels, alpha, score_fn_instance, device=None, class_conditional=False, prior=None, diff_violation=False):
    """
    Calibrate and evaluate a conformal predictor from pre-computed logits.

    A model-less ``SplitPredictor`` is calibrated with
    ``calibrate_with_logits`` and evaluated with ``evaluate_with_logits``.

    Args:
        cal_logits: Pre-computed calibration logits.
        cal_labels: Calibration labels.
        test_logits: Pre-computed test logits.
        test_labels: Test labels.
        alpha: Significance level.
        score_fn_instance: Score function instance. If ``prior`` is truthy and
            the instance has ``set_prior``, it is called with ``(prior, device)``.
        device: When not ``None``, assigned to the predictor's ``_device``.
        class_conditional: Forwarded to the predictor.
        prior: Optional prior for the score function.
        diff_violation: Forwarded to ``evaluate_with_logits``; also appends the
            ``DiffViolation`` metric to the returned tuple.

    Returns:
        tuple: ``(coverage_rate, average_size, CovGap, SSCV, VioClasses,
        EmptySetsPercentage, MacroCoverageRate)``, followed by
        ``DiffViolation`` when ``diff_violation`` is true.
    """
    prior_applicable = prior and hasattr(score_fn_instance, 'set_prior')
    if not prior_applicable:
        print("WARNING: Prior is not set")
    else:
        score_fn_instance.set_prior(prior, device)

    # No model is needed: the logits were computed ahead of time.
    predictor = SplitPredictor(score_function=score_fn_instance, model=None, alpha=alpha, class_conditional=class_conditional)
    if device is not None:
        predictor._device = device

    predictor.calibrate_with_logits(cal_logits, cal_labels, alpha)
    metrics = predictor.evaluate_with_logits(test_logits, test_labels, diff_violation=diff_violation)

    reported_keys = [
        'coverage_rate',
        'average_size',
        'CovGap',
        'SSCV',
        'VioClasses',
        'EmptySetsPercentage',
        'MacroCoverageRate',
    ]
    if diff_violation:
        reported_keys.append('DiffViolation')
    return tuple(metrics[key] for key in reported_keys)

def extract_logits_from_dataloader(model, dataloader, device=None):
    """
    Run the model over a dataloader and collect its raw outputs and the labels.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    No temperature scaling is applied here.

    Args:
        model (torch.nn.Module): The model to use for inference.
        dataloader: Iterable of batches; element 0 of a batch is the input
            and element 1 the labels.
        device (torch.device): Device to run on. Defaults to the device of
            the model's first parameter.

    Returns:
        tuple: ``(logits, labels)`` concatenated along dim 0; logits are cast
        to float.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    collected_logits, collected_labels = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Extracting logits"):
            inputs = batch[0].to(device)
            targets = batch[1].to(device)
            collected_logits.append(model(inputs).detach())
            collected_labels.append(targets)

    all_logits = torch.cat(collected_logits, dim=0).float()
    all_labels = torch.cat(collected_labels, dim=0)
    return all_logits, all_labels

def generate_performance_plot(dataset_name_str, plot_data, method_family_name,
                              alpha_value, target_coverage,
                              output_dir, bar_color, line_style_and_color,
                              conditional_mode_str):
    """Save a bar/line chart of coverage and average set size per setting.

    Coverage is drawn as bars on the left axis (with a dashed line at the
    target coverage) and average set size as a line on a twin right axis.
    The point with the smallest average set size is highlighted in green.

    Args:
        dataset_name_str: Dataset name shown in the title.
        plot_data: Sequence of items where ``item[0]`` is the x label,
            ``item[1]`` the coverage rate and ``item[2]`` the average set size.
        method_family_name: Method name used in the title and file name.
        alpha_value: Significance level shown in the legend/title/file name.
        target_coverage: Height of the dashed target-coverage line.
        output_dir: Directory the PNG is written to.
        bar_color: Color of the bars and of the target line.
        line_style_and_color: Matplotlib format string for the size line.
        conditional_mode_str: Mode label used in the title and file name.
    """
    x_labels = [item[0] for item in plot_data]
    coverages = [item[1] for item in plot_data]
    avg_sizes = [item[2] for item in plot_data]
    positions = list(range(len(x_labels)))

    fig, ax1 = plt.subplots(figsize=(12, 7))

    # Bars are placed by position (not by raw label) so they always line up
    # with the size line below, even for numeric or repeated labels.
    ax1.bar(positions, coverages, color=bar_color, alpha=0.5, label='Coverage Rate')
    min_cov_display = min(0.6, (min(coverages) - 0.05 if coverages else 0.6), target_coverage - 0.05)
    max_cov_display = max(1.0, (max(coverages) + 0.05 if coverages else 1.0), target_coverage + 0.05)
    ax1.set_ylim([min_cov_display, max_cov_display])
    ax1.set_ylabel('Coverage Rate')
    ax1.axhline(y=target_coverage, linestyle='--', color=bar_color, label=f'Target Coverage (1-{alpha_value:.2f})')
    ax1.set_xticks(positions)
    ax1.set_xticklabels([str(label) for label in x_labels], rotation=45, ha="right")
    ax1.set_xlabel('Temperature_energy (or "Without Energy")')

    ax2 = ax1.twinx()

    # Index of the smallest average set size. Compared with ``==`` against an
    # explicit None sentinel: a truthiness test would skip index 0.
    best_idx = avg_sizes.index(min(avg_sizes)) if avg_sizes else None
    for i, val in enumerate(avg_sizes):
        is_best = best_idx is not None and i == best_idx
        color = 'green' if is_best else 'black'
        ax2.plot(i, val, marker='o', color=color, markersize=10 if is_best else 5)
        ax2.annotate(f"{val:.2f}", (i, val), textcoords="offset points", xytext=(0, 8), ha='center', fontsize=10, color=color)

    ax2.plot(positions, avg_sizes, line_style_and_color, label='Average Set Size', zorder=1)
    ax2.set_ylabel('Average Set Size')
    if avg_sizes:
        ax2.set_ylim([0, max(avg_sizes) * 1.1 + 1])

    # Merge both axes' legend entries into a single legend box.
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
    plt.title(f'Effect of Temperature_energy on {method_family_name} ({dataset_name_str}) Performance\n({conditional_mode_str}, α = {alpha_value:.2f})')
    plt.tight_layout()

    plot_filename = f"{method_family_name.lower()}_energy_{conditional_mode_str.lower().replace(' ', '_')}_alpha_{alpha_value:.2f}.png"
    plot_path = os.path.join(output_dir, plot_filename)
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved plot: {plot_path}")





def main(args):
    device = setup_environment(args.seed)

    
    model_name_to_load = args.model_name
    if model_name_to_load is None:
        model_name_to_load = DEFAULT_MODEL_PER_DATASET[args.dataset]
        print(f"--model_name not specified, using default for {args.dataset}: {model_name_to_load}")
    else:
        
        valid_models_for_this_dataset = []
        if args.dataset == "cifar10":
            valid_models_for_this_dataset = CIFAR10_MODEL_NAMES
        elif args.dataset == "cifar100":
            valid_models_for_this_dataset = CIFAR100_MODEL_NAMES
        elif args.dataset == "imbalanced_cifar100":
            valid_models_for_this_dataset = IMBALANCED_CIFAR100_MODEL_NAMES
        elif args.dataset == "imagenet-val":
            valid_models_for_this_dataset = IMAGENET_TORCHVISION_MODEL_NAMES
        elif args.dataset == "places365-val":
            valid_models_for_this_dataset = PLACES365_MODEL_NAMES

        if model_name_to_load not in valid_models_for_this_dataset:
            parser_error_message = (
                f"Model '{model_name_to_load}' is not valid for dataset '{args.dataset}'.\n"
                f"Available models for {args.dataset}:\n{', '.join(valid_models_for_this_dataset)}"
            )
            
            raise ValueError(parser_error_message)
        print(f"Using specified model for {args.dataset}: {model_name_to_load}")

    
    model = load_model(args.dataset, model_name_to_load, device)

    all_trials_results = [] 

    for trial_idx in range(args.num_trials):
        current_trial_seed = args.seed + trial_idx
        print(f"\n--- Starting Trial {trial_idx + 1}/{args.num_trials} (Seed: {current_trial_seed}) ---")
        setup_environment(current_trial_seed) 

        
        if args.dataset == "imagenet-val":
            cal_dataloader, test_dataloader = prepare_imagenet_data(
                args.imagenet_val_dir, args.cal_size_imagenet, args.test_size_imagenet, args.batch_size, args.num_workers,
                args.imagenet_mean, args.imagenet_std, current_trial_seed
            )
        elif "cifar100" in args.dataset:
            cal_dataloader, test_dataloader = prepare_cifar100_data(
                args.cifar_data_root, args.cal_size_cifar100, args.test_size_cifar100, args.batch_size, args.num_workers,
                args.cifar100_mean, args.cifar100_std, current_trial_seed, imbalanced_model=args.imbalanced_model
            )
        elif args.dataset == "cifar10":
            cal_dataloader, test_dataloader = prepare_cifar10_data(
                args.cifar_data_root, args.cal_size_cifar10, args.test_size_cifar10, args.batch_size, args.num_workers,
                args.cifar10_mean, args.cifar10_std, current_trial_seed
            )
        elif args.dataset == "places365-val":
            cal_dataloader, test_dataloader = prepare_places365_data(
                args.places365_root, args.cal_size_places365, args.test_size_places365, args.batch_size, args.num_workers,
                args.places365_mean, args.places365_std, args.download_places365, current_trial_seed
            )
        else:
            raise ValueError(f"Unsupported dataset: {args.dataset}") 

        print("-" * 140)
        
        print(f"{'Trial':<6} {'Method':<14} {'Mode':<20} {'Alpha':<10} {'Temperature_energy':<20} {'Temperature_calibration':<24} {'Coverage Rate':<18} {'Average Set Size':<20} {'CovGap':<18} {'SSCV':<18} {'VioClasses':<18} {'EmptySets':<18} {'MacroCoverageRate':<18}")
        print("-" * 140)

        
        conditional_modes_to_run = [True] if args.class_conditional else [False]

        cal_logits, cal_labels = extract_logits_from_dataloader(model, cal_dataloader, device=device)
        test_logits, test_labels = extract_logits_from_dataloader(model, test_dataloader, device=device)
        
        for is_class_conditional in conditional_modes_to_run:
            conditional_mode_str = "Class-Conditional" if is_class_conditional else "Standard"

            for alpha in args.alphas:
                for temp_cal in args.temperatures_calibration:
                    
                    methods_config_standard = [
                        ('LAC', LAC, {}),
                        ('APS', APS, {}),
                        ('RAPS', RAPS, {'penalty': args.raps_penalty, 'kreg': args.raps_kreg}),
                        ('SAPS', SAPS, {'weight': args.saps_weight})
                    ]
                    for method_name, score_class, base_params in methods_config_standard:
                        score_fn_instance = score_class(**base_params)
                        cal_logits_temp = cal_logits / temp_cal
                        test_logits_temp = test_logits / temp_cal
                        if args.diff_violation:
                            cov, size, covgap, sscv, vio_classes, empty_sets, macro_cov, diff_violation = evaluate_conformal_method_with_logits(cal_logits_temp, cal_labels, test_logits_temp, test_labels, alpha, score_fn_instance, device=device, class_conditional=is_class_conditional, diff_violation=True)
                        else:
                            cov, size, covgap, sscv, vio_classes, empty_sets, macro_cov = evaluate_conformal_method_with_logits(cal_logits_temp, cal_labels, test_logits_temp, test_labels, alpha, score_fn_instance, device=device, class_conditional=is_class_conditional, diff_violation=False)
                        print(f"{trial_idx + 1:<6} {method_name:<14} {conditional_mode_str:<20} {alpha:<10.3f} {'N/A':<20} {temp_cal:<24.2f} {cov:<18.4f} {size:<20.4f} {covgap:<18.4f} {sscv:<18.4f} {vio_classes:<18.4f} {empty_sets:<18.4f} {macro_cov:<18.4f}")
                        if args.diff_violation:
                            print("DiffViolation", diff_violation)
                        all_trials_results.append({
                            'Trial': trial_idx + 1,
                            'Method': method_name,
                            'Alpha': alpha,
                            'Temperature_energy': 'N/A',
                            'Temperature_calibration': temp_cal,
                            'ClassConditional': is_class_conditional,
                            'Coverage': cov,
                            'Size': size,
                            'CovGap': covgap,
                            'SSCV': sscv,
                            'VioClasses': vio_classes,
                            'EmptySetsPercentage': empty_sets,
                            'MacroCoverageRate': macro_cov,
                            'DiffViolation': diff_violation if args.diff_violation else None,
                            'OOD': False
                        })

                    
                    energy_methods_config = [
                        ('EnergyLAC', EnergyLAC, {}),
                        ('EnergyAPS', EnergyAPS, {}),
                        ('EnergyRAPS', EnergyRAPS, {'penalty': args.raps_penalty, 'kreg': args.raps_kreg}),
                        ('EnergySAPS', EnergySAPS, {'weight': args.saps_weight})
                    ]
                    if args.imbalanced_model is not None:
                        energy_methods_config.append(('LAC-Imb', LAC, {}))
                        energy_methods_config.append(('APS-Imb', APS, {}))
                        energy_methods_config.append(('RAPS-Imb', RAPS, {'penalty': args.raps_penalty, 'kreg': args.raps_kreg}))
                        energy_methods_config.append(('SAPS-Imb', SAPS, {'weight': args.saps_weight}))
                    for method_name, score_class, base_params in energy_methods_config:

                        if args.imbalanced_model and method_name.endswith('-Imb'):
                            summary_file = f'../imbalanced_training/cifar100/all_results_summary.json'
                            if os.path.exists(summary_file):
                                with open(summary_file, 'r') as f:
                                    results = json.load(f)
                            else:
                                print(f"File {summary_file} does not exist")
                            samples_per_class = results[args.imbalanced_model]['samples_per_class']
                            samples_per_class = np.array(samples_per_class)
                            prior_per_class = list(samples_per_class / np.sum(samples_per_class))
                            prior = {i: prior_per_class[i] for i in range(len(prior_per_class))}
                        else:
                            prior = None

                        for temp_e in args.temperatures_energy:
                            if 'Energy' in method_name:
                                score_fn_params = {**base_params, 'score_type': args.energy_score_type, 'temp_e': temp_e, 'temp_cal': temp_cal}
                            else:
                                score_fn_params = {**base_params, 'score_type': 'softmax'}
                            score_fn_instance = score_class(**score_fn_params)
                            cal_logits_temp = cal_logits / temp_cal
                            test_logits_temp = test_logits / temp_cal
                            if args.diff_violation:
                                cov, size, covgap, sscv, vio_classes, empty_sets, macro_cov, diff_violation = evaluate_conformal_method_with_logits(cal_logits_temp, cal_labels, test_logits_temp, test_labels, alpha, score_fn_instance, device=device, class_conditional=is_class_conditional, diff_violation=True, prior=prior)
                            else:                              
                                cov, size, covgap, sscv, vio_classes, empty_sets, macro_cov = evaluate_conformal_method_with_logits(cal_logits_temp, cal_labels, test_logits_temp, test_labels, alpha, score_fn_instance, device=device, class_conditional=is_class_conditional, diff_violation=False, prior=prior)

                            
                            print(f"{trial_idx + 1:<6} {method_name:<14} {conditional_mode_str:<20} {alpha:<10.3f} {temp_e:<20.2f} {temp_cal:<24.2f} {cov:<18.4f} {size:<20.4f} {covgap:<18.4f} {sscv:<18.4f} {vio_classes:<18.4f} {empty_sets:<18.4f} {macro_cov:<18.4f}")
                            if args.diff_violation:
                                print("DiffViolation", diff_violation)
                            all_trials_results.append({
                                'Trial': trial_idx + 1,
                                'Method': method_name,
                                'Alpha': alpha,
                                'Temperature_energy': temp_e,
                                'Temperature_calibration': temp_cal,
                                'ClassConditional': is_class_conditional,
                                'Coverage': cov,
                                'Size': size,
                                'CovGap': covgap,
                                'SSCV': sscv,
                                'VioClasses': vio_classes,
                                'EmptySetsPercentage': empty_sets,
                                'MacroCoverageRate': macro_cov,
                                'DiffViolation': diff_violation if args.diff_violation else None,
                                'OOD': False
                            })
                            if 'Imb' in method_name:
                                break
                            
                print("-" * 140) 

    print("\n--- All Trials Complete ---")

    
    print("\n--- Summary Statistics Across Trials ---")
    
    grouped_results = defaultdict(lambda: defaultdict(list))
    for result in all_trials_results:
        config_key = (
            result['Method'],
            result['Alpha'],
            result['Temperature_energy'],
            result['Temperature_calibration'],
            result['ClassConditional']
        )
        
        if result['Coverage'] is not None:
            grouped_results[config_key]['Coverage'].append(result['Coverage'])
        if result['Size'] is not None:
            grouped_results[config_key]['Size'].append(result['Size'])
        if result['CovGap'] is not None:
            grouped_results[config_key]['CovGap'].append(result['CovGap'])
        if result['SSCV'] is not None:
            grouped_results[config_key]['SSCV'].append(result['SSCV'])
        if result['VioClasses'] is not None:
            grouped_results[config_key]['VioClasses'].append(result['VioClasses'])
        if result['EmptySetsPercentage'] is not None:
            grouped_results[config_key]['EmptySetsPercentage'].append(result['EmptySetsPercentage'])
        if result['MacroCoverageRate'] is not None:
            grouped_results[config_key]['MacroCoverageRate'].append(result['MacroCoverageRate'])
        if args.diff_violation and result['DiffViolation'] is not None:
            grouped_results[config_key]['DiffViolation'].append(result['DiffViolation'])

    print("-" * 200)
    print(f"{'Method':<14} {'Mode':<20} {'Alpha':<10} {'Temperature_energy':<20} {'Temperature_calibration':<24} {'Coverage Rate (Mean ± Std)':<32} {'Average Set Size (Mean ± Std)':<32} {'CovGap (Mean ± Std)':<30} {'SSCV (Mean ± Std)':<30} {'VioClasses (Mean ± Std)':<30} {'EmptySets (Mean ± Std)':<30} {'MacroCoverageRate (Mean ± Std)':<30}")
    print("-" * 200)

    for config_key in grouped_results.keys():
        method, alpha, temp_e, temp_cal, is_conditional = config_key
        results = grouped_results[config_key]

        mean_cov = np.mean(results['Coverage'])
        std_cov = np.std(results['Coverage'])
        mean_size = np.mean(results['Size'])
        std_size = np.std(results['Size'])
        mean_covgap = np.mean(results['CovGap'])
        std_covgap = np.std(results['CovGap'])
        mean_sscv = np.mean(results['SSCV'])
        std_sscv = np.std(results['SSCV'])
        mean_macro_cov = np.mean(results['MacroCoverageRate'])
        std_macro_cov = np.std(results['MacroCoverageRate'])
        mean_vio_classes = np.mean(results['VioClasses'])
        std_vio_classes = np.std(results['VioClasses'])
        mean_empty_sets = np.mean(results['EmptySetsPercentage'])
        std_empty_sets = np.std(results['EmptySetsPercentage'])

        mode_str = "Class-Conditional" if is_conditional else "Standard"
        temp_e_str = f"{temp_e:.2f}" if temp_e != 'N/A' else 'N/A'

        print(
            f"{method:<14} {mode_str:<20} {alpha:<10.2f} {temp_e_str:<20} {temp_cal:<24.2f} "
            f"{(f'{mean_cov:.4f} ± {std_cov:.4f}'): <32} "
            f"{(f'{mean_size:.4f} ± {std_size:.4f}'): <32} "
            f"{(f'{mean_covgap:.4f} ± {std_covgap:.4f}'): <30} "
            f"{(f'{mean_sscv:.4f} ± {std_sscv:.4f}'): <30} "
            f"{(f'{mean_vio_classes:.4f} ± {std_vio_classes:.4f}'): <30} "
            f"{(f'{mean_empty_sets:.4f} ± {std_empty_sets:.4f}'): <30}"
            f"{(f'{mean_macro_cov:.4f} ± {std_macro_cov:.4f}'): <30} "
        )
    print("-" * 200)

    
    
    mode_suffix = "_class_conditional" if args.class_conditional else "_standard"
    imbalanced_suffix = f"_imb_{args.imbalanced_model}" if args.imbalanced_model is not None else ""
    csv_filename = f"results/csv/{args.dataset}_{model_name_to_load}_energy_results_trials_{args.num_trials}{mode_suffix}{imbalanced_suffix}_multi_prior.csv"
    pickle_filename = f"results/pkl/{args.dataset}_{model_name_to_load}_energy_results_trials_{args.num_trials}{mode_suffix}{imbalanced_suffix}_multi_prior.pkl"
    os.makedirs(os.path.dirname(csv_filename), exist_ok=True)
    os.makedirs(os.path.dirname(pickle_filename), exist_ok=True)

    
    with open(csv_filename, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=all_trials_results[0].keys())
        writer.writeheader()
        writer.writerows(all_trials_results)
    with open(pickle_filename, 'wb') as f:
        pickle.dump(all_trials_results, f)
    print(f"Saved results to {csv_filename} and {pickle_filename}")

    
    if args.enable_plotting:
        if args.num_trials > 1:
            print(f"\nSkipping plotting because num_trials > 1. Plots are typically generated for single runs.")
        else:
            print("\nGenerating plots...")
            os.makedirs(args.plot_output_dir, exist_ok=True)
            plot_configs = {
                'LAC': {'methods': ['LAC', 'EnergyLAC'], 'bar_color': 'r', 'line_style': 'b-'},
                'APS': {'methods': ['APS', 'EnergyAPS'], 'bar_color': 'b', 'line_style': 'r-'},
                'RAPS': {'methods': ['RAPS', 'EnergyRAPS'], 'bar_color': 'g', 'line_style': 'm-'},
                'SAPS': {'methods': ['SAPS', 'EnergySAPS'], 'bar_color': 'y', 'line_style': 'c-'}
            }

            
            single_trial_results = [r for r in all_trials_results if r['Trial'] == 1]

            
            for is_class_conditional in conditional_modes_to_run:
                conditional_mode_str = "Class-Conditional" if is_class_conditional else "Standard"

                for family_name, config in plot_configs.items():
                    
                    family_results = [r for r in single_trial_results if r['Method'] in config['methods'] and r['ClassConditional'] == is_class_conditional]

                    for alpha_val in args.alphas:
                        
                        alpha_specific_results = [r for r in family_results if r['Alpha'] == alpha_val]

                        plot_data_for_alpha = []
                        
                        for res_dict in alpha_specific_results:
                            x_label = "Without Energy" if res_dict['Temperature_energy'] == 'N/A' else str(res_dict['Temperature_energy'])
                            plot_data_for_alpha.append((x_label, res_dict['Coverage'], res_dict['Size'], res_dict['CovGap'], res_dict['DiffViolation'], res_dict['SSCV'], res_dict['VioClasses'], res_dict['EmptySetsPercentage'], res_dict['MacroCoverageRate']))

                        
                        plot_data_for_alpha.sort(key=lambda x: (x[0] != "Without Energy", float(x[0]) if x[0] != "Without Energy" else -float('inf')))

                        if plot_data_for_alpha: 
                             generate_performance_plot(
                                dataset_name_str=args.dataset, plot_data=plot_data_for_alpha,
                                method_family_name=family_name, alpha_value=alpha_val,
                                target_coverage=(1 - alpha_val), output_dir=args.plot_output_dir,
                                bar_color=config['bar_color'], line_style_and_color=config['line_style'],
                                conditional_mode_str=conditional_mode_str
                            )
    elif args.num_trials > 1 and not args.enable_plotting:
         print(f"\nPlotting is disabled (--enable_plotting is False) and num_trials > 1. No plots will be generated.")


def parse_arguments(argv=None) -> argparse.Namespace:
    """Build the command-line interface and parse it.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Leaving it as ``None`` keeps the original
            behaviour (parse the process command line); passing a list lets
            callers and tests drive the parser programmatically.

    Returns:
        The populated ``argparse.Namespace``. Options are organised into the
        argument groups created below (dataset, dataloader, split sizes,
        experiment hyperparameters, normalization constants, plotting).
    """
    parser = argparse.ArgumentParser(description="Conformal Prediction Method Comparison", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--seed", type=int, default=42, help="Base random seed.")
    parser.add_argument("--num_trials", type=int, default=1, help="Number of trials to run for each configuration.")
    parser.add_argument("--diff_violation", action="store_true", default=False, help="Calculate DiffViolation metric.")

    dataset_group = parser.add_argument_group('Dataset Configuration')
    dataset_group.add_argument("--dataset", type=str, default="cifar100",
                        choices=["cifar10", "cifar100", "imagenet-val", "places365-val", "imbalanced_cifar100"],
                        help="Dataset to use.")
    dataset_group.add_argument("--imbalanced_model", type=str, default=None, help="Use imbalanced model.")

    dataset_group.add_argument("--model_name", type=str, default=None,
                        choices=ALL_AVAILABLE_MODEL_NAMES,
                        help="Model architecture. If None, a default for the dataset is used. "
                             "Ensure the model is compatible with the chosen dataset (validation done at runtime).")
    dataset_group.add_argument("--imagenet_val_dir", type=str, default="",
                        help="Path to ImageNet validation folder.")
    # NOTE: get_dataset_dir() is evaluated when the parser is built, i.e. on
    # every call, not lazily when the option is read.
    dataset_group.add_argument("--cifar_data_root", type=str, default=get_dataset_dir(),
                        help="Root directory for CIFAR datasets.")
    dataset_group.add_argument("--places365_root", type=str, default="./data/places365",
                        help="Root directory for Places365 dataset (will contain downloaded files if download=True).")
    dataset_group.add_argument("--download_places365", action="store_true", default=False,
                        help="Download Places365 dataset if not available locally.")

    loader_group = parser.add_argument_group('Dataloader Configuration')
    loader_group.add_argument("--batch_size", type=int, default=128, help="Batch size.")
    loader_group.add_argument("--num_workers", type=int, default=4, help="Number of dataloader workers.")

    split_group = parser.add_argument_group('Dataset Split Sizes')
    split_group.add_argument("--cal_size_cifar10", type=int, default=5000, help="Calibration size for CIFAR10 (from 10k test images).")
    split_group.add_argument("--test_size_cifar10", type=int, default=5000, help="Test size for CIFAR10 (from 10k test images).")
    split_group.add_argument("--cal_size_cifar100", type=int, default=5000, help="Calibration size for CIFAR100 (from 10k test images).")
    split_group.add_argument("--test_size_cifar100", type=int, default=5000, help="Test size for CIFAR100 (from 10k test images).")
    split_group.add_argument("--cal_size_imagenet", type=int, default=20000, help="Calibration size for ImageNet-val (test is remainder).")
    split_group.add_argument("--test_size_imagenet", type=int, default=20000, help="Test size for ImageNet-val (test is remainder).")
    split_group.add_argument("--cal_size_places365", type=int, default=18000, help="Calibration size for Places365-val (test is remainder).")
    split_group.add_argument("--test_size_places365", type=int, default=18500, help="Test size for Places365-val (test is remainder).")

    exp_group = parser.add_argument_group('Experiment Hyperparameters')
    exp_group.add_argument("--alphas", type=float, nargs='+', default=[0.1, 0.05, 0.025, 0.01],
                        help="Significance levels (alpha).")
    # Default sweep of energy temperatures (kept verbatim from the original
    # script so existing result files stay comparable).
    exp_group.add_argument("--temperatures_energy", type=float, nargs='+',
                        default=[0.0002, 0.00026, 0.00034, 0.00043, 0.00055, 0.00071, 0.00091, 0.00117, 0.0015, 0.00193, 0.00248, 0.00318, 0.00409, 0.00525, 0.00674, 0.00865, 0.01111, 0.01426, 0.01832, 0.02352, 0.0302, 0.03877, 0.04979, 0.06393, 0.08208, 0.1054, 0.13534, 0.17377, 0.22313, 0.2865, 0.36788, 0.47237, 0.60653, 0.7788, 1.0, 1.28403, 1.64872, 2.117, 2.71828, 3.49034, 4.48169, 5.7546, 7.38906, 9.48774, 12.18249, 15.64263, 20.08554, 25.79034, 33.11545, 42.52108, 54.59815, 70.10541, 90.01713, 115.58428, 148.41316, 190.56627, 244.69193, 314.19066, 403.42879, 518.01282, 665.14163, 854.05876, 1096.63316, 1408.10485, 1808.04241, 2321.57241, 2980.95799, 3827.62582, 4914.76884, 6310.68811, 8103.08393, 10404.56572],
                        help="Temperatures of logit scaling for energy methods.")
    exp_group.add_argument("--temperatures_calibration", type=float, nargs='+',
                        default=[25, 20, 17.5, 15, 12.5, 10, 7.5, 5, 2.5, 2.0, 1.5, 1.3, 1.0, 0.9, 0.8, 0.75, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.05, 0.025, 0.01],
                        help="Temperatures of logit scaling for model calibration.")
    exp_group.add_argument("--raps_penalty", type=float, default=0.2, help="Penalty for RAPS.")
    exp_group.add_argument("--raps_kreg", type=int, default=2, help="K_reg for RAPS.")
    exp_group.add_argument("--saps_weight", type=float, default=0.2, help="Weight for SAPS.")
    exp_group.add_argument("--energy_score_type", type=str, default="identity", choices=["identity", "softmax"],
                        help="Score type for EnergyXXX methods.")
    # The default is stated on the argument itself; the separate
    # parser.set_defaults(class_conditional=False) call was redundant.
    exp_group.add_argument("--class_conditional", action="store_true", default=False,
                             help="Enable class-conditional conformal prediction (if set, ONLY class-conditional is run).")

    norm_group = parser.add_argument_group('Normalization Constants')
    norm_group.add_argument("--imagenet_mean", type=float, nargs=3, default=[0.485, 0.456, 0.406], help="ImageNet mean.")
    norm_group.add_argument("--imagenet_std", type=float, nargs=3, default=[0.229, 0.224, 0.225], help="ImageNet std.")
    norm_group.add_argument("--cifar10_mean", type=float, nargs=3, default=[0.4914, 0.4822, 0.4465], help="CIFAR10 mean.")
    norm_group.add_argument("--cifar10_std", type=float, nargs=3, default=[0.2023, 0.1994, 0.2010], help="CIFAR10 std.")
    norm_group.add_argument("--cifar100_mean", type=float, nargs=3, default=[0.5071, 0.4867, 0.4408], help="CIFAR100 mean.")
    norm_group.add_argument("--cifar100_std", type=float, nargs=3, default=[0.2675, 0.2565, 0.2761], help="CIFAR100 std.")
    norm_group.add_argument("--places365_mean", type=float, nargs=3, default=[0.485, 0.456, 0.406], help="Places365 mean.")
    norm_group.add_argument("--places365_std", type=float, nargs=3, default=[0.229, 0.224, 0.225], help="Places365 std.")

    plot_group = parser.add_argument_group('Plotting')
    plot_group.add_argument("--enable_plotting", default=False, action="store_true", help="Generate performance plots.")
    plot_group.add_argument("--plot_output_dir", type=str, default="energy_results", help="Directory for plots.")

    # argv=None makes argparse fall back to sys.argv[1:], exactly as before.
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: parse the CLI, resolve the plot directory, run main().
    args = parse_arguments()
    model_name = args.model_name if args.model_name else DEFAULT_MODEL_PER_DATASET[args.dataset]

    # Derive a per-dataset/per-model plot directory only when the user left
    # --plot_output_dir at the parser default ("energy_results"). Previously an
    # explicitly supplied directory was silently overwritten here.
    if args.plot_output_dir == "energy_results":
        args.plot_output_dir = f"figures/{args.dataset}/{model_name}"

    try:
        main(args)
    except ValueError as e:
        print(f"Configuration Error: {e}")
        # Exit with a non-zero status so shell scripts and job schedulers can
        # detect the failed run instead of seeing a clean exit.
        raise SystemExit(1) from e
        
        
        