import os
import pickle
from pathlib import Path
import numpy as np
import torch
import json
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
import seaborn as sns
import pandas as pd
# Repository root: two directory levels above this script. The path is made
# absolute *before* walking up, so a bare relative __file__ (e.g. "script.py")
# does not collapse to "." and silently resolve to the current directory.
PROJECT_DIR = Path(__file__).absolute().parent.parent

# Experiment whose client index summary is analysed; uncomment exactly one
# line to switch experiments.
dataset = 'shakespeare/global/transformer/e500_lr0.01_div_1'
# dataset = 'cifar10/global/transformer/e500_lr0.01_div_1'
# dataset = 'domain/global/transformer/e500_lr0.01_div_1'
# dataset = 'domain/global/mlp/e5000_lr0.01'
# Number of clients compared below. NOTE(review): must not exceed the number of
# entries stored in index-summary.pkl for the chosen dataset — confirm per dataset.
client_num = 136

def cos_sim(x, y):
    """Return the cosine similarity between ``x`` and ``y`` taken along dim 0.

    For two 1-D tensors of equal length the result is a 0-dim tensor.
    """
    functional = torch.nn.functional
    similarity = functional.cosine_similarity(x, y, dim=0)
    return similarity

# Load the per-client index summary produced for the selected `dataset`.
# NOTE: unpickling can execute arbitrary code — only load files generated by
# this project, never files from an untrusted source.
extracted_path = PROJECT_DIR / "datapreprocess" / "indexs" / dataset / "index-summary.pkl"
clients_index = pickle.loads(extracted_path.read_bytes())

# Pairwise cosine similarity between every pair of client index vectors.
# `sims` compares the full vectors; `sims_1` compares only their first
# PREFIX_DIM entries. Cosine similarity is symmetric, so only the upper
# triangle (j >= i, diagonal included) is computed and then mirrored, halving
# the number of cos_sim calls without changing any entry.
PREFIX_DIM = 512  # NOTE(review): why the first 512 entries matter is not visible here — confirm
sims = torch.zeros((client_num, client_num))
sims_1 = torch.zeros((client_num, client_num))
for i in range(client_num):
    vec_i = clients_index[i]
    prefix_i = vec_i[:PREFIX_DIM]  # loop-invariant for the inner loop; slice once
    for j in range(i, client_num):
        vec_j = clients_index[j]
        sims[i, j] = sims[j, i] = cos_sim(vec_i, vec_j)
        sims_1[i, j] = sims_1[j, i] = cos_sim(prefix_i, vec_j[:PREFIX_DIM])

# Rank every client by its similarity to client 0 in ascending order (least
# similar first; ties broken by client index). `most_closed` uses the
# full-vector similarities and `most_closed_1` the 512-entry prefix ones.
# Each entry is a (0-dim similarity tensor, client index) pair.
most_closed = sorted(zip(sims[0], range(client_num)))
most_closed_1 = sorted(zip(sims_1[0], range(client_num)))

print([client for _, client in most_closed])
print([client for _, client in most_closed_1])
print([value for value, _ in most_closed_1])

# Optional, for the 'domain' datasets (presumably 10 consecutive clients per
# domain — confirm): average sims_1 over each 10x10 block to obtain a
# domain-to-domain similarity matrix.
# domain_to_domain = torch.zeros((client_num // 10, client_num // 10))
# for i in range(client_num):
#     for j in range(client_num):
#         domain_to_domain[i // 10][j // 10] += sims_1[i][j] / 100
# print(domain_to_domain)

# Heat map of the prefix similarities, labelling every 10th client.
# To plot the full-vector similarities instead, build `data` from sims.numpy().
data = pd.DataFrame(sims_1.numpy())
plot = sns.heatmap(data, cmap='YlGnBu', xticklabels=10, yticklabels=10)
# Alternative look: cmap='BuPu_r' with xticklabels=5, yticklabels=5, or
# explicit ticks via plt.xticks/plt.yticks(range(0, client_num, 5)).

for set_axis_label in (plt.xlabel, plt.ylabel):
    set_axis_label('Clients', size=10)
plt.title('Client Index Similarity (Global)', size=10)

# Written next to index-summary.pkl, whose directory was already read above.
heatmap_path = PROJECT_DIR / "datapreprocess" / "indexs" / dataset / 'heat-map.pdf'
plt.savefig(heatmap_path)

# os.makedirs(os.path.dirname(PROJECT_DIR / "datapreprocess" / "indexs" / dataset / "parse_result"),exist_ok=True)
# with open(PROJECT_DIR / "datapreprocess" / "indexs" / dataset / "parse_result", mode="w") as f:    
#     f.writelines('{}\n'.format([sim[1] for sim in most_closed]))
#     f.writelines('{}\n'.format([sim[1] for sim in most_closed_1]))
#     f.writelines('{}\n'.format([sim[0] for sim in most_closed_1]))
        

# with open(PROJECT_DIR / "data" / dataset / "all_stats.json", "r") as f:
#     clients_stats = json.load(f)

# for j in range(len(most_closed)):
#     sim_1, i_1 = most_closed[j]
#     sim_2, i_2 = most_closed_1[j]
#     print(str(clients_stats['{}'.format(i_1)]) + '  ' + str(sim_1) + '  {}'.format(i_1), end='\t\t\t')
#     print(str(clients_stats['{}'.format(i_2)]) + '  ' + str(sim_2) + '  {}'.format(i_2))

# for sim, i in most_closed_1:
#     print(str(clients_stats['{}'.format(i)]) + '  ' + str(sim) + '  {}'.format(i))


