import os
import torch
import math
from PIL import Image
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import Subset, ConcatDataset, DataLoader
import numpy as np
import random

from datasets import load_dataset, Dataset
from sklearn.model_selection import train_test_split

def build_transform(input_size=224, interpolation="bicubic",
                    mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261),
                    crop_pct=0.9):
    """Compose a resize / center-crop / tensor / normalize pipeline.

    Args:
        input_size: Side length passed to ``CenterCrop``. When it is 32 or
            smaller, the resize and crop steps are skipped entirely.
        interpolation: One of ``"bicubic"``, ``"lanczos"`` or ``"hamming"``;
            any other value selects bilinear resampling.
        mean: Per-channel mean handed to ``transforms.Normalize``.
        std: Per-channel standard deviation handed to ``transforms.Normalize``.
        crop_pct: The image is first resized to ``floor(input_size / crop_pct)``
            before being center-cropped to ``input_size``.

    Returns:
        A ``transforms.Compose`` holding the assembled steps.
    """
    steps = []

    if input_size > 32:
        # Resize slightly larger than the target so the center crop keeps
        # roughly ``crop_pct`` of the resized image.
        resized_edge = int(math.floor(input_size / crop_pct))

        # Map the textual method name onto the PIL resampling constant;
        # unrecognised names fall back to bilinear.
        known_filters = (
            ("bicubic", "BICUBIC"),
            ("lanczos", "LANCZOS"),
            ("hamming", "HAMMING"),
        )
        filter_attr = next(
            (attr for key, attr in known_filters if interpolation == key),
            "BILINEAR",
        )
        resample = getattr(Image, filter_attr)

        steps.append(transforms.Resize(resized_edge, interpolation=resample))
        steps.append(transforms.CenterCrop(input_size))

    steps.extend([transforms.ToTensor(), transforms.Normalize(mean, std)])
    return transforms.Compose(steps)


class UTKFaceDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(image, ethnicity_class_index)`` pairs.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing, where
            each item is a mapping with ``'image'`` and ``'ethnicity'`` keys.
        transform: Optional callable applied to the image before it is
            returned.

    Raises:
        KeyError: From ``__getitem__`` when a row's ``'ethnicity'`` value is
            not one of the keys of ``label_map``.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform
        # Ethnicity name -> integer class index returned to the caller.
        self.label_map = {'White': 0, 'Black': 1, 'Indian': 2, 'Asian': 3, 'Other': 4}

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image and its class index."""
        # Fetch the row once: each lookup on the wrapped dataset may be
        # costly (e.g. backends that materialise the image on access).
        record = self.dataset[idx]
        image = record['image']
        label = self.label_map[record['ethnicity']]

        if self.transform:
            image = self.transform(image)

        return image, label
    

def _image_folder_splits(data_path, transform):
    """Return ``(train, test)`` ImageFolder datasets under ``data_path``."""
    # os.path.join works whether or not data_path ends with a separator.
    train_set = datasets.ImageFolder(os.path.join(data_path, 'train'), transform)
    test_set = datasets.ImageFolder(os.path.join(data_path, 'test'), transform)
    return train_set, test_set


def _utkface_splits(transform):
    """Return ``(train, test)`` UTK-Face datasets from a per-class 80/20 split."""
    raw = load_dataset("nu-delta/utkface")
    ethnicity_labels = np.array(raw['train']['ethnicity'])
    train_indices = []
    test_indices = []
    # Split each ethnicity separately so both partitions keep every class
    # in the same 80/20 proportion.
    for ethnicity in np.unique(ethnicity_labels):
        class_indices = np.where(ethnicity_labels == ethnicity)[0]
        # ``stratify`` sees a single label here; it is kept so the resulting
        # index split stays identical to earlier runs with this seed.
        train, test = train_test_split(
            class_indices,
            test_size=0.2,
            stratify=ethnicity_labels[class_indices],
            random_state=42,
        )
        train_indices.extend(train)
        test_indices.extend(test)
    train_set = UTKFaceDataset(raw['train'].select(train_indices), transform=transform)
    test_set = UTKFaceDataset(raw['train'].select(test_indices), transform=transform)
    return train_set, test_set


def dataset(dataset_name, config):
    """Build shuffled train/test ``DataLoader`` objects for a named dataset.

    Args:
        dataset_name: One of ``'CIFAR10'``, ``'CIFAR100'``, ``'MNIST'``,
            ``'SVHN'``, ``'CINIC10'``, ``'GTSRB'``, ``'FER-2013'``,
            ``'ImageNet100'`` or ``'UTK-Face'``.
        config: Object exposing ``learning.train_batch_size``,
            ``learning.test_batch_size`` and ``learning.num_workers``; the
            folder-based datasets additionally read ``path.data_path``, which
            must contain ``train`` and ``test`` sub-folders (plus ``val`` for
            ImageNet100).

    Returns:
        ``(dataloaders, dataset_sizes)``: two dicts keyed by ``"train"`` and
        ``"test"`` holding the loaders and the number of samples.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    if dataset_name == 'CIFAR10':
        transform = build_transform(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261), crop_pct=0.9)
        train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    elif dataset_name == 'CIFAR100':
        transform = build_transform(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761), crop_pct=0.9)
        train_set = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
        test_set = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

    elif dataset_name == 'MNIST':
        transform = build_transform(mean=(0.1307,), std=(0.3081,), crop_pct=0.9)
        train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
        test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    elif dataset_name == 'SVHN':
        transform = build_transform(mean=(0.4377, 0.4438, 0.4728), std=(0.1980, 0.2010, 0.1970), crop_pct=0.9)
        train_set = datasets.SVHN(root='./data', split='train', download=True, transform=transform)
        test_set = datasets.SVHN(root='./data', split='test', download=True, transform=transform)

    elif dataset_name == 'CINIC10':
        transform = build_transform(mean=(0.47889522, 0.47227842, 0.43047404), std=(0.24205776, 0.23828046, 0.25874835), crop_pct=0.9)
        train_set, test_set = _image_folder_splits(config.path.data_path, transform)

    elif dataset_name == 'GTSRB':
        transform = build_transform(mean=(0.3403, 0.3121, 0.3214), std=(0.2724, 0.2608, 0.2669), crop_pct=0.9)
        train_set, test_set = _image_folder_splits(config.path.data_path, transform)

    elif dataset_name == 'FER-2013':
        transform = build_transform(mean=(0.5,), std=(0.5,), crop_pct=0.9)
        train_set, test_set = _image_folder_splits(config.path.data_path, transform)

    elif dataset_name == 'ImageNet100':
        transform = transforms.Compose([
            transforms.Resize([32, 32], Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        full_dataset = ConcatDataset([
            datasets.ImageFolder(os.path.join(config.path.data_path, split), transform)
            for split in ('train', 'test', 'val')
        ])
        # The pooled dataset must be partitioned before loaders can be built;
        # without this the branch failed on an undefined train_set.
        # NOTE(review): the seeded 80/20 split mirrors the UTK-Face branch --
        # confirm this is the intended partition for ImageNet100.
        train_idx, test_idx = train_test_split(
            np.arange(len(full_dataset)), test_size=0.2, random_state=42
        )
        train_set = Subset(full_dataset, train_idx.tolist())
        test_set = Subset(full_dataset, test_idx.tolist())

    elif dataset_name == 'UTK-Face':
        transform = build_transform(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], crop_pct=0.9)
        train_set, test_set = _utkface_splits(transform)

    else:
        # Fail early with a clear message instead of an unbound-variable
        # error further down.
        raise ValueError(f"Unsupported dataset name: {dataset_name!r}")

    train_loader = DataLoader(
        train_set,
        batch_size=config.learning.train_batch_size,
        shuffle=True,
        num_workers=config.learning.num_workers,
        pin_memory=True)
    test_loader = DataLoader(
        test_set,
        batch_size=config.learning.test_batch_size,
        shuffle=True,
        num_workers=config.learning.num_workers,
        pin_memory=True)
    dataloaders = {"train": train_loader, "test": test_loader}
    dataset_sizes = {"train": len(train_set), "test": len(test_set)}

    return dataloaders, dataset_sizes
