import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as datasets
import numpy as np
import util
import augmentations
from augmentations import augment_list, augment_list_grayscale
import torchvision.transforms as transforms
from PIL import Image


# Per-channel normalization shared by the CIFAR10 / CIFAR100 / SVHN pipelines
# built in get_train_dataset and get_val_dataset.
# WeightedAugmentedDataset repeats the same mean/std values to undo and redo
# this normalization around its augmentation step; keep the two in sync.
# NOTE(review): these look like the commonly used ImageNet channel statistics
# rather than dataset-specific ones -- confirm this is intended.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])

def get_train_dataset(dataset):
    """Return the training split of the dataset called ``dataset``.

    Supported names:
      * "CIFAR10", "CIFAR100", "SVHN" -- torchvision datasets with the default
        augmentation pipeline (random horizontal flip, random 32-pixel crop
        with padding 4, conversion to tensor, normalization).
      * "MNIST" -- torchvision MNIST converted to tensor only (no
        augmentation, no normalization).
      * "Reduced-MNIST", "CIFAR10-IMB" -- the wrapper datasets defined in
        this module, built in training mode.

    Data is stored under ``./data`` and downloaded when missing.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset in ("CIFAR10", "CIFAR100", "SVHN"):
        # Only build the augmentation pipeline for the datasets that use it.
        default_augs = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            normalize,
        ])
        if dataset == "CIFAR10":
            return datasets.CIFAR10(root='./data', train=True, transform=default_augs, download=True)
        if dataset == "CIFAR100":
            return datasets.CIFAR100(root='./data', train=True, transform=default_augs, download=True)
        return datasets.SVHN(root='./data', split='train', transform=default_augs, download=True)

    if dataset == "MNIST":
        return datasets.MNIST(root='./data', train=True, transform=transforms.Compose([
            transforms.ToTensor(),
        ]), download=True)
    if dataset == "Reduced-MNIST":
        return ReducedMNIST(train=True)
    if dataset == "CIFAR10-IMB":
        return CIFAR10_IMB(train=True)

    raise ValueError(f"No such dataset: {dataset}")


def get_val_dataset(dataset):
    """Return the validation / test split of the dataset called ``dataset``.

    Supported names:
      * "CIFAR10", "CIFAR100", "SVHN" -- torchvision test splits, converted to
        tensor and normalized (no augmentation).
      * "MNIST" -- torchvision MNIST test split where every image is rotated
        by 30 degrees before conversion to tensor, i.e. the validation images
        are deliberately transformed relative to the training images.
      * "Reduced-MNIST", "CIFAR10-IMB" -- the wrapper datasets defined in
        this module, built in validation mode.

    Data is stored under ``./data`` and downloaded when missing.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset in ("CIFAR10", "CIFAR100", "SVHN"):
        # One shared evaluation pipeline instead of three identical copies.
        eval_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        if dataset == "CIFAR10":
            return datasets.CIFAR10(root='./data', train=False, transform=eval_transform, download=True)
        if dataset == "CIFAR100":
            return datasets.CIFAR100(root='./data', train=False, transform=eval_transform, download=True)
        return datasets.SVHN(root='./data', split="test", transform=eval_transform, download=True)

    if dataset == "MNIST":
        return datasets.MNIST(root='./data', train=False, transform=transforms.Compose([
                transforms.Lambda(lambda x: x.rotate(30)), # Rotate left by 30d
                # Alternative transformations, kept for reference:
                #transforms.Lambda(lambda x: x.transform(x.size, Image.AFFINE, (1, 0, 0, 0, 1, -5))), # Move down by 5
                #transforms.Lambda(lambda x: x.transform(x.size, Image.AFFINE, (1, 0.1, 0, 0, 1, 0))), # ShearX 0.1
                transforms.ToTensor(),
            ]), download=True)
    if dataset == "Reduced-MNIST":
        return ReducedMNIST(train=False)
    if dataset == "CIFAR10-IMB":
        return CIFAR10_IMB(train=False)

    # Include the offending name, matching get_train_dataset's message.
    raise ValueError(f"No such dataset: {dataset}")


