import matplotlib.pyplot as plt
import seaborn as sns

file = './logs/cifar10-c-2swap/sims-conceptEM_SW-03-lr-03-0-adapt-split-distance-12-50.txt'

# file = './logs/cifar10-c-2swap/sims-FedEM_SW-03-lr-03-0-adapt-split-distance-1-50.txt'

# Raw per-client label (field 0) -> coarse group id; any label not listed maps to group 2.
_GROUP_OF_LABEL = {0: 0, 1: 0, 2: 1, 4: 1}

# Log layout, as implied by the parsing below:
#   lines[:100]  one tab-separated record per client
#   lines[100:]  rows of the client-by-client similarity matrix (tab-separated; the last
#                token of each row is discarded — presumably the trailing newline)
# Each client record is rewritten in place as
#   [group, int(field 1), argmax(field 2), argmax(field 3), original row index]
with open(file, 'r') as f:
    lines = f.readlines()
    clients = lines[:100]
    sims = lines[100:]
    for idx in range(100):
        fields = clients[idx].split('\t')
        fields[0] = _GROUP_OF_LABEL.get(int(fields[0]), 2)
        fields[1] = int(fields[1])
        # Field 2: bracketed list of ints (presumably per-class counts — TODO confirm);
        # only the position of its largest entry is kept.
        counts_field = [int(tok) for tok in fields[2][1:-2].split(', ')[:-1]]
        fields[2] = counts_field.index(max(counts_field))
        # Field 3: [8:-3] strips what looks like a printed `tensor([` ... `])\n` wrapper.
        # Its length is taken as the number of clusters, and its argmax is used below as
        # the client's cluster id.
        weights = [float(tok) for tok in fields[3][8:-3].split(', ')]
        cluster_num = len(weights)
        fields[3] = weights.index(max(weights))
        # Remember the original position: `sims` rows/columns stay in this order.
        fields.append(idx)
        clients[idx] = fields

        sims[idx] = [float(tok) for tok in sims[idx].split('\t')[:-1]]

# Group clients by coarse group for display. NOTE: this reorders `clients` but not `sims`,
# so every lookup into `sims` below goes through each client's original row index
# (the last field of its record) instead of the loop position.
clients = sorted(clients, key=lambda x: x[0])
nums = 100
# Per (cluster, cluster) pair: sum -> mean of similarities, max similarity, pair count.
# NOTE(review): the max starts at 0, so it is floored at 0 — fine for non-negative
# distances, but misleading if `sims` can be negative; confirm.
clusters_distances = [[0] * cluster_num for _ in range(cluster_num)]
clusters_max_distances = [[0] * cluster_num for _ in range(cluster_num)]
clusters_count = [[0] * cluster_num for _ in range(cluster_num)]
for i in range(nums):
    row, ci = clients[i][-1], clients[i][3]
    for j in range(nums):
        col, cj = clients[j][-1], clients[j][3]
        value = sims[row][col]
        clusters_distances[ci][cj] += value
        clusters_max_distances[ci][cj] = max(value, clusters_max_distances[ci][cj])
        clusters_count[ci][cj] += 1
# Turn sums into means, each cell over its own pair count; cells with no pairs
# (a cluster no client was assigned to) stay 0 instead of dividing by zero.
for i in range(cluster_num):
    for j in range(cluster_num):
        if clusters_count[i][j]:
            clusters_distances[i][j] /= clusters_count[i][j]

# print(clusters_distances)
# print(clusters_max_distances)
# Row i, column j: on the diagonal, max minus mean within the cluster;
# off the diagonal, the max similarity between clusters i and j.
print([[clusters_max_distances[i][j] - clusters_distances[i][j] if i == j else clusters_max_distances[i][j] for j in range(cluster_num)] for i in range(cluster_num)])


# Aggregate the similarity matrix over the coarse client groups assigned while parsing
# (field 0 of each client record). `clients` may have been re-sorted while `sims` was not,
# so lookups into `sims` use each client's original row index (last field).
NUM_GROUPS = 3
new_sims = [[0] * NUM_GROUPS for _ in range(NUM_GROUPS)]
new_max_sims = [[0] * NUM_GROUPS for _ in range(NUM_GROUPS)]
counts = [[0] * NUM_GROUPS for _ in range(NUM_GROUPS)]
for i in range(nums):
    row, gi = clients[i][-1], clients[i][0]
    for j in range(nums):
        col, gj = clients[j][-1], clients[j][0]
        value = sims[row][col]
        new_sims[gi][gj] += value
        new_max_sims[gi][gj] = max(new_max_sims[gi][gj], value)
        counts[gi][gj] += 1
# Sums -> means; a group with no clients has no pairs, so leave its cells at 0
# rather than dividing by zero.
for i in range(NUM_GROUPS):
    for j in range(NUM_GROUPS):
        if counts[i][j]:
            new_sims[i][j] /= counts[i][j]

# Alternative: aggregate by the argmax of field 2 (10 buckets) instead of the coarse group.
# new_sims = [[0] * 10 for i in range(10)]
# nums = 100
# counts = [[0] * 10 for i in range(10)]
# for i in range(nums):
#     for j in range(nums):
#         new_sims[clients[i][2]][clients[j][2]] += sims[clients[i][-1]][clients[j][-1]]
#         counts[clients[i][2]][clients[j][2]] += 1
# for i in range(10):
#     for j in range(10):
#         new_sims[i][j] /= counts[i][j]

print(counts)
print(new_sims)
print(new_max_sims)

# nums = 10
# new_sims = [[0] * nums for _ in range(nums)]
# for i in range(nums):
#     for j in range(nums):
        # new_sims[i][j] = sims[clients[i][-1]][clients[j][-1]]
        # new_sims[i][j] = sims[i][j]

# sns.heatmap(new_sims, cmap='viridis', cbar=False)
# # plt.xticks(fontsize=12)
# # plt.yticks(fontsize=12)
# plt.xlabel('Client')
# plt.ylabel('Client')
# plt.savefig('./plots/sims-conceptEM-100.pdf')
