import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from pycocotools.coco import COCO
import os
import random

def get_loaders_and_stats_for_batch_exp(dataset_name, batch_size=64, data_root='./data'):
    """Build shuffled train/test loaders and normalization stats for a dataset.

    Images are resized to 32x32 and converted to tensors. The normalization
    statistics are only *returned*; they are not applied by the transform.

    Args:
        dataset_name: One of ``cifar10``, ``mnist`` or ``svhn`` (case-insensitive).
        batch_size: Batch size used for both loaders.
        data_root: Root directory handed to torchvision (``download=True``).

    Returns:
        ``(train_loader, test_loader, norm_mean, norm_std, num_channels)``.
        Both loaders are created with ``shuffle=True``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    key = dataset_name.lower()

    # (mean, std, channel count) per supported dataset; rebuilt on every call
    # so callers receive fresh lists they may mutate freely.
    stats_by_name = {
        'cifar10': ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010], 3),
        'mnist': ([0.1307], [0.3081], 1),
        'svhn': ([0.4377, 0.4438, 0.4728], [0.1980, 0.2010, 0.1970], 3),
    }
    if key not in stats_by_name:
        raise ValueError(f"Dataset '{dataset_name}' not supported.")
    norm_mean, norm_std, num_channels = stats_by_name[key]

    transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
    shared_kwargs = dict(root=data_root, download=True, transform=transform)

    if key == 'svhn':
        # SVHN picks its partition with ``split`` rather than ``train``.
        train_dataset = torchvision.datasets.SVHN(split='train', **shared_kwargs)
        test_dataset = torchvision.datasets.SVHN(split='test', **shared_kwargs)
    else:
        if key == 'cifar10':
            dataset_cls = torchvision.datasets.CIFAR10
        else:
            dataset_cls = torchvision.datasets.MNIST
        train_dataset = dataset_cls(train=True, **shared_kwargs)
        test_dataset = dataset_cls(train=False, **shared_kwargs)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader, norm_mean, norm_std, num_channels

def get_svhn_for_individual_exp(data_root='./data'):
    """Return the SVHN train and test datasets for the individual experiment.

    Both splits are resized to 32x32 and converted to tensors. Only the
    training split is additionally normalized; the test split is left
    un-normalized.
    NOTE(review): the asymmetry looks intentional (e.g. evaluation code that
    works on raw [0, 1] images) -- confirm against the callers.

    Args:
        data_root: Root directory handed to torchvision (``download=True``).

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    # Per-channel statistics applied to the training split only
    # (presumably the SVHN mean/std).
    channel_mean = [0.4377, 0.4438, 0.4728]
    channel_std = [0.1980, 0.2010, 0.1970]

    def _resize_and_tensorize():
        # Fresh transform objects for each pipeline.
        return [transforms.Resize((32, 32)), transforms.ToTensor()]

    transform_train = transforms.Compose(
        _resize_and_tensorize()
        + [transforms.Normalize(mean=channel_mean, std=channel_std)]
    )
    transform_eval = transforms.Compose(_resize_and_tensorize())

    svhn = torchvision.datasets.SVHN
    train_dataset = svhn(data_root, split='train', download=True, transform=transform_train)
    test_dataset = svhn(data_root, split='test', download=True, transform=transform_eval)
    return train_dataset, test_dataset

def partition_svhn_for_exp2(test_dataset, n_normal, n_eval, seed=None):
    """Randomly split a dataset into disjoint "normal" and evaluation subsets.

    The indices ``0..len(test_dataset)-1`` are shuffled; the first
    ``n_normal`` become the normal subset and the following ``n_eval`` the
    evaluation subset, so the two never share a sample.

    Args:
        test_dataset: Any object supporting ``len()`` and indexing.
        n_normal: Number of samples for the normal subset.
        n_eval: Number of samples for the evaluation subset.
        seed: Optional seed for a private RNG, making the split reproducible
            without touching global random state. When ``None`` (default) the
            module-level ``random`` generator is used, exactly as before.

    Returns:
        ``(b_normal_dataset, eval_dataset)`` as ``Subset`` views of
        ``test_dataset``.

    Note:
        If the dataset holds fewer than ``n_normal + n_eval`` samples the
        slices are silently shorter than requested.
    """
    all_indices = list(range(len(test_dataset)))
    # A dedicated generator keeps a seeded split independent of (and invisible
    # to) any other code drawing from the global RNG.
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(all_indices)

    b_normal_indices = all_indices[:n_normal]
    eval_indices = all_indices[n_normal:n_normal + n_eval]
    return Subset(test_dataset, b_normal_indices), Subset(test_dataset, eval_indices)
    
