import random
import tqdm
import numpy as np
from PIL import Image
from sklearn.cluster import KMeans

from easydict import EasyDict
import kornia
import torch

class ClientDataset(torch.utils.data.Dataset):
    """Map-style dataset backed by an in-memory sequence of samples.

    Used to hold one client's share of the data; items are returned exactly
    as stored, without any transformation.
    """

    def __init__(self, data):
        # The backing sequence is kept as-is (not copied).
        self.data = data

    def __len__(self):
        """Return the number of stored samples."""
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        """Return the stored sample at position ``idx``."""
        samples = self.data
        return samples[idx]


def sample_by_label(dataset, N, n_class=10, in_class=2, out_rate=0.1):
    """Split ``dataset`` into ``N`` label-skewed client datasets.

    Each client is assigned ``in_class`` distinct labels drawn from
    ``range(n_class)`` and receives every sample carrying one of those labels.
    On top of that, each client gets a uniform random draw from the whole
    dataset of ``int(r * len(dataset))`` samples, where ``r`` is drawn
    uniformly from ``[0, out_rate)`` independently per client. Labels may be
    assigned to several clients, so a sample can appear in more than one
    client dataset.

    Args:
        dataset: sequence (e.g. a list) of ``(x, y)`` pairs; ``y`` must be
            hashable. ``random.sample`` is called on it, so it must be a
            sequence, not an arbitrary iterable.
        N: number of clients.
        n_class: labels are drawn from ``range(n_class)``.
        in_class: number of distinct labels assigned to each client.
        out_rate: upper bound on each client's fraction of random extra samples.

    Returns:
        list of ``N`` ``ClientDataset`` objects.

    Raises:
        KeyError: if a label assigned to a client never occurs in ``dataset``.
    """
    # Central labels for each client.
    client_labels = [random.sample(range(n_class), k=in_class) for _ in range(N)]
    # One random-draw ratio per client: it is indexed by client below, so it
    # must have N entries (sizing it by n_class raised IndexError for N > n_class).
    ratios = np.random.rand(N) * out_rate
    # Number of extra random samples per client; clamped so random.sample never
    # asks for more items than the population holds (possible if out_rate > 1).
    n_random_draw = np.minimum((ratios * len(dataset)).astype(int), len(dataset))

    # Group samples by label.
    data_by_label = {}
    for (x, y) in dataset:
        data_by_label.setdefault(y, []).append((x, y))

    client_datasets = []
    for client, labels in enumerate(client_labels):
        samples = []
        for label in labels:
            samples += data_by_label[label]
        samples += random.sample(dataset, int(n_random_draw[client]))
        client_datasets.append(ClientDataset(samples))
    return client_datasets


def sample_by_only_positive(dataset, n_class=10):
    """Create one client per class; client ``k`` holds exactly the samples labelled ``k``.

    The number of clients always equals ``n_class``.

    Args:
        dataset: iterable of ``(x, y)`` pairs with hashable labels ``y``.
        n_class: number of classes (and clients). Every label in
            ``range(n_class)`` must occur in ``dataset``, otherwise a
            ``KeyError`` is raised.

    Returns:
        list of ``n_class`` ``ClientDataset`` objects, ordered by label.
    """
    print("For simplicity, the number of client is equal to number of classes")
    grouped = {}
    for x, y in dataset:
        grouped.setdefault(y, []).append((x, y))
    # Copy each group so every client owns an independent list.
    return [ClientDataset(list(grouped[label])) for label in range(n_class)]


