import os

import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("darkgrid")

# Experiment identifiers: they select which log files are read and how the
# output PDF is named.
algorithm, dataset = 'conceptEM', 'cifar10-c-2swap'

# Value suffixes exactly as they appear in the log filenames ("lr-0" + suffix).
# The legend renders each one as "0." + suffix, e.g. '05' -> "rho = 0.05".
rhos = '3 1 05'.split()

# NOTE(review): never appended to or read in this file.
counts = list()

def _read_cluster_counts(path):
    """Return, for each line of *path*, the number of ', '-separated entries.

    Each line is stripped and split on ', '. The length of that split is
    the value plotted as "Number of Clusters" for that round. A blank line
    still counts as 1, because ''.split(', ') == [''].
    """
    with open(path, 'r') as f:
        return [len(line.strip().split(', ')) for line in f]


# One colour per rho value, built once instead of on every iteration.
palette = sns.color_palette('deep')
# Log path template: dataset, algorithm, rho suffix, split id.
log_template = './logs/{}/cluster-weights-{}_SW-0.05-03-lr-0{}-04-adapt-split-{}.txt'

for r, rho in enumerate(rhos):
    count_1 = _read_cluster_counts(log_template.format(dataset, algorithm, rho, 12))
    count_2 = _read_cluster_counts(log_template.format(dataset, algorithm, rho, 1))

    # Average the two splits round by round. zip() stops at the shorter log,
    # so split-1 having fewer rounds than split-12 no longer raises
    # IndexError. Note that round() rounds exact halves to the nearest even
    # integer.
    count = [round((a + b) / 2) for a, b in zip(count_1, count_2)]

    # Plot every second round, limited to the first 200 rounds.
    end = min(len(count), 200)
    x = list(range(0, end, 2))

    plt.plot(x, count[:end:2], label='rho = 0.' + rho, color=palette[r])

# Decorate the axes and write the figure as a PDF under ./plots/.
plt.legend(loc='lower right')  # same placement as the numeric code loc=4
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.xlabel('Communication Rounds', fontsize=18)
plt.ylabel('Number of Clusters', fontsize=18)
# plt.title('CIFAR10 dataset')

out_path = './plots/{}-adapt-{}.pdf'.format(algorithm, dataset)
# savefig does not create missing directories. Ensure ./plots exists so the
# run does not fail at the very last step.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
plt.savefig(out_path)
