import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset

from config import DERMA_DIR
from .experiment import Experiment


class DERMA(Experiment):
    """Experiment wrapper for DermaMNIST stored in a MedMNIST-style ``.npz`` archive.

    The archive path comes from ``config.DERMA_DIR``; each split is read from it
    on demand by :meth:`load_data`.
    """

    def __init__(self, args):
        """Build the model and initialise the base ``Experiment``.

        Args:
            args: configuration dict. Keys read here: ``'model_name'``,
                ``'num_classes'``, ``'classes'``, ``'seed'`` and ``'dataset'``;
                the whole dict is also forwarded to ``self.get_model``.
        """
        model_name, num_classes = args['model_name'], args['num_classes']
        model = self.get_model(model_name, args)
        # NOTE(review): the experiment name is 'cifar{num_classes}', which looks
        # copied from a CIFAR experiment. It is kept as-is because the base class
        # may use it for run/checkpoint naming -- confirm before renaming.
        super().__init__(f'cifar{num_classes}', model, num_classes, args['classes'], args['seed'])
        self.dataset = args['dataset']

    def load_data(self, mode='train'):
        """Load one split of the archive as a ``DermaMnistDS``.

        Args:
            mode: split prefix; the archive entries ``f'{mode}_images'`` and
                ``f'{mode}_labels'`` are read (e.g. ``'train'``, ``'val'``, ``'test'``).

        Returns:
            A ``DermaMnistDS`` whose images are converted to tensors, normalised
            with mean/std 0.5 and resized to 32x32.

        Raises:
            KeyError: if the archive has no entries for ``mode``.
        """
        print(f"Loading {mode} data...")
        # Use the NpzFile as a context manager so its file handle is closed.
        # Indexing an NpzFile reads the whole array into memory, so `images` and
        # `labels` remain valid after the archive is closed.
        with np.load(DERMA_DIR) as npz_file:
            images = npz_file[f'{mode}_images']
            labels = npz_file[f'{mode}_labels']

        transform = transforms.Compose([
            transforms.ToTensor(),
            # Single-element mean/std broadcasts across all image channels.
            transforms.Normalize(mean=[.5], std=[.5]),
            transforms.Resize((32, 32))
        ])

        return DermaMnistDS(images, labels, transform)

    def get_data_loaders(self, batch_size, num_workers=2):
        """Return ``(train_loader, val_loader, test_loader)`` for DermaMNIST.

        Args:
            batch_size: batch size shared by all three loaders.
            num_workers: worker processes per loader (defaults to 2, the
                previously hard-coded value).

        Only the training loader shuffles its data.
        """
        trainset = self.load_data('train')
        valset = self.load_data('val')
        testset = self.load_data('test')

        train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        val_loader = DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        return train_loader, val_loader, test_loader


# https://github.com/MedMNIST/MedMNIST
class DermaMnistDS(Dataset):
    """In-memory ``Dataset`` over parallel image and label arrays.

    Adapted from the MedMNIST reference implementation.
    """

    def __init__(self, images, labels, transform=None):
        """Store the sample arrays and the optional transform.

        Args:
            images: indexable of per-sample image arrays; each item must be
                accepted by ``PIL.Image.fromarray`` (presumably uint8 HxWx3 for
                DermaMNIST -- confirm against the archive).
            labels: indexable of per-sample labels, same length as ``images``.
            transform: optional callable applied to each PIL image; ``None``
                returns the PIL image unchanged.

        Raises:
            ValueError: if ``images`` and ``labels`` differ in length.
        """
        if len(images) != len(labels):
            raise ValueError(
                f"images and labels must have the same length, "
                f"got {len(images)} and {len(labels)}"
            )
        self.images = images
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        The image is converted to a PIL image and passed through ``transform``
        when one was given. The label is returned as a ``torch.long`` tensor with
        size-1 dimensions squeezed out (0-d for a single-element label).
        """
        img, label = self.images[index], self.labels[index]
        # Image.fromarray only accepts a 2-D (HxW) or 3-D (HxWxC) array, so each
        # sample is a single 2-D image, not a volume.
        img = Image.fromarray(img)
        # torch.tensor copies, so the label never aliases self.labels' storage.
        label = torch.tensor(label, dtype=torch.long).squeeze()

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)
