# -*- coding: utf-8 -*-
import os.path as path
import typing

import numpy as np
import torch
import torchvision

__all__ = ["load_dataset", "distribute_dataset", "wrap_dataset"]


def _cifar_transforms(mean, std):
    """Build the (train, eval) transform pipelines used for the CIFAR datasets.

    The train pipeline applies random crop (32px, padding 4), random horizontal
    flip, tensor conversion, normalization with ``mean``/``std`` and random
    erasing. The eval pipeline only converts to tensor and normalizes.
    """
    tv = torchvision.transforms
    train_pipeline = tv.Compose([
        tv.RandomCrop(32, padding=4),
        tv.RandomHorizontalFlip(),
        tv.ToTensor(),
        tv.Normalize(mean, std),
        tv.RandomErasing(),
    ])
    eval_pipeline = tv.Compose([
        tv.ToTensor(),
        tv.Normalize(mean, std),
    ])
    return train_pipeline, eval_pipeline


def _load_torchvision_splits(dataset_cls, datadir, train_transform, eval_transform):
    """Instantiate the train split, then the held-out split, of a torchvision dataset class."""
    train_split = dataset_cls(root=datadir, train=True, download=True,
                              transform=train_transform)
    eval_split = dataset_cls(root=datadir, train=False, download=True,
                             transform=eval_transform)
    return train_split, eval_split


