import random
import torchvision
from torchvision import transforms
from continuum.datasets import birdsnap
from continuum.datasets import sun397
from continuum.datasets import car196
import torch
def _labels_of(dataset):
    """Return the class label of every sample in *dataset*, in index order.

    Exists so that class-stratified splitting does not have to decode and
    transform every image just to learn its label. ``Subset`` objects are
    resolved through their parent's labels and ``indices``. Other objects use
    a stored label list when they expose one and have no ``target_transform``;
    otherwise every sample is loaded and its second element is taken.
    """
    if isinstance(dataset, torch.utils.data.Subset):
        parent_labels = _labels_of(dataset.dataset)
        return [parent_labels[i] for i in dataset.indices]
    if getattr(dataset, "target_transform", None) is None:
        # NOTE(review): attribute names follow torchvision's dataset classes:
        # `targets` (CIFAR10/100), `labels` (SVHN), `_labels` (Food101,
        # SUN397, OxfordIIITPet with its default "category" target) and
        # `_samples` (StanfordCars). The underscored ones are private and may
        # change between torchvision versions; if they are missing this falls
        # back to the slow iteration below.
        for attr in ("targets", "labels", "_labels"):
            stored = getattr(dataset, attr, None)
            if stored is not None:
                return [int(label) for label in stored]
        samples = getattr(dataset, "_samples", None)
        if samples is not None:
            return [int(label) for _, label in samples]
    return [label for _, label in dataset]


def _stratified_split(dataset, num_classes, fraction):
    """Carve a class-balanced random subset out of *dataset*.

    ``int(fraction * len(dataset) / num_classes)`` indices are drawn from each
    class with ``random.sample`` (one call per class, in class order). The
    quota is capped at the class size, so a class smaller than the quota is
    taken whole instead of raising ``ValueError``. Drawn indices are
    interleaved round-robin: first pick of every class, then the second, ...

    Args:
        dataset: indexable dataset whose samples are ``(x, label)`` with
            integer labels in ``range(num_classes)``.
        num_classes: number of classes.
        fraction: approximate share of *dataset* to hold out.

    Returns:
        ``(held_out, remaining)``: two ``torch.utils.data.Subset`` views of
        *dataset*; ``remaining`` holds every other index in ascending order.
    """
    n = len(dataset)
    by_class = [[] for _ in range(num_classes)]
    for i, label in enumerate(_labels_of(dataset)):
        by_class[label].append(i)
    n_each_class = int(fraction * n / num_classes)
    picked = [random.sample(indexes, min(n_each_class, len(indexes))) for indexes in by_class]
    held_out = [row[i] for i in range(n_each_class) for row in picked if i < len(row)]
    held_set = set(held_out)
    remaining = [i for i in range(n) if i not in held_set]
    return torch.utils.data.Subset(dataset, held_out), torch.utils.data.Subset(dataset, remaining)


