import argparse
import json
import os
import shutil
import sys

import numpy as np

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))


def add_args(parser):
    '''Register this script's command-line options and parse the command line.

    Args:
        parser: object exposing ``add_argument`` and ``parse_args``
            (an ``argparse.ArgumentParser`` in this file).

    Return:
        the value of ``parser.parse_args()``, carrying the integer options
        ``client_num_per_round`` (default 3) and ``comm_round`` (default 10).
    '''
    option_table = (
        ('--client_num_per_round',
         dict(type=int, default=3, metavar='NN', help='number of workers')),
        ('--comm_round',
         dict(type=int, default=10,
              help='how many round of communications we should use')),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    # parse_args() with no arguments reads the options from sys.argv.
    return parser.parse_args()


def _iter_json_payloads(data_dir):
    '''Yield the parsed content of every ``.json`` file directly inside data_dir.

    Files are visited in sorted name order: os.listdir returns entries in
    arbitrary order, and the order in which files are merged determines the
    position of each client in the lists built by read_data.
    '''
    for file_name in sorted(os.listdir(data_dir)):
        if not file_name.endswith('.json'):
            continue
        with open(os.path.join(data_dir, file_name), 'r') as inf:
            yield json.load(inf)


def read_data(train_data_dir, test_data_dir):
    '''parses data in given train and test data directories

    assumes:
    - the data in the input directories are .json files; train files have
        keys 'users', 'num_samples' and 'user_data', test files have keys
        'num_samples' and 'user_data'
    - the set of train set users is the same as the set of test set users
        (client ids are collected from the train files only)

    side effect:
    - parses the command line (via add_args) to decide how many worker
        descriptors to create

    Return:
        clients: list of client ids, in sorted train-file order
        train_num_samples: 'num_samples' entries merged from the train files
        test_num_samples: 'num_samples' entries merged from the test files
        train_data: dictionary of train data ('user_data' merged over files)
        test_data: dictionary of test data ('user_data' merged over files)
        client_list: one Args object per worker, with client_id running from
            0 to client_num_per_round - 1 and an empty client_sample_list
    '''
    clients = []
    train_num_samples = []
    test_num_samples = []
    train_data = {}
    test_data = {}

    for cdata in _iter_json_payloads(train_data_dir):
        clients.extend(cdata['users'])
        train_num_samples.extend(cdata['num_samples'])
        train_data.update(cdata['user_data'])

    for cdata in _iter_json_payloads(test_data_dir):
        test_num_samples.extend(cdata['num_samples'])
        test_data.update(cdata['user_data'])

    # parse python script input parameters
    parser = argparse.ArgumentParser()
    main_args = add_args(parser)

    class Args:
        '''Per-worker settings plus the client indexes assigned to it.'''

        def __init__(self, client_id, client_num_per_round, comm_round):
            self.client_num_per_round = client_num_per_round
            self.comm_round = comm_round
            self.client_id = client_id
            # Filled in by the caller, one sampled client index per round.
            self.client_sample_list = []

    client_list = [
        Args(client_number, main_args.client_num_per_round, main_args.comm_round)
        for client_number in range(main_args.client_num_per_round)
    ]
    return clients, train_num_samples, test_num_samples, train_data, test_data, client_list


def client_sampling(round_idx, client_num_in_total, client_num_per_round):
    '''Choose which client indexes take part in round ``round_idx``.

    When the per-round count equals the total, every index
    ``0 .. client_num_in_total - 1`` is returned as a plain list. Otherwise
    ``min(client_num_per_round, client_num_in_total)`` distinct indexes are
    drawn with ``np.random.choice`` after seeding NumPy's global random state
    with ``round_idx``, and the result is a NumPy array.

    The chosen indexes are printed before being returned.
    '''
    if client_num_in_total != client_num_per_round:
        # Seed with the round number so that each comparison run selects the
        # same clients in the same round.
        np.random.seed(round_idx)
        sample_size = min(client_num_per_round, client_num_in_total)
        client_indexes = np.random.choice(
            range(client_num_in_total), sample_size, replace=False
        )
    else:
        client_indexes = list(range(client_num_in_total))

    print("client_indexes = %s" % str(client_indexes))
    return client_indexes


if __name__ == '__main__':
    # Split the MNIST client data into one shard per worker, write each shard
    # as MNIST_mobile/<id>/{train,test}/*.json and zip it into MNIST_mobile_zip.

    parser = argparse.ArgumentParser()
    main_args = add_args(parser)
    train_path = "../../FedML/data/MNIST/train"
    test_path = "../../FedML/data/MNIST/test"

    users, train_num_samples, test_num_samples, train_data, test_data, client_list = read_data(train_path, test_path)

    # Sample over the users that were actually loaded instead of assuming a
    # population of exactly 1000, so the sampled indexes always fit `users`.
    client_num_in_total = len(users)

    # Each round, worker k is assigned the k-th sampled client index.
    # comm_round is read from main_args (the same value every worker carries)
    # so an empty client_list does not raise IndexError.
    for round_idx in range(main_args.comm_round):
        sample_list = client_sampling(round_idx, client_num_in_total, main_args.client_num_per_round)
        for worker in client_list:
            worker.client_sample_list.append(sample_list[worker.client_id])

    # exist_ok lets the script be re-run without failing on the output folder.
    os.makedirs('MNIST_mobile_zip', exist_ok=True)

    for worker in client_list:
        sampled_indexes = tuple(worker.client_sample_list)
        sampled_users = [users[i] for i in sampled_indexes]
        splits = (
            ('train', train_num_samples, train_data),
            ('test', test_num_samples, test_data),
        )
        for split_name, split_num_samples, split_data in splits:
            out_file = 'MNIST_mobile/{}/{}/{}.json'.format(worker.client_id, split_name, split_name)
            os.makedirs(os.path.dirname(out_file), mode=0o770, exist_ok=True)
            # Key order matches the layout of the input files' consumers:
            # num_samples, users, user_data.
            shard = {
                'num_samples': [split_num_samples[i] for i in sampled_indexes],
                'users': sampled_users,
                'user_data': {user: split_data[user] for user in sampled_users},
            }
            with open(out_file, 'w') as fp:
                json.dump(shard, fp)

        # Write the archive straight into the output folder; an existing
        # <id>.zip from a previous run is overwritten instead of making a
        # follow-up shutil.move fail.
        shutil.make_archive(
            os.path.join('MNIST_mobile_zip', str(worker.client_id)),
            'zip',
            'MNIST_mobile',
            str(worker.client_id),
        )
