from torchvision import datasets, transforms
import os

def get_dataset_dict(args, dataset_name, train =True):

    dataroot = os.path.join(os.getcwd(),  "data")
    dataset_name_to_dataset = {}

    if dataset_name == "mnist":
        mnist_dataset = datasets.MNIST(dataroot, train=train,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))
                            ]),
                            download = True)
        dataset_name_to_dataset[dataset_name] = mnist_dataset

    elif dataset_name == "fmnist":

        fmnist_dataset = datasets.FashionMNIST(dataroot, train=train,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.2861,),
                                                           (0.3530,))]),
                                    download = True)
        
        dataset_name_to_dataset[dataset_name] = fmnist_dataset
    
    elif dataset_name == "SVHN":

        if train:
            SVHN_split_str = "train"
        else:
            SVHN_split_str = "test"
        
        SVHN_dataset = datasets.SVHN(dataroot, split = SVHN_split_str,
                                transform=transforms.Compose([
                                    transforms.Resize((32, 32)),
                                    transforms.ToTensor(),
                                    #  transforms.Grayscale(num_output_channels=3),
                                    transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))]),
                                download = True)

        dataset_name_to_dataset[dataset_name] = SVHN_dataset

    elif dataset_name == "cifar10":
        
        if args.cifar_model == "resnet18":

            cifar10_transform_32 = transforms.Compose([
                    transforms.ToTensor(),  # Convert images to tensor
                    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),  # Normalize with CIFAR-10 statistics
                ])
            cifar10_dataset = datasets.CIFAR10(root=dataroot, train=train,
                                            download=True, transform=cifar10_transform_32)
            
            dataset_name_to_dataset[dataset_name] = cifar10_dataset
        
        
        elif args.cifar_model == "simpleVGG" or args.cifar_model == "squeezenet":

            if train:

                transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),  # Randomly crop to size 32x32
                    transforms.RandomHorizontalFlip(),     # Random horizontal flip
                    transforms.ToTensor(),                 # Convert images to tensors
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # Normalize to mean and std of CIFAR-10
                ])

                cifar10_dataset = datasets.CIFAR10(root=dataroot, train=train,
                                            download=True, transform=transform_train)
                
                dataset_name_to_dataset[dataset_name] = cifar10_dataset
            
            else:
                
                cifar10_transform_32 = transforms.Compose([
                    transforms.ToTensor(),  # Convert images to tensor
                    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),  # Normalize with CIFAR-10 statistics
                ])

                cifar10_dataset = datasets.CIFAR10(root=dataroot, train=train,
                                            download=True, transform=cifar10_transform_32)
                
                dataset_name_to_dataset[dataset_name] = cifar10_dataset
        
        else: raise ValueError("For cifar100, model name only accpets resnet10, resnet18, simpleVGG")

    elif dataset_name == "cifar100":

        if "efficient" in args.cifar_model:

            cifar100_transform_224 = transforms.Compose([
                transforms.Resize((224, 224)),  # Resize images to 224x224
                transforms.ToTensor(),  # Convert images to tensor
                transforms.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761]),  # Normalize with CIFAR-100 statistics
            ])
            cifar100_dataset = datasets.CIFAR100(root=dataroot, train=train,
                                                download=True, transform=cifar100_transform_224)
            
            dataset_name_to_dataset[dataset_name] = cifar100_dataset

        elif "resnet18" in args.cifar_model:

            cifar100_transform_32 = transforms.Compose([
                    transforms.ToTensor(),  # Convert images to tensor
                    transforms.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761]),  # Normalize with CIFAR-100 statistics
                ])

            cifar100_dataset = datasets.CIFAR100(root=dataroot, train=train,
                                                download=True, transform=cifar100_transform_32)
            dataset_name_to_dataset[dataset_name] = cifar100_dataset

        
        elif "resnet34" in args.cifar_model:

            if train:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),     # Randomly crop to 32x32 with padding
                    transforms.RandomHorizontalFlip(),        # Randomly flip images horizontally
                    transforms.ToTensor(),                    # Convert images to tensor
                    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),  # Normalize using CIFAR-100 mean and std
                ])

                cifar100_dataset = datasets.CIFAR100(root=dataroot, train=train,
                                                download=True, transform=transform_train)
                
                dataset_name_to_dataset[dataset_name] = cifar100_dataset

            else:

                cifar100_transform_32 = transforms.Compose([
                    transforms.ToTensor(),  # Convert images to tensor
                    transforms.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761]),  # Normalize with CIFAR-100 statistics
                ])

                cifar100_dataset = datasets.CIFAR100(root=dataroot, train=train,
                                                download=True, transform=cifar100_transform_32)
                
                dataset_name_to_dataset[dataset_name] = cifar100_dataset

        else: raise ValueError("For cifar100, model name only accpets efficientnet_b0, resnet18, densenet121")
    else:
        raise ValueError("For the datasetset, we only accpet mnist, fmnist, SVHN, cifar10, cifar100")
    
    return dataset_name_to_dataset

    


