import json
import logging
import os
import csv
import random
import numpy as np
import torch
import h5py
import pandas as pd
import numpy as np
import copy
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
import random
from itertools import combinations
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch.

    Args:
        seed: Value passed unchanged to ``random.seed``, ``np.random.seed``
            and ``torch.manual_seed`` (in that order).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

def read_data(dataset, data_dir, client_num_in_total, users_per_client, items_per_client):
    """Load a MovieLens rating file and give every client a private user/item block.

    Args:
        dataset: ``'movielens100k'`` (reads ``u.data``) or ``'movielens1m'``
            (reads ``ratings.dat``).
        data_dir: Directory containing the rating file.
        client_num_in_total: Number of clients to create (ids ``0..n-1``).
        users_per_client: Number of not-yet-used users sampled for each client.
        items_per_client: Number of not-yet-used items sampled for each client.

    Returns:
        A tuple ``(data, clients, remaining_data, rating_counts)`` where

        * ``data`` is the full rating frame with columns ``uid``, ``iid``, ``y``
          and a random integer ``labels`` column in ``[0, 10)``;
        * ``clients`` maps client id to a dict with keys ``users``, ``items``,
          ``data``, ``train_data`` and ``test_data``;
        * ``remaining_data`` is ``[remaining_users, remaining_items, remain_data]``
          built from the users/items no client received;
        * ``rating_counts`` is the normalized rating distribution for
          ``movielens100k`` and a uniform distribution over ratings 1..5 for
          ``movielens1m``.

    Raises:
        ValueError: If ``dataset`` is not supported, or if fewer unused
            users/items remain than a client is asked to sample.
    """
    supported = ('movielens100k', 'movielens1m')
    if dataset not in supported:
        # Previously an unknown name fell through both branches and failed
        # later with an UnboundLocalError on ``data``.
        raise ValueError(
            'Unsupported dataset {!r}; expected one of {}.'.format(dataset, supported))

    set_seed(42)
    column_names = ['uid', 'iid', 'y', 'timestamp']
    if dataset == 'movielens100k':
        data = pd.read_csv(os.path.join(data_dir, 'u.data'), sep='\t', names=column_names)
    else:
        # A multi-character separator is only handled by the Python parsing
        # engine; requesting it explicitly avoids pandas' fallback warning.
        data = pd.read_csv(os.path.join(data_dir, 'ratings.dat'), sep='::',
                           names=column_names, engine='python')

    data = data.drop('timestamp', axis=1)
    # Shift ids so they start at 0.
    data['uid'] = data['uid'] - 1
    data['iid'] = data['iid'] - 1

    if dataset == 'movielens100k':
        rating_counts = data['y'].value_counts(normalize=True).sort_index()
    else:
        num_classes = 5
        equal_proportion = 1 / num_classes
        rating_counts = pd.Series([equal_proportion] * num_classes,
                                  index=range(1, num_classes + 1))

    # Random pseudo-labels; they are consumed by the Dirichlet partitioning step.
    data['labels'] = np.random.randint(0, 10, len(data.index))

    clients = {
        i: {'users': set(), 'items': set(), 'data': pd.DataFrame(),
            'train_data': pd.DataFrame(), 'test_data': pd.DataFrame()}
        for i in range(client_num_in_total)
    }

    all_users = set(data['uid'].unique())
    all_items = set(data['iid'].unique())
    used_users = set()
    used_items = set()

    for client_id in clients:
        # The per-iteration re-seed is kept from the original implementation so
        # that previously generated partitions remain reproducible.
        set_seed(42)
        available_users = list(all_users - used_users)
        available_items = list(all_items - used_items)
        if len(available_users) < users_per_client or len(available_items) < items_per_client:
            raise ValueError(
                'Client {}: requested {} users / {} items but only {} users / {} items '
                'are still unassigned.'.format(
                    client_id, users_per_client, items_per_client,
                    len(available_users), len(available_items)))

        client = clients[client_id]
        client['users'] = set(random.sample(available_users, users_per_client))
        client['items'] = set(random.sample(available_items, items_per_client))
        client['data'] = data[
            (data['uid'].isin(client['users'])) & (data['iid'].isin(client['items']))]
        used_users.update(client['users'])
        used_items.update(client['items'])

    remaining_users = all_users - used_users
    remaining_items = all_items - used_items

    remain_data = data[
        (data['uid'].isin(remaining_users)) & (data['iid'].isin(remaining_items))]

    remaining_data = [remaining_users, remaining_items, remain_data]

    return data, clients, remaining_data, rating_counts

