import os
import glob
import numpy as np
from torchvision import transforms
from torch.utils.data import Subset
from torch.utils.data import Dataset
from torchvision.datasets.folder import pil_loader
from torchvision.datasets.utils import download_and_extract_archive

from datasets.utils import train_test_split

class TinyImageNet(Dataset):
    """Tiny ImageNet dataset read from an extracted ``tiny-imagenet-200`` folder.

    The class ids are read from ``wnids.txt`` and sorted; a sample's integer
    label is the position of its class id in that sorted list.

    Args:
        root: Directory that contains (or will receive) the
            ``tiny-imagenet-200`` folder.
        train: If true, samples come from ``train/<class id>/images/*``;
            otherwise from ``val/images`` with labels taken from
            ``val/val_annotations.txt``.
        transform: Callable applied to every loaded image, or ``None`` to
            return the image exactly as ``pil_loader`` produced it.
        download: If true and ``<root>/tiny-imagenet-200/`` does not exist,
            download and extract the archive into ``root``.

    Attributes:
        paths: Image file paths, one per sample.
        targets: Integer labels, aligned with ``paths``.
        ids_string: Sorted array of class id strings.
        ids: Mapping from class id string to integer label.
    """

    def __init__(self, root, train, transform, download=True):
        # Direct link to the zip archive.
        self.url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
        self.root = root

        if download:
            if os.path.exists(f'{self.root}/tiny-imagenet-200/'):
                print('File already downloaded')
            else:
                download_and_extract_archive(self.url, root, filename="tiny-imagenet-200.zip")

        self.root = os.path.join(self.root, "tiny-imagenet-200")
        self.train = train
        self.transform = transform

        # ndmin=1 keeps a one-line file from collapsing to a 0-d array,
        # which np.sort cannot handle.
        self.ids_string = np.sort(
            np.loadtxt(os.path.join(self.root, "wnids.txt"), "str", ndmin=1)
        )
        self.ids = {class_string: i for i, class_string in enumerate(self.ids_string)}

        if train:
            # glob order depends on the filesystem; sorting makes sample
            # indices stable between instances and runs, so index lists
            # computed on one instance stay valid for another.
            self.paths = sorted(
                glob.glob(os.path.join(self.root, "train", "*", "images", "*"))
            )
            # The class id is the directory two levels above the image file.
            # os.path is used so native separators returned by glob are
            # handled on every platform.
            self.targets = [
                self.ids[os.path.basename(os.path.dirname(os.path.dirname(path)))]
                for path in self.paths
            ]
        else:
            # ndmin=2 keeps every row a sequence of columns even when the
            # annotation file has a single line. Column 0 is the image file
            # name and column 1 is its class id.
            self.val_annotations = np.loadtxt(
                os.path.join(self.root, "val", "val_annotations.txt"), "str", ndmin=2
            )
            self.paths = [
                os.path.join(self.root, "val", "images", sample[0])
                for sample in self.val_annotations
            ]
            self.targets = [self.ids[sample[1]] for sample in self.val_annotations]

    def __len__(self):
        """Return the number of samples."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        The image is loaded with ``pil_loader`` and passed through
        ``self.transform`` when one was given.
        """
        image = pil_loader(self.paths[idx])
        if self.transform is not None:
            image = self.transform(image)
        return image, self.targets[idx]

def get_dataset_from_config(config):
    """Build the train / valid / test Tiny ImageNet datasets for ``config``.

    The train and valid datasets both read the training images; they differ
    in their transform (augmented vs. plain) and in the index subset each
    one keeps. The test dataset reads the validation images of the archive.

    Args:
        config: Object providing ``dataset_root``; it is also handed to
            ``train_test_split`` unchanged.

    Returns:
        Dict with keys ``"train"``, ``"valid"`` and ``"test"``.
    """
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Training images get random geometric augmentation before normalisation.
    augmented = transforms.Compose([
        transforms.RandomAffine(degrees=20.0, scale=(0.8, 1.2), shear=20.0),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Evaluation images are only converted to tensors and normalised.
    plain = transforms.Compose([transforms.ToTensor(), normalize])

    # Only the first instance may trigger a download; the others reuse the
    # files it left on disk.
    full_train = TinyImageNet(
        root=config.dataset_root, train=True, download=True, transform=augmented
    )
    full_valid = TinyImageNet(
        root=config.dataset_root, train=True, download=False, transform=plain
    )
    test_split = TinyImageNet(
        root=config.dataset_root, train=False, download=False, transform=plain
    )

    # The two results are used as sample indices for the train and valid
    # subsets; presumably disjoint, but that depends on train_test_split.
    train_idx, valid_idx = train_test_split(config, full_train.targets)

    return {
        "train": Subset(full_train, train_idx),
        "valid": Subset(full_valid, valid_idx),
        "test": test_split,
    }

def is_valid_dataset_name(dataset_name):
    """Tell whether ``dataset_name`` is the name this module handles.

    The comparison is exact and case-sensitive against ``"tinyimagenet"``.
    """
    matches = dataset_name == "tinyimagenet"
    return matches