import os
import pickle
from pathlib import Path
import numpy as np
import torch
import json
import copy
import random
from sklearn.cluster import DBSCAN, KMeans
# from torch.utils.data import WeightedDatasetSampler
# Parent of the directory containing this script; data/ and datapreprocess/
# paths below are resolved relative to it.
PROJECT_DIR = Path(__file__).parent.parent.absolute()

# Sub-path under datapreprocess/indexs (input) and datapreprocess/sample_orders
# (output) identifying which index summary to order.
dataset = 'domain/global/transformer/e500_lr0.01_div_1'
# Total number of clients; client ids are 0 .. client_num - 1.
client_num = 60
# Clients sampled per round (6 or 10).
sample_num_per_round = 6
# Ordering strategy: 'greedy', 'group' or 'global'.
order_type = 'greedy'
# order_type = 'group'
# order_type = 'global'
# Presumably the length of each client's index vector (not read elsewhere here).
index_dim = 1024
# Split point used by cos_sim to compare the two halves of a vector separately.
half_dim = index_dim // 2
# How cos_sim combines the two half similarities: 'mean' or 'min'.
sim_metrics = 'mean'
# If True, the search procedures invert their swap-acceptance test.
reverse = False
# Pair aggregation in distance_*: 'min', 'max', anything else = size-weighted mean.
group = 'mean'
# Softmax temperature used by greedy_order when sampling a round.
tau = 1.0
# Folder under data/ containing args.json and partition.pkl.
dataset_name = 'domain'


def get_dataset_size():
    """Return the number of training samples held by each client.

    Reads ``data/<dataset_name>/args.json`` and
    ``data/<dataset_name>/partition.pkl`` under ``PROJECT_DIR``. The
    partition's ``"data_indices"`` entry is indexed by client id, and the
    length of each entry's ``'train'`` item is taken as that client's size.
    The number of partition entries is printed as a side effect.

    Returns:
        list[int]: ``sizes[i]`` is ``len(data_indices[i]['train'])`` for
        ``i`` in ``range(client_num)``.

    Raises:
        FileNotFoundError: if ``args.json`` or ``partition.pkl`` is missing.
            The latter is reported with a hint to partition the dataset first.
    """
    data_dir = PROJECT_DIR / "data" / dataset_name

    # NOTE(review): the parsed args are not used below; the read only makes a
    # missing or malformed args.json fail early.
    with open(data_dir / "args.json", "r") as f:
        json.load(f)

    # Only a missing partition file becomes the "partition first" hint; other
    # errors (e.g. a corrupt pickle) propagate with their real cause.
    partition_path = data_dir / "partition.pkl"
    try:
        with open(partition_path, "rb") as f:
            partition = pickle.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Please partition {dataset_name} first.") from err

    clients_indices = partition["data_indices"]
    print(len(clients_indices))
    return [len(clients_indices[i]['train']) for i in range(client_num)]

# Per-client training-set sizes (list index = client id), computed once when the
# script runs; distance_in_groups / distance_between_groups use them as pair weights.
sizes = get_dataset_size()
print(sizes)


def cos_sim(x, y):
    """Cosine similarity of two vectors, computed separately on their two halves.

    Both vectors are split at ``half_dim``. The cosine similarity of the
    first halves and that of the second halves are combined according to
    ``sim_metrics``: ``'mean'`` averages them, and ``'min'`` keeps the
    smaller one.

    Args:
        x, y: 1-D arrays / sequences / tensors of equal length. Anything
            that is not already a tensor is converted with ``torch.tensor``.

    Returns:
        A 0-dim tensor holding the combined similarity.

    Raises:
        ValueError: if ``sim_metrics`` is neither ``'mean'`` nor ``'min'``.
    """
    # Convert each argument on its own so mixed tensor / array inputs work.
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x)
    if not isinstance(y, torch.Tensor):
        y = torch.tensor(y)
    sim_1 = torch.nn.functional.cosine_similarity(x[:half_dim], y[:half_dim], dim=0)
    sim_2 = torch.nn.functional.cosine_similarity(x[half_dim:], y[half_dim:], dim=0)
    if sim_metrics == 'mean':
        return (sim_1 + sim_2) / 2
    if sim_metrics == 'min':
        return min(sim_1, sim_2)
    # Fail loudly: returning None here would only surface later as a TypeError.
    raise ValueError(f"Unknown sim_metrics {sim_metrics!r}; expected 'mean' or 'min'.")
    
    
