import itertools
import json
import random
from pathlib import Path
from typing import Callable, List, Iterable, Tuple

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import Subset
from collections import Counter, defaultdict

# Directory containing this module; the dataset loaders in this file pass it
# to torchvision as the `root` location for the dataset files.
_DATA_ROOT = Path(__file__).parent


class ShuffleIterator:
    """Endless iterator that serves ``iterable`` in freshly shuffled passes.

    Every time the current pass is used up, ``iterable`` is materialised
    again, shuffled with the module-level ``random`` generator, and served
    from the start.  The iterator therefore never ends as long as
    ``iterable`` yields at least one item per pass.

    :param iterable: Source of items; it is re-read at the start of each
        pass, so a re-iterable container (e.g. a list) is expected.
    :raises StopIteration: from ``next()`` when a pass yields no items
        (empty container, or a one-shot iterator that is already consumed).
    """

    def __init__(self, iterable):
        self.iterable = iterable
        self.buffer = []
        self.index = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= len(self.buffer):
            # Current pass exhausted (or nothing loaded yet): start a new one.
            self.buffer = list(self.iterable)
            random.shuffle(self.buffer)
            self.index = 0
            if not self.buffer:
                # Nothing to serve; signal exhaustion through the iterator
                # protocol instead of failing with IndexError below.
                raise StopIteration

        item = self.buffer[self.index]
        self.index += 1
        return item


def MNIST_datasets():
    """Return the MNIST splits as ``{"train": ..., "test": ...}``.

    Both splits are stored under ``_DATA_ROOT`` (requested with
    ``download=True``) and share one transform: conversion to a tensor
    followed by normalisation with mean 0.1307 and std 0.3081.
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    pipeline = transforms.Compose([to_tensor, normalize])
    root = str(_DATA_ROOT)

    # The train split is built first, then the test split.
    return {
        split: torchvision.datasets.MNIST(
            root=root, train=is_train, download=True, transform=pipeline
        )
        for split, is_train in (("train", True), ("test", False))
    }


def EMNIST_datasets():
    """Return the EMNIST "balanced" splits restricted to labels 0-15.

    Both splits are stored under ``_DATA_ROOT`` (requested with
    ``download=True``) and use the same tensor conversion + normalisation
    (mean 0.1307, std 0.3081) as the MNIST loader.  Each split is wrapped in
    a ``Subset`` that keeps only the samples whose label is in ``range(16)``.

    :return: ``{"train": Subset, "test": Subset}``
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    # Pass the root as a string, like every other loader in this file.
    root = str(_DATA_ROOT)

    train_set = torchvision.datasets.EMNIST(
        root=root, split='balanced', train=True, download=True, transform=transform
    )
    test_set = torchvision.datasets.EMNIST(
        root=root, split='balanced', train=False, download=True, transform=transform
    )

    target_class_indices = list(range(16))

    def filter_dataset(dataset, target_indices):
        """Wrap ``dataset`` in a Subset holding only the wanted labels."""
        wanted = set(target_indices)
        # torchvision datasets keep their raw labels in `targets`; reading
        # them directly avoids decoding and transforming every image just to
        # look at its label.
        labels = getattr(dataset, "targets", None)
        if labels is None:
            # Fallback for datasets without `targets`: read the labels
            # through __getitem__, as a plain traversal would.
            labels = (label for _, label in dataset)
        elif hasattr(labels, "tolist"):
            # Tensor/array labels -> plain Python ints for set membership.
            labels = labels.tolist()
        indices = [i for i, label in enumerate(labels) if label in wanted]
        return Subset(dataset, indices)

    datasets = {
        'train': filter_dataset(train_set, target_class_indices),
        'test': filter_dataset(test_set, target_class_indices),
    }
    return datasets


def KMNIST_datasets():
    """Return the KMNIST splits as ``{"train": ..., "test": ...}``.

    Both splits are stored under ``_DATA_ROOT`` (requested with
    ``download=True``) and share one transform: conversion to a tensor
    followed by normalisation with mean 0.1307 and std 0.3081.
    """
    pipeline = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    root = str(_DATA_ROOT)
    kmnist = torchvision.datasets.KMNIST

    train_split = kmnist(root=root, train=True, download=True, transform=pipeline)
    test_split = kmnist(root=root, train=False, download=True, transform=pipeline)
    return {"train": train_split, "test": test_split}