class ReducedMNIST(Dataset):
    """MNIST restricted to the digits 0, 1 and 2.

    The training split keeps, in dataset order, at most ``max_per_class``
    examples of each kept digit. The validation split keeps every example of
    those digits. Items are whatever the underlying MNIST dataset returns for
    the selected positions.
    """

    # Examples whose label is greater than this value are dropped.
    _MAX_LABEL = 2

    def __init__(self, train, max_per_class=2000):
        """Build the index list over the underlying MNIST split.

        Args:
            train: use the training split (capped per class) when true,
                otherwise the uncapped validation split.
            max_per_class: maximum number of training examples kept per digit.
        """
        if train:
            self.mnist = get_train_dataset("MNIST")
            cap = max_per_class
        else:
            self.mnist = get_val_dataset("MNIST")
            cap = None  # validation keeps every matching example
        self.indices = self._select_indices(self.mnist, cap)

    @staticmethod
    def _select_indices(samples, max_per_class=None):
        """Return positions in ``samples`` of kept digits, at most ``max_per_class`` each."""
        kept = []
        seen = {}
        for position, (_, label) in enumerate(samples):
            if label > ReducedMNIST._MAX_LABEL:
                continue
            taken = seen.get(label, 0)
            # Strict comparison: exactly ``max_per_class`` examples are kept,
            # not ``max_per_class + 1``.
            if max_per_class is None or taken < max_per_class:
                seen[label] = taken + 1
                kept.append(position)
        return kept

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        mnist_index = self.indices[index]
        return self.mnist[mnist_index]


class CIFAR10_IMB(Dataset):
    """Class-imbalanced subset of CIFAR10.

    Class ``i`` (0..9) keeps, in dataset order, at most
    ``int((len(ds) // 10) * imb_factor ** (i / 9))`` examples, so class 0 keeps
    the most and class 9 roughly ``imb_factor`` times as many. The same
    sub-sampling is applied to whichever split (train or validation) is used.
    """

    def __init__(self, train, imb_factor = 0.01):
        """Select the imbalanced subset and print a short summary.

        Args:
            train: wrap the CIFAR10 training split when true, otherwise the
                validation split.
            imb_factor: ratio between the size of the smallest and the
                largest class.
        """
        if train:
            self.ds = get_train_dataset("CIFAR10")
        else:
            self.ds = get_val_dataset("CIFAR10")
        num_sample = self._class_sizes(len(self.ds), imb_factor)
        # NOTE: iterating the dataset runs its transform on every image just
        # to read the label.
        self.indices = self._select_indices(self.ds, num_sample)
        print(f"Initialized CIFAR10_IMB (Train: {train}) with {len(self.indices)} examples using imb_factor of {imb_factor}")
        print(f"Distribution: {num_sample}")

    @staticmethod
    def _class_sizes(total, imb_factor, num_classes=10):
        """Return the per-class caps for a dataset of ``total`` examples."""
        img_max = total // num_classes
        return [int(img_max * (imb_factor ** (i / (num_classes - 1))))
                for i in range(num_classes)]

    @staticmethod
    def _select_indices(samples, class_sizes):
        """Return positions in ``samples``, keeping at most ``class_sizes[label]`` per label."""
        kept = []
        taken = [0] * len(class_sizes)
        for position, (_, label) in enumerate(samples):
            # Strict comparison so the kept counts match the printed
            # distribution instead of exceeding it by one per class.
            if taken[label] < class_sizes[label]:
                taken[label] += 1
                kept.append(position)
        return kept

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        cifar10_index = self.indices[index]
        return self.ds[cifar10_index]


class IndexedDataset(Dataset):
    """Training dataset whose items also carry their own index.

    Wraps the dataset returned by ``get_train_dataset(dataset)`` and yields
    ``(data, target, index)`` instead of ``(data, target)``.
    """

    def __init__(self, dataset, train=True):
        """Wrap the training split of the dataset called ``dataset``.

        Raises:
            ValueError: if ``train`` is false (only the training split is
                supported), or if ``dataset`` is rejected by
                ``get_train_dataset``.
        """
        if not train:
            # The dataset name is not the problem here, so say what is.
            raise ValueError("IndexedDataset only supports the training split (train=True)")
        self.ds = get_train_dataset(dataset)

    def __getitem__(self, index):
        data, target = self.ds[index]
        return data, target, index

    def __len__(self):
        return len(self.ds)


