import numpy as np
import torch
from tqdm.auto import tqdm


class BaseDataset:
    """In-memory dataset holding parallel ``data`` / ``targets`` containers.

    ``data`` and ``targets`` are indexed along their first axis and kept
    aligned. Filling them (``load_dataset``) is left to subclasses.
    """

    def __init__(self):
        self.data = None  # samples; first axis indexes the sample
        self.targets = None  # labels aligned with ``data`` along the first axis
        self.path_input = None  # NOTE(review): presumably an input path set by subclasses — confirm
        self.type = None  # NOTE(review): meaning defined by subclasses — confirm
        self.size = None  # NOTE(review): meaning defined by subclasses — confirm

    def __len__(self):
        """Return the number of samples (length of ``data`` along axis 0)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return the ``(sample, target)`` pair at position ``idx``."""
        return self.data[idx], self.targets[idx]

    def load_dataset(self):
        """Populate ``data`` and ``targets``; must be implemented by subclasses."""
        raise NotImplementedError

    def preprocess_dataset(self):
        """Scale ``data`` into [0, 1] and convert torch tensors to NumPy arrays.

        If the maximum of ``data`` is exactly 255, ``data`` is divided by 255.
        Afterwards ``data`` must have minimum exactly 0 and maximum exactly 1.
        Finally ``data`` and ``targets`` are converted to NumPy arrays if they
        are torch tensors.

        Raises:
            AssertionError: if ``data`` does not span exactly [0, 1] after the
                optional scaling.
        """
        if self.data.max() == 255:
            self.data = self.data / 255

        # Explicit checks instead of ``assert`` so the validation is not
        # stripped when Python runs with ``-O``. AssertionError is kept so
        # callers relying on the former asserts see the same exception type.
        data_min = self.data.min()
        data_max = self.data.max()
        if data_min != 0:
            raise AssertionError(
                f"expected data minimum of 0 after scaling, got {data_min}"
            )
        if data_max != 1:
            raise AssertionError(
                f"expected data maximum of 1 after scaling, got {data_max}"
            )

        if torch.is_tensor(self.data):
            self.data = self.data.numpy()
        if torch.is_tensor(self.targets):
            self.targets = self.targets.numpy()

    def remove_duplicates(self):
        """Drop repeated samples, keeping the first occurrence of each.

        Duplicates are detected on ``data`` rows only. The surviving samples
        keep their original relative order, and ``targets`` is filtered with
        the same indices so the pairs stay aligned.
        """
        _, first_idx = np.unique(self.data, axis=0, return_index=True)
        # np.unique orders its indices by the sorted unique values; sorting
        # them (once) restores the original sample order.
        keep = np.sort(first_idx)
        self.data = self.data[keep]
        self.targets = self.targets[keep]


class TransformDataset(BaseDataset):
    """Dataset built by materialising a transformed source dataset as arrays."""

    def to_numpy(self, dataset, transform, targets_dtype=np.uint16):
        """Apply ``transform`` to every image of ``dataset`` and store the result.

        Args:
            dataset: sized iterable yielding ``(img, target)`` pairs.
            transform: callable mapping ``img`` to an object with a ``.numpy()``
                method (e.g. a torch tensor). Every output must have the same
                shape as the first one.
            targets_dtype: NumPy dtype used for ``self.targets``. Defaults to
                ``np.uint16``. Pass a wider or signed dtype if labels can exceed
                65535 or be negative.

        Sets ``self.data`` to a float32 array of shape
        ``(len(dataset), *transformed_shape)`` and ``self.targets`` to a 1-D
        array of length ``len(dataset)``. For an empty ``dataset``,
        ``self.data`` is an empty float32 array of shape ``(0,)`` rather than
        ``None``, so ``len(self)`` still works.
        """
        num_images = len(dataset)
        data_tmp = None
        targets_tmp = np.zeros((num_images,), dtype=targets_dtype)

        for i_img, (img, target) in tqdm(enumerate(dataset), total=num_images):
            img = transform(img).numpy()

            if data_tmp is None:
                # Allocate lazily: the per-image shape is only known once the
                # first transformed image is available.
                data_tmp = np.zeros((num_images, *img.shape), dtype=np.float32)

            data_tmp[i_img] = img
            targets_tmp[i_img] = target

        if data_tmp is None:
            # No images were produced, so the image shape is unknown; use an
            # empty array instead of leaving ``data`` as None.
            data_tmp = np.zeros((0,), dtype=np.float32)

        self.data = data_tmp
        self.targets = targets_tmp
