import torch
import numpy as np


def prepare_hetero_datasets(train_data, test_data, n_classes, n_workers, q, train_batch_size, pred_batch_size):
    """Build per-worker training loaders plus un-shuffled prediction loaders.

    Both splits are first truncated so every class has the same number of
    samples, then shuffled with NumPy's global RNG. The training split is
    partitioned across workers by ``mk_multi_loaders`` (label skew set by
    ``q``). The two prediction loaders iterate over the whole balanced,
    shuffled train/test data in a fixed order.

    Args:
        train_data: training samples; ``sample[1]`` is the class label.
        test_data: test samples; ``sample[1]`` is the class label.
        n_classes: number of classes.
        n_workers: number of training workers.
        q: fraction of each class kept by its owning worker.
        train_batch_size: batch size of each worker's training loader.
        pred_batch_size: batch size of the two prediction loaders.

    Returns:
        ``(train_loaders, pred_loader_on_train_data, pred_loader_on_test_data)``.
    """
    def pred_loader(data):
        # Prediction loaders never shuffle, so outputs line up with ``data``.
        return torch.utils.data.DataLoader(data, batch_size=pred_batch_size, shuffle=False)

    # Balance both splits before any shuffling happens, so the random draws
    # are consumed in the same order (train first, then test).
    balanced_train = to_same_n_data_per_class(train_data, n_classes)
    balanced_test = to_same_n_data_per_class(test_data, n_classes)
    shuffled_train = randomize(balanced_train)
    shuffled_test = randomize(balanced_test)

    train_loaders = mk_multi_loaders(
        dataset=shuffled_train,
        n_workers=n_workers,
        batch_size=train_batch_size,
        n_classes=n_classes,
        q=q,
    )
    on_train = pred_loader(shuffled_train)
    on_test = pred_loader(shuffled_test)
    return train_loaders, on_train, on_test


def split_list(lst, n):
    """Cut ``lst`` into ``n`` contiguous chunks whose sizes differ by at most one.

    The first ``len(lst) % n`` chunks each receive one extra element. Slicing
    preserves the sequence type (lists give lists, strings give strings).
    ``n == 0`` raises ``ZeroDivisionError``.
    """
    chunk_size, n_bigger = divmod(len(lst), n)
    chunks = []
    start = 0
    for idx in range(n):
        stop = start + chunk_size + (1 if idx < n_bigger else 0)
        chunks.append(lst[start:stop])
        start = stop
    return chunks


def test_split_list():
    """Check that 10 items over 3 chunks places the leftover item in the first chunk."""
    expected = [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10]]
    assert split_list(list(range(1, 11)), 3) == expected

    
def random_split(dataset, n_classes, n_workers, q):
    """Partition ``dataset`` across workers with label skew controlled by ``q``.

    Worker ``i`` is "owned" by class ``i``. For each class with ``n`` samples:

    - the owning worker receives the first ``int(q * n)`` samples;
    - the rest are cut into ``n_workers - 1`` near-equal contiguous chunks
      (see ``split_list``), one chunk per other worker.

    No shuffling happens here. Samples keep their within-class order, so
    shuffle the dataset beforehand if that matters.

    Args:
        dataset: iterable of samples where ``sample[1]`` is the class label,
            expected to be in ``range(n_classes)``.
        n_classes: number of classes.
        n_workers: number of workers; must equal ``n_classes``.
        q: fraction of each class kept by its owning worker.

    Returns:
        dict mapping worker index -> list of samples.

    Raises:
        ValueError: if ``n_workers != n_classes``.
        KeyError: if a label is not in ``range(n_classes)``.
    """
    if n_workers != n_classes:
        # Explicit raise rather than ``assert False``: asserts are stripped
        # under ``python -O``, which would let an invalid split continue.
        raise ValueError("The number of workers needs to be the number of classes in the current implementation.")
    # NOTE(review): n_workers == 1 makes split_list(..., 0) raise
    # ZeroDivisionError; mk_multi_loaders never calls this with one worker.

    classwise = {i: [] for i in range(n_classes)}
    for data in dataset:
        classwise[data[1]].append(data)

    ans = {i: [] for i in range(n_workers)}
    for cls, samples in classwise.items():
        n_own = int(q * len(samples))
        ans[cls] += samples[:n_own]
        remains_split = split_list(samples[n_own:], n_workers - 1)
        for other_cls in classwise:
            if other_cls != cls:
                # pop() hands out chunks starting from the last one.
                ans[other_cls] += remains_split.pop()
    return ans


def to_same_n_data_per_class(dataset, n_classes):
    """Truncate every class to the size of the smallest class.

    Args:
        dataset: iterable of samples where ``sample[1]`` is the class label,
            expected to be in ``range(n_classes)``.
        n_classes: number of classes.

    Returns:
        ``dataset`` itself (the same object) when all classes already have
        equal counts. Otherwise, a new list grouped by label in ascending
        order, where each class keeps its first ``n`` samples in encounter
        order and ``n`` is the smallest class count. A class with no samples
        makes the result empty.
    """
    buckets = {label: [] for label in range(n_classes)}
    for sample in dataset:
        buckets[sample[1]].append(sample)

    sizes = [len(bucket) for bucket in buckets.values()]
    smallest = np.min(sizes)
    if smallest == np.max(sizes):
        # Already balanced: hand back the input untouched.
        return dataset

    balanced = []
    for bucket in buckets.values():
        balanced.extend(bucket[:smallest])
    return balanced


def randomize(dataset):
    """Return a new list holding the items of ``dataset`` in random order.

    Draws from NumPy's global RNG, so results are reproducible via
    ``np.random.seed``. ``dataset`` must support ``len()`` and integer
    indexing.
    """
    # Legacy ``np.random.choice(np.arange(n), n, replace=False)`` is
    # implemented as ``permutation(n)``; calling it directly draws the same
    # values.
    order = np.random.permutation(len(dataset))
    return [dataset[idx] for idx in order]


def mk_multi_loaders(dataset, n_workers, batch_size, n_classes, q=1.0):
    """Build one shuffled training DataLoader per worker.

    With a single worker, the whole ``dataset`` goes into one loader and
    ``n_classes``/``q`` are unused. Otherwise the data is partitioned with
    ``random_split`` (which requires ``n_workers == n_classes``), and worker
    ``i`` gets a loader over its share.

    Args:
        dataset: samples where ``sample[1]`` is the class label.
        n_workers: number of workers / loaders to create.
        batch_size: batch size of every loader.
        n_classes: number of classes (used only when ``n_workers != 1``).
        q: fraction of each class kept by its owning worker.

    Returns:
        list of ``torch.utils.data.DataLoader``, one per worker, in worker order.
    """
    if n_workers == 1:
        return [torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)]

    divided_dataset = random_split(dataset, n_classes, n_workers, q)
    return [
        torch.utils.data.DataLoader(divided_dataset[i], batch_size=batch_size, shuffle=True)
        for i in range(n_workers)
    ]




