import os
import torch

from torchvision.datasets import CIFAR10,CIFAR100,Food101,DTD,OxfordIIITPet,FGVCAircraft,Flowers102
from torch.utils.data.dataset import Subset, random_split


def cifar10(transform):
    """Load CIFAR-10 and return ``(trainval, train, val, test)``.

    The 50,000-image training set is split deterministically (no shuffling):
    the first 45,000 images become ``train`` and the last 5,000 become
    ``val``. Files are downloaded into ``~/.cache`` when missing.

    Args:
        transform: Image transform forwarded unchanged to every
            ``CIFAR10`` instance.

    Returns:
        A 4-tuple ``(trainval, train, val, test)``; ``train`` and ``val`` are
        ``Subset`` views over ``trainval``.
    """
    cache_dir = os.path.expanduser("~/.cache")
    full_train = CIFAR10(cache_dir, download=True, train=True, transform=transform)
    # Contiguous index ranges: [0, 45000) for training, [45000, 50000) for validation.
    train_part, val_part = (
        Subset(full_train, range(start, stop))
        for start, stop in ((0, 45000), (45000, 50000))
    )
    test_set = CIFAR10(cache_dir, download=True, train=False, transform=transform)
    return full_train, train_part, val_part, test_set

def cifar100(transform):
    """Load CIFAR-100 and return ``(trainval, train, val, test)``.

    The 50,000-image training set is split deterministically (no shuffling):
    the first 45,000 images become ``train`` and the last 5,000 become
    ``val``. Files are downloaded into ``~/.cache`` when missing.

    Args:
        transform: Image transform forwarded unchanged to every
            ``CIFAR100`` instance.

    Returns:
        A 4-tuple ``(trainval, train, val, test)``; ``train`` and ``val`` are
        ``Subset`` views over ``trainval``.
    """
    root = os.path.expanduser("~/.cache")
    trainval = CIFAR100(root, download=True, train=True, transform=transform)
    # Contiguous (not random) split so the validation set is identical across
    # runs and matches the CIFAR-10 loader's convention.
    train = Subset(trainval, range(0, 45000))
    val = Subset(trainval, range(45000, 50000))
    test = CIFAR100(root, download=True, train=False, transform=transform)
    return trainval, train, val, test

def food101(transform):
    """Load Food-101 and return ``(trainval, train, val, test)``.

    10% of the official ``"train"`` split is held out as ``val`` using a
    seeded (seed 42) ``random_split``, so the split is reproducible. Files
    are downloaded into ``~/.cache`` when missing.

    Args:
        transform: Image transform forwarded unchanged to every
            ``Food101`` instance.

    Returns:
        A 4-tuple ``(trainval, train, val, test)``; ``train`` and ``val`` are
        subsets of ``trainval``.
    """
    root = os.path.expanduser("~/.cache")
    trainval = Food101(root, download=True, split="train", transform=transform)
    # Derive the split sizes from the dataset instead of hard-coding
    # 68175/7575: random_split raises if the lengths do not sum to
    # len(trainval). For the standard 75,750-image split this yields exactly
    # 68175/7575, so the seeded permutation (and hence the split) is unchanged.
    n_val = len(trainval) // 10
    n_train = len(trainval) - n_val
    train, val = random_split(
        trainval,
        lengths=[n_train, n_val],
        generator=torch.Generator().manual_seed(42),
    )
    test = Food101(root, download=True, split="test", transform=transform)
    return trainval, train, val, test

def DescribableTextures(transform):
    """Load the Describable Textures Dataset (DTD) as ``(trainval, train, val, test)``.

    ``train``, ``val`` and ``test`` are the dataset's own splits (default
    partition); ``trainval`` concatenates ``train`` followed by ``val``.
    Files are downloaded into ``~/.cache`` when missing.

    Args:
        transform: Image transform forwarded unchanged to every ``DTD``
            instance.
    """
    cache_dir = os.path.expanduser("~/.cache")

    def _load(split_name):
        # Every split shares the same root, download flag and transform.
        return DTD(cache_dir, download=True, split=split_name, transform=transform)

    train_set = _load("train")
    val_set = _load("val")
    combined = torch.utils.data.ConcatDataset([train_set, val_set])
    test_set = _load("test")
    return combined, train_set, val_set, test_set