def partition_svhn_for_exp2_b(
    test_dataset,
    n_normal,
    n_calib,
    n_eval_per_group,
    group_names=("Natural", "FGSM", "PGD", "BIM", "AutoAttack"),
    seed=None,
):
    """Randomly split a dataset into normal, calibration and per-group eval subsets.

    After shuffling the indices ``0..len(test_dataset)-1`` they are consumed
    in order: ``n_normal`` for the normal subset, ``n_calib`` for the
    calibration subset, then ``n_eval_per_group`` for each evaluation group.
    All resulting subsets are therefore mutually disjoint.

    Args:
        test_dataset: Any object supporting ``len()`` and indexing.
        n_normal: Number of samples for the normal subset.
        n_calib: Number of samples for the calibration subset.
        n_eval_per_group: Number of samples assigned to each evaluation group.
        group_names: Names of the evaluation groups, in allocation order.
            Defaults to the five groups this function has always produced.
        seed: Optional seed for a private RNG, making the split reproducible
            without touching global random state. When ``None`` (default) the
            module-level ``random`` generator is used, exactly as before.

    Returns:
        ``(b_normal_dataset, calib_dataset, eval_datasets)`` where
        ``eval_datasets`` maps each group name to a ``Subset``, in the order
        given by ``group_names``.

    Note:
        If the dataset is too small for the requested sizes, the trailing
        subsets are silently shorter (possibly empty).
    """
    group_names = tuple(group_names)

    all_indices = list(range(len(test_dataset)))
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(all_indices)

    b_normal_indices = all_indices[:n_normal]
    calib_indices = all_indices[n_normal:n_normal + n_calib]

    # Carve the evaluation pool out once, then hand out equal consecutive
    # chunks so no group overlaps another.
    eval_start = n_normal + n_calib
    total_eval_samples = len(group_names) * n_eval_per_group
    eval_indices = all_indices[eval_start:eval_start + total_eval_samples]

    eval_datasets = {}
    for position, group in enumerate(group_names):
        chunk_start = position * n_eval_per_group
        chunk = eval_indices[chunk_start:chunk_start + n_eval_per_group]
        eval_datasets[group] = Subset(test_dataset, chunk)

    b_normal_dataset = Subset(test_dataset, b_normal_indices)
    calib_dataset = Subset(test_dataset, calib_indices)
    return b_normal_dataset, calib_dataset, eval_datasets

def get_sota_loaders(dataset_name, batch_size=128, data_root='./data'):
    """Build train/validation loaders for the SOTA comparison experiments.

    Args:
        dataset_name: ``cifar10`` or ``coco`` (case-insensitive).
        batch_size: Batch size of the training loader. The validation loader
            always uses ``batch_size=1`` and is not shuffled.
        data_root: Root data directory. CIFAR-10 is downloaded there by
            torchvision; COCO must already exist under
            ``<data_root>/coco2017`` (``annotations/``, ``train2017/``,
            ``val2017/``).

    Returns:
        ``(train_loader, val_loader, num_classes)``.

        For COCO every image gets a single classification label: the
        category of its *first* annotation, remapped to a contiguous index
        built from the training annotations. Images without annotations are
        labelled ``-1``; callers need to filter or ignore that value.

    Raises:
        FileNotFoundError: If a COCO annotation file is missing.
        ValueError: If ``dataset_name`` is not supported.
    """
    name = dataset_name.lower()

    if name == 'cifar10':
        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
        transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(), normalize])
        train_ds = torchvision.datasets.CIFAR10(data_root, train=True, download=True, transform=transform)
        val_ds = torchvision.datasets.CIFAR10(data_root, train=False, download=True, transform=transform)
        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
        val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)
        num_classes = 10
        return train_loader, val_loader, num_classes

    if name == 'coco':
        coco_root = os.path.join(data_root, 'coco2017')
        ann_train_path = os.path.join(coco_root, "annotations/instances_train2017.json")
        ann_val_path = os.path.join(coco_root, "annotations/instances_val2017.json")
        img_train_path = os.path.join(coco_root, "train2017")
        img_val_path = os.path.join(coco_root, "val2017")

        # Fail fast on BOTH annotation files before any parsing starts:
        # otherwise a missing validation file is only discovered after the
        # training annotations have already been loaded.
        for ann_path in (ann_train_path, ann_val_path):
            if not os.path.exists(ann_path):
                raise FileNotFoundError(f"COCO annotations not found at {ann_path}. Please download the dataset.")

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        transform_train = transforms.Compose([transforms.Resize((256, 256)), transforms.RandomCrop(224), transforms.ToTensor(), normalize])
        transform_eval = transforms.Compose([transforms.Resize((256, 256)), transforms.CenterCrop(224), transforms.ToTensor(), normalize])

        # COCO category ids are remapped to 0..num_classes-1 in the order
        # returned by getCatIds().
        coco = COCO(ann_train_path)
        cat_ids = coco.getCatIds()
        cat2idx = {cid: i for i, cid in enumerate(cat_ids)}
        num_classes = len(cat_ids)

        def target_transform_coco(targets):
            # No annotations for this image -> sentinel label -1.
            if not targets:
                return torch.tensor(-1, dtype=torch.long)
            # Only the first annotation determines the image-level label.
            return torch.tensor(cat2idx[targets[0]['category_id']], dtype=torch.long)

        train_ds = torchvision.datasets.CocoDetection(img_train_path, ann_train_path, transform=transform_train, target_transform=target_transform_coco)
        val_ds = torchvision.datasets.CocoDetection(img_val_path, ann_val_path, transform=transform_eval, target_transform=target_transform_coco)

        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=False)
        val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=False)
        return train_loader, val_loader, num_classes

    raise ValueError(f"Dataset '{dataset_name}' for SOTA comparison not supported.")