def Getdataset(args):
    """Build the ``(train, test, val)`` datasets for ``args.task``.

    Args:
        args: object with ``task`` (one of the names handled below) and
            ``data_dir`` (root directory passed to every dataset constructor).

    Returns:
        ``(train, test, val)``. Splits the task does not provide are ``None``;
        an unrecognised task returns ``(None, None, None)``. Where a split is
        carved out of the training data with ``_stratified_split`` it is a
        ``Subset`` and shares the training transform.
    """
    task = args.task
    root = args.data_dir
    train_set = None
    test_set = None
    val_set = None

    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
    # Currently identical to train_transform; kept separate so training-time
    # augmentation can be added without affecting evaluation.
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

    datasets = torchvision.datasets

    if task == "CIFAR10":
        train_set = datasets.CIFAR10(root=root, download=False, transform=train_transform, train=True)
        test_set = datasets.CIFAR10(root=root, download=False, transform=eval_transform, train=False)
        # Hold out 10% of the training images (class-balanced) for validation.
        val_set, train_set = _stratified_split(train_set, 10, 0.1)
    elif task == "CIFAR100":
        train_set = datasets.CIFAR100(root=root, download=False, transform=train_transform, train=True)
        test_set = datasets.CIFAR100(root=root, download=False, transform=eval_transform, train=False)
        val_set, train_set = _stratified_split(train_set, 100, 0.1)
    elif task == "CIFAR10_noresize":
        train_set = datasets.CIFAR10(root=root, download=False, transform=transforms.ToTensor(), train=True)
        test_set = datasets.CIFAR10(root=root, download=False, transform=transforms.ToTensor(), train=False)
    elif task == "DTD":
        train_set = datasets.DTD(root=root, download=True, transform=train_transform, split='train')
        val_set = datasets.DTD(root=root, download=True, transform=eval_transform, split='val')
        test_set = datasets.DTD(root=root, download=True, transform=eval_transform, split='test')
    elif task == "Flowers102":
        train_set = datasets.Flowers102(root=root, download=False, transform=train_transform, split='train')
        val_set = datasets.Flowers102(root=root, download=False, transform=eval_transform, split='val')
        test_set = datasets.Flowers102(root=root, download=False, transform=eval_transform, split='test')
    elif task == "SUN397":
        # No split is requested from the constructor: 40% of all images
        # (class-balanced) become the test set, the rest the training set.
        train_set = datasets.SUN397(root=root, download=True, transform=train_transform)
        test_set, train_set = _stratified_split(train_set, 397, 0.4)
    elif task == "Caltech101":
        # NOTE(review): train and test are built identically from the full
        # dataset with no split, so test metrics are computed on training
        # images. Confirm whether a held-out split was intended.
        train_set = datasets.Caltech101(root=root, download=False, transform=train_transform)
        test_set = datasets.Caltech101(root=root, download=False, transform=eval_transform)
    elif task == "Food101":
        train_set = datasets.Food101(root=root, download=False, transform=train_transform, split='train')
        test_set = datasets.Food101(root=root, download=False, transform=eval_transform, split='test')
        val_set, train_set = _stratified_split(train_set, 101, 0.1)
    elif task == "ImageNet":
        train_set = datasets.ImageFolder(root=root + '/ImageNet/train', transform=train_transform)
        val_set = datasets.ImageFolder(root=root + '/ImageNet/val', transform=eval_transform)
    elif task == "car196":
        train_set = datasets.StanfordCars(root=root, download=True, transform=train_transform, split='train')
        test_set = datasets.StanfordCars(root=root, download=True, transform=eval_transform, split='test')
        val_set, train_set = _stratified_split(train_set, 196, 0.15)
    elif task == "aircraft":
        train_set = datasets.FGVCAircraft(root=root, download=True, transform=train_transform, split='train')
        val_set = datasets.FGVCAircraft(root=root, download=True, transform=eval_transform, split='val')
        test_set = datasets.FGVCAircraft(root=root, download=True, transform=eval_transform, split='test')
    elif task == "Pets":
        train_set = datasets.OxfordIIITPet(root=root, download=True, transform=train_transform)
        test_set = datasets.OxfordIIITPet(root=root, download=False, transform=eval_transform, split='test')
        val_set, train_set = _stratified_split(train_set, 37, 0.1)
    elif task == "VOC2007":
        train_set = datasets.VOCDetection(root=root, year='2007', download=True, transform=train_transform, image_set='train')
        val_set = datasets.VOCDetection(root=root, year='2007', download=True, transform=eval_transform, image_set='val')
        test_set = datasets.VOCDetection(root=root, year='2007', download=True, transform=eval_transform, image_set='test')
    elif task == 'SVHN':
        train_set = datasets.SVHN(root=root, split='train', transform=train_transform, download=True)
        test_set = datasets.SVHN(root=root, split='test', transform=eval_transform, download=True)
        val_set, train_set = _stratified_split(train_set, 10, 0.1)
    elif task == 'Country211':
        train_set = datasets.Country211(root=root, split='train', download=True, transform=train_transform)
        test_set = datasets.Country211(root=root, split='test', download=True, transform=eval_transform)
        val_set = datasets.Country211(root=root, split='valid', download=True, transform=eval_transform)
    elif task == 'Caltech256':
        # Only a training set is built for this task.
        train_set = datasets.Caltech256(root=root, transform=train_transform, download=True)
    elif task == 'FER2013':
        train_set = datasets.FER2013(root=root, split='train', transform=train_transform)
        test_set = datasets.FER2013(root=root, split='test', transform=eval_transform)

    return train_set, test_set, val_set