def sample_by_feature(dataset, N, n_worker=10, out_rate=0.1):
    """Split ``dataset`` into ``N`` clients by k-means clustering on raw pixel values.

    Each ``x`` is flattened to a single vector (no colour-space conversion is
    applied) and clustered with ``KMeans(n_clusters=N)``. A sample goes to the
    client whose index equals its cluster id. Each client additionally receives
    a uniform random draw from the whole dataset of ``int(r * len(dataset))``
    samples, where ``r`` is drawn uniformly from ``[0, out_rate)`` per client.

    Args:
        dataset: sequence (e.g. a list) of ``(x, y)`` pairs where every ``x`` is
            a tensor of the same shape (they are stacked).
        N: number of clients / clusters.
        n_worker: accepted for backward compatibility and ignored;
            ``KMeans(n_jobs=...)`` was removed in scikit-learn 1.0.
        out_rate: upper bound on each client's fraction of random extra samples.

    Returns:
        list of ``N`` ``ClientDataset`` objects.
    """
    features = torch.stack([x for (x, _) in tqdm.tqdm(dataset)])
    # Flatten each sample to one row; works for any per-sample shape.
    features = features.reshape(len(features), -1).cpu().numpy()
    # Cluster id (0..N-1) for each sample, in dataset order.
    cluster_ids = KMeans(n_clusters=N).fit_predict(features)

    ratios = np.random.rand(N) * out_rate
    # Clamp so random.sample never asks for more than the population.
    n_random_draw = np.minimum((ratios * len(dataset)).astype(int), len(dataset))

    client_datasets = [[] for _ in range(N)]
    for cluster_id, sample in zip(cluster_ids, dataset):
        client_datasets[cluster_id].append(sample)
    # Index the draw count by the client being built, not by a leftover loop variable.
    return [
        ClientDataset(client_dataset + random.sample(dataset, int(n_random_draw[client])))
        for client, client_dataset in enumerate(client_datasets)
    ]


def split_dataset(dataset, style, N):
    """Shuffle ``dataset`` in place and partition it into client datasets.

    Args:
        dataset: mutable sequence (e.g. a list) of ``(x, y)`` samples. It is
            shuffled in place with ``random.shuffle`` before splitting.
        style: partitioning scheme:
            ``"i"`` - IID: ``N`` contiguous chunks of ``len(dataset) // N``
            samples each; the trailing ``len(dataset) % N`` samples are dropped.
            ``"l"`` - label-skewed split via ``sample_by_label``.
            ``"f"`` - k-means feature split via ``sample_by_feature``.
            ``"p"`` - one client per class via ``sample_by_only_positive``
            (``N`` is ignored).
        N: number of clients.

    Returns:
        list of client datasets.

    Raises:
        ValueError: for style ``"i"`` when ``N`` is not in ``1..len(dataset)``.
        NotImplementedError: for an unknown ``style``.
    """
    random.shuffle(dataset)
    if style == "i":
        if not 0 < N <= len(dataset):
            raise ValueError(
                f"N must be between 1 and len(dataset)={len(dataset)}, got {N}"
            )
        step = len(dataset) // N
        # Exactly N equal chunks. The remainder (< N samples) is dropped rather
        # than emitted as an extra, smaller client.
        return [ClientDataset(dataset[i * step:(i + 1) * step]) for i in range(N)]
    elif style == 'l':
        return sample_by_label(dataset, N)
    elif style == 'f':
        return sample_by_feature(dataset, N)
    elif style == 'p':
        return sample_by_only_positive(dataset)
    else:
        raise NotImplementedError


def process_data(args, dataset):
    """Partition ``dataset`` across clients and build one DataLoader per client.

    Args:
        args: config object read for ``federated_style``, ``N`` and
            ``batch_size``.
        dataset: samples passed to ``split_dataset`` (shuffled in place there).

    Returns:
        ``(client_info, dataloaders)`` where ``client_info`` is an ``EasyDict``
        with ``client_datasets`` and their lengths ``client_lens``, and
        ``dataloaders`` holds a shuffling DataLoader (2 worker processes) per
        client, in the same order.
    """
    client_datasets = split_dataset(dataset, args.federated_style, args.N)
    client_lens = list(map(len, client_datasets))

    def make_loader(client_dataset):
        return torch.utils.data.DataLoader(
            client_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=2,
        )

    dataloaders = [make_loader(client_dataset) for client_dataset in client_datasets]
    client_info = EasyDict(client_datasets=client_datasets, client_lens=client_lens)
    return client_info, dataloaders
