import os
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from zipfile import ZipFile

import numpy as np
import spacy
import torch
import torchvision
from matplotlib.image import imread
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.distributions import Beta
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
from torchtext.data.datasets_utils import _RawTextIterableDataset
from torchtext.datasets import IMDB as torch_IMDB
from torchtext.vocab import GloVe

from puupl.lib.utils import boolean_nested_and

# Workaround for when Yann LeCun's website (the original MNIST download
# host) is down: make torchvision fetch the four MNIST archives from the
# S3 mirror instead. Each entry pairs a URL with the hash string that
# accompanies it (presumably the md5 checksum torchvision verifies).
# See https://github.com/pytorch/vision/issues/3549
torchvision.datasets.MNIST.resources = [(
    'https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz',
    'f68b3c2dcbeaaa9fbdd348bbdeb94873'
), (
    'https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz',
    'd53e105ee54ea40749a09fcbcd1e9432'
), (
    'https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz',
    '9fb629c4189551a2d022fa330f9573f3'
), (
    'https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz',
    'ec29112dd5afa0611ce80d1b7f02629c'
)]


def unlabel_positives(y: Tensor, keep_positives: int, seed: Optional[int]) -> Tensor:
    """
    Keeps a random (with seed) subset of the positives in `y` and sets all
    the others to negative.

    `y` holds binary labels, the positives being the entries equal to 1.
    `keep_positives` is the number of positives that stay labeled and
    `seed` drives the random choice of which ones are kept.

    Returns a boolean mask indicating the kept subset of positives; the
    mask is always boolean, whatever the dtype of `y`.

    Raises ValueError when `y` has fewer than `keep_positives` positives.
    """

    pos_mask = y == 1
    n_positives = int(pos_mask.sum().item())

    if n_positives < keep_positives:
        raise ValueError('more positive samples requested than originally available')
    if n_positives == keep_positives:
        # nothing to drop: every positive stays labeled
        return pos_mask

    # one flag per positive: the first `keep_positives` are True, then the
    # flags are shuffled with the seeded generator
    new_mask = torch.ones(n_positives).bool()
    new_mask[keep_positives:] = False
    new_mask = new_mask[
        np.random.default_rng(seed).permutation(len(new_mask))
    ]

    # NOTE(review): boolean_nested_and comes from puupl.lib.utils; it
    # presumably applies new_mask to the True entries of pos_mask - confirm
    y_new = boolean_nested_and(pos_mask, new_mask)
    assert y_new.sum() == keep_positives

    return y_new


