"""Prepares the MNIST or label corrupted MNIST data sets for training and
estimation of the terms in the decompositions.
"""
from __future__ import annotations

from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import random_split
import torch

import numpy as np

from env.user import ROOT

class PrepareMNIST:
    """Prepare the MNIST data set.

    Imports the MNIST data set and implements a subsampling method.

    Attributes:
        train_set: The MNIST training split, or 90% of it when the instance
            was created with ``val_set=True``.
        test_set: The MNIST test split.
        val_set: The remaining 10% of the training split; only present when
            the instance was created with ``val_set=True``.
        train_subset: Random subset of ``train_set``; only present after
            :meth:`subsample` has been called.
    """

    def __init__(
            self,
            *,
            val_set: bool = False,
            val_seed: int = 42,
    ) -> None:
        """Load MNIST and optionally hold out a validation split.

        Args:
            val_set: If True, randomly split the training data 90/10 into
                ``self.train_set`` and ``self.val_set``.
            val_seed: Seed for the validation split, so the same seed always
                produces the same split.
        """
        self.train_set, self.test_set = self.fetch_data()
        if val_set:
            generator = torch.Generator().manual_seed(val_seed)
            self.train_set, self.val_set = random_split(
                dataset=self.train_set,
                lengths=[0.9, 0.1],
                generator=generator,
            )

    def subsample(
            self,
            trial_seed: int,
            split: list[float] | tuple[float, ...] = (0.9, 0.1),
    ) -> None:
        """Store a seeded random subset of ``self.train_set`` in ``self.train_subset``.

        Args:
            trial_seed: Seed for the split; the same seed (with the same
                ``train_set``) always yields the same subset.
            split: ``lengths`` argument forwarded to ``random_split``
                (fractions or absolute sizes). Only the first part is kept;
                any number of parts is accepted. The default is a tuple so no
                mutable object is shared between calls.
        """
        generator = torch.Generator().manual_seed(trial_seed)
        # Keep only the first part; the remaining parts are discarded.
        self.train_subset = random_split(
            dataset=self.train_set,
            lengths=split,
            generator=generator,
        )[0]

    def fetch_data(self):
        """Load (downloading if needed) the MNIST train and test splits.

        Both splits are stored under ``ROOT + '/data'`` and converted with
        ``ToTensor``.

        Returns:
            Tuple ``(train_data, test_data)`` of ``datasets.MNIST`` objects.
        """
        train_data, test_data = (
            datasets.MNIST(
                root=ROOT + '/data',
                train=is_train,
                transform=ToTensor(),
                download=True,
            )
            for is_train in (True, False)
        )
        return train_data, test_data
    
