import csv, torchvision, numpy as np, random, os
from PIL import Image
import numpy as np
import copy

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets


def _build_transforms(name):
    """Return the per-dataset preprocessing configuration for *name*.

    Returns:
        (hw, mean, std, transform_train, transform_test, num_classes), where
        ``hw`` is the square side length the transforms produce and
        ``mean``/``std`` are the per-channel normalization constants.

    Raises:
        ValueError: if *name* is not a configured dataset.  This is raised
            before any transform is constructed.
    """
    if name == "Caltech256":
        hw, num_classes = 224, 256
        mean = [0.5940, 0.5675, 0.5403]
        std = [0.1759, 0.1741, 0.1934]
        transform_train = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(hw),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(hw),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    elif name == "tinyimagenet":
        # NOTE(review): Tiny ImageNet has 200 classes; 1000 is kept as-is in
        # case it is intentional (e.g. to match an ImageNet-sized head or
        # existing checkpoints) -- confirm.
        hw, num_classes = 32, 1000
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(hw),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        transform_test = transforms.Compose([
            transforms.Resize(hw),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    elif name.startswith("cifar"):
        hw = 32
        # Any "cifar*" name other than exactly "cifar10" is treated as 100-class.
        num_classes = 10 if name == "cifar10" else 100
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2023, 0.1994, 0.2010]
        transform_train = transforms.Compose([
            transforms.RandomCrop(hw, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    elif name == "imagenet":
        # Previously accepted but never configured, which crashed later with an
        # UnboundLocalError; fail early with an explicit message instead.
        raise ValueError(
            "Unsupported dataset: imagenet (no transform/path configuration is defined)"
        )
    else:
        raise ValueError("Unknown dataset: {}".format(name))
    return hw, mean, std, transform_train, transform_test, num_classes


def _split_dirs(root, name, kwargs):
    """Return (train_dir, val_dir) for the on-disk layout used by every dataset.

    Training data lives in ``root/name/train/ss<subset_num>/seed<seed>`` and
    validation data in ``root/name/val``.
    """
    train_dir = os.path.join(
        root, name, "train", "ss%d" % kwargs["subset_num"], "seed%d" % kwargs["seed"]
    )
    val_dir = os.path.join(root, name, "val")
    return train_dir, val_dir


def load_dataset(name, root, shuffle=True, **kwargs):
    """Build train/val DataLoaders for one of the configured image datasets.

    Args:
        name: "Caltech256", "tinyimagenet", or a name starting with "cifar"
            ("cifar10" -> 10 classes, any other "cifar*" -> 100 classes).
        root: directory containing one sub-directory per dataset.
        shuffle: whether the training loader shuffles.
        **kwargs:
            bs (required): batch size for both loaders.
            subset_num, seed (required, ints): select the training split
                directory ``root/name/train/ss<subset_num>/seed<seed>``.
            drop_last (optional, default False): passed to the train loader.
            worker_init_fn (optional, default None): passed to both loaders.
            num_workers (optional, default 32): passed to both loaders.

    Returns:
        (trainloader, valloader, num_classes, hw, mean_image), where
        ``mean_image`` is a (1, C, hw, hw) tensor of standard-normal noise
        scaled per channel by ``std`` and shifted by ``mean``.

    Raises:
        ValueError: if *name* is not a configured dataset.
    """
    hw, mean, std, transform_train, transform_test, num_classes = _build_transforms(name)

    train_dir, val_dir = _split_dirs(root, name, kwargs)
    # NOTE: ImageFolder indexes each directory independently from its class
    # sub-folders, so train/val labels agree only if both contain the same
    # set of class folders.
    trainset = datasets.ImageFolder(root=train_dir, transform=transform_train)
    valset = datasets.ImageFolder(root=val_dir, transform=transform_test)

    # Random image with the dataset's per-channel statistics (N(mean, std^2)).
    mean_image = torch.randn(1, len(mean), hw, hw)
    for ic, (m, s) in enumerate(zip(mean, std)):
        mean_image[0, ic] = mean_image[0, ic] * s + m

    num_workers = kwargs.get("num_workers", 32)
    worker_init_fn = kwargs.get("worker_init_fn")
    trainloader = DataLoader(
        trainset,
        batch_size=kwargs["bs"],
        num_workers=num_workers,
        shuffle=shuffle,
        drop_last=kwargs.get("drop_last", False),
        worker_init_fn=worker_init_fn,
    )
    valloader = DataLoader(
        valset,
        batch_size=kwargs["bs"],
        num_workers=num_workers,
        shuffle=False,
        worker_init_fn=worker_init_fn,
    )
    return trainloader, valloader, num_classes, hw, mean_image
