from tqdm import trange
import numpy as np
import random
import json
import os
import argparse
from torchvision.datasets import MNIST
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.data import Dataset, Subset


class CustomDataset(Dataset):
    """Pair two datasets item-by-item, wrapping around the shorter one.

    The combined length is the length of the longer dataset. Item ``idx``
    is the tuple ``(dataset1[idx % len(dataset1)], dataset2[idx % len(dataset2)])``,
    so the shorter dataset is cycled as many times as needed.
    """

    def __init__(self, dataset1, dataset2):
        self.dataset1, self.dataset2 = dataset1, dataset2

    def __len__(self):
        size1 = len(self.dataset1)
        size2 = len(self.dataset2)
        return size1 if size1 >= size2 else size2

    def __getitem__(self, idx):
        # Wrap the index independently for each dataset.
        pos1, pos2 = idx % len(self.dataset1), idx % len(self.dataset2)
        return self.dataset1[pos1], self.dataset2[pos2]


def rearrange_data_by_class(data, targets, n_class):
    """Group samples by label.

    Returns a list of length ``n_class`` whose ``i``-th entry is
    ``data[targets == i]``; ``trange`` displays progress while the labels
    ``0 .. n_class - 1`` are processed.
    """
    return [data[targets == label] for label in trange(n_class)]

def get_dataset(dataset, n_sample, num_class=10):
    """Load up to ``n_sample`` samples from ``dataset`` and group them by class.

    One batch of size ``n_sample`` is drawn from an unshuffled ``DataLoader``
    and written back onto ``dataset.data`` / ``dataset.targets``. With
    ``n_sample >= len(dataset)`` that batch is the entire dataset. The
    stored tensors are then split per label.

    Args:
        dataset: Map-style dataset whose items are ``(x, y)`` pairs. Its
            ``data`` and ``targets`` attributes are overwritten in place and
            must be tensors afterwards (``.cpu().detach().numpy()`` is
            called on them).
        n_sample: Batch size used for the single load.
        num_class: Number of labels; samples are grouped for labels
            ``0 .. num_class - 1``.

    Returns:
        Tuple ``(data_by_class, n_sample, num_class)`` where
        ``data_by_class[i]`` holds the samples whose target equals ``i``.
    """
    loader = DataLoader(dataset, batch_size=n_sample, shuffle=False)

    # Consume only the first batch. Replacing dataset.data while the loader is
    # still iterating would make any later batch index into the tensor that
    # was just written, which breaks datasets whose __getitem__ reads
    # self.data. An empty dataset yields no batch and is left untouched.
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        dataset.data, dataset.targets = first_batch

    data_by_class = rearrange_data_by_class(
        dataset.data.cpu().detach().numpy(),
        dataset.targets.cpu().detach().numpy(),
        num_class,
    )

    return data_by_class, n_sample, num_class

def sample_class(SRC_N_CLASS, NUM_LABELS, user_id, label_random=False):
    """Choose ``NUM_LABELS`` class ids out of ``SRC_N_CLASS`` for one user.

    By default the classes are consecutive, starting at ``user_id`` and
    wrapping around modulo ``SRC_N_CLASS``. With ``label_random=True`` the
    classes are the first ``NUM_LABELS`` entries of a random shuffle of all
    class ids (drawn from the ``random`` module's global state).
    """
    assert NUM_LABELS <= SRC_N_CLASS
    if not label_random:
        return [(user_id + offset) % SRC_N_CLASS for offset in range(NUM_LABELS)]

    candidates = list(range(SRC_N_CLASS))
    random.shuffle(candidates)
    return candidates[:NUM_LABELS]

