import jax.numpy as np
from jax import random
from jax.config import config
import jax
import torch
#config.update('jax_enable_x64', True)
import numpy as onp
import torchvision


class MNIST:
    """MNIST loaded through torchvision and converted to NumPy/JAX arrays.

    After construction the instance exposes ``x_train``, ``y_train``,
    ``x_test`` and ``y_test``. Pixel values are divided by 255. Images have
    shape ``(n, side, side, 1)``, or ``(n, side * side)`` when ``flat``.
    Labels have shape ``(1, n, 1)`` (``(1, n, 10)`` when ``one_hot``), except
    in ``two_class`` mode, where ``two_classify`` replaces them with an
    ``(n, 1)`` array of +1 / -1.

    Args:
        n_train: Number of training examples to keep.
        n_test: Number of test examples to keep.
        flat: Flatten every image into a vector.
        identity: Use the images as their own targets (``y = x``).
        binary: Relabel digits 0-4 as +1 and 5-9 as -1. Ignored when
            ``two_class`` is set, because that mode builds its own +1/-1 labels.
        two_class: Keep only digit 3 (label +1) and digit 2 (label -1), at
            most ``n_train // 2`` / ``n_test // 2`` examples of each.
        one_hot: One-hot encode the labels over 10 classes.
        permute_key: Optional JAX PRNG key. When given, ``n_train`` training
            examples are drawn without replacement from the first 50000
            instead of taking the first ``n_train``.
        scale_dim: Optional side length to resize the images to
            (``torch.nn.functional.interpolate`` with its default mode).
    """

    def __init__(self, n_train, n_test, flat=False, identity=False, binary=False,
                 two_class=False, one_hot=False, permute_key=None, scale_dim=None):
        self.n_train = n_train
        self.n_test = n_test
        self.identity = identity
        self.flat = flat
        self.binary = binary
        self.two_class = two_class
        self.one_hot = one_hot
        self.permute_key = permute_key
        self.scale_dim = scale_dim

        self.get_data()

        # two_class produces its own +1/-1 labels, so it takes precedence over binary.
        if self.two_class:
            self.two_classify()
        elif self.binary:
            self.binarize()

    def _load_images(self, dataset):
        """Return ``dataset.data`` divided by 255, resized to ``scale_dim`` if set."""
        images = dataset.data
        if self.scale_dim is None:
            return images.numpy() / 255.0
        # interpolate expects (N, C, H, W): add a channel axis, resize, then squeeze it away.
        images = torch.unsqueeze(images, 1) / 255.0
        size = (self.scale_dim, self.scale_dim)
        return torch.nn.functional.interpolate(images, size).squeeze().numpy()

    def get_data(self):
        """Download/load MNIST and populate x_train, y_train, x_test and y_test."""
        trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True)
        testset = torchvision.datasets.MNIST(root='./data', train=False, download=True)

        x_train = self._load_images(trainset)
        x_test = self._load_images(testset)
        train_labels = trainset.targets.numpy()
        test_labels = testset.targets.numpy()

        if self.two_class:
            # Keep a large pool; two_classify() later picks the per-class subset.
            n_train, n_test = 50000, 10000
        else:
            n_train, n_test = self.n_train, self.n_test

        if self.permute_key is None:
            train_idx = slice(None, n_train)
        else:
            train_idx = onp.asarray(random.choice(key=self.permute_key, a=np.arange(50000),
                                                  shape=(self.n_train,), replace=False))

        self.x_train = onp.expand_dims(x_train[train_idx], 3)
        self.y_train = onp.expand_dims(train_labels[train_idx], axis=(0, 2))
        self.x_test = onp.expand_dims(x_test[:n_test], 3)
        self.y_test = onp.expand_dims(test_labels[:n_test], axis=(0, 2))

        if self.flat:
            # Use the real row counts. They differ from n_train/n_test when
            # permute_key is combined with two_class, or when more examples
            # are requested than the split contains.
            self.x_train = onp.reshape(self.x_train, (self.x_train.shape[0], -1))
            self.x_test = onp.reshape(self.x_test, (self.x_test.shape[0], -1))

        if self.identity:
            self.y_train = self.x_train
            self.y_test = self.x_test

        if self.one_hot:
            self.y_train = np.expand_dims(jax.nn.one_hot(self.y_train, 10).squeeze(), axis=0)
            self.y_test = np.expand_dims(jax.nn.one_hot(self.y_test, 10).squeeze(), axis=0)

    def randomize(self, noise_level, key, runs):
        """Replace ``self.y_train`` with ``runs`` randomly corrupted copies.

        Binary mode (``self.binary``): each (example, run) entry is selected
        with probability ``noise_level``, and a selected label is negated with
        probability 1/2. ``y_correct`` holds the clean labels tiled along
        axis 1. ``num_correct`` is twice the per-example fraction of flipped
        runs. Assumes ``y_train`` has shape ``(n_train, 1)``, as produced by
        ``two_classify`` (TODO confirm for other configurations).

        Otherwise the labels are presumably one-hot ``(1, n_train, 10)``. A
        mask of examples, shared across runs, is drawn with probability
        ``noise_level``. Each masked example gets a uniformly random class,
        drawn independently per run. ``y_correct`` stacks the clean labels
        along axis 0.

        Args:
            noise_level: Probability that a label is selected for corruption.
            key: JAX PRNG key.
            runs: Number of noisy copies to draw.
        """
        if self.binary:
            # Independent keys for "which entries" and "which sign". Reusing
            # one key makes both bernoulli draws compare the same uniforms,
            # so every selected entry flips whenever noise_level <= 0.5.
            _, flip_key, sign_key = random.split(key, 3)
            indices = random.bernoulli(flip_key, shape=(self.n_train, runs), p=noise_level)
            noise = random.bernoulli(sign_key, shape=(self.n_train, runs), p=0.5)
            y_train_noisy = self.y_train - 2 * noise * indices * self.y_train
            self.y_correct = np.concatenate([self.y_train for _ in range(runs)], axis=1)
            self.y_train = y_train_noisy
            self.num_correct = 2 * (1 - np.mean(self.y_correct == self.y_train, axis=1))
        else:
            _, mask_key, class_key = random.split(key, 3)
            indices = random.bernoulli(mask_key, shape=(1, self.n_train, 1), p=noise_level)
            noise = random.choice(class_key, a=np.arange(10), shape=(runs, self.n_train))
            y_train_noisy = self.y_train * (1 - indices) + indices * jax.nn.one_hot(noise, 10)

            self.y_correct = np.concatenate([self.y_train for _ in range(runs)], axis=0)
            self.y_train = y_train_noisy

    def binarize(self):
        """Relabel in place: digits <= 4 become +1, digits > 4 become -1.

        Requires mutable (NumPy) label arrays. With ``one_hot`` the labels are
        JAX arrays, which do not support in-place item assignment.
        """
        # Order matters: the new +1 values must not be caught by the "> 4" test.
        self.y_train[self.y_train <= 4] = 1
        self.y_train[self.y_train > 4] = -1

        self.y_test[self.y_test <= 4] = 1
        self.y_test[self.y_test > 4] = -1

    @staticmethod
    def _select_two_classes(x, y, per_class):
        """Keep up to ``per_class`` examples of digit 3 (+1) and digit 2 (-1).

        Returns ``(x_selected, y_selected)``. ``y_selected`` has shape
        ``(k, 1)``, and the digit-3 rows come first.
        """
        labels = onp.asarray(y).reshape(-1)
        if labels.shape[0] != x.shape[0]:
            raise ValueError('two_class requires integer labels with one label per example')
        pos = onp.flatnonzero(labels == 3)[:per_class]
        neg = onp.flatnonzero(labels == 2)[:per_class]
        x_sel = np.concatenate([x[pos], x[neg]], axis=0)
        # Size each label block by its own count; the classes may not be balanced.
        y_sel = np.concatenate([np.ones(pos.shape[0]), -np.ones(neg.shape[0])], axis=0)
        return x_sel, np.expand_dims(y_sel, axis=1)

    def two_classify(self):
        """Restrict train/test to digits 3 (+1) and 2 (-1) with ``(n, 1)`` labels."""
        self.x_train, self.y_train = self._select_two_classes(self.x_train, self.y_train, self.n_train // 2)
        self.x_test, self.y_test = self._select_two_classes(self.x_test, self.y_test, self.n_test // 2)


class RandomData:
    """Synthetic data: standard-normal inputs paired with random +1/-1 labels.

    Args:
        n_train: Number of training points.
        n_test: Number of test points.
        num_classes: Stored on the instance; ``get_data`` does not use it
            (labels are always +1/-1).
        dim: Dimensionality of each input vector.
        key: JAX PRNG key from which every sample is derived.
    """

    def __init__(self, n_train, n_test, num_classes, dim, key):
        self.n_train, self.n_test = n_train, n_test
        self.num_classes = num_classes
        self.dim = dim
        self.key = key

        self.get_data()

    def get_data(self):
        """Sample x_train/x_test from N(0, I) and y_train/y_test uniformly from {-1, +1}."""
        k_x_train, k_x_test, k_y_train, k_y_test = random.split(self.key, 4)

        def signs(sub_key, count):
            # A Bernoulli(0.5) column mapped {False, True} -> {-1, +1}.
            return 2 * random.bernoulli(sub_key, p=0.5, shape=(count, 1)) - 1

        self.x_train = random.normal(k_x_train, shape=(self.n_train, self.dim))
        self.x_test = random.normal(k_x_test, shape=(self.n_test, self.dim))
        self.y_train = signs(k_y_train, self.n_train)
        self.y_test = signs(k_y_test, self.n_test)


class CIFAR10:
    """CIFAR-10 loaded through torchvision and converted to NumPy/JAX arrays.

    The dataset is opened with ``download=False``, so it must already exist
    under ``./data``. After construction the instance exposes ``x_train``,
    ``y_train``, ``x_test`` and ``y_test``. Pixel values are divided by 255.
    Labels have shape ``(1, n, 1)``, or ``(1, n, 10)`` when ``one_hot``.

    Args:
        n_train: Number of training examples to keep.
        n_test: Number of test examples to keep.
        flat: Flatten each image into a vector. Otherwise singleton axes are
            squeezed out.
        binary: Relabel classes 0-4 as +1 and 5-9 as -1.
        one_hot: One-hot encode the labels over 10 classes.
        permute_key: Optional JAX PRNG key. When given, ``n_train`` training
            examples are drawn without replacement instead of taking the
            first ``n_train``.
        two_class: Only makes ``get_data`` keep the first 50000 train / 10000
            test examples; this class performs no class filtering.
        scale_dim: Optional side length to resize the images to.
        num_channels: When not None (and ``scale_dim`` is set), only the first
            colour channel is kept. The value itself is not used.
    """

    def __init__(self, n_train, n_test, flat=True, binary=False, one_hot=False, permute_key=None, two_class=False,
                 scale_dim=None, num_channels=None):
        self.n_train = n_train
        self.n_test = n_test
        self.flat = flat
        self.binary = binary
        self.one_hot = one_hot
        self.permute_key = permute_key
        self.two_class = two_class
        self.scale_dim = scale_dim
        self.num_channels = num_channels

        self.get_data()
        if self.binary:
            self.binarize()

    def _scaled_images(self, data):
        """Return ``data`` divided by 255, resized to ``scale_dim`` if set.

        ``data`` is assumed to be uint8 ``(N, H, W, C)``, the channels-last
        layout torchvision uses for CIFAR-10 (TODO confirm for other versions).
        The resized output is ``(N, d, d)`` for a single channel and
        ``(N, d, d, C)`` otherwise.
        """
        if self.scale_dim is None:
            return data / 255.0
        if self.num_channels is not None:
            data = data[:, :, :, :1]
        # interpolate operates on channels-first (N, C, H, W) tensors.
        images = (torch.tensor(data) / 255.0).permute(0, 3, 1, 2)
        images = torch.nn.functional.interpolate(images, (self.scale_dim, self.scale_dim))
        if self.num_channels is not None:
            # Drop the single channel axis. Image rows stay rows (no H/W swap).
            return images[:, 0].numpy()
        return images.permute(0, 2, 3, 1).numpy()

    def get_data(self):
        """Load CIFAR-10 from ./data and populate x_train, y_train, x_test and y_test."""
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False)
        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False)

        if self.two_class:
            n_train, n_test = 50000, 10000
        else:
            n_train, n_test = self.n_train, self.n_test

        x_train = self._scaled_images(trainset.data)
        x_test = self._scaled_images(testset.data)
        train_labels = onp.array(trainset.targets)
        test_labels = onp.array(testset.targets)

        if self.permute_key is None:
            train_idx = slice(None, n_train)
        else:
            train_idx = onp.asarray(random.choice(key=self.permute_key, a=np.arange(50000),
                                                  shape=(self.n_train,), replace=False))

        self.x_train = onp.expand_dims(x_train[train_idx], 3)
        self.y_train = onp.expand_dims(train_labels[train_idx], axis=(0, 2))
        self.x_test = onp.expand_dims(x_test[:n_test], 3)
        self.y_test = onp.expand_dims(test_labels[:n_test], axis=(0, 2))

        if self.flat:
            # Use the real row counts: two_class keeps 50000/10000 rows
            # regardless of n_train/n_test.
            self.x_train = onp.reshape(self.x_train, (self.x_train.shape[0], -1))
            self.x_test = onp.reshape(self.x_test, (self.x_test.shape[0], -1))
        else:
            self.x_train = self.x_train.squeeze()
            self.x_test = self.x_test.squeeze()

        if self.one_hot:
            self.y_train = np.expand_dims(jax.nn.one_hot(self.y_train, 10).squeeze(), axis=0)
            self.y_test = np.expand_dims(jax.nn.one_hot(self.y_test, 10).squeeze(), axis=0)

    def binarize(self):
        """Relabel in place: classes <= 4 become +1, classes > 4 become -1.

        Requires mutable (NumPy) label arrays. With ``one_hot`` the labels are
        JAX arrays, which do not support in-place item assignment.
        """
        # Order matters: the new +1 values must not be caught by the "> 4" test.
        self.y_train[self.y_train <= 4] = 1
        self.y_train[self.y_train > 4] = -1

        self.y_test[self.y_test <= 4] = 1
        self.y_test[self.y_test > 4] = -1

    def randomize(self, noise_level, key, runs):
        """Replace ``self.y_train`` with ``runs`` randomly corrupted copies.

        Binary mode: each (example, run) entry is selected with probability
        ``noise_level``, and a selected label is negated with probability 1/2.
        ``y_correct`` tiles the clean labels along axis 1. ``num_correct`` is
        twice the per-example fraction of flipped runs. Assumes ``(n_train, 1)``
        labels (TODO confirm; ``get_data`` produces ``(1, n, 1)``).

        Otherwise the labels are presumably one-hot ``(1, n_train, 10)``. A
        mask shared across runs selects examples with probability
        ``noise_level``. Each selected example gets a uniformly random class,
        drawn independently per run. ``y_correct`` stacks the clean labels
        along axis 0.

        Args:
            noise_level: Probability that a label is selected for corruption.
            key: JAX PRNG key.
            runs: Number of noisy copies to draw.
        """
        if self.binary:
            _, flip_key, sign_key = random.split(key, 3)
            indices = random.bernoulli(flip_key, shape=(self.n_train, runs), p=noise_level)
            noise = random.bernoulli(sign_key, shape=(self.n_train, runs), p=0.5)
            y_train_noisy = self.y_train - 2 * noise * indices * self.y_train
            self.y_correct = np.concatenate([self.y_train for _ in range(runs)], axis=1)
            self.y_train = y_train_noisy
            self.num_correct = 2 * (1 - np.mean(self.y_correct == self.y_train, axis=1))
        else:
            # Separate keys for the mask and the replacement classes; reusing
            # one key would correlate the two draws.
            _, mask_key, class_key = random.split(key, 3)
            indices = random.bernoulli(mask_key, shape=(1, self.n_train, 1), p=noise_level)
            noise = random.choice(class_key, a=np.arange(10), shape=(runs, self.n_train))
            y_train_noisy = self.y_train * (1 - indices) + indices * jax.nn.one_hot(noise, 10)

            self.y_correct = np.concatenate([self.y_train for _ in range(runs)], axis=0)
            self.y_train = y_train_noisy
