import os
import random

from torch.utils.data import Dataset
import numpy as np
import torch
from sklearn.preprocessing import OneHotEncoder
from imutils import paths
import cv2

from .augments import Augmenter


class IM_data(Dataset):
    """Image-classification dataset backed by a directory tree of images.

    Images are discovered under ``data_pts['TRAIN']`` (when ``train`` is true)
    or ``data_pts['TEST']`` (otherwise) via ``imutils.paths.list_images``.
    The class id of each image is parsed from the name of its parent
    directory: the first character is dropped and the remainder is converted
    to ``int`` (e.g. a folder named ``c3`` yields id ``3``).

    Args:
        data_pts: Mapping with ``'TRAIN'`` / ``'TEST'`` keys holding dataset
            root directories.
        train: Selects the training root; also forwarded to ``Augmenter``.
        onehot: If true, labels are one-hot encoded with a scikit-learn
            ``OneHotEncoder``. Otherwise the parsed ids are shifted down by
            one (the ids are presumably 1-based) to give class indices.
        normalize: If true, each image is min-max scaled to ``[0, 1]`` as
            ``float32`` before augmentation.
        aug_dict: Augmentation configuration forwarded to ``Augmenter``.
        shape: Target image size; defaults to ``[244, 244]``.
        label_encoder: Optional already-fitted encoder (presumably the one
            fitted on the training split, so both splits share column order).
            When ``None`` and ``onehot`` is true, a new encoder is fitted.

    Attributes:
        data_shape: ``(shape, number_of_classes)``.
        label_order: When ``onehot``: ``[identity matrix, decoded labels]``,
            pairing every one-hot row with its original folder id; else ``None``.
    """

    def __init__(self, data_pts: dict, train: bool, onehot: bool = True,
                 normalize: bool = True, aug_dict: dict = None, shape=None, label_encoder=None):

        if shape is None:
            shape = [244, 244]

        self.train = train
        if self.train:
            name = 'Training'
            data_pt = data_pts['TRAIN']
        else:
            name = 'Testing'
            data_pt = data_pts['TEST']

        self.onehot = onehot
        self.normalize = normalize
        self.shape = shape

        self.augmenter = Augmenter(aug_dict=aug_dict, shape=shape, train=train)

        # Caller-supplied encoder, or None; onehot_score fits one lazily.
        self.label_encoder = label_encoder
        self.label_order = None

        self.im_pts = list(paths.list_images(data_pt))
        self.lbl = np.array([self._class_id_from_path(x) for x in self.im_pts])

        if onehot:
            self.lbl = self.onehot_score(self.lbl)
            self.data_shape = shape, self.lbl.shape[-1]

            # Each row of the identity matrix is one one-hot vector; decoding
            # it recovers the folder id that column stands for.
            labels_all = np.eye(self.lbl.shape[-1])
            labels_decoded = self.label_encoder.inverse_transform(labels_all)
            self.label_order = [labels_all, labels_decoded]
        else:
            self.lbl = self.lbl - 1  # to have classes [0...n]
            # Number of distinct classes (output size, e.g. for NLLLoss).
            self.data_shape = shape, np.unique(self.lbl).size

        print(f'[INFO] {name} dataset loaded')

    @staticmethod
    def _class_id_from_path(pt: str) -> int:
        """Return the class id encoded in the parent directory name of ``pt``.

        The first character of the directory name is dropped and the rest is
        parsed as an int, e.g. ``'root/c3/img.jpg' -> 3``. ``os.path`` is used
        so paths joined with the platform separator are handled, not only '/'.
        """
        folder = os.path.basename(os.path.dirname(pt))
        return int(folder[1:])

    def __len__(self) -> int:
        """Return the number of images found under the dataset root."""
        return len(self.im_pts)

    def __getitem__(self, idx) -> 'tuple[torch.Tensor, torch.Tensor, dict]':
        """Load, preprocess and augment the image at position ``idx``.

        Returns:
            ``(input_tensor, lbl, meta)``:
            ``input_tensor`` is the output of ``Augmenter.apply_augs``;
            ``lbl`` is a float ``torch.Tensor`` (the one-hot row, or a
            1-element tensor holding the class index);
            ``meta`` holds the (optionally normalized) image resized to
            ``self.shape``, its file path, and the raw label value.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        pt = self.im_pts[idx]
        image = cv2.imread(pt)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {pt}')

        if self.normalize:
            image = cv2.normalize(image, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

        input_tensor = self.augmenter.apply_augs(image=image)

        if self.onehot:
            lbl = torch.Tensor(self.lbl[idx])
        else:
            lbl = torch.Tensor([self.lbl[idx]])

        meta = {'input_tensor': {'tensor': cv2.resize(image, tuple(self.shape)), 'pt': pt},
                'label': {'val': self.lbl[idx]}
                }

        return input_tensor, lbl, meta

    def onehot_score(self, S_tensor) -> np.ndarray:
        """One-hot encode an array of integer labels.

        If no encoder is set yet, a new ``OneHotEncoder`` is fitted on
        ``S_tensor`` and stored on ``self.label_encoder``; the (new or
        pre-existing) encoder then transforms the labels.

        Returns:
            Dense ``int`` array with one row per label and one column per
            category known to the encoder.
        """
        if self.label_encoder is None:
            self.label_encoder = OneHotEncoder()
            self.label_encoder.fit(S_tensor.reshape(-1, 1))
        return self.label_encoder.transform(S_tensor.reshape(-1, 1)).toarray().astype(int)


def get_train_test_idx(TRAIN_SIZE: float, ds_path: str, seed: int = None) -> dict:
    """Randomly split the indices of the images under ``ds_path`` into train/test.

    Args:
        TRAIN_SIZE: Fraction of images assigned to the train split; must lie in
            ``[0, 1]``. The train split holds ``int(TRAIN_SIZE * n)`` indices.
        ds_path: Directory scanned with ``imutils.paths.list_images``.
        seed: If given, the module-level ``random`` generator is reseeded with
            it. This is a global side effect that also affects later
            ``random`` calls in the process.

    Returns:
        ``{'train': [...], 'test': [...]}``: disjoint index lists that together
        cover ``range(n)``. ``train`` is in sampling order; ``test`` is in
        ascending order.

    Raises:
        ValueError: if ``TRAIN_SIZE`` is outside ``[0, 1]``.
    """
    print('[INFO] Loading datasets')
    if not 0 <= TRAIN_SIZE <= 1:
        raise ValueError(f'TRAIN_SIZE must be in [0, 1], got {TRAIN_SIZE}')

    # Reseed the global RNG; kept global for backward-compatible reproducibility.
    if seed is not None:
        random.seed(seed)

    # Only the image count is needed, so don't materialize the path list.
    n_images = sum(1 for _ in paths.list_images(ds_path))
    ind_list = range(n_images)
    train_idx = random.sample(ind_list, int(TRAIN_SIZE * n_images))

    # Build the complement in index order; a set difference would iterate in
    # hash-table order instead.
    train_set = set(train_idx)
    test_idx = [i for i in ind_list if i not in train_set]
    return {'train': train_idx, 'test': test_idx}