def batch_data(data, batch_size):
    """Shuffle a rating frame with a fixed seed and cut it into tensor batches.

    Args:
        data: Frame providing the columns ``uid``, ``iid`` and ``y``.
        batch_size: Maximum number of rows per batch.

    Returns:
        A list of ``(x, y)`` tuples where ``x`` is a long tensor built from the
        ``[uid, iid]`` rows and ``y`` is a float tensor built from ``y``.

    Note:
        Calls ``set_seed(42)``, so the global RNGs are reset on every call.
    """
    features = np.array(data[['uid', 'iid']])
    targets = np.array(data['y'])

    set_seed(42)
    # Both arrays are shuffled starting from the same RNG state; this is
    # intended to keep each feature row aligned with its target.
    saved_state = np.random.get_state()
    np.random.shuffle(features)
    np.random.set_state(saved_state)
    np.random.shuffle(targets)

    batches = []
    for start in range(0, len(features), batch_size):
        stop = start + batch_size
        x_chunk = torch.from_numpy(features[start:stop]).long()
        y_chunk = torch.from_numpy(targets[start:stop]).float()
        batches.append((x_chunk, y_chunk))
    return batches

def non_iid_partition_with_dirichlet_distribution(label_list,
                                                  client_num,
                                                  classes,
                                                  alpha,
                                                  task='classification'):
    """Split sample indices across clients using per-class Dirichlet proportions.

    The partition is redrawn until every client holds at least 10 samples.

    Args:
        label_list: Array-like of integer labels, one per sample.
        client_num: Number of clients to partition into.
        classes: Number of classes; labels ``0..classes-1`` are distributed.
        alpha: Dirichlet concentration parameter.
        task: Accepted for interface compatibility; not used by this function.

    Returns:
        Dict mapping client index to a shuffled list of sample positions.

    Raises:
        ValueError: If ``client_num`` or ``classes`` is smaller than 1, or if
            there are too few samples for every client to receive 10 of them
            (the retry loop could otherwise never terminate).
    """
    min_required = 10
    label_list = np.asarray(label_list)
    K = classes
    N = len(label_list)

    if client_num < 1 or K < 1:
        raise ValueError(
            'client_num and classes must both be >= 1 (got {} and {}).'.format(client_num, K))
    if N < min_required * client_num:
        # With fewer than min_required * client_num samples the smallest client
        # can never reach min_required, so the loop below would spin forever.
        raise ValueError(
            'Need at least {} samples to give each of {} clients {} samples, got {}.'.format(
                min_required * client_num, client_num, min_required, N))

    set_seed(42)
    net_dataidx_map = {}

    # guarantee the minimum number of samples in each client
    min_size = 0
    while min_size < min_required:
        idx_batch = [[] for _ in range(client_num)]

        for k in range(K):
            # positions of the samples that belong to label k
            idx_k = np.where(label_list == k)[0]
            idx_batch, min_size = partition_class_samples_with_dirichlet_distribution(
                N, alpha, client_num, idx_batch, idx_k)

    for i in range(client_num):
        np.random.shuffle(idx_batch[i])
        net_dataidx_map[i] = idx_batch[i]

    return net_dataidx_map

def partition_class_samples_with_dirichlet_distribution(N, alpha, client_num, idx_batch, idx_k):
    """Distribute the samples of one class over the clients.

    Args:
        N: Total number of samples in the dataset being partitioned.
        alpha: Dirichlet concentration parameter.
        client_num: Number of clients.
        idx_batch: List with one list of already-assigned sample indices per
            client.
        idx_k: NumPy array of sample indices belonging to the current class.
            It is shuffled in place.

    Returns:
        ``(idx_batch, min_size)``: new per-client index lists extended with this
        class's samples, and the size of the smallest client list.
    """
    # The global RNG is intentionally NOT re-seeded here. Re-seeding on every
    # call made each class draw the same Dirichlet vector and made the caller's
    # retry loop repeat the exact same (failing) attempt forever. The caller
    # seeds once, which keeps the whole partition reproducible.
    np.random.shuffle(idx_k)

    # using dirichlet distribution to determine the unbalanced proportion for each client
    # e.g., when client_num = 4, proportions = [0.29543505 0.38414498 0.31998781 0.00043216]
    proportions = np.random.dirichlet(np.repeat(alpha, client_num))

    # Clients that already hold their fair share (N / client_num) get no more samples.
    capacity = N / client_num
    masked = np.array([p if len(idx_j) < capacity else 0.0
                       for p, idx_j in zip(proportions, idx_batch)])
    total = masked.sum()
    if total > 0:
        proportions = masked / total
    # else: every client is at capacity (or got zero mass); fall back to the raw
    # Dirichlet draw instead of dividing by zero.

    split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]

    # generate the batch list for each client
    idx_batch = [idx_j + idx.tolist()
                 for idx_j, idx in zip(idx_batch, np.split(idx_k, split_points))]
    min_size = min(len(idx_j) for idx_j in idx_batch)

    return idx_batch, min_size

