import copy
import os
import random
import shutil
from collections import defaultdict

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
#
# Fix the RNGs this script relies on so the sampled subsets are reproducible:
# Python's `random` is used for index sampling further down, torch's RNG is
# used by the random augmentation transforms, and the cuDNN flags request
# deterministic algorithms with autotuning turned off.
seed = 42
random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Training-style augmentation: random 32x32 crop from a 4-pixel-padded image,
# random horizontal flip, then conversion to a tensor.
augmentations = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
]
train_transforms = transforms.Compose(augmentations)

# CIFAR-10 training split (train=True), downloaded into ./data if missing.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transforms,
)
# Indices of every image in the CIFAR-10 *training* split (trainset uses
# train=True). NOTE(review): the files below are named "val_*" even though the
# samples come from the training split -- confirm this is intended.
train_indices = list(range(len(trainset)))

# Create output directory
output_dir = "./calibration_data/cifar10_calibration_sets"
os.makedirs(output_dir, exist_ok=True)


def _subset_path(size):
    """Return the path under ``output_dir`` where the ``size``-sample subset is saved."""
    return os.path.join(output_dir, f"val_{size}_samples.pth")


def _compact_subset(dataset, indices):
    """Return a ``Subset`` that stores only the samples at ``indices``.

    ``Subset(dataset, indices)`` keeps a reference to the whole dataset, so
    ``torch.save`` would pickle the dataset's full ``data`` array (all training
    images) into every file. Instead, shallow-copy the dataset, keep only the
    selected rows of its ``data``/``targets``, and wrap the copy in a ``Subset``
    covering all of it.

    Position ``i`` of the result yields the same image, label and transform as
    ``dataset[indices[i]]``. The result's ``.indices`` are ``0..len(indices)-1``;
    the original dataset positions are kept on ``.source_indices``.

    Assumes ``dataset`` exposes a NumPy ``data`` array and a ``targets`` list,
    as torchvision's ``CIFAR10`` does.
    """
    compact = copy.copy(dataset)  # shallow: transforms etc. are shared, not copied
    compact.data = dataset.data[indices]  # fancy indexing -> independent copy of the rows
    compact.targets = [dataset.targets[i] for i in indices]
    subset = Subset(compact, list(range(len(indices))))
    subset.source_indices = list(indices)
    return subset


# Store results: sample size -> number of samples saved
results = {}
# Set the desired number of samples for each subset
sample_sizes = [100, 200, 400, 600, 800, 1000]

for size in sample_sizes:
    # Sample without replacement. Each size is drawn independently, so smaller
    # subsets are not guaranteed to be contained in larger ones.
    sampled_indices = random.sample(train_indices, size)

    # NOTE(review): the saved subset keeps `train_transforms`, so every read of a
    # calibration sample applies a fresh random crop/flip -- confirm augmentation
    # is wanted for calibration data.
    sampled_subset = _compact_subset(trainset, sampled_indices)
    torch.save(sampled_subset, _subset_path(size))

    # Store subset information in the results dictionary
    results[size] = len(sampled_subset)

# Print results, using the exact paths the files were written to
print("Sampled subsets saved:")
for size, count in results.items():
    print(f"{size} samples: {count} samples saved in {_subset_path(size)}")

#
# # ##### CIFAR-10 imbalanced calibration set (disabled): 100 samples from class 0 plus 1 sample from each of classes 1-9
# import torch
# import random
# import os
# import torchvision
# import torchvision.transforms as transforms
# from torch.utils.data import DataLoader, Subset
# from collections import defaultdict

# # Set random seed for reproducibility
# seed = 42
# torch.manual_seed(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# random.seed(seed)

# train_transforms = transforms.Compose([
#         transforms.RandomCrop(32, padding=4),
#         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#     ])

# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transforms)

# # Get indices for each class
# class_indices = defaultdict(list)
# for idx, (data, label) in enumerate(trainset):
#     class_indices[label].append(idx)

# # Select 100 samples from class 0, and 1 sample from each of the other classes
# selected_indices = []
# selected_indices.extend(random.sample(class_indices[0], 100))

# # Take 1 sample from each of the other classes
# for label in range(1, 10):
#     selected_indices.extend(random.sample(class_indices[label], 1))

# # Create subset
# subset = Subset(trainset, selected_indices)
# print(len(subset))

# # Save new dataset
# torch.save(subset, './calibration_data/cifar10_calibration_sets/imbalanced_subset100_1.pth')