def distance_in_groups(group_indices, indexs):
    """Score how similar the members of a single group are to each other.

    Every ordered pair ``(a, b)`` of members with ``a != b`` is compared with
    ``cos_sim``, so each unordered pair contributes twice. The module-level
    ``group`` setting selects the aggregation:

    * ``'min'``: smallest pair similarity, starting from 1.0.
    * ``'max'``: largest pair similarity, starting from -1.0.
    * anything else: average of the pair similarities weighted by
      ``sizes[a]`` (the first member of each pair).

    Args:
        group_indices: iterable of client ids (ints or 0-dim integer tensors).
        indexs: per-client vectors, indexable by ``int`` client id.

    Returns:
        The aggregated similarity. When the accumulated weight is zero (always
        for 'min'/'max', or a group without distinct pairs), the accumulator
        is returned without normalisation.
    """
    result = 1.0 if group == 'min' else (-1.0 if group == 'max' else 0.0)
    weight_total = 0
    for a in group_indices:
        for b in group_indices:
            if a == b:
                continue
            pair_sim = cos_sim(indexs[int(a)], indexs[int(b)])
            if group == 'min':
                result = min(pair_sim, result)
            elif group == 'max':
                result = max(pair_sim, result)
            else:
                weight = sizes[a]
                result += weight * pair_sim
                weight_total += weight

    if weight_total > 0:
        return result / weight_total
    return result

def distance_between_groups(group_1_indices, group_2_indices, indexs):
    """Score how similar two groups of clients are to each other.

    Every pair ``(a, b)`` with ``a`` from the first group and ``b`` from the
    second is compared with ``cos_sim``. Pairs with ``a == b`` are included.
    The module-level ``group`` setting selects the aggregation:

    * ``'min'``: smallest pair similarity, starting from 1.0.
    * ``'max'``: largest pair similarity, starting from -1.0.
    * anything else: average weighted by ``sizes[a]`` of the first-group member.

    Args:
        group_1_indices: iterable of client ids (ints or 0-dim integer tensors).
        group_2_indices: iterable of client ids (ints or 0-dim integer tensors).
        indexs: per-client vectors, indexable by ``int`` client id.

    Returns:
        The aggregated similarity. When the accumulated weight is zero (always
        for 'min'/'max', or when a group is empty), the accumulator is
        returned without normalisation.
    """
    result = 1.0 if group == 'min' else (-1.0 if group == 'max' else 0.0)
    weight_total = 0
    for a in group_1_indices:
        for b in group_2_indices:
            pair_sim = cos_sim(indexs[int(a)], indexs[int(b)])
            if group == 'min':
                result = min(pair_sim, result)
            elif group == 'max':
                result = max(pair_sim, result)
            else:
                weight = sizes[a]
                result += weight * pair_sim
                weight_total += weight
    return result / weight_total if weight_total > 0 else result

# def clustering(indexs, round_sample_num):
#     clustering = KMeans(n_clusters=round_sample_num, random_state=0).fit()

# def cluster_search(indexs, round_sample_num, num_epoch):
#     clients_indexs = [list(indexs[i]) for i in range(client_num)]
#     cluster = KMeans(n_clusters=6, metric=cos_sim).fit(clients_indexs)
#     print(cluster.labels_)