class PrepareCorruptMNIST:
    """Prepare the label corrupted MNIST data set.

    Imports the MNIST data set and corrupts a percentage of the training
    labels (the test set is left clean). Also implements a subsampling
    method.

    Attributes:
        train_set: The (possibly label-corrupted) MNIST training split, or
            90% of it when the instance was created with ``val_set=True``.
        test_set: The clean MNIST test split.
        val_set: The remaining 10% of the training split; only present when
            ``val_set=True``. It is split off after corruption, so it
            contains corrupted labels too.
        corrupt_meta: The metadata dict returned by :meth:`corrupt_labels`,
            or ``None`` when no corruption was applied.
        train_subset: Random subset of ``train_set``; only present after
            :meth:`subsample` has been called.
    """

    def __init__(
            self,
            *,
            val_set: bool = False,
            val_seed: int = 42,
            train_corrupt_percentage: float = 0.2,
            train_corrupt_seed: int = 123,
    ) -> None:
        """Load MNIST, corrupt training labels, and optionally split off validation data.

        Args:
            val_set: If True, randomly split the (corrupted) training data
                90/10 into ``self.train_set`` and ``self.val_set``.
            val_seed: Seed for the validation split.
            train_corrupt_percentage: Probability, in ``[0, 1]``, that each
                training label is replaced by a wrong one. ``0.0`` disables
                corruption.
            train_corrupt_seed: Seed for choosing which labels to corrupt and
                their replacement values.
        """
        print('CORRUPT PERCENTAGE: ' + str(train_corrupt_percentage))
        self.train_set, self.test_set = self.fetch_data(
            train_corrupt_percentage, train_corrupt_seed,
        )
        if val_set:
            generator = torch.Generator().manual_seed(val_seed)
            self.train_set, self.val_set = random_split(
                dataset=self.train_set,
                lengths=[0.9, 0.1],
                generator=generator,
            )

    def subsample(
            self,
            trial_seed: int,
            split: list[float] | tuple[float, ...] = (0.9, 0.1),
    ) -> None:
        """Store a seeded random subset of ``self.train_set`` in ``self.train_subset``.

        Args:
            trial_seed: Seed for the split; the same seed (with the same
                ``train_set``) always yields the same subset.
            split: ``lengths`` argument forwarded to ``random_split``
                (fractions or absolute sizes). Only the first part is kept;
                any number of parts is accepted. The default is a tuple so no
                mutable object is shared between calls.
        """
        generator = torch.Generator().manual_seed(trial_seed)
        # Keep only the first part; the remaining parts are discarded.
        self.train_subset = random_split(
            dataset=self.train_set,
            lengths=split,
            generator=generator,
        )[0]

    def fetch_data(self, train_corrupt_percentage, train_corrupt_seed):
        """Load MNIST and corrupt the training labels if requested.

        Both splits are downloaded (if needed) under ``ROOT + '/data'`` and
        converted with ``ToTensor``. The corruption metadata is kept in
        ``self.corrupt_meta`` (``None`` when no corruption is applied).

        Args:
            train_corrupt_percentage: Per-sample corruption probability;
                corruption is skipped when it is not positive.
            train_corrupt_seed: Seed passed to :meth:`corrupt_labels`.

        Returns:
            Tuple ``(train_data, test_data)``; only ``train_data`` is corrupted.
        """
        train_data, test_data = (
            datasets.MNIST(
                root=ROOT + '/data',
                train=is_train,
                transform=ToTensor(),
                download=True,
            )
            for is_train in (True, False)
        )
        self.corrupt_meta = None
        if train_corrupt_percentage > 0.0:
            train_data, self.corrupt_meta = self.corrupt_labels(
                train_data, train_corrupt_percentage, train_corrupt_seed,
            )
        return train_data, test_data

    def corrupt_labels(self, dataset, train_corrupt_percentage, train_corrupt_seed):
        """Replace a random fraction of ``dataset.targets`` with wrong labels.

        Each sample is independently selected with probability
        ``train_corrupt_percentage``. A selected sample's label is replaced by
        one drawn uniformly from every integer label in
        ``[targets.min(), targets.max()]`` except its true label.
        ``dataset.targets`` is modified in place.

        Args:
            dataset: Object with an integer tensor attribute ``targets``.
            train_corrupt_percentage: Per-sample corruption probability in
                ``[0, 1]``.
            train_corrupt_seed: Seed for the private random number generator.

        Returns:
            ``(dataset, meta)``. ``meta['corrupt_label_index']`` lists the
            corrupted sample indices in ascending order, and
            ``meta['corrupt_label_original']`` lists their original labels.
        """
        # A private legacy RandomState yields exactly the same draws as
        # ``np.random.seed(seed)`` followed by ``np.random.*`` calls. This
        # keeps the corruption pattern reproducible without resetting the
        # process-wide NumPy RNG.
        rng = np.random.RandomState(train_corrupt_seed)
        meta = {'corrupt_label_index': [], 'corrupt_label_original': []}
        targets = dataset.targets
        if len(targets) == 0:
            return dataset, meta

        # All labels in the closed range [min, max]. Correct for any minimum,
        # not only 0.
        labels = list(range(int(targets.min()), int(targets.max()) + 1))
        # Snapshot the original labels once. Index i is only written after it
        # has been read, so reading from the snapshot is equivalent.
        for i, old_x in enumerate(targets.tolist()):
            if rng.binomial(1, train_corrupt_percentage):
                # Randomly chosen from all except the true label.
                choices = [label for label in labels if label != old_x]
                new_x = rng.choice(choices)
                targets[i] = int(new_x)
                meta['corrupt_label_index'].append(i)
                meta['corrupt_label_original'].append(old_x)
        return dataset, meta