# -*- coding: utf-8 -*-
import itertools
import os.path as path
import typing

import numpy as np
import pandas as pd
import torch
import torchvision

# Public API of this module: dataset loading, per-node distribution, and (currently no-op) wrapping.
__all__ = ["load_dataset", "distribute_dataset", "wrap_dataset"]


# Per-channel normalization statistics used by the CIFAR loaders below.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)
_CIFAR100_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
_CIFAR100_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)


def _cifar_transforms(mean, std):
    """Return the ``(train, eval)`` transform pair shared by the CIFAR loaders.

    Training: 32px random crop with 4px padding, random horizontal flip,
    tensor conversion, normalization with ``mean``/``std``, random erasing.
    Evaluation: tensor conversion and normalization only.
    """
    train_transform = torchvision.transforms.Compose([
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean, std),
        torchvision.transforms.RandomErasing(),
    ])
    eval_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean, std),
    ])
    return train_transform, eval_transform


def _digit5_split_params(kwargs):
    """Return ``(RandomState, test_fraction)`` from the optional ``kwargs["digit-5"]`` dict.

    Defaults: ``seed=11``, ``ntest=0.2``.
    """
    config = kwargs.get("digit-5", {"ntest": 0.2, "seed": 11})
    return np.random.RandomState(config.get("seed", 11)), config.get("ntest", 0.2)


