import os

import pandas as pd
import torch
import torchvision.transforms as transforms
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, random_split
from torchvision import datasets
from torchvision.datasets import MNIST, FashionMNIST, CIFAR10, CIFAR100, SVHN

def build_dataset(dataset_name='mnist', dataset_dir='../datasets/'):
    """Return the ``(train_dataset, test_dataset)`` pair for a named dataset.

    Args:
        dataset_name: One of ``'fashionmnist'``, ``'cifar10'``, ``'cinic10'``
            or ``'ham'``. The default ``'mnist'`` is not handled by this
            function and raises ``ValueError`` like any other unknown name.
        dataset_dir: Root directory passed to the per-dataset loader.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple from the matching loader.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
            Previously such names fell through every branch and surfaced
            as an ``UnboundLocalError`` at the ``return``.
    """
    if dataset_name == 'fashionmnist':
        return dataset_fashionmnist(dataset_dir)
    if dataset_name == 'cifar10':
        return dataset_cifar10(dataset_dir)
    if dataset_name == 'cinic10':
        return dataset_cinic10(dataset_dir)
    if dataset_name == 'ham':
        return dataset_ham(dataset_dir)
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}; expected one of "
        "'fashionmnist', 'cifar10', 'cinic10', 'ham'"
    )

def dataset_ham(data_path, train_fraction=0.8, seed=1234):
    """Load the HAM image-folder dataset and split it into train/test subsets.

    Images are read from ``<data_path>/base_dir/train_dir`` with
    ``datasets.ImageFolder``. They are resized to 32x32 and converted to
    tensors; no normalization is applied.

    Args:
        data_path: Directory containing ``base_dir/train_dir``. A trailing
            path separator is optional.
        train_fraction: Fraction of samples (rounded down) assigned to the
            training subset; the remainder becomes the test subset.
        seed: Seed applied to the global torch RNGs (CPU and all CUDA
            devices) before the split is drawn.

    Returns:
        A ``(train_dataset, test_dataset)`` pair produced by ``random_split``.

    Raises:
        ValueError: If ``train_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction!r}")

    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])

    # Global seeding makes the split below reproducible. NOTE: this is a
    # process-wide side effect that also affects any later RNG use by callers,
    # and it turns on deterministic cuDNN kernels globally.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    # os.path.join works whether or not data_path ends with a separator;
    # plain concatenation produced "...datasetsbase_dir/..." without one.
    full_dataset = datasets.ImageFolder(
        os.path.join(data_path, 'base_dir', 'train_dir'), transform=transform)

    total_size = len(full_dataset)
    train_size = int(total_size * train_fraction)
    test_size = total_size - train_size

    # random_split draws its permutation from the default generator seeded above.
    train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])
    return train_dataset, test_dataset


def dataset_fashionmnist(data_path):
    """Load the FashionMNIST train and test splits, downloading if needed.

    Both splits get the same preprocessing: tensor conversion followed by
    normalization with mean 0.2860 and std 0.3530. No augmentation is applied.

    Args:
        data_path: Root directory handed to ``FashionMNIST``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    def make_pipeline():
        # Each split receives its own Compose instance.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.2860,), (0.3530,)),
        ])

    train_dataset, test_dataset = (
        FashionMNIST(root=data_path, train=is_train, download=True,
                     transform=make_pipeline())
        for is_train in (True, False)
    )
    return train_dataset, test_dataset


def dataset_cifar10(data_path):
    """Load CIFAR-10 train and test splits, downloading if needed.

    Training images are randomly cropped to 32x32 from a 4-pixel-padded copy
    and randomly flipped horizontally. They are then converted to tensors and
    normalized per channel. Test images are only converted and normalized.

    Args:
        data_path: Root directory handed to ``CIFAR10``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    channel_mean = [0.49139968, 0.48215827, 0.44653124]
    channel_std = [0.24703233, 0.24348505, 0.26158768]

    def tensor_steps():
        # Fresh ToTensor/Normalize instances for each pipeline.
        return [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]

    augment_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    train_pipeline = transforms.Compose(augment_steps + tensor_steps())
    eval_pipeline = transforms.Compose(tensor_steps())

    train_dataset = CIFAR10(root=data_path, train=True, download=True,
                            transform=train_pipeline)
    test_dataset = CIFAR10(root=data_path, train=False, download=True,
                           transform=eval_pipeline)
    return train_dataset, test_dataset


def dataset_cinic10(data_path):
    """Load the CINIC-10 train and test splits from an image-folder layout.

    Expects ``<data_path>/CINIC-10/train`` and ``<data_path>/CINIC-10/test``
    directories readable by ``datasets.ImageFolder``. Both splits are
    converted to tensors and normalized per channel; no augmentation is
    applied.

    Args:
        data_path: Directory containing the ``CINIC-10`` folder. A trailing
            path separator is optional.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    cinic_mean = [0.47889522, 0.47227842, 0.43047404]
    cinic_std = [0.24205776, 0.23828046, 0.25874835]

    def _load_split(split):
        # Build the split directory with os.path.join; the previous
        # '{}/{}'.format(...) form produced paths like '../datasets///CINIC-10/train/'.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cinic_mean, std=cinic_std),
        ])
        return datasets.ImageFolder(os.path.join(data_path, 'CINIC-10', split),
                                    transform=transform)

    train_dataset = _load_split('train')
    test_dataset = _load_split('test')
    return train_dataset, test_dataset


   
    
    