def noniid_merge_data_with_dirichlet_distribution(clients, client_num_in_total, remaining_data, alpha, class_num=10):
    """Spread the unassigned ratings over the clients and build train/test splits.

    Args:
        clients: Dict of client records (keys ``users``, ``items``, ``data``,
            ``train_data``, ``test_data``). It is updated in place.
        client_num_in_total: Number of clients to partition the leftovers into.
        remaining_data: Sequence whose third element is the frame of unassigned
            ratings; that frame must provide ``labels``, ``uid`` and ``iid``.
        alpha: Dirichlet concentration parameter.
        class_num: Number of label classes used for the partition.

    Returns:
        The same ``clients`` dict, with each client's ``data`` extended and
        ``train_data`` / ``test_data`` filled by an 80/20 split
        (``random_state=42``).
    """
    leftover = remaining_data[2]
    label_list = np.asarray(leftover["labels"])
    idx_map = non_iid_partition_with_dirichlet_distribution(label_list, client_num_in_total, class_num,
                                                            alpha)
    # Diagnostic output goes through logging rather than a stray print to stdout.
    logging.debug("Dirichlet partition produced %d client shards", len(idx_map))

    for client_id, indices in idx_map.items():
        client = clients[client_id]
        # ``indices`` are positions within ``leftover`` (derived from label_list).
        client_data = leftover.iloc[indices]
        client['data'] = pd.concat([client['data'], client_data])
        client['users'].update(client_data['uid'].unique())
        client['items'].update(client_data['iid'].unique())

    for client in clients.values():
        client['train_data'], client['test_data'] = train_test_split(
            client['data'], test_size=0.2, random_state=42)

    return clients

def find_common_uids(data_list, client_num_in_total):
    """Return the ``uid`` values present in every client's ``'data'`` frame.

    Args:
        data_list: Mapping (or sequence) indexed by ``0..client_num_in_total-1``
            whose entries hold a frame under the ``'data'`` key.
        client_num_in_total: Number of clients to intersect. Entry ``0`` is
            always read.

    Returns:
        A set with the uids shared by all inspected clients.
    """
    positions = [0] + list(range(1, client_num_in_total))
    uid_sets = [set(data_list[pos]['data']['uid']) for pos in positions]
    return set.intersection(*uid_sets)

def load_partition_data(batch_size,
                        client_num_in_total,
                        users_per_client,
                        items_per_client,
                        alpha,
                        data_dir,
                        dataset):
    """Read the dataset, partition it over the clients and batch every split.

    Args:
        batch_size: Batch size handed to ``batch_data``.
        client_num_in_total: Number of clients.
        users_per_client: Users initially sampled for each client.
        items_per_client: Items initially sampled for each client.
        alpha: Dirichlet concentration parameter for the leftover data.
        data_dir: Directory containing the rating file.
        dataset: Dataset name understood by ``read_data``.

    Returns:
        An 11-tuple: ``rating_counts``, the set of uids common to all clients,
        the ``clients`` dict, total train rows, total test rows, all train
        batches, all test batches, per-client train row counts, per-client
        train batches, per-client test batches, and ``class_num`` (10).
    """
    _, clients, remaining_data, rating_counts = read_data(
        dataset, data_dir, client_num_in_total, users_per_client, items_per_client)

    clients = noniid_merge_data_with_dirichlet_distribution(
        clients, client_num_in_total, remaining_data, alpha)

    logging.info("loading data...")

    train_total = 0
    test_total = 0
    train_sizes_by_client = {}
    train_batches_by_client = {}
    test_batches_by_client = {}
    all_train_batches = []
    all_test_batches = []

    for cid in range(client_num_in_total):
        shard = clients[cid]
        n_train = len(shard['train_data'])
        train_total += n_train
        test_total += len(shard['test_data'])
        train_sizes_by_client[cid] = n_train

        # transform to batches (train first, then test)
        train_batches = batch_data(shard['train_data'], batch_size)
        test_batches = batch_data(shard['test_data'], batch_size)

        train_batches_by_client[cid] = train_batches
        test_batches_by_client[cid] = test_batches
        all_train_batches.extend(train_batches)
        all_test_batches.extend(test_batches)

    class_num = 10

    co_uid = find_common_uids(clients, client_num_in_total)

    return (rating_counts, co_uid, clients, train_total, test_total,
            all_train_batches, all_test_batches, train_sizes_by_client,
            train_batches_by_client, test_batches_by_client, class_num)