def load_dataset(name: str, *args, datadir: typing.Optional[str] = None,
                 **kwargs) -> typing.Sequence[torch.utils.data.Dataset]:
    """Load the ``(train, eval)`` pair of datasets registered under ``name``.

    Args:
        name: One of "cifar-10", "cifar-100", "mnist", "fashion-mnist", "emnist",
            "femnist", "femnist_raw", "digit-5", "digit-5_source", "glue",
            "20newsgroup", "PTB".
        *args: Forwarded only to the GLUE loader.
        datadir: Root data directory. torchvision datasets are constructed with
            ``download=True`` under this root.
        **kwargs: Dataset-specific options:
            ``split`` for "emnist" (default "byclass");
            a ``"digit-5"`` dict with ``ntest`` (test fraction, default 0.2) and
            ``seed`` (default 11) for both digit-5 variants;
            forwarded verbatim to the GLUE, 20newsgroup and PTB loaders.

    Returns:
        A ``(train, eval)`` tuple of datasets.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """
    if name == "cifar-10":
        train_tf, eval_tf = _cifar_transforms(_CIFAR10_MEAN, _CIFAR10_STD)
        trains = torchvision.datasets.CIFAR10(root=datadir, train=True, download=True,
                                              transform=train_tf)
        evals = torchvision.datasets.CIFAR10(root=datadir, train=False, download=True,
                                             transform=eval_tf)
        return trains, evals

    elif name == "cifar-100":
        train_tf, eval_tf = _cifar_transforms(_CIFAR100_MEAN, _CIFAR100_STD)
        trains = torchvision.datasets.CIFAR100(root=datadir, train=True, download=True,
                                               transform=train_tf)
        evals = torchvision.datasets.CIFAR100(root=datadir, train=False, download=True,
                                              transform=eval_tf)
        return trains, evals

    elif name == "mnist":
        transform = torchvision.transforms.ToTensor()
        # 28x28
        trains = torchvision.datasets.MNIST(root=datadir, train=True, download=True,
                                            transform=transform)
        tests = torchvision.datasets.MNIST(root=datadir, train=False, download=True,
                                           transform=transform)
        return trains, tests

    elif name == "fashion-mnist":
        transform = torchvision.transforms.ToTensor()
        trains = torchvision.datasets.FashionMNIST(root=datadir, train=True, download=True,
                                                   transform=transform)
        tests = torchvision.datasets.FashionMNIST(root=datadir, train=False, download=True,
                                                  transform=transform)
        return trains, tests

    elif name == "emnist":
        # 28x28 images, normalized with the MNIST mean/std.
        transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307, ), (0.3081, )),
        ])
        split = kwargs.get("split", "byclass")
        trains = torchvision.datasets.EMNIST(root=datadir, split=split, train=True, download=True,
                                             transform=transform)
        tests = torchvision.datasets.EMNIST(root=datadir, split=split, train=False, download=True,
                                            transform=transform)
        return trains, tests

    elif name == "femnist":
        from .femnist import FEMNIST
        trains = FEMNIST(datadir, train=True)
        tests = FEMNIST(datadir, train=False)
        return trains, tests

    elif name == "femnist_raw":
        from .femnist_raw import FEMNIST
        trains = FEMNIST(datadir, train=True, resize=(28, 28),
                         channel_first=True,)
        tests = FEMNIST(datadir, train=False, resize=(28, 28),
                        channel_first=True,)
        return trains, tests

    elif name == "digit-5":
        rs, ntest = _digit5_split_params(kwargs)

        from .digit5 import Digit5
        datas = Digit5(datadir)
        targets = datas.targets.detach().cpu().numpy()
        class_of_indices = {target: [] for target in range(10)}
        for i, target in enumerate(targets):
            class_of_indices[target].append(i)

        train_indices, test_indices = [], []
        train_targets, test_targets = [], []
        for target, indices in class_of_indices.items():
            # Stratified split: a shuffled int(len * ntest) prefix of each class goes to test.
            idx = int(len(indices) * ntest)
            test_idxs, train_idxs = np.split(rs.permutation(indices), [idx])
            train_indices.extend(train_idxs)
            train_targets.extend([target] * len(train_idxs))
            test_indices.extend(test_idxs)
            test_targets.extend([target] * len(test_idxs))

        trains = torch.utils.data.Subset(datas, [int(i) for i in train_indices])
        tests = torch.utils.data.Subset(datas, [int(i) for i in test_indices])

        # Labels are recorded during the split instead of by indexing every sample, which
        # would run Digit5.__getitem__ over the whole dataset. NOTE(review): assumes the
        # label Digit5 returns for sample i equals datas.targets[i], as the
        # "digit-5_source" branch also assumes.
        setattr(trains, "targets", train_targets)
        setattr(tests, "targets", test_targets)

        return trains, tests

    elif name == "digit-5_source":
        rs, ntest = _digit5_split_params(kwargs)

        from .digit5 import Digit5
        datas = Digit5(datadir)

        data_sources = np.unique(datas.sources)
        datalist = pd.DataFrame(dict(source=datas.sources,
                                     target=datas.targets.detach().cpu().numpy()))
        train_indices, test_indices = [], []
        train_sources, test_sources = [], []
        train_targets, test_targets = [], []
        # Stratify jointly over (source, class) pairs.
        for source, target in itertools.product(data_sources, range(10)):
            indices = list(datalist.loc[(datalist.source == source)
                                        & (datalist.target == target)].index)
            idx = int(len(indices) * ntest)
            test_idxs, train_idxs = np.split(rs.permutation(indices), [idx])
            train_indices.extend(train_idxs)
            train_sources.extend([source] * len(train_idxs))
            train_targets.extend([target] * len(train_idxs))

            test_indices.extend(test_idxs)
            test_sources.extend([source] * len(test_idxs))
            test_targets.extend([target] * len(test_idxs))

        trains = torch.utils.data.Subset(datas, [int(i) for i in train_indices])
        tests = torch.utils.data.Subset(datas, [int(i) for i in test_indices])

        setattr(trains, "targets", train_targets)
        setattr(tests, "targets", test_targets)
        # Stored as arrays so distribute_dataset can compare them element-wise.
        setattr(trains, "sources", np.array(train_sources))
        setattr(tests, "sources", np.array(test_sources))

        return trains, tests

    elif name == "glue":
        from .glue import load_glue_dataset
        trains, evals = load_glue_dataset(
            name, *args, datadir=datadir, **kwargs)
        return trains, evals

    elif name == "20newsgroup":
        from .twenty_news import TwentyNewsgroupDataset
        trains = TwentyNewsgroupDataset(
            path.join(datadir, "20news.input_ids.train.txt"),
            path.join(datadir, "20news.targets.train.txt"),
            **kwargs)
        tests = TwentyNewsgroupDataset(
            path.join(datadir, "20news.input_ids.test.txt"),
            path.join(datadir, "20news.targets.test.txt"),
            **kwargs)
        return trains, tests

    elif name == "PTB":
        from .ptb_datasets import PTBDataset
        trains = PTBDataset(path.join(datadir, "ptb.train.tokenized.txt"),
                            **kwargs)
        tests = PTBDataset(path.join(datadir, "ptb.test.tokenized.txt"),
                           **kwargs)
        return trains, tests

    else:
        raise ValueError(f"Unsupported dataset: {name}")