class WeightedAugmentedDataset(Dataset):
    """Dataset of ``(data, target, weight)`` triples plus optional augmented extras.

    Wraps a base dataset ``ds`` whose items start with ``(data, target)``.
    The "plain" part is materialized by ``init_full_dataset`` or
    ``init_subset``; ``augment_all`` / ``augment_subset`` add augmented copies.
    Indexing yields the plain items first, followed by the augmented ones.
    """

    # Channel statistics used to undo / redo the normalization of 3-channel
    # inputs around the PIL-based augmentation step (same values as the
    # module-level ``normalize`` transform). Shaped (3, 1, 1) to broadcast
    # over channel-first images.
    _MEAN = np.array([0.485, 0.456, 0.406])[:, None, None]
    _STD = np.array([0.229, 0.224, 0.225])[:, None, None]

    def __init__(self, ds):
        self.ds = ds

        self.data = []
        self.weights = []
        self.augmented_data = []
        self.augmented_weights = []

    def init_full_dataset(self, weights=None, keep_augment=False):
        """Materialize every item of ``ds`` as the plain part of the dataset.

        Args:
            weights: optional per-item weights (one per item of ``ds``);
                every item gets weight 1 when omitted.
            keep_augment: keep previously generated augmented items when true.

        Raises:
            ValueError: if ``weights`` is given with the wrong length.
        """
        count = len(self.ds)
        self.weights = self._resolve_weights(weights, count)
        self.data = [self.ds[i] for i in range(count)]

        if not keep_augment:
            self.reset_augments()

    def init_subset(self, subset, weights=None, keep_augment=False):
        """Materialize the items of ``ds`` at the positions in ``subset``.

        Args:
            subset: sequence of indices into ``ds``.
            weights: optional per-item weights aligned with ``subset``;
                every item gets weight 1 when omitted.
            keep_augment: keep previously generated augmented items when true.

        Raises:
            ValueError: if ``weights`` is given with the wrong length.
        """
        self.weights = self._resolve_weights(weights, len(subset))
        self.data = [self.ds[i] for i in subset]

        if not keep_augment:
            self.reset_augments()

    def reset_augments(self):
        """Drop all augmented items and their weights."""
        self.augmented_data = []
        self.augmented_weights = []

    def augment_all(self, weights=None, C=1, L=2, S=1, model=None, bs=128):
        """Run ``augment_subset`` over every index of ``ds``."""
        self.augment_subset(list(range(len(self.ds))), weights=weights, C=C, L=L, S=S, model=model, bs=bs)

    def augment_subset(self, subset, weights=None, C=2, L=2, S=1, model=None, bs=128):
        """Replace the augmented part with augmentations of ``subset``.

        For every index in ``subset``, ``C`` augmented candidates are
        generated. When ``S == C`` all of them are kept; otherwise the ``S``
        candidates with the highest cross-entropy loss under ``model`` are
        kept for each example.

        Args:
            subset: sequence of indices into ``ds``.
            weights: optional per-example weights aligned with ``subset``;
                each kept candidate inherits the weight of its example.
            C: number of augmented candidates generated per example.
            L: forwarded as ``L`` to ``util.augment_single``.
            S: number of candidates kept per example (``S <= C``).
            model: model used to score candidates; required when ``S < C``.
                Batches are moved to the device of its first parameter.
            bs: batch size used while scoring.

        Raises:
            ValueError: if ``S < C`` and no ``model`` is given.
        """
        assert S <= C
        assert weights is None or len(weights) == len(subset)
        if C != S and model is None:
            raise ValueError("A model is required to select augmentations when S < C")

        num_examples = len(subset)
        weights = self._resolve_weights(weights, num_examples)

        # Candidates are stored round-major: round r of example j sits at
        # position r * num_examples + j.
        candidates = []
        candidate_weights = []
        for _ in range(C):
            for i in subset:
                sample = self.ds[i]  # fetch once; data and label come from the same draw
                candidates.append((self._augment(sample[0], L), sample[1]))
            candidate_weights += weights

        if C == S:
            self.augmented_data = candidates
            self.augmented_weights = candidate_weights
            return

        losses = self._compute_losses(model, candidates, bs)
        keep = self._select_hardest(losses, num_examples, C, S)
        self.augmented_data = [candidates[k] for k in keep]
        self.augmented_weights = [candidate_weights[k] for k in keep]

    @staticmethod
    def _resolve_weights(weights, count):
        """Return ``weights`` as a list of length ``count`` (all ones when None)."""
        if weights is None:
            return [1] * count
        resolved = list(weights)
        if len(resolved) != count:
            raise ValueError(f"Expected {count} weights, got {len(resolved)}")
        return resolved

    @staticmethod
    def _model_device(model):
        """Return the device of the model's first parameter (CUDA if it has none)."""
        try:
            return next(model.parameters()).device
        except (AttributeError, StopIteration, TypeError):
            return torch.device("cuda")

    @classmethod
    def _compute_losses(cls, model, samples, bs):
        """Return the per-sample cross-entropy loss of ``model`` on ``samples``."""
        if not samples:
            return np.zeros(0, dtype=np.float32)

        device = cls._model_device(model)
        criterion = torch.nn.CrossEntropyLoss(reduction='none')
        chunks = []
        with torch.no_grad():
            for start in range(0, len(samples), bs):
                batch = samples[start:start + bs]
                batch_x = torch.stack([x for x, _ in batch], 0).to(device)
                batch_y = torch.as_tensor(
                    np.asarray([y for _, y in batch]), dtype=torch.long).to(device)
                chunks.append(criterion(model.forward(batch_x), batch_y).cpu().numpy())
        return np.concatenate(chunks)

    @staticmethod
    def _select_hardest(losses, num_examples, C, S):
        """Return, per example, the positions of its ``S`` highest-loss candidates.

        ``losses`` follows the round-major layout, so the ``C`` candidates of
        example ``j`` are at ``j, j + num_examples, j + 2 * num_examples, ...``.
        """
        losses = np.asarray(losses)
        chosen = []
        for example in range(num_examples):
            candidate_ids = np.arange(example, num_examples * C, num_examples)
            ranked = np.argsort(losses[candidate_ids])[::-1][:S]
            chosen.extend(int(k) for k in candidate_ids[ranked])
        return chosen

    def _augment(self, input_unaugmented, L):
        """Return an augmented float32 copy of a channel-first image tensor.

        Single-channel inputs are treated as un-normalized grayscale images in
        [0, 1]; other inputs are treated as 3-channel images normalized with
        ``_MEAN`` / ``_STD``. The image is converted to a PIL image, passed
        through ``util.augment_single`` and converted back (re-normalized for
        3-channel inputs). The input tensor is left unmodified.
        """
        is_grayscale = input_unaugmented.shape[0] == 1
        # astype() copies, so none of the arithmetic below writes into the
        # caller's tensor memory.
        pixels = input_unaugmented.detach().cpu().numpy().astype(np.float64)

        if is_grayscale:
            pixels = pixels[0]
            aug_list = augment_list_grayscale()
        else:
            # Undo normalization and move channels last for PIL.
            pixels = (pixels * self._STD + self._MEAN).transpose(1, 2, 0)
            aug_list = augment_list()

        # Round instead of truncating so float error cannot shift a pixel
        # value down by one; clip to stay inside the uint8 range.
        pixels = np.clip(np.rint(pixels * 255), 0, 255).astype('uint8')
        image = util.augment_single(Image.fromarray(pixels), L=L, aug_list=aug_list)

        input_augmented = np.array(image, dtype='float32') / 255
        if is_grayscale:
            input_augmented = np.expand_dims(input_augmented, 0)
        else:
            input_augmented = input_augmented.transpose(2, 0, 1)
            input_augmented = ((input_augmented - self._MEAN) / self._STD).astype('float32')

        return torch.from_numpy(np.ascontiguousarray(input_augmented))

    def __getitem__(self, index):
        """Return ``(data, target, weight)``; plain items come before augmented ones."""
        if not self.data:
            raise ValueError("Dataset not initializaed!")

        if index < len(self.data):
            data, target = self.data[index]
            weight = self.weights[index]
        else:
            aug_index = index - len(self.data)
            data, target = self.augmented_data[aug_index]
            weight = self.augmented_weights[aug_index]

        return data, target, weight

    def __len__(self):
        return len(self.data) + len(self.augmented_data)