def mapper(src, tgt, co_uid_all):
    """Reorder two rating frames so that rows of the shared users come first.

    Args:
        src: Source frame with a ``uid`` column.
        tgt: Target frame with a ``uid`` column.
        co_uid_all: Collection of uids shared between clients.

    Returns:
        ``(src, tgt, co_uid_all, len(co_uid_all))`` where both frames have been
        rebuilt with shared-user rows first and a fresh ``0..n-1`` index.
    """
    shared = set(co_uid_all)

    def _shared_first(frame):
        # Rows of shared users, followed by all other rows; index is renumbered.
        is_shared = frame.uid.isin(shared)
        return pd.concat([frame[is_shared], frame[~is_shared]], ignore_index=True)

    src = _shared_first(src)
    tgt = _shared_first(tgt)

    shared_count = len(co_uid_all)
    every_uid = set(src.uid) | set(tgt.uid)
    print('All uid: {}, Co uid: {}.'.format(len(every_uid), shared_count))

    return src, tgt, co_uid_all, shared_count

def deal_test(test, rating_counts1):
    """Down-sample a test frame to equal rating counts while keeping every user.

    Each rating value is sampled down to the size of the rarest rating. Any
    user that lost all rows in the process gets one of its rows added back.
    The result is shuffled (``random_state=42``) and re-indexed.

    Args:
        test: Frame with at least the columns ``uid`` and ``y``.
        rating_counts1: Accepted for interface compatibility; not used.

    Returns:
        The balanced, shuffled frame. An empty input yields an empty frame
        with the same columns instead of raising.
    """
    if test.empty:
        # pd.concat([]) raises "No objects to concatenate"; an empty test split
        # (e.g. when no test users were sampled) is passed through instead.
        return test.reset_index(drop=True)

    per_rating = test['y'].value_counts()
    min_count = per_rating.min()

    balanced_test = pd.concat([
        test[test['y'] == rating].sample(n=min_count, random_state=42)
        for rating in per_rating.index
    ])

    uids = test['uid'].unique()
    missing_uids = set(uids) - set(balanced_test['uid'].unique())

    if missing_uids:
        # One concat for all re-added rows (same row order as appending one by
        # one, without re-copying the frame for every user).
        extras = [test[test['uid'] == uid].sample(n=1, random_state=42)
                  for uid in missing_uids]
        balanced_test = pd.concat([balanced_test] + extras)

    balanced_test = balanced_test.sample(frac=1, random_state=42).reset_index(drop=True)
    return balanced_test

def split(src, tgt, co_uid, co_user_num, rating_counts):
    """Build the source/target training sets and the held-out test set.

    20% of the shared users (rounded) are sampled as test users; their target
    rows form the test set, which is then balanced by ``deal_test``.

    Args:
        src: Source-client rating frame (``uid``, ``iid`` columns are read).
        tgt: Target-client rating frame.
        co_uid: Set of uids shared between the clients.
        co_user_num: Number of shared users used to size the test sample.
        rating_counts: Forwarded to ``deal_test``.

    Returns:
        ``(train_src, train_tgt, test, train_meta, test_users, co_users)`` with
        the last two returned as lists.
    """
    set_seed(42)
    print('All iid: {}.'.format(len(set(src.iid) | set(tgt.iid))))
    tgt_users = set(tgt.uid.unique())
    co_users = co_uid
    # random.sample() no longer accepts sets (TypeError since Python 3.11), so
    # sample from a sorted list; this also makes the draw independent of set
    # iteration order.
    test_users = set(random.sample(sorted(co_users), round(0.2 * co_user_num)))
    train_src = src
    train_tgt = tgt[tgt['uid'].isin(tgt_users - test_users)]
    test = tgt[tgt['uid'].isin(test_users)]
    test = deal_test(test, rating_counts)
    # Target rows of the shared users that were not held out for testing.
    train_meta = tgt[tgt['uid'].isin(co_users - test_users)]
    return train_src, train_tgt, test, train_meta, list(test_users), list(co_users)