def devide_train_data(data, n_sample, SRC_CLASSES, NUM_USERS, alpha, max_diff_sample=5000, sampling_ratio=0.9):
    """Split per-class data across users with Dirichlet-distributed proportions.

    For every label in ``SRC_CLASSES`` the sample positions are shuffled,
    optionally truncated (when ``sampling_ratio < 1``) and divided between
    the users according to a Dirichlet(``alpha``) draw. The whole assignment
    is redrawn until the users' sample counts differ by at most
    ``max_diff_sample`` and every user holds at least 100 samples.

    NOTE(review): nothing bounds the number of retries, so inputs that can
    never satisfy both conditions (e.g. an empty ``SRC_CLASSES``, where every
    count stays 0) appear to loop forever -- confirm callers never do this.

    Args:
        data: Indexable by label; ``data[l]`` must support ``len()`` and
            indexing with a list of positions followed by ``.tolist()``
            (presumably one NumPy array per class -- verify against caller).
        n_sample: Total sample count, used to derive the per-user cap
            ``sampling_ratio * n_sample / NUM_USERS``.
        SRC_CLASSES: Iterable of labels to distribute.
        NUM_USERS: Number of users to split the data between.
        alpha: Dirichlet concentration, repeated once per user.
        max_diff_sample: Largest accepted gap between the biggest and the
            smallest user.
        sampling_ratio: Fraction of each class eligible for assignment; the
            remaining positions are returned as extra data.

    Returns:
        Tuple ``(X, y, Labels, idx_batch, samples_per_user, X_extra)``:
        per-user sample lists, per-user label lists (stored as floats),
        per-user sets of labels present, per-user ``{label: positions}``
        dicts, per-user sample counts, and the samples left unassigned.
    """

    max_diff_size = 100000 # starts above the default max_diff_sample so the loop body runs at least once
    ###### Determine Sampling #######
    min_sample = 100000
    # Redraw the whole split until it is balanced enough and no user is too small.
    while max_diff_size > max_diff_sample or min_sample<100:
        # print("Try to find valid data separation")
        # One {label: positions} dict per user, rebuilt on every attempt.
        idx_batch=[{} for _ in range(NUM_USERS)]
        idx_batch_extra = {}
        samples_per_user = [0 for _ in range(NUM_USERS)]
        # Cap on how many samples a user may hold before being excluded from further draws.
        max_samples_per_user = sampling_ratio * n_sample / NUM_USERS
        for l in SRC_CLASSES:
            # positions of all samples with this label, in random order
            idx_l = [i for i in range(len(data[l]))]
            np.random.shuffle(idx_l)
            if sampling_ratio < 1:
                # Keep at most the per-user cap, and at most sampling_ratio of the class.
                samples_for_l = int( min(max_samples_per_user, int(sampling_ratio * len(data[l]))) )
                # TODO: this is not the left idx
                # Positions beyond the kept prefix are set aside as extra data.
                idx_batch_extra[l] = idx_l[samples_for_l:]
                idx_l = idx_l[:samples_for_l]
                # print('label:', l, 'total len:', len(data[l]), 'sample len:', len(idx_l))
            # Dirichlet draw: one share of this label per user
            proportions=np.random.dirichlet(np.repeat(alpha, NUM_USERS))
            # re-balance: users already at or above the cap get a zero share
            proportions=np.array([p * (n_per_user < max_samples_per_user) for p, n_per_user in zip(proportions, samples_per_user)])
            # NOTE(review): if every user has reached the cap the sum is 0 and
            # this division produces NaN -- confirm this cannot happen in practice.
            proportions=proportions / proportions.sum()
            # Turn the shares into cumulative cut points (the last one is dropped for np.split).
            proportions=(np.cumsum(proportions) * len(idx_l)).astype(int)[:-1]
            # partition the positions of this label between the users
            for u, new_idx in enumerate(np.split(idx_l, proportions)):
                # record this user's positions for the label
                idx_batch[u][l] = new_idx.tolist()
                samples_per_user[u] += len(idx_batch[u][l])
        max_diff_size = max(samples_per_user) - min(samples_per_user)
        min_sample = min(samples_per_user)
        # print("max_diff_size: ", max_diff_size)

    ###### CREATE USER DATA SPLIT #######
    X = [[] for _ in range(NUM_USERS)]
    y = [[] for _ in range(NUM_USERS)]
    X_extra = []
    Labels=[set() for _ in range(NUM_USERS)]
    print("processing users...")
    for u, user_idx_batch in enumerate(idx_batch):
        for l, indices in user_idx_batch.items():
            # Labels with no samples for this user are not recorded in Labels[u].
            if len(indices) == 0: continue
            X[u] += data[l][indices].tolist()
            # l * np.ones(...) makes the stored labels floats.
            y[u] += (l * np.ones(len(indices))).tolist()
            Labels[u].add(l)

    # Samples that were set aside by the sampling_ratio truncation.
    for l, indices in idx_batch_extra.items():
        X_extra += data[l][indices].tolist()

    return X, y, Labels, idx_batch, samples_per_user, X_extra

def process_user_data(data, n_sample, SRC_CLASSES, num_users=5, alpha=0.1, Labels=None, unknown_test=0):
    """Partition ``data`` between users and print a per-user class summary.

    The split itself is delegated to ``devide_train_data``. For each user
    two lines are printed: the sample count per class, then the number of
    classes and the total number of training samples.

    Args:
        data: Per-class data, forwarded to ``devide_train_data``.
        n_sample: Total sample count, forwarded to ``devide_train_data``.
        SRC_CLASSES: Labels to distribute, forwarded to ``devide_train_data``.
        num_users: Number of users to split the data between.
        alpha: Dirichlet concentration used for the split.
        Labels: Unused; kept for backward compatibility. The per-user label
            sets come from ``devide_train_data``.
        unknown_test: Unused; kept for backward compatibility.

    Returns:
        Tuple ``(X, y)`` of per-user sample lists and per-user label lists.
    """
    X, y, user_labels, idx_batch, _samples_per_user, _extra = devide_train_data(
        data, n_sample, SRC_CLASSES, num_users, alpha)

    for u in range(num_users):
        # (label, sample count) pairs for this user, in ascending label order.
        per_class = [(l, len(idx_batch[u][l])) for l in sorted(user_labels[u])]
        print("".join("c={},n={}| ".format(l, count) for l, count in per_class))
        n_samples_for_u = sum(count for _, count in per_class)
        print("{} Labels/ {} Number of training samples for user [{}]:".format(len(user_labels[u]), n_samples_for_u, u))
    return X, y
