# -*- coding: utf-8 -*-
"""
Created on Mon Nov  7 13:45:51 2022

@author: xiato
"""

import numpy as np
import matplotlib.pyplot as plt

# Build the client-by-class count matrix to plot.
# The input file starts with two header lines. Each following line is
# comma-separated: a leading field (skipped), per-class values, and a final
# field used as a multiplier for those values.
# NOTE(review): the per-class values look like proportions and the last field
# like the client's sample total -- confirm against whatever writes this log.
N_CLASSES = 10
DIST_PATH = 'logs_all//data//cifar10_100_distribution_0.5.txt'

rows = []
with open(DIST_PATH, encoding='utf-8') as f:
    for idx, line in enumerate(f):
        # Skip the two header lines and any blank (e.g. trailing) lines;
        # a blank line would otherwise leave `temp` empty and crash on temp[-1].
        if idx < 2 or not line.strip():
            continue
        temp = [float(x) for x in line.strip().split(',')[1:]]
        # round() rather than int(): int() truncates products such as
        # 0.29 * 100 == 28.999999999999996 down to 28.
        rows.append([round(temp[i] * temp[-1]) for i in range(N_CLASSES)])

# Size the matrix from the rows actually read (one row per data line) instead
# of from a different file's shape. The reshape keeps it 2-D even when no
# data rows were found.
heatmap = np.array(rows, dtype=float).reshape(-1, N_CLASSES)
            
        
        
        
        
# Wide, short figure: the transposed matrix has one column per row of
# `heatmap` and one row per column of it.
fig = plt.figure(figsize=(11, 1), dpi=500)

# Transpose so clients run along the x-axis and classes along the y-axis.
plt.imshow(heatmap.T)

# Tick every 2nd class and every 20th client. The ticks are derived from the
# data shape so the axes stay correct for other client/class counts (e.g.
# 100 clients x 10 classes gives [0, 20, 40, 60, 80] and [0, 2, 4, 6, 8]).
n_clients, n_classes = heatmap.shape
plt.yticks(np.arange(0, n_classes, 2), size=12)
plt.xticks(np.arange(0, n_clients, 20), size=12)

plt.xlabel('Client', fontsize=12)
plt.ylabel('Class', fontsize=12)
cbar = plt.colorbar()
cbar.ax.tick_params(labelsize=10)

# Without show(), nothing is displayed when run as a plain script
# (outside an IDE with inline plotting).
plt.show()