def _targets_as_numpy(dataset: torch.utils.data.Dataset) -> np.ndarray:
    """Return ``dataset.targets`` as a NumPy array, or all-zero labels if the dataset has none."""
    if not hasattr(dataset, "targets"):
        # No labels available: treat every sample as belonging to class 0.
        return np.zeros(len(dataset), dtype=np.int64)
    targets = dataset.targets
    if isinstance(targets, torch.Tensor):
        return targets.detach().cpu().numpy()
    # Lists (e.g. torchvision CIFAR, the digit-5 subsets) must become arrays so that
    # `targets == label` compares element-wise instead of list-vs-scalar.
    return np.asarray(targets)


def distribute_dataset(dataset: torch.utils.data.Dataset, ndistribute: int, method: str, seed: int = 17,
                       nduplicate: int = 1, **kwargs) -> typing.Sequence[torch.utils.data.Subset]:
    """Split ``dataset`` into ``ndistribute`` subsets, one per node.

    Args:
        dataset: Dataset to split. "class_assign" and "dirichlet" read its
            ``targets`` attribute (list, array or tensor of labels); a dataset
            without ``targets`` is treated as a single class ``0``.
        ndistribute: Number of subsets to produce.
        method: Distribution strategy:
            "no"           -- every node gets the whole dataset, repeated ``nduplicate`` times.
            "even"         -- a random permutation split into near-equal parts
                              (``nduplicate`` is not applied).
            "class_assign" -- each node is assigned ``kwargs["nassign"]`` labels
                              (default 1); each label's samples, repeated
                              ``nduplicate`` times, are split evenly between the
                              nodes holding it. Unassigned labels are dropped.
            "dirichlet"    -- delegated to ``distributes.dirichlet`` with ``kwargs``;
                              datasets with ``sources`` first partition the nodes
                              among the sources.
        seed: Seed for the NumPy RandomState driving all shuffles (also passed to
            ``dirichlet``).
        nduplicate: Number of times sample indices are repeated before distribution.
        **kwargs: Method-specific options (see above).

    Returns:
        A list of ``ndistribute`` ``torch.utils.data.Subset`` objects.

    Raises:
        AssertionError: If the dataset is empty or ``ndistribute``/``nduplicate`` < 1.
        ValueError: If ``method`` is not recognized.
    """
    assert len(dataset) > 0
    assert ndistribute > 0
    assert nduplicate > 0

    rs = np.random.RandomState(seed)
    if method == "no":
        indices = np.concatenate([np.arange(len(dataset))] * nduplicate)
        return [torch.utils.data.Subset(dataset, [int(idx) for idx in indices]) for _ in range(ndistribute)]

    if method == "even":
        parts = np.array_split(rs.permutation(np.arange(len(dataset))), ndistribute)
        return [torch.utils.data.Subset(dataset, [int(i) for i in part]) for part in parts]

    if method == "class_assign":
        nassign = kwargs.get("nassign", 1)
        targets = _targets_as_numpy(dataset)
        labels = np.unique(targets)
        if len(labels) <= nassign:
            # Too few labels to tell nodes apart: every node gets the full (duplicated) dataset.
            ndatas = len(targets)
            idxs = [int(i % ndatas)
                    for i in np.arange(len(targets) * nduplicate)]
            return [torch.utils.data.Subset(dataset, idxs) for _ in range(ndistribute)]

        node_label_assigned = {i: [] for i in range(ndistribute)}
        label_node_assigned = {label: [] for label in labels}
        assign_labels = rs.permutation(labels)
        for node, assigned in node_label_assigned.items():
            assigned.extend(assign_labels[:nassign])
            assign_labels = assign_labels[nassign:]
            if len(assigned) < nassign:
                # Pool exhausted: draw the missing labels from those this node does not
                # hold yet, then rebuild the pool from every label except the ones just drawn.
                assign_labels = rs.permutation(
                    list(set(labels) - set(assigned)))
                selected = assign_labels[:nassign - len(assigned)]
                assign_labels = rs.permutation(np.concatenate([assign_labels[nassign - len(assigned):],
                                                               assigned]))
                assigned.extend(selected)

            for label in assigned:
                label_node_assigned[label].append(node)

        label_indices = {label: np.where(targets == label)[0] for label in labels}
        node_indices = {i: [] for i in range(ndistribute)}
        for label, nodes in label_node_assigned.items():
            if len(nodes) < 1:  # unassigned label.
                continue
            indices = rs.permutation(np.concatenate(
                [label_indices[label]] * nduplicate))
            for node_id, chunk in zip(nodes, np.array_split(indices, len(nodes))):
                node_indices[node_id].extend(chunk)
        return [torch.utils.data.Subset(dataset, [int(i) for i in rs.permutation(idxs)])
                for idxs in node_indices.values()]

    if method == "dirichlet":
        from .distributes import dirichlet
        if hasattr(dataset, "sources"):  # for digit-5
            sources = np.unique(dataset.sources)
            # Give each source a disjoint, random group of node ids.
            dupindices = np.array_split(rs.permutation(np.arange(ndistribute)),
                                        len(sources))
            ret = {}
            for source, dupidxs in zip(sources, dupindices):
                indices = np.where(dataset.sources == source)[0]
                targets = np.array(dataset.targets)[indices]
                ds = torch.utils.data.Subset(dataset,
                                             [int(i) for i in indices])
                setattr(ds, "targets", targets)

                ndatas = len(targets)
                targets = np.concatenate([targets] * nduplicate)
                idxs = dirichlet(targets, len(dupidxs), seed=seed, **kwargs)
                for dupid, ids in zip(dupidxs, idxs):
                    # Duplicated positions wrap back onto the original samples.
                    ret[dupid] = torch.utils.data.Subset(
                        ds, [int(i % ndatas) for i in ids])

            return [ret[i] for i in np.arange(ndistribute)]

        targets = _targets_as_numpy(dataset)
        ndatas = len(targets)
        targets = np.concatenate([targets] * nduplicate)
        idxs = dirichlet(targets, ndistribute, seed=seed, **kwargs)
        return [torch.utils.data.Subset(dataset, [int(i % ndatas) for i in node_idxs]) for node_idxs in idxs]

    raise ValueError(f"Unsupported distribution method: {method}")


def wrap_dataset(dataset: torch.utils.data.Dataset, **kwargs) -> torch.utils.data.Dataset:
    """Return ``dataset`` unchanged.

    Args:
        dataset: The dataset to pass through.
        **kwargs: Accepted for interface compatibility; currently ignored.

    Returns:
        The very same ``dataset`` object (no copy, no wrapper).
    """
    del kwargs  # no wrapping options are implemented
    wrapped = dataset
    return wrapped