def CIFAR_datasets():
    """Return the CIFAR-10 splits as ``{"train": ..., "test": ...}``.

    The train split is augmented (random horizontal flip, random 32x32 crop
    with 4 pixels of padding, random rotation of up to 15 degrees) before
    tensor conversion and per-channel normalisation; the test split only
    gets tensor conversion and the same normalisation.  Both are stored
    under ``_DATA_ROOT`` and requested with ``download=True``.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    augmenting_pipeline = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    plain_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    root = str(_DATA_ROOT)
    cifar10 = torchvision.datasets.cifar.CIFAR10
    train_split = cifar10(
        root=root, train=True, download=True, transform=augmenting_pipeline
    )
    test_split = cifar10(
        root=root, train=False, download=True, transform=plain_pipeline
    )
    return {"train": train_split, "test": test_split}

def CIFAR100_datasets():
    """Return the CIFAR-100 splits restricted to labels 0-15.

    The train split is augmented (random horizontal flip, random 32x32 crop
    with 4 pixels of padding, random rotation of up to 15 degrees) before
    tensor conversion and per-channel normalisation; the test split only
    gets tensor conversion and the same normalisation.  Both are stored
    under ``_DATA_ROOT`` and requested with ``download=True``.  Each split
    is wrapped in a ``Subset`` that keeps only the samples whose label is in
    ``range(16)``.

    :return: ``{"train": Subset, "test": Subset}``
    """
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]
    )

    train_set = torchvision.datasets.CIFAR100(
        root=str(_DATA_ROOT), train=True, download=True, transform=train_transform
    )
    test_set = torchvision.datasets.CIFAR100(
        root=str(_DATA_ROOT), train=False, download=True, transform=transform
    )

    target_class_indices = list(range(16))

    def filter_dataset(dataset, target_indices):
        """Wrap ``dataset`` in a Subset holding only the wanted labels."""
        wanted = set(target_indices)
        # torchvision datasets keep their raw labels in `targets`; reading
        # them directly avoids decoding and randomly augmenting every image
        # just to look at its label.
        labels = getattr(dataset, "targets", None)
        if labels is None:
            # Fallback for datasets without `targets`: read the labels
            # through __getitem__, as a plain traversal would.
            labels = (label for _, label in dataset)
        elif hasattr(labels, "tolist"):
            # Tensor/array labels -> plain Python ints for set membership.
            labels = labels.tolist()
        indices = [i for i, label in enumerate(labels) if label in wanted]
        return Subset(dataset, indices)

    datasets = {
        'train': filter_dataset(train_set, target_class_indices),
        'test': filter_dataset(test_set, target_class_indices),
    }

    return datasets


def SVHN_datasets():
    """Return the SVHN splits as ``{"train": ..., "test": ...}``.

    Both splits are stored under ``_DATA_ROOT`` (requested with
    ``download=True``) and share one transform: conversion to a tensor
    followed by normalisation with mean 0.1307 and std 0.3081.
    """
    # NOTE(review): these are the same single-value normalisation constants
    # the MNIST-style loaders in this file use; SVHN images are presumably
    # multi-channel, so confirm that one mean/std for all channels is intended.
    pipeline = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    root = str(_DATA_ROOT)

    # The dictionary key doubles as torchvision's `split` argument.
    return {
        part: torchvision.datasets.SVHN(
            root=root, split=part, download=True, transform=pipeline
        )
        for part in ("train", "test")
    }


def get_datasets(dataset_name: str = "MNIST", digit_base=10):
    """Return the ``{"train": ..., "test": ...}`` datasets for ``dataset_name``.

    :param dataset_name: One of ``"MNIST"``, ``"KMNIST"``, ``"CIFAR10"``
        (alias ``"CIFAR"``) or ``"SVHN"``.
    :param digit_base: With ``16``, ``"MNIST"`` resolves to the EMNIST loader
        and ``"CIFAR10"``/``"CIFAR"`` to the CIFAR-100 loader (both
        restricted to 16 classes); any other value selects the regular
        loader.  Ignored for ``"KMNIST"`` and ``"SVHN"``.
    :return: Dictionary with a ``"train"`` and a ``"test"`` dataset.
    :raises ValueError: if ``dataset_name`` is not one of the names above.
    """
    if dataset_name == "MNIST":
        return EMNIST_datasets() if digit_base == 16 else MNIST_datasets()
    if dataset_name == 'KMNIST':
        return KMNIST_datasets()
    if dataset_name in ("CIFAR10", 'CIFAR'):
        return CIFAR100_datasets() if digit_base == 16 else CIFAR_datasets()
    if dataset_name == 'SVHN':
        return SVHN_datasets()
    # Fail here with a clear message instead of returning None and letting
    # the caller crash later on `None["train"]`.
    raise ValueError(
        f"Unknown dataset name {dataset_name!r}; expected one of "
        "'MNIST', 'KMNIST', 'CIFAR10', 'CIFAR', 'SVHN'"
    )


def digits_to_number(digits, digit_base=10) -> int:
    """Read ``digits`` (most significant first) as one number in ``digit_base``.

    An empty ``digits`` sequence yields ``0``.
    """
    total = 0
    for digit in digits:
        # Shift what has been read so far one place left, then add the digit.
        total = total * digit_base + digit
    return total


def addition(n: int, dataset: str, seed=None, train: bool = True, z_list=None, sequence_num=30000, digit_base=10):
    """Build a ``DigitsOperator`` whose label is the sum of two ``n``-digit numbers.

    The query name is ``"addition"`` for single-digit numbers (``n == 1``)
    and ``"multi_addition"`` otherwise; every other argument is forwarded to
    ``DigitsOperator`` unchanged.
    """
    if n == 1:
        query_name = "addition"
    else:
        query_name = "multi_addition"

    settings = dict(
        dataset_name=dataset,
        function_name=query_name,
        operator=sum,
        size=n,
        arity=2,  # two operands per example
        seed=seed,
        train=train,
        sequence_num=sequence_num,
        z_list=z_list,
        digit_base=digit_base,
    )
    return DigitsOperator(**settings)


class DigitsOperator(TorchDataset):
    """Dataset of ``operator(number, number, ...)`` examples over digit images.

    ``self.data`` holds one entry per example: a list of ``arity`` numbers,
    each a list of ``size`` indices into the underlying image dataset.
    ``self.images`` and ``self.concept_labels`` cache that dataset's images
    and labels so examples are resolved without touching it again.
    """

    def __getitem__(self, index: int) -> Tuple[list, list, int]:
        """Return ``(images_of_first_number, images_of_second_number, label)``.

        The example is unpacked into exactly two numbers, so this accessor
        only works for datasets built with ``arity == 2``.
        """
        l1, l2 = self.data[index]
        label = self._get_label(index, self.digit_base)
        l1 = [self.images[x] for x in l1]
        l2 = [self.images[x] for x in l2]
        return l1, l2, label

    def balance_indices(self):
        """Return shuffled dataset indices with the same number per class.

        Labels are read from ``self.dataset.labels``, so the underlying
        dataset must expose that attribute.  Every class contributes as many
        samples as the rarest class has, which means no index is repeated
        and the result has ``rarest_count * number_of_classes`` entries.

        :return: List of indices in random order; empty if there are no labels.
        """
        labels = self.dataset.labels
        class_counts = Counter(labels)
        if not class_counts:
            return []
        # Cap every class at the size of the rarest one; a single shuffled
        # pass is then guaranteed to fill every class exactly.
        balance_size = min(class_counts.values())

        order = list(range(len(self.dataset)))
        random.shuffle(order)

        labels_dist = defaultdict(int)
        balanced_indices = []
        for sample in order:
            sampled_class = labels[sample]
            if labels_dist[sampled_class] >= balance_size:
                continue
            balanced_indices.append(sample)
            labels_dist[sampled_class] += 1
        return balanced_indices

    def indices(self):
        """Return every index of the selected split, in ascending order."""
        return list(range(len(self.dataset)))

    def __init__(
        self,
        dataset_name: str,
        function_name: str,
        operator: Callable[[List[int]], int],
        size=1,
        arity=2,
        seed=None,
        train: bool = True,
        sequence_num: int = 30000,
        z_list: List[int] = None,
        digit_base: int = 10,
    ):
        """Generic dataset for operator(img, img) style datasets.

        :param dataset_name: Name of the image dataset, as accepted by
            ``get_datasets`` (e.g. ``"MNIST"``)
        :param function_name: Name of Problog function to query.
        :param operator: Operator to generate correct examples
        :param size: Size of numbers (number of digits)
        :param arity: Number of arguments for the operator
        :param seed: Seed for RNG; when given, the generated examples are
            reproducible.  When ``None`` the module-level ``random``
            generator is used.
        :param train: Use the ``"train"`` split if true, else ``"test"``
        :param sequence_num: Number of examples to generate
        :param z_list: List of digits to include in the dataset
        :param digit_base: Base in which the digits of a number are combined
        """
        super(DigitsOperator, self).__init__()
        assert size >= 1
        assert arity >= 1
        self.datasets = get_datasets(dataset_name, digit_base)
        self.dataset = self.datasets["train" if train else "test"]
        self.function_name = function_name
        self.operator = operator
        self.size = size
        self.arity = arity
        self.seed = seed
        self.z_list = z_list
        self.digit_base = digit_base
        mnist_indices = self.indices()

        # One generator drives the initial shuffle and every later reshuffle,
        # so a given seed fully determines the examples.
        if seed is not None:
            rng = random.Random(seed)
            rng.shuffle(mnist_indices)
        else:
            rng = random

        # Read the dataset once: every access runs its transform pipeline.
        self.images = []
        self.concept_labels = []
        for x in range(len(self.dataset)):
            image, concept = self.dataset[x]
            self.images.append(image)
            self.concept_labels.append(concept)

        # Filter indices based on z_list
        if self.z_list is not None:
            mnist_indices = [
                idx for idx in mnist_indices if self.concept_labels[idx] in self.z_list
            ]

        # Build list of examples (mnist indices)
        self.data = self._build_examples(
            mnist_indices, self.size, self.arity, sequence_num, rng
        )

    @staticmethod
    def _build_examples(indices, size, arity, sequence_num, rng):
        """Draw ``sequence_num`` examples of ``arity`` numbers with ``size`` indices each.

        Indices are served in shuffled passes over ``indices``; when a pass
        is used up, a new one is shuffled with ``rng.shuffle``.  Returns an
        empty list when there are no indices to draw from.
        """
        if not indices:
            return []

        def index_stream():
            while True:
                current_pass = list(indices)
                rng.shuffle(current_pass)
                yield from current_pass

        stream = index_stream()
        return [
            [[next(stream) for _ in range(size)] for _ in range(arity)]
            for _ in range(sequence_num)
        ]

    def to_file_repr(self, i):
        """Old file representation dump. Not a very clear format as multi-digit arguments are not separated"""
        return f"{tuple(itertools.chain(*self.data[i]))}\t{self._get_label(i, self.digit_base)}"

    def to_json(self):
        """
        Convert to JSON, for easy comparisons with other systems.

        Format is [EXAMPLE, ...]
        EXAMPLE :- [ARGS, expected_result]
        ARGS :- [MULTI_DIGIT_NUMBER, ...]
        MULTI_DIGIT_NUMBER :- [mnist_img_id, ...]
        """
        data = [
            (example, self._get_label(i, self.digit_base))
            for i, example in enumerate(self.data)
        ]
        return json.dumps(data)

    def _get_label(self, i, digit_base):
        """Return ``operator`` applied to the numbers of example ``i``.

        Each argument's image labels are combined into one number in base
        ``digit_base`` before the operator is applied.
        """
        # Figure out what the ground truth is, first map each parameter to the value:
        ground_truth = [
            digits_to_number([self.concept_labels[j] for j in number], digit_base)
            for number in self.data[i]
        ]
        # Then compute the expected value:
        return self.operator(ground_truth)

    def _get_symbol_label(self, i: int):
        """Return the labels of all images of example ``i`` as one flat list."""
        return [self.concept_labels[j] for number in self.data[i] for j in number]

    def __len__(self):
        """Return the number of generated examples."""
        return len(self.data)


def get_data(train=True, get_pseudo_label=False, n=1, digit_base=10, dataset="MNIST"):
    """Materialise an addition dataset as three parallel lists.

    :return: ``(X, Z, Y)`` where ``X[k]`` is the concatenated image list of
        both operands of example ``k``, ``Y[k]`` its sum label and ``Z[k]``
        the flat list of its image labels.  ``Z`` is replaced by ``None``
        unless ``get_pseudo_label`` is true.
    """
    examples = addition(n, dataset, train=train, digit_base=digit_base)

    images, sums, symbols = [], [], []
    for position in range(len(examples)):
        first, second, total = examples[position]
        symbol_labels = examples._get_symbol_label(position)
        images.append(first + second)
        sums.append(total)
        symbols.append(symbol_labels)

    if not get_pseudo_label:
        return images, None, sums
    return images, symbols, sums

def get_phase_data(
        train=True,
        get_pseudo_label=False,
        n=2,
        min_sequence_num=32,
        z_lists=[list(range(10))],
        digit_base=10,
        dataset="MNIST",
    ):
    """Split an addition dataset into one subset per allowed-digit list.

    A full addition dataset is generated once.  For every list in
    ``z_lists`` the examples whose image labels all belong to that list are
    selected; if fewer than ``min_sequence_num`` are found, the missing
    amount is generated from a dataset restricted to that list.  One line
    with the list and the subset size is printed per list.

    :param train: Use the train split if true, else the test split.
    :param get_pseudo_label: Include the per-image labels ``Z`` in the
        result; otherwise ``None`` takes their place.
    :param n: Number of digits per operand.
    :param min_sequence_num: Minimum number of examples per subset.
    :param z_lists: Lists of allowed digit labels, one per subset.  The
        default is shared between calls and is only read, never modified.
    :param digit_base: Base in which digits are combined into numbers.
    :param dataset: Dataset name forwarded to ``addition``.
    :return: List with one ``(X, Z, Y)`` or ``(X, None, Y)`` tuple per entry
        of ``z_lists``: concatenated operand images, per-image labels, sums.
    """

    def collect(examples):
        """Flatten an addition dataset into parallel (X, Y, Z) lists."""
        xs, ys, zs = [], [], []
        for idx in range(len(examples)):
            x1, x2, y = examples[idx]
            xs.append(x1 + x2)
            ys.append(y)
            zs.append(examples._get_symbol_label(idx))
        return xs, ys, zs

    mnistDataset = addition(n, dataset, train=train, digit_base=digit_base)
    all_X, all_Y, all_Z = collect(mnistDataset)

    res = []

    for sub_z_list in z_lists:
        X, Y, Z = [], [], []
        for x, y, z in zip(all_X, all_Y, all_Z):
            if all(elem in sub_z_list for elem in z):
                X.append(x)
                Y.append(y)
                Z.append(z)
        if len(X) < min_sequence_num:
            # Too few matching examples: top up with examples generated
            # exclusively from the allowed digits.
            remain_len = min_sequence_num - len(X)
            extra = addition(n, dataset, train=train, sequence_num=remain_len, z_list=sub_z_list, digit_base=digit_base)
            extra_X, extra_Y, extra_Z = collect(extra)
            X.extend(extra_X)
            Y.extend(extra_Y)
            Z.extend(extra_Z)
        if get_pseudo_label:
            res.append((X, Z, Y))
        else:
            # list.append takes a single item: store the triple as one tuple.
            res.append((X, None, Y))
        print(sub_z_list, len(X))
    return res


if __name__ == "__main__":
    # Smoke test: build a 10-digit addition dataset and show a few examples.
    # Pseudo labels are requested because the loop below indexes the second
    # element of the result, which is None without them.
    mnist_add = get_data(get_pseudo_label=True, n=10)
    print(len(mnist_add), len(mnist_add[0]))
    for i in range(3):
        # Per-image labels of example i, followed by its sum label.
        print(mnist_add[1][i], mnist_add[2][i])