import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# Select seaborn's 'whitegrid' style before any figure is created below.
sns.set_style(style='whitegrid')


def _plot_series(ax, x, y, ylabel, *, linewidth, label_size, tick_size):
    """Draw one line series on *ax* with the shared x label and tick styling."""
    ax.plot(x, y, linewidth=linewidth)
    ax.set_xlabel('Training Iterations (1e3)', fontsize=label_size)
    ax.set_ylabel(ylabel, fontsize=label_size)
    # which='both' sizes the major and minor tick labels in a single call.
    ax.tick_params(axis='both', which='both', labelsize=tick_size)


def main(output_path='./local/test.pdf'):
    """Plot disentanglement and performance against training iterations.

    Draws two side-by-side line plots that share the same x values (training
    iterations, shown in thousands) and saves the figure to *output_path*.

    Args:
        output_path: Destination file for the figure. Its parent directory is
            created if missing. Defaults to ``./local/test.pdf``, the location
            the script has always written to.
    """
    from pathlib import Path

    label_size = 18
    tick_size = 16
    linewidth = 3
    train_iters = [312, 625, 938, 1564, 3129, 6263, 9397, 12531, 15665]
    performance = [0.5157, 0.7602, 0.9106, 0.9925, 1, 0.9921, 0.924, 1, 0.9654]
    # NOTE(review): this series is named `impurity` but is labelled
    # 'Disentanglement' on the plot; confirm the label is intended.
    impurity = [0.19356114, 0.23143111, 0.36321658, 0.17739552, 0.16393591,
                0.14986867, 0.0992323, 0.061316922, 0.038101535]

    # Express the x axis in thousands of iterations, matching the '(1e3)' label.
    iters_k = np.array(train_iters) / 1e3
    style = dict(linewidth=linewidth, label_size=label_size, tick_size=tick_size)

    fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(6.4 * 2.3, 4.8))
    try:
        _plot_series(ax_left, iters_k, impurity, 'Disentanglement', **style)
        _plot_series(ax_right, iters_k, performance, 'Performance', **style)

        fig.tight_layout()
        # Override only the horizontal gap chosen by tight_layout.
        fig.subplots_adjust(wspace=0.4)

        out = Path(output_path)
        # savefig does not create missing directories; do it here so the
        # script also works when ./local does not exist yet.
        out.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


if __name__ == "__main__":
    # Build and save the figure only when run as a script, not on import.
    main()
