import os
from typing import List

import numpy as np
import torch
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset
from torchvision import datasets, transforms


import numpy as np
from torchvision.datasets import CIFAR100


class CIFAR100Coarse(CIFAR100):
    """CIFAR-100 whose targets are remapped to 20 coarse labels.

    The constructor takes the same arguments as ``CIFAR100`` and forwards
    them unchanged. After the parent has loaded the data:

    * ``self.targets`` is replaced by a NumPy array obtained by looking every
      original (fine, 0..99) target up in a 100-entry table of coarse ids
      (0..19);
    * ``self.classes`` is replaced by a list of 20 lists, each holding five
      fine-class names; the list at position ``k`` is the group for coarse
      label ``k``.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super().__init__(root, train, transform, target_transform, download)

        # Lookup table: position = fine label, value = coarse label.
        # fmt: off
        fine_to_coarse = np.array([
            4, 1, 14, 8, 0, 6, 7, 7, 18, 3,
            3, 14, 9, 18, 7, 11, 3, 9, 7, 11,
            6, 11, 5, 10, 7, 6, 13, 15, 3, 15,
            0, 11, 1, 10, 12, 14, 16, 9, 11, 5,
            5, 19, 8, 8, 15, 13, 14, 17, 18, 10,
            16, 4, 17, 4, 2, 0, 17, 4, 18, 17,
            10, 3, 2, 12, 12, 16, 12, 1, 9, 19,
            2, 10, 0, 1, 16, 12, 9, 13, 15, 13,
            16, 19, 2, 4, 6, 19, 5, 5, 8, 19,
            18, 1, 2, 15, 6, 0, 17, 8, 14, 13,
        ])
        # fmt: on
        # Fancy indexing maps every sample at once and yields an ndarray.
        self.targets = fine_to_coarse[self.targets]

        # One whitespace-separated row per coarse label, split into names.
        grouped_names = (
            "beaver dolphin otter seal whale",
            "aquarium_fish flatfish ray shark trout",
            "orchid poppy rose sunflower tulip",
            "bottle bowl can cup plate",
            "apple mushroom orange pear sweet_pepper",
            "clock keyboard lamp telephone television",
            "bed chair couch table wardrobe",
            "bee beetle butterfly caterpillar cockroach",
            "bear leopard lion tiger wolf",
            "bridge castle house road skyscraper",
            "cloud forest mountain plain sea",
            "camel cattle chimpanzee elephant kangaroo",
            "fox porcupine possum raccoon skunk",
            "crab lobster snail spider worm",
            "baby boy girl man woman",
            "crocodile dinosaur lizard snake turtle",
            "hamster mouse rabbit shrew squirrel",
            "maple_tree oak_tree palm_tree pine_tree willow_tree",
            "bicycle bus motorcycle pickup_truck train",
            "lawn_mower rocket streetcar tank tractor",
        )
        self.classes = [row.split() for row in grouped_names]


def _build_transform(mean, std, augment=()):
    """Compose ``augment``, then ``ToTensor`` and ``Normalize(mean, std)``."""
    return transforms.Compose([*augment, transforms.ToTensor(), transforms.Normalize(mean, std)])


def load_dset(path_dset, dset_name):
    """Create the train and test datasets for ``dset_name``.

    ``path_dset`` is created if it does not exist yet (also when the name
    turns out to be unsupported) and is used as the dataset root.

    Supported names:

    * ``"mnist"``, ``"fashionmnist"``: train pipeline applies a 15 degree
      random rotation and a random resized crop to 28 px; MNIST is
      normalized with (0.1307, 0.3081), FashionMNIST with (0.5, 0.5).
    * ``"cifar10"`` and any name starting with ``"cifar100"``: random
      horizontal flip for training; ``"cifar100Coarse"`` selects
      ``CIFAR100Coarse``, every other ``cifar100*`` name the plain
      ``CIFAR100``.
    * ``"iris"``: standardized features, stratified 80/20 split with
      ``random_state=0``, wrapped in ``TensorDataset`` objects that also
      carry ``data`` and ``targets`` attributes.
    * ``"stl10"``: the ``"unlabeled"`` split resized to 32x32. Its
      ``targets`` are a random permutation of the labels 0..9 (10000 copies
      each) and the returned test set is an empty list.
    * ``"imagenet100"``, ``"tinyimagenet"``: ``ImageFolder`` datasets read
      from ``train``/``val`` and ``train_resize/box``/``val_resize`` below
      ``path_dset``.

    All RGB pipelines normalize with mean 0.5 and std 0.5 per channel rather
    than with per-dataset channel statistics.

    Args:
        path_dset: Root directory of the dataset files.
        dset_name: One of the names listed above.

    Returns:
        ``(dset_train, dset_test)``. Except for the STL10 test placeholder,
        ``targets`` of both datasets is a NumPy array.

    Raises:
        ValueError: If ``dset_name`` is not supported.
    """
    # exist_ok=True replaces the former isdir()/makedirs() pair, which could
    # raise FileExistsError when another process created the directory
    # between the check and the call.
    os.makedirs(path_dset, exist_ok=True)

    rgb_half = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    if dset_name == "mnist":
        mnist_stats = ((0.1307,), (0.3081,))
        trans_train = _build_transform(
            *mnist_stats,
            augment=[
                transforms.RandomRotation(15),
                transforms.RandomResizedCrop(28, scale=(0.8, 1.0)),
            ],
        )
        trans_test = _build_transform(*mnist_stats)
        dset_train = datasets.MNIST(path_dset, train=True, download=True, transform=trans_train)
        dset_test = datasets.MNIST(path_dset, train=False, download=True, transform=trans_test)
    elif dset_name == "cifar10":
        trans_train = _build_transform(*rgb_half, augment=[transforms.RandomHorizontalFlip()])
        trans_test = _build_transform(*rgb_half)
        dset_train = datasets.CIFAR10(path_dset, train=True, download=True, transform=trans_train)
        dset_test = datasets.CIFAR10(path_dset, train=False, download=True, transform=trans_test)
    elif dset_name.startswith("cifar100"):
        trans_train = _build_transform(*rgb_half, augment=[transforms.RandomHorizontalFlip()])
        trans_test = _build_transform(*rgb_half)
        dataset = CIFAR100Coarse if dset_name == "cifar100Coarse" else datasets.CIFAR100
        dset_train = dataset(path_dset, train=True, download=True, transform=trans_train)
        dset_test = dataset(path_dset, train=False, download=True, transform=trans_test)
    elif dset_name == "fashionmnist":
        gray_half = ((0.5,), (0.5,))
        trans_train = _build_transform(
            *gray_half,
            augment=[
                transforms.RandomRotation(15),
                transforms.RandomResizedCrop(28, scale=(0.8, 1.0)),
            ],
        )
        trans_test = _build_transform(*gray_half)
        # NOTE(review): unlike the other torchvision datasets this one is not
        # downloaded automatically; kept as-is, confirm whether intentional.
        dset_train = datasets.FashionMNIST(path_dset, train=True, download=False, transform=trans_train)
        dset_test = datasets.FashionMNIST(path_dset, train=False, download=False, transform=trans_test)
    elif dset_name == "iris":
        data, target = load_iris(return_X_y=True)
        data = StandardScaler().fit_transform(data)
        x_train, x_test, y_train, y_test = train_test_split(
            data, target, test_size=0.2, shuffle=True, random_state=0, stratify=target
        )
        x_train = torch.tensor(x_train, dtype=torch.float)
        x_test = torch.tensor(x_test, dtype=torch.float)
        y_train = torch.tensor(y_train, dtype=torch.long)
        y_test = torch.tensor(y_test, dtype=torch.long)

        dset_train = TensorDataset(x_train, y_train)
        dset_test = TensorDataset(x_test, y_test)
        # Mirror the attribute names the torchvision datasets expose.
        dset_train.data = x_train
        dset_test.data = x_test
        dset_train.targets = y_train
        dset_test.targets = y_test
    elif dset_name == "stl10":
        trans_train = _build_transform(*rgb_half, augment=[transforms.Resize((32, 32))])
        dset_train = datasets.STL10(
            root=path_dset,
            split="unlabeled",
            folds=None,
            transform=trans_train,
            download=True,
        )
        # The unlabeled split gets placeholder targets: 10000 copies of each
        # of the labels 0..9 in random order.
        dset_train.targets = np.random.permutation(list(range(10)) * 10000)
        # The test placeholder is a plain list, so it cannot go through the
        # shared ``targets`` conversion below; return right away.
        return dset_train, []
    elif dset_name in ("imagenet100", "tinyimagenet"):
        if dset_name == "imagenet100":
            train_dir, test_dir = "train", "val"
        else:
            train_dir, test_dir = "train_resize/box", "val_resize"
        trans_train = _build_transform(*rgb_half, augment=[transforms.RandomHorizontalFlip()])
        # The evaluation pipeline used to contain RandomHorizontalFlip as
        # well, which made test results nondeterministic and disagreed with
        # every other dataset here; augmentation is now train-only.
        trans_test = _build_transform(*rgb_half)
        dset_train = datasets.ImageFolder(root=os.path.join(path_dset, train_dir), transform=trans_train)
        dset_test = datasets.ImageFolder(root=os.path.join(path_dset, test_dir), transform=trans_test)
    else:
        raise ValueError(f"Got {dset_name=}")

    # Downstream splitting code indexes and compares targets as ndarrays.
    dset_train.targets = np.array(dset_train.targets)
    dset_test.targets = np.array(dset_test.targets)

    return dset_train, dset_test


def split_dset_iid(targets: np.array, ratios: List):
    """Split sample indices into ``len(ratios)`` class-stratified random parts.

    For every distinct label, the indices carrying that label are shuffled
    with ``np.random.permutation`` and cut into consecutive pieces of
    ``int(class_size * ratio)`` elements. The last part receives whatever
    remains of the class, so every index is assigned to exactly one part.

    Labels are taken from ``np.unique(targets)``, so they do not have to be
    the contiguous range ``0..n-1``. ``targets`` may be any array-like
    (e.g. a plain list); it is converted with ``np.asarray`` first.

    Args:
        targets: Label of every sample.
        ratios: Fraction of each class assigned to each part; must sum to 1.

    Returns:
        A list with one int64 index array per entry of ``ratios``. Inside an
        array, indices are grouped by class in ascending label order.

    Raises:
        AssertionError: If ``ratios`` does not sum to 1 (within 1e-8).
    """
    assert np.isclose(sum(ratios), 1, atol=1e-8), f"ratios must sum to 1, got {sum(ratios)}"

    # Element-wise comparison below needs an ndarray; with a plain list
    # ``targets == cls`` would be a single False and every part stay empty.
    targets = np.asarray(targets)
    last = len(ratios) - 1

    # Seed every part with an empty int64 array so the result dtype is fixed
    # and the final concatenate also works when there are no samples.
    parts = [[np.array([], dtype=np.int64)] for _ in ratios]

    # Iterate over the labels that actually occur instead of assuming
    # 0..n_classes-1, which silently dropped samples of other labels.
    for cls in np.unique(targets):
        cls_idxs = np.random.permutation(np.where(targets == cls)[0])
        tot = len(cls_idxs)
        start = 0
        for c, ratio in enumerate(ratios):
            end = start + int(tot * ratio)
            # The last part takes the remainder left over by int() truncation.
            parts[c].append(cls_idxs[start:] if c == last else cls_idxs[start:end])
            start = end

    return [np.concatenate(chunks) for chunks in parts]


def split_dset_diri(targets, n_clients, alpha, double_stochstic=True):
    """Partition sample indices among clients with Dirichlet class shares.

    One share vector over the ``n_clients`` clients is drawn per class from
    a symmetric Dirichlet distribution with concentration ``alpha``. When
    ``double_stochstic`` is true the resulting class-by-client matrix is
    first passed through ``make_double_stochstic``. The indices of each
    class (in their original order) are then cut at the cumulative shares
    and the pieces handed to the clients in order. The resulting split is
    printed with ``print_split``.

    Args:
        targets: Integer labels; a ``torch.Tensor`` is converted to NumPy.
            The number of classes is taken as ``max(targets) + 1``.
        n_clients: Number of index sets to produce.
        alpha: Dirichlet concentration parameter.
        double_stochstic: Whether to rebalance the share matrix.

    Returns:
        A list of ``n_clients`` index arrays.
    """
    if isinstance(targets, torch.Tensor):
        targets = targets.numpy()

    n_classes = np.max(targets) + 1
    # Row = class, column = client.
    shares = np.random.dirichlet([alpha] * n_clients, n_classes)
    if double_stochstic:
        shares = make_double_stochstic(shares)

    labels = np.array(targets)
    pieces_per_client = [[] for _ in range(n_clients)]
    for cls, fracs in enumerate(shares):
        members = np.argwhere(labels == cls).ravel()
        # Cumulative shares (without the final 1.0) scaled to cut positions.
        cuts = (np.cumsum(fracs)[:-1] * len(members)).astype(int)
        for client, piece in enumerate(np.split(members, cuts)):
            pieces_per_client[client].append(piece)

    client_idcs = [np.concatenate(pieces) for pieces in pieces_per_client]

    print_split(client_idcs, targets)

    return client_idcs


def print_split(idcs, labels):
    """Print how many samples of each class every client received.

    For each index set in ``idcs`` the labels at those indices are counted
    per class (classes ``0..max(labels)``). With fewer than 30 clients every
    client gets a line; otherwise only the first ten and the last nine are
    shown, separated by three rows of dots. A final line shows the per-class
    totals over all clients.

    Args:
        idcs: Sequence of index arrays, one per client.
        labels: Label of every sample; indexed with the entries of ``idcs``.
    """
    n_labels = np.max(labels) + 1
    class_column = np.arange(n_labels).reshape(-1, 1)
    label_arr = np.array(labels)
    n_clients = len(idcs)

    print("Data split:")
    counts_per_client = []
    for client, members in enumerate(idcs):
        # Compare a row of this client's labels against a column of class
        # ids; summing each row of the boolean grid gives per-class counts.
        hits = label_arr[members].reshape(1, -1) == class_column
        counts = np.sum(hits, axis=1)
        counts_per_client.append(counts)

        shown = n_clients < 30 or client < 10 or client > n_clients - 10
        if shown:
            print(f" - Client {client}: {str(counts):55} -> sum={np.sum(counts)}", flush=True)
        elif client == n_clients - 10:
            dots = ".  " * 10
            print("\n".join((dots, dots, dots)))

    print(f" - Total:     {np.stack(counts_per_client, axis=0).sum(axis=0)}")
    print()


def make_double_stochstic(x, max_iter=1000, tol=0.0):
    """Alternately normalize the columns and rows of ``x`` to sum to 1.

    Each iteration divides every column by its sum and then every row by its
    sum. Iteration stops once all row and column sums are within ``tol`` of
    1 or after ``max_iter`` iterations. Because the row step comes last, the
    rows of the result sum to 1 up to rounding. Both conditions can only
    hold together for a square matrix, since the total of all entries would
    have to equal the number of rows and the number of columns at once; for
    other shapes the loop simply runs ``max_iter`` times.

    The input is never modified: the function works on a copy (the previous
    in-place division altered the caller's array), and integer input is
    promoted to float instead of failing on the division.

    Args:
        x: 2-D array-like of non-negative weights.
        max_iter: Maximum number of column/row normalization rounds.
        tol: Allowed absolute deviation of each sum from 1. The default of
            0.0 demands exact equality, matching the original behaviour.

    Returns:
        A new array with the normalized weights.
    """
    x = np.array(x)  # np.array copies, so the caller's data stays untouched
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(float)

    for _ in range(max_iter):
        x = x / x.sum(0)
        x = x / x.sum(1)[:, np.newaxis]
        # Written as "all within tol" so NaN sums count as not converged,
        # exactly like the former ``!= 1`` test.
        rows_done = np.all(np.abs(x.sum(1) - 1) <= tol)
        cols_done = np.all(np.abs(x.sum(0) - 1) <= tol)
        if rows_done and cols_done:
            break

    return x