def oxfordpets(transform):
    """Load Oxford-IIIT Pet and return ``(trainval, train, val, test)``.

    10% of the official ``"trainval"`` split is held out as ``val`` using a
    seeded (seed 42) ``random_split``, so the split is reproducible. Files
    are downloaded into ``~/.cache`` when missing.

    Args:
        transform: Image transform forwarded unchanged to every
            ``OxfordIIITPet`` instance.

    Returns:
        A 4-tuple ``(trainval, train, val, test)``; ``train`` and ``val`` are
        subsets of ``trainval``.
    """
    root = os.path.expanduser("~/.cache")
    trainval = OxfordIIITPet(root, download=True, split="trainval", transform=transform)
    # Derive the split sizes from the dataset instead of hard-coding 3312/368:
    # random_split raises if the lengths do not sum to len(trainval). For the
    # standard 3,680-image split this yields exactly 3312/368, so the seeded
    # permutation (and hence the split) is unchanged.
    n_val = len(trainval) // 10
    n_train = len(trainval) - n_val
    train, val = random_split(
        trainval,
        lengths=[n_train, n_val],
        generator=torch.Generator().manual_seed(42),
    )
    test = OxfordIIITPet(root, download=True, split="test", transform=transform)
    return trainval, train, val, test

def fgvcaircraft(transform):
    """Load FGVC-Aircraft (variant-level labels) as ``(trainval, train, val, test)``.

    All four splits come directly from the dataset's own split definitions,
    each built with ``annotation_level='variant'``. Files are downloaded into
    ``~/.cache`` when missing.

    Args:
        transform: Image transform forwarded unchanged to every
            ``FGVCAircraft`` instance.
    """
    cache_dir = os.path.expanduser("~/.cache")
    # Built in the order train, val, trainval, test (dicts preserve insertion order).
    by_split = {
        split_name: FGVCAircraft(
            cache_dir,
            download=True,
            annotation_level='variant',
            split=split_name,
            transform=transform,
        )
        for split_name in ('train', 'val', 'trainval', 'test')
    }
    return by_split['trainval'], by_split['train'], by_split['val'], by_split['test']

class Warper:
    """Dataset wrapper that shifts every sample's label down by ``offset``.

    Each item of the wrapped dataset is unpacked as an ``(img, label)`` pair
    and returned as ``(img, label - offset)``. The default ``offset=1``
    converts 1-based labels to 0-based ones.

    NOTE(review): some torchvision versions already return 0-based labels for
    ``Flowers102``; with those, pass ``offset=0`` (or drop the wrapper) to
    avoid producing negative labels. Confirm against the installed version.

    Args:
        dataset: Any indexable, sized object whose items are ``(img, label)``
            pairs with numeric labels.
        offset: Amount subtracted from every label (default ``1``).
    """

    def __init__(self, dataset, offset: int = 1) -> None:
        super().__init__()
        self.dataset = dataset
        self.offset = offset

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index):
        img, label = self.dataset[index]
        return img, label - self.offset

def flowers102(transform):
    """Load Oxford Flowers-102 as ``(trainval, train, val, test)``.

    ``train``, ``val`` and ``test`` are the dataset's own splits, and
    ``trainval`` concatenates ``train`` followed by ``val``. Every returned
    dataset is wrapped in ``Warper``, which subtracts 1 from each label.
    Files are downloaded into ``~/.cache`` when missing.

    NOTE(review): some torchvision versions already return 0-based
    Flowers102 labels; on those, the extra shift would yield labels in
    [-1, 100]. Verify against the installed torchvision.

    Args:
        transform: Image transform forwarded unchanged to every
            ``Flowers102`` instance.
    """
    cache_dir = os.path.expanduser("~/.cache")

    def _load(split_name):
        return Flowers102(cache_dir, download=True, split=split_name, transform=transform)

    raw_train = _load("train")
    raw_val = _load("val")
    raw_trainval = torch.utils.data.ConcatDataset([raw_train, raw_val])
    raw_test = _load("test")
    # Apply the label shift uniformly to all four returned datasets.
    return tuple(Warper(ds) for ds in (raw_trainval, raw_train, raw_val, raw_test))