def load_dataset(name: str, *args, datadir: str = None, **kwargs) -> typing.Sequence[torch.utils.data.Dataset]:
    """Load a dataset by name and return its ``(train, eval)`` pair.

    Args:
        name: Dataset name, matched case-insensitively. Supported values are
            ``"cifar-10"``, ``"cifar-100"``, ``"mnist"``, ``"fashion-mnist"``
            and ``"glue"``.
        *args: Extra positional arguments; only forwarded to the GLUE loader.
        datadir: Passed as ``root`` to the torchvision dataset classes (which
            are created with ``download=True``) and as ``datadir`` to the GLUE
            loader.
        **kwargs: Extra keyword arguments; only forwarded to the GLUE loader.

    Returns:
        A 2-tuple ``(train_dataset, eval_dataset)``.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    key = name.lower()

    if key == "cifar-10":
        train_tf, eval_tf = _cifar_transforms(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        return _load_torchvision_splits(
            torchvision.datasets.CIFAR10, datadir, train_tf, eval_tf)

    if key == "cifar-100":
        train_tf, eval_tf = _cifar_transforms(
            (0.5070751592371323, 0.48654887331495095, 0.4409178433670343),
            (0.2673342858792401, 0.2564384629170883, 0.27615047132568404))
        return _load_torchvision_splits(
            torchvision.datasets.CIFAR100, datadir, train_tf, eval_tf)

    if key == "mnist":
        # 28x28 images; a single ToTensor transform is shared by both splits.
        to_tensor = torchvision.transforms.ToTensor()
        return _load_torchvision_splits(
            torchvision.datasets.MNIST, datadir, to_tensor, to_tensor)

    if key == "fashion-mnist":
        to_tensor = torchvision.transforms.ToTensor()
        return _load_torchvision_splits(
            torchvision.datasets.FashionMNIST, datadir, to_tensor, to_tensor)

    if key == "glue":
        # Imported lazily so the GLUE module is only needed when requested.
        from .glue import load_glue_dataset
        train_split, eval_split = load_glue_dataset(
            key, *args, datadir=datadir, **kwargs)
        return train_split, eval_split

    raise ValueError(f"Unsupported dataset: {key}")


def _dataset_targets(dataset):
    """Return ``dataset.targets`` if present, otherwise one zero label per sample."""
    if hasattr(dataset, "targets"):
        return dataset.targets
    # np.zeros, not np.ndarray: np.ndarray interprets its argument as a *shape*,
    # which yields an empty array (or an error for many samples) instead of
    # one label per sample.
    return np.zeros(len(dataset), dtype=np.int64)


def _group_indices_by_target(dataset):
    """Map each distinct label to the list of sample positions that carry it."""
    groups = {}
    for position, target in enumerate(_dataset_targets(dataset)):
        if isinstance(target, torch.Tensor):
            # Tensors hash by identity, so equal labels would end up under
            # different keys; convert to a plain Python scalar first.
            target = target.detach().cpu().item()
        groups.setdefault(target, []).append(position)
    return groups


def distribute_dataset(dataset: torch.utils.data.Dataset, ndistribute: int, method: str, seed: int = 17,
                       nduplicate: int = 1, **kwargs) -> typing.Sequence[torch.utils.data.Subset]:
    """Split ``dataset`` into ``ndistribute`` parts using the given ``method``.

    Args:
        dataset: Non-empty dataset. Label-aware methods read ``dataset.targets``
            when it exists and otherwise treat every sample as label ``0``.
        ndistribute: Number of parts to produce (must be positive).
        method: One of

            * ``"no"`` - every part is the full ``dataset`` object itself
              (not wrapped in a ``Subset``).
            * ``"even"`` - shuffle all indices (repeated ``nduplicate`` times)
              and split them into nearly equal parts.
            * ``"target_even"`` - split the indices of each label nearly
              equally across the parts.
            * ``"dirichlet"`` - delegate to ``.distributes.dirichlet`` with the
              labels, ``ndistribute``, ``seed`` and the remaining ``kwargs``.
            * ``"random_sample"`` - give each part ``nsample`` indices drawn
              from ``nduplicate`` concatenated permutations.
            * ``"label_sample"`` - give each part ``nsample`` indices per label.
        seed: Seed of the ``numpy.random.RandomState`` used for shuffling.
        nduplicate: How many times the index pool is repeated before splitting
            (must be positive).
        **kwargs: ``nsample`` is consumed by ``"random_sample"`` and
            ``"label_sample"``; everything is forwarded for ``"dirichlet"``.
            If ``nsample`` is missing or not a positive integer for
            ``"random_sample"``, it defaults to
            ``len(dataset) * nduplicate // ndistribute``. If it is ``None``
            for ``"label_sample"``, it defaults to the largest per-label count
            that fits every part.

    Returns:
        A list of ``ndistribute`` datasets (``Subset`` objects, except for
        ``method="no"``).

    Raises:
        AssertionError: If a size precondition is violated.
        ValueError: If ``method`` is not recognised.
    """
    assert len(dataset) > 0
    assert ndistribute > 0
    assert nduplicate > 0
    rs = np.random.RandomState(seed)

    if method == "no":
        return [dataset for _ in range(ndistribute)]

    if method == "even":
        # Cast to built-in int: NumPy integer indices caused errors downstream.
        shuffled = rs.permutation(list(range(len(dataset))) * nduplicate)
        return [torch.utils.data.Subset(dataset, [int(i) for i in part])
                for part in np.array_split(shuffled, ndistribute)]

    if method == "target_even":
        groups = _group_indices_by_target(dataset)

        buckets = [[] for _ in range(ndistribute)]
        for data_indices in groups.values():
            chunks = np.array_split(data_indices * nduplicate, ndistribute)
            # When a label's count is not divisible by ndistribute, the leading
            # chunks are one element larger. Randomize which part receives
            # which chunk so that part sizes stay as balanced as possible.
            for part, chunk in zip(rs.permutation(range(ndistribute)), chunks):
                # Plain int lists avoid the float64 promotion that concatenating
                # onto an empty array would cause.
                buckets[part].extend(int(i) for i in chunk)

        return [torch.utils.data.Subset(dataset, [int(i) for i in rs.permutation(idxs)])
                for idxs in buckets]

    if method == "dirichlet":
        from .distributes import dirichlet
        targets = _dataset_targets(dataset)
        ndatas = len(targets)
        targets = np.concatenate([targets] * nduplicate)
        parts = dirichlet(targets, ndistribute, seed=seed, **kwargs)
        # Indices refer to the duplicated label array; fold them back onto the
        # real dataset with a modulo.
        return [torch.utils.data.Subset(dataset, [int(i % ndatas) for i in idxs])
                for idxs in parts]

    if method == "random_sample":
        ndata = len(dataset)
        nsample = kwargs.pop("nsample", None)
        if not isinstance(nsample, (int, np.integer)) or nsample < 1:
            # nsample not given: behave like "even" by sharing the whole pool
            # equally between the parts.
            nsample = (ndata * nduplicate) // ndistribute
        nsample = int(nsample)
        assert nsample > 0
        assert nsample * ndistribute <= ndata * nduplicate

        pool = np.concatenate([rs.permutation(np.arange(ndata))
                               for _ in range(nduplicate)])

        return [torch.utils.data.Subset(dataset, [int(i) for i in pool[k * nsample:(k + 1) * nsample]])
                for k in range(ndistribute)]

    if method == "label_sample":
        # For label_sample, nsample is the number of samples *per label*.
        nsample = kwargs.pop("nsample", None)

        groups = _group_indices_by_target(dataset)
        min_target = int(min(len(idxs) for idxs in groups.values()))
        if nsample is None:
            # Largest per-label count that every part can still receive.
            nsample = (min_target * nduplicate) // ndistribute
        assert nsample * ndistribute <= min_target * nduplicate, \
            f"nsample={nsample}, ndistribute={ndistribute}, min target={min_target}, nduplicate={nduplicate}"

        shuffled = {target: np.concatenate([rs.permutation(idxs) for _ in range(nduplicate)])
                    for target, idxs in groups.items()}

        parts = []
        for k in range(ndistribute):
            picked = []
            for index_list in shuffled.values():
                picked.extend(index_list[k * nsample:(k + 1) * nsample])
            parts.append(picked)

        return [torch.utils.data.Subset(dataset, [int(i) for i in rs.permutation(idxs)])
                for idxs in parts]

    raise ValueError(f"Unsupported distribute method: {method}")


def wrap_dataset(dataset: torch.utils.data.Dataset, **kwargs) -> torch.utils.data.Dataset:
    """Hook for wrapping ``dataset``; currently returns it unchanged.

    Args:
        dataset: The dataset to wrap.
        **kwargs: Accepted but not used by the current implementation.

    Returns:
        The very same ``dataset`` object that was passed in.
    """
    # TODO: provisional - no wrapper is applied yet.
    return dataset