class PuDataset(Dataset, ABC):
    """
    Abstract base class for dataset objects with some boilerplate included.

    Subclasses must define `relative_folder` and implement `download` and
    `load_dataset`; they may override `transform` and `collate_functions`.
    """

    x: Tensor  # network inputs (some subclasses store a list of raw samples)
    y: Tensor  # network targets
    t: Tensor  # true labels
    p: Tensor  # indicator for original positives
    l: Tensor  # indicator for pseudo-labeled samples
    w: Tensor  # sample weights
    # optional per-key collate functions used by supervised_collate_fn; this
    # default dict is shared by all subclasses: replace it, do not mutate it
    collate_functions: Dict[str, Callable[[List[Tensor]], Tensor]] = {}
    relative_folder: str  # dataset folder, relative to base_folder

    def __init__(self, part: str, n_labeled: Optional[int],
                 val_size: Optional[Union[int, float]],
                 seed: int, relative_folder: Optional[str] = None,
                 mixup_beta_param: Optional[float] = None,
                 base_folder: str = './data/',
                 **kwargs: Any) -> None:
        """
        Stores the configuration, downloads the dataset when its folder
        does not exist yet, then loads the data through `load_dataset`.

        `part`, `n_labeled`, `val_size` and `seed` are stored on the
        instance for `load_dataset` to use; `relative_folder`, when given,
        replaces the class-level default; `mixup_beta_param` enables mixup
        (see `initialize_and_enable_mixup`); extra keyword arguments are
        accepted and ignored.
        """

        self.base_folder = base_folder
        self.part = part
        self.n_labeled = n_labeled
        self.val_size = val_size
        self.seed = seed
        self.mixup_beta_param = mixup_beta_param
        self.mixup_weights: Optional[Tensor] = None
        self.mixup_indices: Optional[np.ndarray] = None

        if relative_folder is not None:
            self.relative_folder = relative_folder

        # the download is only triggered by a missing folder
        if not os.path.exists(self.folder):
            os.makedirs(self.folder)
            self.download()

        x, t, p = self.load_dataset()
        self.x = x
        self.t = t.bool()
        self.p = p.bool()
        # the initial targets are the original positives
        self.y = self.p.clone().float()
        if self.part == 'test':
            # labeled test set to compute cross-entropy loss
            self.l = torch.ones_like(self.p).bool()
        else:
            self.l = torch.zeros_like(self.p).bool()
        self.w = torch.ones_like(self.y)

    @property
    def folder(self) -> str:
        """
        Folder where the dataset is stored.
        """
        return os.path.join(self.base_folder, self.relative_folder)

    def initialize_and_enable_mixup(self) -> None:
        """
        Initializes mixup data structures for the next epoch.
        Call at the beginning of each epoch.
        Does nothing when no mixup_beta_param was given.
        """
        if self.mixup_beta_param is not None:
            # one Beta(b, b) mixing weight and one random partner per sample
            self.mixup_weights = Beta(
                torch.FloatTensor([self.mixup_beta_param]),
                torch.FloatTensor([self.mixup_beta_param])
            ).sample([len(self.x)]).view(-1)
            self.mixup_indices = np.arange(len(self.x))
            np.random.shuffle(self.mixup_indices)

    def disable_mixup(self) -> None:
        """
        Makes __getitem__ return un-mixed samples again.
        """
        if self.mixup_beta_param is not None:
            self.mixup_weights = self.mixup_indices = None

    @abstractmethod
    def download(self) -> None:
        """
        Download the dataset from the net and save it locally.
        """

    @abstractmethod
    def load_dataset(self) -> Tuple[Any, Tensor, Tensor]:
        """
        Load the dataset from the local storage and return a tuple with
        inputs, true labels, and indicator of original positives
        """

    def size(self) -> int:
        """
        Number of samples; same as len(), so that it also works when the
        inputs are stored in a list instead of a tensor.
        """
        return len(self.x)

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """
        Each sample is associated with six things:
         - x: the input to the network
         - y: the pseudo-label
         - t: the true label
         - p: True iff the sample was in the original set of labeled positives
         - l: True iff the sample was assigned a pseudo-label in some time in the past
         - w: scalar weight for each sample

        When mixup is enabled, x and y are convex combinations of this
        sample and of its mixup partner; t, p, l and w always refer to the
        sample at `index`.
        """

        if self.mixup_indices is None:
            x = self.transform(self.x[index])
            y = self.y[index]
        else:
            assert self.mixup_weights is not None
            mix = self.mixup_weights[index]

            other_index = self.mixup_indices[index]
            x1, x2 = self.transform(self.x[index]), self.transform(self.x[other_index])
            y1, y2 = self.y[index], self.y[other_index]

            x = mix * x1 + (1 - mix) * x2
            y = mix * y1 + (1 - mix) * y2

        return {
            'x': x,
            'y': y,
            't': self.t[index],
            'p': self.p[index],
            'l': self.l[index],
            'w': self.w[index],
        }

    def transform(self, x: Any) -> Tensor:
        """
        Override to apply transformations to x.
        """
        return x.float()

    def supervised_collate_fn(self, items: List[Dict[str, Tensor]]) -> Dict[str, Tensor]:
        """
        This combines several items coming from __getitem__ into a single batch.
        Each key is collated with its entry in `collate_functions`, or with
        torch's default_collate when it has none.
        """

        ts: Dict[str, List[Tensor]] = {}
        for item in items:
            for k, v in item.items():
                ts.setdefault(k, []).append(v)

        return {
            k: self.collate_functions.get(k, default_collate)(v)
            for k, v in ts.items()
        }