def get_aug_dataset(args):
    """Build augmented ``(train, test, val)`` datasets for ``args.task``.

    Each split is a ``ConcatDataset`` containing one copy of the underlying
    dataset per augmentation listed below, so it is five times the base
    dataset's length. Every copy applies Resize(256, 256) -> CenterCrop(224)
    -> <one augmentation> -> ToTensor. Unlike ``Getdataset``, no Normalize
    step is applied.

    Args:
        args: object with ``task`` ("CIFAR10", "CIFAR100" or "ImageNet") and
            ``data_dir`` (root directory passed to the dataset constructors).

    Returns:
        ``(train, test, val)``. Splits the task does not provide are ``None``;
        an unrecognised task returns ``(None, None, None)``.
    """
    task = args.task
    root = args.data_dir
    train_set = None
    test_set = None
    val_set = None

    basic_transforms = [
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224))]

    augmentations = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
        # p=1: every image in this copy is flipped.
        transforms.RandomHorizontalFlip(p=1),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
        # p=1: every image in this copy is converted to grayscale.
        transforms.RandomGrayscale(p=1),
        # Odd kernel of 224 // 20 * 2 + 1 = 23 px.
        transforms.GaussianBlur(kernel_size=224 // 20 * 2 + 1, sigma=(0.1, 2.0))]
    pipelines = [transforms.Compose([*basic_transforms, aug, transforms.ToTensor()]) for aug in augmentations]

    def concat(make_dataset):
        """Concatenate one dataset per pipeline; ``make_dataset(t)`` builds each copy."""
        return torch.utils.data.ConcatDataset([make_dataset(t) for t in pipelines])

    if task == "CIFAR10":
        train_set = concat(lambda t: torchvision.datasets.CIFAR10(root=root, download=False, transform=t, train=True))
        # NOTE(review): the test split is augmented too — confirm this is intended.
        test_set = concat(lambda t: torchvision.datasets.CIFAR10(root=root, download=False, transform=t, train=False))
    elif task == "CIFAR100":
        train_set = concat(lambda t: torchvision.datasets.CIFAR100(root=root, download=False, transform=t, train=True))
        test_set = concat(lambda t: torchvision.datasets.CIFAR100(root=root, download=False, transform=t, train=False))
    elif task == "ImageNet":
        train_root = root + '/ImageNet/train'
        val_root = root + '/ImageNet/val'
        train_set = concat(lambda t: torchvision.datasets.ImageFolder(root=train_root, transform=t))
        val_set = concat(lambda t: torchvision.datasets.ImageFolder(root=val_root, transform=t))

    return train_set, test_set, val_set

# Classifier output size for each task name handled by Getdataset / get_aug_dataset.
_NUM_CLASSES = {
    "CIFAR10": 10,
    "CIFAR10_noresize": 10,
    "SUN397": 397,
    "DTD": 47,
    # NOTE(review): 102 for Caltech101 follows the dataset description's extra
    # "background" class, and aircraft shares that value. torchvision's
    # Caltech101 may drop the background folder, and FGVCAircraft's default
    # 'variant' level may have 100 classes, so these may over-allocate. Confirm
    # before shrinking, because that would change saved checkpoint shapes.
    "Caltech101": 102,
    "aircraft": 102,
    "Flowers102": 102,
    "CIFAR100": 100,
    "Food101": 101,
    "car196": 196,
    "Pets": 37,
    "VOC2007": 20,
    "ImageNet": 1000,
    "SVHN": 10,
    "Country211": 211,
    # 256 object categories plus the "clutter" folder, which torchvision's
    # Caltech256 also labels, so label values run 0..256.
    "Caltech256": 257,
    "FER2013": 7,
}


def Getnumclass(task):
    """Return the number of classifier outputs to use for *task*.

    Raises:
        ValueError: if *task* is not a known task name.
    """
    try:
        return _NUM_CLASSES[task]
    except KeyError:
        raise ValueError(f"Unknown task: {task!r}") from None