def group_search(indexs, round_sample_num, num_epoch, max_passes=1):
    """Split clients into groups by local search and return one order per column.

    The ``client_num`` clients start in a ``(round_sample_num, group_num)``
    grid, where each row is a group and ``group_num = client_num //
    round_sample_num``. Each group is scored with ``distance_in_groups``.

    A sweep tries swapping every pair of clients that sit in different
    groups. A swap is kept when it raises the objective: ``torch.min`` of the
    group scores when ``group`` is 'min', 'max' or 'mean', and ``torch.mean``
    otherwise. The acceptance test is inverted when ``reverse`` is set. After
    each ``client_i`` the current objective is printed.

    Args:
        indexs: per-client vectors indexable by ``int`` client id.
        round_sample_num: number of groups, i.e. clients per round.
        num_epoch: unused; kept for call compatibility.
        max_passes: maximum number of sweeps. Sweeping also stops early once
            a sweep leaves the mean group score unchanged. The default of 1
            performs a single sweep.

    Returns:
        tuple: ``(orders, sims)``.

        * ``orders`` has shape ``(group_num, round_sample_num)``; row ``k``
          holds the ``k``-th member of every group.
        * ``sims`` holds the per-group scores of the final grouping.

    Raises:
        RuntimeError: from ``reshape`` when ``client_num`` is not divisible
            by ``round_sample_num``.
    """
    group_num = client_num // round_sample_num
    groups = torch.arange(client_num).reshape(round_sample_num, group_num)
    # position[c] = [row, column] of client c inside ``groups``.
    position = {
        int(groups[r, c]): [r, c]
        for r in range(round_sample_num)
        for c in range(group_num)
    }
    objective = torch.min if group in ('min', 'max', 'mean') else torch.mean

    sims = torch.tensor([distance_in_groups(groups[r], indexs) for r in range(round_sample_num)])

    for _ in range(max_passes):
        mean_before = torch.mean(sims)
        for client_i in range(client_num):
            for client_j in range(client_num):
                row_i, col_i = position[client_i]
                row_j, col_j = position[client_j]
                if row_i == row_j:
                    continue
                # Tentatively swap the two clients.
                groups[row_i, col_i] = client_j
                groups[row_j, col_j] = client_i
                # Only the two touched groups change; the other scores are reused.
                new_sims = sims.clone()
                new_sims[row_i] = distance_in_groups(groups[row_i], indexs)
                new_sims[row_j] = distance_in_groups(groups[row_j], indexs)
                accept = bool(objective(new_sims) > objective(sims))
                if reverse:
                    accept = not accept
                if accept:
                    sims = new_sims
                    position[client_i], position[client_j] = [row_j, col_j], [row_i, col_i]
                else:
                    groups[row_i, col_i] = client_i
                    groups[row_j, col_j] = client_j
            print(client_i, objective(sims))
        if torch.mean(sims) == mean_before:
            break

    # Transpose so each row takes one client from every group. Return the
    # scores of the final grouping, not the initial one.
    return groups.t().contiguous(), sims


def global_search(indexs, round_sample_num, max_passes=None):
    """Arrange clients into rounds so consecutive rounds are similar.

    The ``client_num`` clients start in a ``(group_num, round_sample_num)``
    grid, where row ``r`` is round ``r`` and ``group_num = client_num //
    round_sample_num``. Link ``r`` is scored as
    ``distance_between_groups(row r, row (r + 1) % group_num)``, so the
    rounds form a ring.

    A sweep tries swapping every pair of clients in different rows. A swap
    is kept when it raises the objective: ``torch.min`` of the link scores
    when ``group`` is 'min', 'max' or 'mean', and ``torch.mean`` otherwise.
    The test is inverted when ``reverse`` is set. After each ``client_i`` the
    current objective is printed.

    Args:
        indexs: per-client vectors indexable by ``int`` client id.
        round_sample_num: clients per round.
        max_passes: optional cap on the number of sweeps. ``None`` (the
            default) sweeps until a sweep leaves the mean link score
            unchanged. With ``reverse`` set, that condition may never be met,
            so pass a cap there.

    Returns:
        tuple: ``(orders, sims)``.

        * ``orders`` has shape ``(group_num, round_sample_num)``.
        * ``sims`` holds the link scores of the final arrangement.

    Raises:
        RuntimeError: from ``reshape`` when ``client_num`` is not divisible
            by ``round_sample_num``.
    """
    group_num = client_num // round_sample_num
    groups = torch.arange(client_num).reshape(group_num, round_sample_num)
    # position[c] = [row, column] of client c inside ``groups``.
    position = {
        int(groups[r, c]): [r, c]
        for r in range(group_num)
        for c in range(round_sample_num)
    }
    objective = torch.min if group in ('min', 'max', 'mean') else torch.mean

    def link_sim(r):
        # Similarity between round r and the next round (wrapping around).
        return distance_between_groups(groups[r], groups[(r + 1) % group_num], indexs)

    sims = torch.tensor([link_sim(r) for r in range(group_num)])

    passes = 0
    while max_passes is None or passes < max_passes:
        passes += 1
        mean_before = torch.mean(sims)
        for client_i in range(client_num):
            for client_j in range(client_num):
                row_i, col_i = position[client_i]
                row_j, col_j = position[client_j]
                if row_i == row_j:
                    continue
                # Tentatively swap the two clients.
                groups[row_i, col_i] = client_j
                groups[row_j, col_j] = client_i
                # Only links starting or ending at the two touched rows change.
                new_sims = sims.clone()
                for r in {row_i, row_j, (row_i - 1) % group_num, (row_j - 1) % group_num}:
                    new_sims[r] = link_sim(r)
                accept = bool(objective(new_sims) > objective(sims))
                if reverse:
                    accept = not accept
                if accept:
                    sims = new_sims
                    position[client_i], position[client_j] = [row_j, col_j], [row_i, col_i]
                else:
                    groups[row_i, col_i] = client_i
                    groups[row_j, col_j] = client_j
            print(client_i, objective(sims))
        if torch.mean(sims) == mean_before:
            break

    return groups, sims
        