class MNIST(PuDataset):
    """
    Positive-unlabeled dataset sampled from MNIST with the procedure of
    https://arxiv.org/pdf/2006.11280.pdf
    """

    relative_folder = 'MNIST'

    def load_dataset(self) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Builds the positive-unlabeled splits from the cached MNIST arrays.
        Returns (samples, true labels, is-positive indicator) for the part
        selected by `self.part`.
        """

        folder = self.folder

        # training data; odd digits are the positive class
        inputs = torch.tensor(np.load(f'{folder}/X_train.npy'))
        digits = torch.tensor(np.load(f'{folder}/Y_train.npy'))
        tt = digits % 2 == 1

        x_train, x_val, t_train, t_val = train_test_split(
            inputs, tt, test_size=self.val_size, random_state=self.seed)

        # keep only some positives labeled; the validation set gets a
        # number of labeled positives proportional to its size
        assert self.n_labeled is not None
        p_train = unlabel_positives(t_train, self.n_labeled, self.seed)
        n_val_labeled = int(len(t_val) * self.n_labeled / len(t_train))
        p_val = unlabel_positives(t_val, n_val_labeled, self.seed + 55)

        # test data, fully labeled
        x_test = torch.tensor(np.load(f'{folder}/X_test.npy'))
        test_digits = torch.tensor(np.load(f'{folder}/Y_test.npy'))
        t_test = (test_digits % 2 == 1).float()
        p_test = t_test > 0.5

        n_pos = torch.sum(tt).item()
        n_neg = torch.sum(1 - tt.int()).item()
        print('MNIST - total train pos: %d - total train neg: %d - total test: %d' % (
            n_pos, n_neg, len(x_test)))
        n_labeled = torch.sum(p_train).item()
        n_unlabeled = torch.sum(1 - p_train.int()).item()
        print('MNIST - labeled training positives: %d - unlabeled training samples: %d' % (
            n_labeled, n_unlabeled))

        parts = {
            'train': (x_train, t_train, p_train),
            'val': (x_val, t_val, p_val),
            'test': (x_test, t_test, p_test),
        }
        if self.part not in parts:
            raise ValueError(f'unknown part {self.part}')
        return parts[self.part]

    def download(self) -> None:
        """
        Downloads MNIST with torchvision and caches the standardized images
        and the digit labels as .npy files in the dataset folder.
        """
        trainset = torchvision.datasets.MNIST(
            root=self.folder, train=True, download=True, transform=None)
        testset = torchvision.datasets.MNIST(
            root=self.folder, train=False, download=True, transform=None)

        # add the channel dimension
        train_images = trainset.data.unsqueeze(1).float()
        test_images = testset.data.unsqueeze(1).float()

        # standardize both sets with the training statistics
        mean, std = train_images.mean(), train_images.std()
        arrays = (
            ('X_train', (train_images - mean) / std),
            ('X_test', (test_images - mean) / std),
            ('Y_train', trainset.targets),
            ('Y_test', testset.targets),
        )
        for name, arr in arrays:
            np.save(file=f"{self.folder}/{name}.npy", arr=arr)


class CIFAR10(PuDataset):
    """
    synthetic dataset sampled from CIFAR10 following procedure in
    https://arxiv.org/pdf/2006.11280.pdf
    """

    relative_folder = 'cifar10'

    def load_dataset(self) -> Tuple[Tensor, Tensor, Tensor]:
        """
        creates a positive-unlabeled dataset based on CIFAR10
        returns a tuple of (sample, true label, is-positive indicator)
        for the part selected by `self.part`
        """
        xx = torch.tensor(np.load(f'{self.folder}/X_train.npy'))
        tt = np.load(f'{self.folder}/Y_train.npy')

        # re-label: vehicles = positives, animals = negatives
        # see: https://www.cs.toronto.edu/~kriz/cifar.html for mapping of classes to idx
        pos_classes = [0, 1, 8, 9]
        tt = torch.tensor([t in pos_classes for t in tt]).float()

        x_train, x_val, t_train, t_val = train_test_split(
            xx, tt,
            test_size=self.val_size,
            random_state=self.seed
        )

        # the validation set gets a number of labeled positives that is
        # proportional to its size
        assert self.n_labeled is not None
        p_train = unlabel_positives(t_train, self.n_labeled, self.seed)
        p_val = unlabel_positives(t_val, int(len(t_val) * self.n_labeled / len(t_train)),
                                  self.seed + 55)

        # Test
        x_test = torch.tensor(np.load(f'{self.folder}/X_test.npy'))
        t_test = np.load(f'{self.folder}/Y_test.npy')
        t_test = torch.tensor([t in pos_classes for t in t_test])
        p_test = t_test > 0.5

        print('CIFAR-10 - total train pos: %d - total train neg: %d - total test: %d' % (
            torch.sum(tt).item(), torch.sum(1 - tt.int()).item(), len(x_test)))
        print('CIFAR10 - labeled training samples: %d - unlabeled training samples: %d' % (
            torch.sum(p_train).item(), torch.sum(1 - p_train.int()).item()))

        if self.part == 'train':
            return x_train, t_train, p_train
        elif self.part == 'val':
            return x_val, t_val, p_val
        elif self.part == 'test':
            return x_test, t_test, p_test
        else:
            raise ValueError(f'unknown part {self.part}')

    def download(self) -> None:
        """
        Downloads CIFAR10 with torchvision and caches the standardized
        images and the class labels as .npy files in the dataset folder.
        """
        trainset = torchvision.datasets.CIFAR10(
            root=self.folder, train=True, download=True, transform=None)
        testset = torchvision.datasets.CIFAR10(
            root=self.folder, train=False, download=True, transform=None)

        Y_train = trainset.targets
        Y_test = testset.targets

        # The per-channel statistics must be computed while the channels
        # are still the last axis: reducing over axes (0, 1, 2) after the
        # swap would produce one value per image row instead, and normalize
        # along the wrong axis.
        # assumes torchvision exposes `.data` as (N, H, W, C) - TODO confirm
        X_mean = trainset.data.mean(axis=(0, 1, 2))
        X_std = trainset.data.std(axis=(0, 1, 2))

        # standardize, then move the channels to axis 1 as before
        X_train = ((trainset.data - X_mean) / X_std).swapaxes(1, 3)
        X_test = ((testset.data - X_mean) / X_std).swapaxes(1, 3)

        np.save(file=f"{self.folder}/X_train.npy", arr=X_train)
        np.save(file=f"{self.folder}/X_test.npy", arr=X_test)
        np.save(file=f"{self.folder}/Y_train.npy", arr=Y_train)
        np.save(file=f"{self.folder}/Y_test.npy", arr=Y_test)


class Skewed_CIFAR10(CIFAR10):
    """
    CIFAR10 variant where the labeled positives are not sampled uniformly:
    each positive class has its own probability of being labeled.
    """

    relative_folder = 'cifar10'

    def __init__(self, positive_probs: Optional[Dict[int, float]] = None,
                 *args: Any, **kwargs: Any):
        """
        `positive_probs` maps the positive classes to their labeling
        probability; entries that are not strictly positive are dropped,
        and a default over the classes 0, 1, 8 and 9 is used when it is
        None. The other arguments are forwarded to the parent class.
        """
        # must be set before super().__init__, which calls load_dataset
        self.positive_probs = {
            k: v for k, v in positive_probs.items() if v > 0
        } if positive_probs is not None else {
            0: 0.5, 1: 0.3, 8: 0.15, 9: 0.05
        }
        super().__init__(*args, **kwargs)

    def load_dataset(self) -> Tuple[Tensor, Tensor, Tensor]:
        """
        returns a tuple of (sample, true label, is-positive indicator)
        for the part selected by `self.part`; the positives are the
        classes listed in `self.positive_probs`
        """
        # re-label: vehicles = positives, animals = negatives
        # see: https://www.cs.toronto.edu/~kriz/cifar.html for mapping of classes to idx
        if self.part == 'test':
            # Test
            x_test = torch.tensor(np.load(f'{self.folder}/X_test.npy'))
            t_test = np.load(f'{self.folder}/Y_test.npy')
            t_test = torch.tensor([t in self.positive_probs.keys() for t in t_test])
            p_test = t_test > 0.5
            return x_test, t_test, p_test

        xx = torch.tensor(np.load(f'{self.folder}/X_train.npy'))
        tt = np.load(f'{self.folder}/Y_train.npy')

        probs = torch.tensor([  # probability of each example to be labeled
            self.positive_probs.get(t, 0.0)
            for t in tt
        ]).float()
        tt = torch.tensor([  # true label
            t in self.positive_probs.keys() for t in tt
        ]).float()

        train_idx, val_idx = train_test_split(
            np.arange(len(xx)),
            test_size=self.val_size,
            random_state=self.seed
        )

        def resample(indices: np.ndarray, n: int, seed: Optional[int]) -> np.ndarray:
            # chooses which positives to label according to
            # the probabilities computed above; the generator is seeded so
            # that the same seed always labels the same samples, like
            # unlabel_positives does for the other datasets
            pos_indices = [i for i in indices if tt[i] > 0.5]
            # normalize in double precision to be safe with the sum-to-one
            # check performed by choice
            pos_probs = probs[pos_indices].double().numpy()
            pos_probs = pos_probs / pos_probs.sum()
            return np.random.default_rng(seed).choice(
                pos_indices, size=n, replace=False, p=pos_probs
            )

        # the validation set gets a number of labeled positives that is
        # proportional to its size
        assert self.n_labeled is not None
        labl_train_idx = resample(train_idx, self.n_labeled, self.seed)
        labl_val_idx = resample(
            val_idx, int(len(val_idx) * self.n_labeled / len(train_idx)),
            self.seed + 55
        )

        # the masks are built over the whole dataset, then restricted
        x_train, t_train = xx[train_idx], tt[train_idx]
        p_train = torch.zeros(len(xx)).bool()
        p_train[labl_train_idx] = True
        p_train = p_train[train_idx]
        assert torch.all(t_train[p_train])

        x_val, t_val = xx[val_idx], tt[val_idx]
        p_val = torch.zeros(len(xx)).bool()
        p_val[labl_val_idx] = True
        p_val = p_val[val_idx]
        assert torch.all(t_val[p_val])

        # the test part was already returned above
        if self.part == 'train':
            return x_train, t_train, p_train
        if self.part == 'val':
            return x_val, t_val, p_val
        raise ValueError(f'unknown part {self.part}')


class CIFAR10_PUSB(CIFAR10):
    """
    Synthetic dataset with a bias in the positives from [1]
    [1] https://openreview.net/forum?id=rJzLciCqKm
    """

    relative_folder = 'cifar10'

    def load_dataset(self) -> Tuple[Tensor, Tensor, Tensor]:
        """
        returns a tuple of (sample, true label, is-positive indicator)
        for the part selected by `self.part`; the labeled positives are
        sampled according to predictions stored in
        'resampling_predictions.npy', which must already exist

        raises RuntimeError when the predictions file is missing
        """
        # re-label: vehicles = positives, animals = negatives
        # see: https://www.cs.toronto.edu/~kriz/cifar.html for mapping of classes to idx
        pos_classes = [0, 1, 8, 9]
        if self.part == 'test':
            # Test
            x_test = torch.tensor(np.load(f'{self.folder}/X_test.npy'))
            t_test = np.load(f'{self.folder}/Y_test.npy')
            t_test = torch.tensor([t in pos_classes for t in t_test])
            p_test = t_test > 0.5
            return x_test, t_test, p_test

        xx = torch.tensor(np.load(f'{self.folder}/X_train.npy'))
        tt = np.load(f'{self.folder}/Y_train.npy')
        tt = torch.tensor([t in pos_classes for t in tt]).float()

        preds_path = os.path.join(self.folder, 'resampling_predictions.npy')
        if not os.path.exists(preds_path):
            raise RuntimeError(f'resampling predictions not found at {preds_path}'
                               f' - please run "python src/puupl/lib/resampling.py"'
                               ' and try again')
        logits = np.load(preds_path)
        # sigmoid of the logits, raised to the 10th power to sharpen the
        # preference towards the samples with the highest predictions
        probs = 1 / (1 + np.exp(-logits))
        probs = probs**10

        train_idx, val_idx = train_test_split(
            np.arange(len(xx)),
            test_size=self.val_size,
            random_state=self.seed
        )

        def resample(indices: np.ndarray, n: int, seed: Optional[int]) -> np.ndarray:
            # chooses which positives to label according to the
            # probabilities computed above; the generator is seeded so that
            # the same seed always labels the same samples, like
            # unlabel_positives does for the other datasets
            # NOTE(review): sampling is done with replacement, so duplicates
            # can leave fewer than n distinct labeled positives - confirm
            # this is intended
            pos_indices = [i for i in indices if tt[i] > 0.5]
            pos_probs = probs[pos_indices] / probs[pos_indices].sum()
            return np.random.default_rng(seed).choice(
                pos_indices, size=n, replace=True, p=pos_probs
            )

        # the validation set gets a number of labeled positives that is
        # proportional to its size
        assert self.n_labeled is not None
        labl_train_idx = resample(train_idx, self.n_labeled, self.seed)
        labl_val_idx = resample(
            val_idx, int(len(val_idx) * self.n_labeled / len(train_idx)),
            self.seed + 55
        )

        # the masks are built over the whole dataset, then restricted
        x_train, t_train = xx[train_idx], tt[train_idx]
        p_train = torch.zeros(len(xx)).bool()
        p_train[labl_train_idx] = True
        p_train = p_train[train_idx]
        assert torch.all(t_train[p_train])

        x_val, t_val = xx[val_idx], tt[val_idx]
        p_val = torch.zeros(len(xx)).bool()
        p_val[labl_val_idx] = True
        p_val = p_val[val_idx]
        assert torch.all(t_val[p_val])

        # the test part was already returned above
        if self.part == 'train':
            return x_train, t_train, p_train
        if self.part == 'val':
            return x_val, t_val, p_val
        raise ValueError(f'unknown part {self.part}')


class FashionMNIST(PuDataset):
    """
    synthetic dataset sampled from fashion MNIST
    """

    relative_folder = 'fashion_mnist'

    def load_dataset(self) -> Tuple[Tensor, Tensor, Tensor]:
        """
        creates a positive-unlabeled dataset based on fashion MNIST
        returns a tuple of (sample, true label, is-positive indicator)
        for the part selected by `self.part`
        """
        xx = torch.tensor(np.load(f'{self.folder}/X_train.npy'))
        tt = np.load(f'{self.folder}/Y_train.npy')

        # positives = Trousers, Dress, Sandals, Sneaker, Ankle boot
        # negatives = T-shirt/top, Pullover, Coat, Shirt, Bag
        # i.e. the classes with an odd index are the positives
        tt = torch.tensor([t % 2 for t in tt]).float()

        x_train, x_val, t_train, t_val = train_test_split(
            xx, tt,
            test_size=self.val_size,
            random_state=self.seed
        )

        # the validation set gets a number of labeled positives that is
        # proportional to its size
        assert self.n_labeled is not None
        p_train = unlabel_positives(t_train, self.n_labeled, self.seed)
        p_val = unlabel_positives(t_val, int(len(t_val) * self.n_labeled / len(t_train)),
                                  self.seed + 55)

        # Test
        x_test = torch.tensor(np.load(f'{self.folder}/X_test.npy'))
        t_test = np.load(f'{self.folder}/Y_test.npy')
        t_test = torch.tensor([t % 2 for t in t_test])
        p_test = t_test > 0.5

        print('F-MNIST - total train pos: %d - total train neg: %d - total test: %d' % (
            torch.sum(tt).item(), torch.sum(1 - tt.int()).item(), len(x_test)))
        print('F-MNIST - labeled training samples: %d - unlabeled training samples: %d' % (
            torch.sum(p_train).item(), torch.sum(1 - p_train.int()).item()
        ))

        if self.part == 'train':
            return x_train, t_train, p_train
        elif self.part == 'val':
            return x_val, t_val, p_val
        elif self.part == 'test':
            return x_test, t_test, p_test
        else:
            raise ValueError(f'unknown part {self.part}')

    def download(self) -> None:
        """
        Downloads fashion MNIST with torchvision and caches the
        standardized images and the class labels as .npy files in the
        dataset folder.
        """
        trainset = torchvision.datasets.FashionMNIST(
            root=self.folder, train=True, download=True, transform=None)
        testset = torchvision.datasets.FashionMNIST(
            root=self.folder, train=False, download=True, transform=None)

        # add the channel dimension
        X_train = trainset.data.unsqueeze(1).float()
        X_test = testset.data.unsqueeze(1).float()
        Y_train = trainset.targets
        Y_test = testset.targets

        # There is a single channel, so the statistics are global scalars
        # as for MNIST. Reducing over axes (0, 1, 2) of the 4-d tensor
        # would instead give one value per position of the last image axis.
        X_mean, X_std = X_train.mean(), X_train.std()
        X_train = (X_train - X_mean) / X_std
        X_test = (X_test - X_mean) / X_std

        np.save(file=f"{self.folder}/X_train.npy", arr=X_train)
        np.save(file=f"{self.folder}/X_test.npy", arr=X_test)
        np.save(file=f"{self.folder}/Y_train.npy", arr=Y_train)
        np.save(file=f"{self.folder}/Y_test.npy", arr=Y_test)


class IMDB(PuDataset):
    """
    Positive-unlabeled dataset sampled from the IMDB reviews.
    """

    relative_folder = 'IMDB'

    def __init__(self, spacy_language: str = 'en_core_web_sm',
                 glove_name: str = '6B', glove_dim: int = 300,
                 relative_folder: Optional[str] = None,
                 **kwargs: Any) -> None:
        """
        Stores the tokenizer and embedding settings, then lets the parent
        class download and load the data.
        """

        self.spacy_language = spacy_language
        self.glove_name = glove_name
        self.glove_dim = glove_dim
        if relative_folder is not None:
            self.relative_folder = relative_folder

        # inputs are collated with pad_sequence
        self.collate_functions = {'x': pad_sequence}
        super().__init__(**kwargs)

    def transform(self, x: str) -> Tensor:
        """
        Tokenizes the review and looks up the vector of each token.
        """
        tokens = [token.text for token in self.nlp.tokenizer(x)]
        embedded = self.vectors.get_vecs_by_tokens(tokens, lower_case_backup=True)
        return embedded.float()

    def load_dataset(self) -> Tuple[List[str], Tensor, Tensor]:
        """
        Returns (reviews, true labels, is-positive indicator) for the part
        selected by `self.part`; also sets the tokenizer and the vectors
        used by `transform`.
        """
        train_ds, test_ds, self.nlp, self.vectors = self._get_data_tokenizer_and_vectors()

        def read(dataset: Any) -> Tuple[List[str], Tensor]:
            # collects the reviews and their binary labels ('pos' -> 1)
            reviews, labels = [], []
            for label, review in dataset:
                reviews.append(review)
                labels.append(1 if label == 'pos' else 0)
            return reviews, torch.tensor(labels, dtype=torch.float)

        if self.part == 'test':
            x_test, t_test = read(test_ds)
            p_test = t_test > 0.5
            print('IMDb - total test:', len(x_test))
            return x_test, t_test, p_test

        xx, tt = read(train_ds)
        x_train, x_val, t_train, t_val = train_test_split(
            xx, tt, test_size=self.val_size, random_state=self.seed)

        # the validation set gets a number of labeled positives that is
        # proportional to its size
        assert self.n_labeled is not None
        p_train = unlabel_positives(t_train, self.n_labeled, self.seed)
        n_val_labeled = int(len(t_val) * self.n_labeled / len(t_train))
        p_val = unlabel_positives(t_val, n_val_labeled, self.seed + 55)

        n_pos = torch.sum(tt).item()
        n_neg = torch.sum(1 - tt.int()).item()
        print('IMDb - total train pos: %d - total train neg: %d' % (n_pos, n_neg))
        n_labeled = torch.sum(p_train).item()
        n_unlabeled = torch.sum(1 - p_train.int()).item()
        print('IMDb - labeled training positives: %d - unlabeled training samples: %d' % (
            n_labeled, n_unlabeled))

        if self.part == 'train':
            return x_train, t_train, p_train
        if self.part == 'val':
            return x_val, t_val, p_val
        raise ValueError(f'unknown part {self.part}')

    def download(self) -> None:
        """
        Saves the spacy pipeline in the dataset folder, unless it is already
        there, then triggers the loading of the reviews and of the vectors.
        """
        spacy_folder = os.path.join(self.folder, self.spacy_language)
        if not os.path.exists(spacy_folder):
            spacy.cli.download(self.spacy_language)  # type: ignore[attr-defined]
            spacy.load(self.spacy_language).to_disk(spacy_folder)

        self._get_data_tokenizer_and_vectors()

    def _get_data_tokenizer_and_vectors(self) -> Tuple[
        _RawTextIterableDataset, _RawTextIterableDataset, spacy.Language, GloVe
    ]:
        """
        Returns the train and test reviews, the spacy pipeline and the
        GloVe vectors, all rooted in the dataset folder.
        """
        folder = self.folder
        nlp = spacy.load(os.path.join(folder, self.spacy_language))
        vectors = GloVe(name=self.glove_name, dim=self.glove_dim,
                        cache=os.path.join(folder, 'glove_cache'))
        train_data, test_data = torch_IMDB(root=folder)
        return train_data, test_data, nlp, vectors


class BrainCancer(PuDataset):
    """
    Positive-unlabeled dataset of images stored in 'Images.zip', whose
    labels are listed in the 'train.txt' file inside the archive.
    The inputs are the names of the images within the archive; the images
    themselves are decoded on the fly by `transform`.
    """

    relative_folder: str = 'brain_cancer'

    def __init__(self, use_unlabeled_test_set: bool = False, test_size: int = 50000,
                 use_adni_preset: bool = False, **kwargs: Any) -> None:
        """
        `use_adni_preset` overrides `test_size`, `val_size` and `n_labeled`
        with fixed values; the remaining keyword arguments are forwarded to
        the parent class.
        """

        self.use_adni_preset = use_adni_preset
        if self.use_adni_preset:
            test_size = 113
            kwargs['val_size'] = 164
            kwargs['n_labeled'] = None
            print('Using ADNI preset!')

        # the test set that comes with this dataset does not have labels,
        # therefore we extract a synthetic test set from the training set.
        self.test_size = test_size

        # when this is true we augment the training set with real unlabeled
        # examples from the real test set.
        self.use_unlabeled_test_set = use_unlabeled_test_set

        # opened lazily by transform
        self.zipfile: Optional[ZipFile] = None

        super().__init__(**kwargs)

    def download(self) -> None:
        raise NotImplementedError('contact authors for zip file')

    def load_dataset(self) -> Tuple[Any, Tensor, Tensor]:
        """
        returns a tuple of (image names, true label, is-positive indicator)
        for the part selected by `self.part`
        """
        zip_path = os.path.join(self.folder, 'Images.zip')
        with ZipFile(zip_path) as zf:
            xx = [f for f in zf.namelist() if f.endswith('.png')]
            # maps the base name of each labeled image to its class
            all_labels = {}
            with zf.open('train.txt') as f:
                for row in f:
                    img, lbl = row.decode('utf8').split()
                    all_labels[img] = int(lbl)

        if self.use_adni_preset:
            # fixed-size random subsets of the two classes
            pos_images = np.random.default_rng(self.seed + 55).choice([
                # alzheimer images
                fname for fname in xx if all_labels.get(fname.split('/')[-1]) == 0
            ], size=247 + 34, replace=False).tolist()
            neg_images = np.random.default_rng(self.seed + 56).choice([
                # healthy images
                fname for fname in xx if all_labels.get(fname.split('/')[-1]) == 1
            ], size=575 + 79, replace=False).tolist()

            xx = pos_images + neg_images
            tt = [1] * len(pos_images) + [0] * len(neg_images)

        else:
            # make binary labels: class 1 is negative, classes 0, 2, 3, 4 positive
            # resulting in 203260 neg and 100044+5568+23550+16740=145902 pos
            # images that are not listed in train.txt are marked with -1
            tt = []
            for fname in xx:
                fname = fname.split('/')[-1]
                t = all_labels.get(fname)
                if t is None:
                    tt.append(-1)
                elif t == 1:
                    tt.append(0)
                else:
                    tt.append(1)

        # split labeled and unlabeled
        xx_lab, xx_unlab, tt_lab = [], [], []
        for x, t in zip(xx, tt):
            if t < 0:
                xx_unlab.append(x)
            else:
                xx_lab.append(x)
                tt_lab.append(t)

        xx_tv, x_test, tt_tv, t_test = train_test_split(
            xx_lab, tt_lab, test_size=self.test_size, random_state=self.seed
        )

        x_train, x_val, t_train, t_val = train_test_split(
            xx_tv, tt_tv, test_size=self.val_size, random_state=self.seed + 1
        )

        if self.n_labeled is not None:
            # the labels are plain lists here, but unlabel_positives works
            # on tensors
            p_train = unlabel_positives(
                torch.tensor(t_train), self.n_labeled, self.seed + 2
            ).tolist()
            p_val = unlabel_positives(
                torch.tensor(t_val),
                int(len(t_val) * self.n_labeled / len(t_train)),
                self.seed + 55
            ).tolist()
        else:
            # every positive stays labeled, in the validation set too
            p_train = [t > 0 for t in t_train]
            p_val = [t > 0 for t in t_val]

        # if requested include the unlabeled images into the training set
        if self.use_unlabeled_test_set and xx_unlab:
            x_train += xx_unlab
            p_train += [False] * len(xx_unlab)

            # we need to assign some "true" labels to the unlabeled test set
            # this is only used to evaluate the quality of the pseudo-labeling
            # the best we can do is assign labels to maintain the same class
            # balance in the training set, hoping not to bias the evaluation
            # too much
            prior = np.mean(t_train)
            t_train += [
                int(r <= prior) for r in np.random.default_rng(
                    self.seed + 3
                ).random(len(xx_unlab))
            ]

        if self.part == 'train':
            return x_train, torch.tensor(t_train), torch.tensor(p_train)
        elif self.part == 'val':
            return x_val, torch.tensor(t_val), torch.tensor(p_val)
        elif self.part == 'test':
            return x_test, torch.tensor(t_test), torch.tensor(t_test) > 0.5
        else:
            raise ValueError(f'unknown part {self.part}')

    def transform(self, x: str) -> Tensor:
        """
        Reads the image named `x` from the archive and returns it as a
        tensor with the first and the last axis swapped.
        """
        # keep zipfile opened to avoid overhead
        # NOTE(review): the handle is never closed explicitly
        if self.zipfile is None:
            zip_path = os.path.join(self.folder, 'Images.zip')
            self.zipfile = ZipFile(zip_path)

        with self.zipfile.open(x) as f:
            img = imread(f)

        if img.shape[-1] == 4:  # FIXME how to handle 3/4 channels ?!
            img = img[..., :3]

        return torch.tensor(img).transpose(0, 2)  # channel first


class STL10(PuDataset):
    """
    synthetic dataset sampled from STL10
    """

    relative_folder = 'STL10'

    def __init__(self, val_fold: int, **kwargs: Any) -> None:
        """
        `val_fold` is the index of the predefined fold used for
        validation, the other folds being used for training;
        `n_labeled` and `val_size` are not used and are forced to None.
        """
        self.val_fold = val_fold
        kwargs['n_labeled'] = kwargs['val_size'] = None
        super().__init__(**kwargs)

    def load_dataset(self) -> Tuple[Tensor, Tensor, Tensor]:
        """
        returns a tuple of (sample, true label, is-positive indicator)
        for the part selected by `self.part`; the training part also
        contains the unlabeled images, with both indicators set to zero
        """
        xx = torch.tensor(np.load(f'{self.folder}/X_train.npy'))
        uu = torch.tensor(np.load(f'{self.folder}/X_unl.npy'))
        tt = torch.tensor(np.load(f'{self.folder}/Y_train.npy')).int()
        # re-label: classes 0, 2, 8 and 9 are the positives
        tt = ((tt == 0) | (tt == 2) | (tt == 8) | (tt == 9)).float()
        ff = torch.tensor(np.load(f'{self.folder}/folds.npy')).long()

        # one fold for validation, all the others for training
        val_idx = ff[self.val_fold]
        train_idx = torch.cat([
            ff[i] for i in range(len(ff)) if i != self.val_fold
        ])

        x_train = torch.cat([xx[train_idx], uu])
        t_train = torch.cat([tt[train_idx], torch.zeros(len(uu))])
        # here we use all available positives for training
        p_train = torch.cat([tt[train_idx], torch.zeros(len(uu))])

        x_val = xx[val_idx]
        t_val = tt[val_idx]
        p_val = tt[val_idx]

        # Test
        x_test = torch.tensor(np.load(f'{self.folder}/X_test.npy'))
        t_test = torch.tensor(np.load(f'{self.folder}/Y_test.npy'))
        t_test = ((t_test == 0) | (t_test == 2) | (t_test == 8) | (t_test == 9)).float()
        p_test = t_test > 0.5

        print('STL10 - total train pos: %d - total train neg: %d - total test: %d' % (
            torch.sum(tt).item(), torch.sum(1 - tt.int()).item(), len(x_test)))
        print('STL10 - labeled training positives: %d - unlabeled training samples: %d' % (
            torch.sum(p_train).item(), torch.sum(1 - p_train.int()).item()
        ))

        if self.part == 'train':
            return x_train, t_train, p_train
        elif self.part == 'val':
            return x_val, t_val, p_val
        elif self.part == 'test':
            return x_test, t_test, p_test
        else:
            raise ValueError(f'unknown part {self.part}')

    def download(self) -> None:
        """
        Downloads STL10 with torchvision and caches the standardized
        images, the labels and the fold indices as .npy files in the
        dataset folder.
        """
        trainset = torchvision.datasets.STL10(
            root=self.folder, split='train', download=True)
        unlset = torchvision.datasets.STL10(
            root=self.folder, split='unlabeled', download=True)
        testset = torchvision.datasets.STL10(
            root=self.folder, split='test', download=True)

        X_train = trainset.data.astype(np.float32)
        X_unl = unlset.data.astype(np.float32)
        X_test = testset.data.astype(np.float32)

        Y_train = trainset.labels
        Y_test = testset.labels

        # standardize all sets with the statistics of the labeled training set
        X_mean, X_std = X_train.mean(), X_train.std()
        X_train = (X_train - X_mean) / X_std
        X_unl = (X_unl - X_mean) / X_std
        X_test = (X_test - X_mean) / X_std

        # every row of the folds file lists the indices of one fold
        path_to_folds = os.path.join(trainset.root, trainset.base_folder,
                                     trainset.folds_list_file)
        with open(path_to_folds, 'r') as f:
            str_idx = f.read().splitlines()
            folds = np.array([
                np.fromstring(row, dtype=np.int64, sep=' ')
                for row in str_idx
            ])

        np.save(file=f"{self.folder}/X_train.npy", arr=X_train)
        np.save(file=f"{self.folder}/X_unl.npy", arr=X_unl)
        np.save(file=f"{self.folder}/X_test.npy", arr=X_test)
        np.save(file=f"{self.folder}/folds.npy", arr=folds)
        np.save(file=f"{self.folder}/Y_train.npy", arr=Y_train)
        np.save(file=f"{self.folder}/Y_test.npy", arr=Y_test)


def get_dataset(params: Dict[str, Any]) -> Tuple[PuDataset, PuDataset, PuDataset]:
    """
    Creates the train, validation and test parts of a dataset.

    `params` must contain the (case-insensitive) name of the dataset under
    the key 'class'; all the other entries are passed as keyword arguments
    to the constructor of the dataset. The dictionary is not modified, so
    it can be reused by the caller.

    Raises KeyError when 'class' is missing and ValueError when the
    dataset is unknown.
    """
    # work on a shallow copy to avoid removing 'class' from the caller's dict
    params = dict(params)
    cls = params.pop('class').lower()

    dset_classes: Dict[str, Type[PuDataset]] = {
        'mnist': MNIST,
        'cifar10': CIFAR10,
        'skewed_cifar10': Skewed_CIFAR10,
        'cifar10_pusb': CIFAR10_PUSB,
        'fashion_mnist': FashionMNIST,
        'imdb': IMDB,
        'brain_cancer': BrainCancer,
        'stl10': STL10,
    }

    dset = dset_classes.get(cls)
    if dset is None:
        raise ValueError(f'dataset "{cls}" not found')

    train_data = dset(part='train', **params)
    val_data = dset(part='val', **params)
    test_data = dset(part='test', **params)

    return train_data, val_data, test_data
