import os

import numpy as np
import pandas as pd
import torch
import torch.utils.data as D
import torchvision.transforms as transforms
from PIL import Image
from sklearn.model_selection import train_test_split

from config import KATHER_DIR
from .experiment import Experiment


class KATHER(Experiment):
    """Experiment on the Kather histology image tiles.

    Images are read from ``KATHER_DIR``, which is expected to contain one
    sub-directory per class, each holding that class's image files.
    """

    def __init__(self, args):
        """Build the model and initialise the base experiment.

        Args:
            args: Mapping providing at least ``model_name``, ``num_classes``,
                ``classes``, ``seed`` and ``dataset``.
        """
        model_name, num_classes = args['model_name'], args['num_classes']
        model = self.get_model(model_name, args)
        # NOTE(review): the experiment name is prefixed with 'cifar' although
        # this is the Kather experiment -- looks like a copy-paste leftover.
        # Kept unchanged because the name may be used elsewhere (e.g. for
        # output paths); confirm before renaming.
        super().__init__(f'cifar{num_classes}', model, num_classes, args['classes'], args['seed'])
        self.dataset = args['dataset']

    def load_data(self):
        """Index every image file found under ``KATHER_DIR``.

        Returns:
            A ``pandas.DataFrame`` with one row per image and the columns
            ``img_path`` (full path), ``label`` (name of the class directory)
            and ``label_num`` (position of the label among the sorted unique
            labels).
        """
        imgs_paths, labels = [], []
        # ``os.listdir`` returns entries in arbitrary order. Sorting keeps the
        # row order -- and therefore the seeded split below -- reproducible
        # across machines and file systems.
        for label in sorted(os.listdir(KATHER_DIR)):
            class_dir = os.path.join(KATHER_DIR, label)
            # Skip stray files and hidden entries (e.g. ``.DS_Store``) that
            # are not class directories.
            if label.startswith('.') or not os.path.isdir(class_dir):
                continue
            for file_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, file_name)
                if file_name.startswith('.') or not os.path.isfile(img_path):
                    continue
                imgs_paths.append(img_path)
                labels.append(label)

        df = pd.DataFrame(data={'img_path': imgs_paths, 'label': labels})
        # ``np.unique`` returns the labels sorted, so the numbering is stable.
        label_num = {item: idx for idx, item in enumerate(np.unique(df.label))}
        df['label_num'] = df['label'].map(label_num)

        return df

    def get_data_loaders(self, batch_size, num_workers=2):
        """Create the train, validation and test data loaders.

        The data is split with stratification on the label: 80% goes to
        training, and the remaining 20% is split again into 20% validation
        and 80% test (i.e. 4% and 16% of the full data set).

        Args:
            batch_size: Batch size used by all three loaders.
            num_workers: Number of worker processes used by each loader.

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``.
        """
        df = self.load_data()
        train_df, tmp_df = train_test_split(df,
                                            test_size=0.2,
                                            random_state=self.seed,
                                            stratify=df['label'])

        # Second split of the 20% hold-out: 20% validation / 80% test.
        valid_df, test_df = train_test_split(tmp_df,
                                             test_size=0.8,
                                             random_state=self.seed,
                                             stratify=tmp_df['label'])

        # The same deterministic pipeline is used for all three splits.
        # The mean/std values match the widely used ImageNet channel statistics.
        transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        ds_train = HistologyMnistDS(train_df, transform)
        ds_val = HistologyMnistDS(valid_df, transform, mode='val')
        ds_test = HistologyMnistDS(test_df, transform, mode='test')

        # Only the training loader shuffles; evaluation order stays fixed.
        train_loader = D.DataLoader(ds_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        val_loader = D.DataLoader(ds_val, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        test_loader = D.DataLoader(ds_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        return train_loader, val_loader, test_loader


# from: https://www.kaggle.com/mariazorkaltseva/colorectal-histology-mnist-images-classification
class HistologyMnistDS(D.Dataset):
    """Dataset that serves images listed in a DataFrame.

    Each item is loaded from the row's ``img_path``; when the mode is one of
    ``'train'``, ``'val'`` or ``'test'`` the row's ``label_num`` is returned
    alongside the image.
    """

    def __init__(self, df, transforms, mode='train'):
        """Store the rows of ``df`` together with the transform and mode.

        Args:
            df: DataFrame whose rows are read through the ``img_path`` and
                ``label_num`` columns.
            transforms: Callable applied to every loaded image; a falsy value
                disables transformation.
            mode: Selects what ``__getitem__`` returns (see class docstring).
        """
        self.mode = mode
        self.transforms = transforms
        self.len = df.shape[0]
        # Record array gives attribute-style access to the columns of a row.
        self.records = df.to_records(index=False)

    @staticmethod
    def _load_image_pil(path):
        """Open the image file at ``path`` with PIL."""
        image = Image.open(path)
        return image

    def __getitem__(self, index):
        """Return the (optionally transformed) image, plus its label in labelled modes."""
        record = self.records[index]
        image = self._load_image_pil(record.img_path)

        if self.transforms:
            image = self.transforms(image)

        # Any mode outside the labelled ones yields the image alone.
        if self.mode not in ['train', 'val', 'test']:
            return image

        target = torch.from_numpy(np.array(record.label_num))
        return image, target

    def __len__(self):
        """Return the number of rows captured at construction time."""
        return self.len