def greedy_order(indexs, round_sample_num, pre, num_epoch):
    """Sample ``num_epoch`` rounds, favouring clients similar to the previous round.

    Each round proceeds as follows:

    1. Every client ``k`` not in the recent-history list ``searched`` is
       scored with ``distance_between_groups(pres, [k], indexs)`` against the
       previous round ``pres``.
    2. The scores become sampling weights via ``softmax(score / tau)``.
    3. ``round_sample_num`` distinct clients are drawn without replacement
       with ``random.choices``.
    4. The chosen clients are appended to ``searched``, which is then trimmed
       from the front to at most ``0.5 * client_num`` entries.

    Args:
        indexs: per-client vectors indexable by ``int`` client id.
        round_sample_num: clients drawn per round.
        pre: client ids treated as the round before the first one.
        num_epoch: number of rounds to generate.

    Returns:
        tuple: ``(orders, sims, last)``.

        * ``orders`` is a list of per-round client-id lists.
        * ``sims[t]`` is the mean ``distance_between_groups`` score of round
          ``t``'s clients against the round before it.
        * ``last`` is the final round (``pre`` when ``num_epoch`` is 0).

    Raises:
        ValueError: if fewer than ``round_sample_num`` clients are eligible
            in a round.
    """
    pres = pre
    group_client_indicies = []
    group_average_sims = []
    searched = []
    history_cap = 0.5 * client_num
    for _ in range(num_epoch):
        blocked = set(searched)
        if client_num - len(blocked) < round_sample_num:
            raise ValueError(
                f"Only {client_num - len(blocked)} eligible clients, "
                f"cannot draw {round_sample_num}."
            )
        # float() accepts both tensor and plain-float results (.item() does not).
        scores = torch.tensor([
            float(distance_between_groups(pres, [k], indexs)) if k not in blocked else -1e9
            for k in range(client_num)
        ])
        weights = torch.softmax(scores / tau, dim=0)
        # Recently used clients must not be drawn. Zeroing their weight keeps the
        # relative weights of eligible clients (same distribution as rejecting
        # them) and avoids endless redraws when tau is large.
        for k in blocked:
            weights[k] = 0
        chosen = []
        for _ in range(round_sample_num):
            pick = random.choices(range(client_num), k=1, weights=weights)[0]
            weights[pick] = 0  # without replacement
            chosen.append(pick)

        average_sims = 0.0
        for c in chosen:
            searched.append(c)
            average_sims += distance_between_groups(pres, [c], indexs)
        while len(searched) > history_cap:
            searched.pop(0)
        group_average_sims.append(average_sims / round_sample_num)
        group_client_indicies.append(chosen)

        pres = chosen

    return group_client_indicies, group_average_sims, pres


        



# Per-client index vectors consumed by cos_sim.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
extracted_path = PROJECT_DIR / "datapreprocess" / "indexs" / dataset / "index-summary.pkl"
with open(extracted_path, "rb") as f:
    clients_index = pickle.load(f)

# greedy_order treats every client as the "previous round" before round 0.
pre = [i for i in range(client_num)]

if order_type == 'greedy':
    group_client_indicies, group_average_sims, pre = greedy_order(clients_index, sample_num_per_round, pre, 500)
elif order_type == 'global':
    group_client_indicies, group_average_sims = global_search(clients_index, sample_num_per_round)
elif order_type == 'group':
    group_client_indicies, group_average_sims = group_search(clients_index, sample_num_per_round, 500)
else:
    # Fail here with a clear message instead of a NameError on the print below.
    raise ValueError(f"Unknown order_type {order_type!r}; expected 'greedy', 'global' or 'group'.")
print(group_client_indicies, sum(group_average_sims) / len(group_average_sims), min(group_average_sims), max(group_average_sims))

saved_orders = {
    'orders': group_client_indicies,
    'sims': group_average_sims
}

order_dir = PROJECT_DIR / "datapreprocess" / "sample_orders" / dataset / order_type
order_dir.mkdir(parents=True, exist_ok=True)

# The file name encodes the settings that produced the order; '-reverse'
# marks runs made with the inverted acceptance test.
suffix = "-reverse" if reverse else ""
new_root = order_dir / "{}-{}-{}-{}{}.pkl".format(sample_num_per_round, sim_metrics, group, tau, suffix)

with open(new_root, "wb") as f:
    pickle.dump(saved_orders, f)


