import torch
import numpy as np

from torch.utils.data import DataLoader, random_split

# Seed NumPy's global RNG at import time so the client partitions drawn by
# get_FL_dataloader (np.random.permutation / np.random.dirichlet) are
# reproducible across runs.
# NOTE(review): this resets global state for every importer of this module, and
# the partitions produced by successive calls depend on how many draws happened
# before them; consider passing an explicit np.random.Generator instead.
np.random.seed(42)


def _parse_dirichlet_alpha(split_strategy, default=0.5):
    """Return the Dirichlet concentration encoded in a ``"Dirichlet<alpha>"`` name.

    ``"Dirichlet0.1"`` -> 0.1, ``"Dirichlet64"`` -> 64.0. When the name does not
    start with ``"Dirichlet"`` or its suffix is not a positive, finite number,
    ``default`` is returned.
    """
    prefix = "Dirichlet"
    if not split_strategy.startswith(prefix):
        return default
    try:
        alpha = float(split_strategy[len(prefix):])
    except ValueError:
        return default
    # np.random.dirichlet rejects alpha <= 0; nan/inf are never meaningful here.
    if np.isfinite(alpha) and alpha > 0:
        return alpha
    return default


def _dirichlet_partition(dataset, num_clients, alpha):
    """Split the indices of ``dataset`` into ``num_clients`` Dirichlet-sized chunks.

    Quantity skew as in NIID-Bench (https://arxiv.org/pdf/2102.02079.pdf,
    https://github.com/Mangata1/NIID-Bench/blob/5371adbff98156793a413c7658923673b4aef7d7/utils.py#L179):
    shuffle all indices, draw client shares from a symmetric Dirichlet(alpha),
    and redraw until every client would receive at least one sample. Draws from
    NumPy's global RNG.

    Returns:
        A list of ``num_clients`` NumPy index arrays covering every index once.

    Raises:
        ValueError: if there are too few samples for the redraw loop to ever
            succeed.
    """
    num_samples = len(dataset)
    # sum(p_i * N) == N, so requiring every p_i * N >= 1 needs N >= K; with
    # N == K it needs all shares exactly 1/K, which a continuous draw never
    # produces unless K == 1. Without this check the loop below spins forever.
    if num_samples < num_clients or (num_clients > 1 and num_samples == num_clients):
        raise ValueError(
            f"Dirichlet partitioning needs more samples than clients "
            f"(got {num_samples} samples for {num_clients} clients)"
        )
    idxs = np.random.permutation(num_samples)
    min_size = 0
    while min_size < 1:
        proportions = np.random.dirichlet(np.repeat(alpha, num_clients))
        proportions = proportions / proportions.sum()
        min_size = np.min(proportions * num_samples)
    # Cumulative shares -> integer cut points; the final point (== N) is dropped.
    cut_points = (np.cumsum(proportions) * num_samples).astype(int)[:-1]
    return np.split(idxs, cut_points)


def _build_client_loaders(client_datasets, need_validation, batch_size, do_shuffle, num_workers):
    """Wrap each client partition in DataLoaders, optionally holding out a validation split.

    Returns:
        ``(trainloaders, valloaders, client_datasets)``; ``valloaders`` is None
        when ``need_validation`` is False.
    """
    if not need_validation:
        trainloaders = [
            DataLoader(ds, batch_size=batch_size, shuffle=do_shuffle, num_workers=num_workers)
            for ds in client_datasets
        ]
        return trainloaders, None, client_datasets

    trainloaders = []
    valloaders = []
    for ds in client_datasets:
        len_val = len(ds) // 10  # 10 % validation set (rounded down)
        len_train = len(ds) - len_val
        # Fixed-seed generator: each client's train/val split is reproducible
        # and independent of the global torch RNG.
        ds_train, ds_val = random_split(ds, [len_train, len_val], torch.Generator().manual_seed(42))
        trainloaders.append(
            DataLoader(ds_train, batch_size=batch_size, shuffle=do_shuffle, num_workers=num_workers))
        valloaders.append(DataLoader(ds_val, batch_size=batch_size, num_workers=num_workers))
    return trainloaders, valloaders, client_datasets


def get_FL_dataloader(dataset, num_clients, split_strategy="Uniform",
                      do_train=True, need_validation=True, batch_size=64,
                      do_shuffle=True, num_workers=0):
    """Partition ``dataset`` across simulated federated clients and build DataLoaders.

    Args:
        dataset: Map-style dataset (supports ``len`` and indexing).
        num_clients: Number of client partitions.
        split_strategy: ``"Uniform"`` for equal-sized random partitions (the last
            client also gets the remainder); any name containing ``"Dirichlet"``
            for Dirichlet quantity skew, with the concentration read from a
            ``"Dirichlet<alpha>"`` suffix (e.g. ``"Dirichlet0.1"``) and
            falling back to 0.5 when it cannot be parsed; or
            ``"SingleSensitive"``, which is the same as ``"Dirichlet0.5"``.
        do_train: When False, no client split is returned; a single DataLoader
            over the whole ``dataset`` is returned instead.
        need_validation: Hold out ``len(partition) // 10`` samples of each client
            partition as a validation set.
        batch_size: Batch size for every DataLoader created.
        do_shuffle: ``shuffle`` flag for the training DataLoaders only.
        num_workers: Forwarded to every DataLoader.

    Returns:
        If ``do_train``: ``(trainloaders, valloaders, client_datasets)``, where
        ``valloaders`` is None when ``need_validation`` is False. Otherwise a
        single DataLoader over ``dataset``.

    Raises:
        ValueError: for an unknown ``split_strategy``, or for a Dirichlet-style
            split when there are too few samples to give every client one.
    """
    if "Dirichlet" in split_strategy or split_strategy == "SingleSensitive":
        alpha = 0.5 if split_strategy == "SingleSensitive" else _parse_dirichlet_alpha(split_strategy)
        # Drawn before the do_train check so the global NumPy RNG advances the
        # same way in both modes; later calls' partitions depend on that stream.
        batch_idxs = _dirichlet_partition(dataset, num_clients, alpha)
        if not do_train:
            return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
        client_datasets = [torch.utils.data.Subset(dataset=dataset, indices=idx) for idx in batch_idxs]

    elif split_strategy == "Uniform":
        # Split the training set into several partitions to simulate per-client data.
        partition_size = len(dataset) // num_clients
        lengths = [partition_size] * num_clients
        lengths[-1] += len(dataset) - partition_size * num_clients
        if not do_train:
            return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
        client_datasets = random_split(dataset, lengths, torch.Generator().manual_seed(42))

    else:
        raise ValueError(
            f"Unknown split_strategy {split_strategy!r}; expected 'Uniform', "
            f"'SingleSensitive' or a 'Dirichlet<alpha>' name"
        )

    return _build_client_loaders(client_datasets, need_validation, batch_size, do_shuffle, num_workers)
