from __future__ import annotations

from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import random_split
import torch
import numpy as np
from env.user import ROOT

def fetch_data():
    """Load the CIFAR-10 training and test splits.

    Both splits are stored under ``ROOT + '/data'``, are requested with
    ``download=True``, and use ``ToTensor()`` as their transform.

    Returns:
        A ``(train_data, test_data)`` tuple of ``datasets.CIFAR10`` objects.
    """
    data_root = ROOT + '/data'

    # Build the training split first, then the test split; each gets its own
    # transform instance.
    train_data, test_data = (
        datasets.CIFAR10(
            root=data_root,
            train=is_train,
            transform=ToTensor(),
            download=True,
        )
        for is_train in (True, False)
    )

    return train_data, test_data

class PrepareCorruptCIFAR:
    def __init__(
            self,
            *,
            val_set: bool = False,
            val_seed: int = 42,
            train_corrupt_percentage = 0.1,
            train_corrupt_seed = 123
    ) -> None:
        print('CORRUPT PERCENTAGE: ' + str(train_corrupt_percentage))
        self.train_set, self.test_set = self.fetch_data(train_corrupt_percentage, train_corrupt_seed)
        if val_set:
            generator = torch.Generator().manual_seed(val_seed)
            self.train_set, self.val_set = random_split(
                dataset=self.train_set,
                lengths=[0.9, 0.1],
                generator=generator,
            )

    def subsample(
            self,
            trial_seed: int,
            split: list = [0.9, 0.1],
    ) -> None:
        generator = torch.Generator().manual_seed(trial_seed)
        self.train_subset, _ = random_split(
            dataset=self.train_set,
            lengths=split,
            generator=generator,
        )

    def fetch_data(self, train_corrupt_percentage, train_corrupt_seed):
        train_data = datasets.CIFAR10(
            root=ROOT + '/data',
            train=True,
            transform=ToTensor(),
            download=True
        )
        test_data = datasets.CIFAR10(
            root=ROOT + '/data',
            train=False,
            transform=ToTensor(),
            download=True
        )
        if train_corrupt_percentage > 0.0:
            train_data, data_meta = self.corrupt_labels(train_data, train_corrupt_percentage, train_corrupt_seed)

        return train_data, test_data

    def corrupt_labels(self, dataset, train_corrupt_percentage, train_corrupt_seed):

        np.random.seed(train_corrupt_seed)
        min_label = np.array(dataset.targets).min()
        max_label = np.array(dataset.targets).max()

        labels = np.arange(min_label, max_label - min_label + 1)
        meta = {}
        meta['corrupt_label_index'] = []
        meta['corrupt_label_original'] = []
        for i in range(len(dataset.targets)):
            if np.random.binomial(1, train_corrupt_percentage):
                # Randomly chosen from all except true label
                old_x = int(dataset.targets[i])
                choices = list(labels)
                choices.remove(old_x)
                new_x = np.random.choice(choices)
                dataset.targets[i] = int(new_x)
                meta['corrupt_label_index'].append(i)
                meta['corrupt_label_original'].append(old_x)
        return dataset, meta