def save(dataset, client_num, client_a, client_b, train_src, train_tgt, test, train_meta, test_list, co_user_list, alpha):
    """Write one source/target client pair's splits to CSV files.

    Files are written below
    ``../data/<dataset>/ready/_<client_num>_/tgtclient_<client_b>_srcclient_<client_a>``.
    The four frames are cut to their first three columns and written without
    header or index; the two lists are each written as a single CSV row.

    Args:
        dataset: Dataset name used in the output path.
        client_num: Total number of clients, used in the output path.
        client_a: Source client id.
        client_b: Target client id.
        train_src, train_tgt, test, train_meta: Frames to write.
        test_list: Test user ids, written to ``test_list.csv``.
        co_user_list: Shared user ids, written to ``co_user_list.csv``.
        alpha: Accepted for interface compatibility; not used.
    """
    output_root = '../data/' + dataset + '/ready/_' + str(client_num) + \
                  '_/tgtclient_' + str(client_b) + '_srcclient_' + str(client_a)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_root, exist_ok=True)

    columns = ['uid', 'iid', 'y']
    trimmed_frames = []
    for stem, frame in (('train_src', train_src), ('train_tgt', train_tgt),
                        ('test', test), ('train_meta', train_meta)):
        trimmed = frame.iloc[:, :3]
        trimmed.columns = columns
        trimmed_frames.append((stem, trimmed))

    for stem, trimmed in trimmed_frames:
        # header=False is the documented way to suppress the header row.
        trimmed.to_csv(output_root + '/' + stem + '.csv', sep=',', header=False, index=False)

    for stem, row in (('test_list', test_list), ('co_user_list', co_user_list)):
        # newline='' is required for csv.writer so no extra blank lines are
        # emitted on platforms that translate line endings.
        with open(output_root + '/' + stem + '.csv', 'w', newline='') as f:
            csv.writer(f).writerow(row)

def process_clients(dataset, co_uid_all, clients, rating_counts, alpha):
    """Generate and save the split files for every ordered pair of clients.

    For each (source, target) pair of distinct clients, deep copies of their
    ``'data'`` frames are passed through ``mapper`` and ``split`` and the
    result is written with ``save``.

    Args:
        dataset: Dataset name forwarded to ``save``.
        co_uid_all: Collection of uids shared by all clients.
        clients: Dict of client records holding a ``'data'`` frame.
        rating_counts: Forwarded to ``split``.
        alpha: Forwarded to ``save``.
    """
    ordered_ids = list(clients.keys())
    total_clients = len(ordered_ids)

    for src_pos, client_a in enumerate(ordered_ids):
        for tgt_pos, client_b in enumerate(ordered_ids):
            if src_pos == tgt_pos:
                continue

            # Work on copies so the clients' own frames are never touched.
            source_frame = copy.deepcopy(clients[client_a]['data'])
            target_frame = copy.deepcopy(clients[client_b]['data'])

            src, tgt, co_uid, co_uid_num = mapper(source_frame, target_frame, co_uid_all)
            pieces = split(src, tgt, co_uid, co_uid_num, rating_counts)
            train_src, train_tgt, test, train_meta, test_list, co_user_list = pieces
            save(dataset, total_clients, client_a, client_b, train_src, train_tgt,
                 test, train_meta, test_list, co_user_list, alpha)

def load_data1(batch_size,
               client_num_in_total,
               users_per_client,
               items_per_client,
               alpha,
               data_dir,
               dataset):
    """Partition the dataset, write all client-pair split files, return the partition.

    Args:
        batch_size: Batch size for the batched splits.
        client_num_in_total: Number of clients.
        users_per_client: Users initially sampled for each client.
        items_per_client: Items initially sampled for each client.
        alpha: Dirichlet concentration parameter.
        data_dir: Directory containing the rating file.
        dataset: Dataset name.

    Returns:
        The tuple produced by ``load_partition_data``, unchanged.
    """
    rating_counts, co_uid, clients, *remainder = load_partition_data(
        batch_size,
        client_num_in_total,
        users_per_client,
        items_per_client,
        alpha,
        data_dir,
        dataset)

    process_clients(dataset, co_uid, clients, rating_counts, alpha)

    return (rating_counts, co_uid, clients) + tuple(remainder)
