import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset, Subset
from torchvision import datasets, transforms
from transformers import CLIPModel, CLIPProcessor
from tqdm import tqdm
from eval import *
import random
# import wandb
from torch.utils.data import DataLoader, ConcatDataset, Subset
import numpy as np
from datasets import load_dataset
from torchvision import datasets, transforms

class TinyImageNetDataset(torch.utils.data.Dataset):
    """Map-style dataset over a row-indexable collection of image/label records.

    The wrapped ``dataset`` must support ``len()`` and integer indexing, and
    each row must be a mapping with ``'image'`` and ``'label'`` keys.

    Labels changed through :meth:`set_label` are stored in this wrapper rather
    than written into the wrapped dataset. Backends that build a new row dict
    on every access (e.g. a Hugging Face ``datasets.Dataset``) would silently
    discard an in-place write, so the wrapped dataset is never modified.

    Args:
        dataset: Row-indexable collection of ``{'image': ..., 'label': ...}``.
        transform: Optional callable applied to the image before it is
            returned from ``__getitem__``.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform
        # Relabelled samples: non-negative row index -> replacement label.
        self._label_overrides = {}

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def _normalize_index(self, idx):
        """Convert ``idx`` to a non-negative Python int within range.

        Accepts anything ``int()`` accepts (e.g. NumPy integer scalars) and
        negative indices counted from the end, so that every alias of a row
        maps to the same override key.

        Raises:
            IndexError: If ``idx`` does not refer to an existing row.
        """
        size = len(self.dataset)
        position = int(idx)
        if position < 0:
            position += size
        if not 0 <= position < size:
            raise IndexError(
                f"index {int(idx)} is out of range for dataset of size {size}"
            )
        return position

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is passed through ``transform`` when one was supplied; the
        label is the override recorded by :meth:`set_label`, if any, otherwise
        the label stored in the wrapped dataset.
        """
        position = self._normalize_index(idx)
        # Fetch the row once: each access to the backing dataset may rebuild
        # the whole record, including the image.
        record = self.dataset[position]
        image = record['image']
        label = self._label_overrides.get(position, record['label'])
        if self.transform:
            image = self.transform(image)
        return image, label

    def set_label(self, idx, target_label):
        """Make row ``idx`` report ``target_label`` from now on.

        Raises:
            IndexError: If ``idx`` does not refer to an existing row.
        """
        self._label_overrides[self._normalize_index(idx)] = target_label

def _tiny_imagenet_to_rgb(image):
    """Return ``image`` converted to 3-channel RGB.

    Defined at module level (not as a lambda) so the transform pipeline stays
    picklable for multi-process data loading.
    Assumes ``image`` is a PIL image, which is what the rest of the pipeline
    (``Resize`` followed by ``ToTensor``) is fed here.
    """
    return image.convert('RGB')


def load_tiny_imagenet_data():
    """Load the Tiny ImageNet train and validation splits as torch datasets.

    Both splits come from the ``Maysee/tiny-imagenet`` Hugging Face dataset
    (``train`` and ``valid``) and are wrapped in ``TinyImageNetDataset``.
    Each image is resized to 64x64, converted to RGB and turned into a tensor.

    Returns:
        tuple: ``(train_dataset, test_dataset)``.
    """
    from torchvision import transforms

    tiny_imagenet_train = load_dataset('Maysee/tiny-imagenet', split='train')
    tiny_imagenet_test = load_dataset('Maysee/tiny-imagenet', split='valid')
    preprocess = transforms.Compose([
        transforms.Resize((64, 64)),
        # Guarantee 3 channels without discarding colour: single-channel
        # images are expanded to RGB, colour images pass through unchanged.
        # (Grayscale(num_output_channels=3) would strip colour from all.)
        transforms.Lambda(_tiny_imagenet_to_rgb),
        transforms.ToTensor(),
    ])
    # Create training and test dataset objects
    train_dataset = TinyImageNetDataset(tiny_imagenet_train, transform=preprocess)
    test_dataset = TinyImageNetDataset(tiny_imagenet_test, transform=preprocess)
    return train_dataset, test_dataset

# def load_CUB200_dataset():
def load_before(dataset):
    """Load the train and test splits for the named dataset.

    Args:
        dataset: ``'CIFAR-10'``, ``'CIFAR-100'``, or any name containing
            ``'imagenet'`` (case-insensitive), which selects Tiny ImageNet.

    Returns:
        tuple: ``(train_dataset, test_dataset)``.

    Raises:
        ValueError: If ``dataset`` matches none of the supported names.
    """
    if dataset in ('CIFAR-10', 'CIFAR-100'):
        from torchvision import transforms

        cifar_cls = datasets.CIFAR10 if dataset == 'CIFAR-10' else datasets.CIFAR100
        # Load the CIFAR dataset; ensure shuffle=False elsewhere to keep
        # the sample order consistent.
        transform = transforms.Compose([transforms.ToTensor()])
        train_dataset = cifar_cls(root='./data', train=True, download=True, transform=transform)
        test_dataset = cifar_cls(root='./data', train=False, download=True, transform=transform)
        return train_dataset, test_dataset

    if 'imagenet' in dataset.lower():
        return load_tiny_imagenet_data()

    # Fail loudly here instead of hitting an UnboundLocalError at the return.
    raise ValueError(
        f"Unsupported dataset {dataset!r}; expected 'CIFAR-10', 'CIFAR-100', "
        "or a name containing 'imagenet'."
    )
