import re
from matplotlib import pyplot as plt
import seaborn as sns

# Numeric value following "train_loss=" / "val_loss=". An optional sign and
# exponent are accepted so values printed in scientific notation (e.g.
# "3.4e-05") are not silently truncated to their mantissa.
_LOSS_NUMBER = r"([-+]?\d+\.?\d*(?:[eE][-+]?\d+)?)"
_TRAIN_LOSS_RE = re.compile(r"train_loss=" + _LOSS_NUMBER)
_VAL_LOSS_RE = re.compile(r"val_loss=" + _LOSS_NUMBER)


def get_train_val_loss(path):
    """Extract training and validation losses from a progress-bar log file.

    Only lines containing ``%`` are considered. Of those, the first two are
    skipped and then every other one (the 3rd, 5th, 7th, ...) is parsed.
    NOTE(review): this presumably picks one end-of-epoch line per epoch for the
    logger that produced these files -- confirm against a sample log before
    reusing it with a different log format.

    Args:
        path: Path of the log file to read (decoded as UTF-8).

    Returns:
        A ``(train_losses, val_losses)`` tuple of float lists. Every selected
        line contributes one training loss; a validation loss is recorded only
        when the line contains ``val_loss=<number>``, so ``val_losses`` may be
        shorter than ``train_losses``.

    Raises:
        ValueError: if a selected line has no ``train_loss=<number>`` field.
    """
    train_losses = []
    val_losses = []
    progress_lines = 0

    # Decode explicitly instead of relying on the platform default; bytes that
    # are not valid UTF-8 are replaced because only the ASCII "name=value"
    # fields are parsed.
    with open(path, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            if '%' not in line:
                continue
            progress_lines += 1
            # Skip the first two '%' lines and every even-numbered one.
            if progress_lines <= 2 or progress_lines % 2 == 0:
                continue

            train_match = _TRAIN_LOSS_RE.search(line)
            if train_match is None:
                raise ValueError(
                    f"{path}: no train_loss value in line {line.rstrip()!r}"
                )
            train_losses.append(float(train_match.group(1)))

            # Validation loss is optional on a selected line.
            val_match = _VAL_LOSS_RE.search(line)
            if val_match is not None:
                val_losses.append(float(val_match.group(1)))

    return train_losses, val_losses

# Parse each run's log into (train_losses, val_losses). The validation lists
# are kept for reference but are not plotted below.
train_loss_ft, val_loss_ft = get_train_val_loss('train.log')
train_loss_gno, val_loss_gno = get_train_val_loss('train-gno.log')
train_loss_fz_pb, val_loss_fz_pb = get_train_val_loss('train-frozen.log')
train_loss_fz_pn, val_loss_fz_pn = get_train_val_loss('train-frozen-pointnext.log')

# Line colors for the plotted curves.
color1 = '#038355'
color2 = '#ffc34f'
color3 = '#3d5e80'
color4 = '#c85454'

# Only the first NUM_EPOCHS training-loss values of each run are plotted,
# against 1-based epoch numbers.
NUM_EPOCHS = 20
x_train = list(range(1, NUM_EPOCHS + 1))

# Fail with a message naming the offending log (instead of a bare IndexError)
# when a run recorded fewer values than are plotted.
for log_name, losses in (
    ('train.log', train_loss_ft),
    ('train-gno.log', train_loss_gno),
    ('train-frozen.log', train_loss_fz_pb),
    ('train-frozen-pointnext.log', train_loss_fz_pn),
):
    if len(losses) < NUM_EPOCHS:
        raise ValueError(
            f"{log_name}: expected at least {NUM_EPOCHS} training-loss values, "
            f"found {len(losses)}"
        )

train_loss_ft = train_loss_ft[:NUM_EPOCHS]
train_loss_gno = train_loss_gno[:NUM_EPOCHS]
train_loss_fz_pb = train_loss_fz_pb[:NUM_EPOCHS]
train_loss_fz_pn = train_loss_fz_pn[:NUM_EPOCHS]

# seaborn's set_style() overwrites rcParams['font.family'] with its own
# sans-serif default, so the style must be applied BEFORE the custom font;
# in the reverse order the Times New Roman setting is silently discarded.
sns.set_style('whitegrid')
font = {'family': 'Times New Roman', 'size': 12}
plt.rc('font', **font)

# (y values, line color, marker, legend label) for each training-loss curve.
# Curves are drawn in this order, which is also the order the legend lists them.
curves = [
    (train_loss_gno, color1, 'o', 'None'),
    (train_loss_fz_pn, color2, 's', 'PointNeXt (frozen)'),
    (train_loss_fz_pb, color3, '^', 'Point-BERT (frozen)'),
    (train_loss_ft, color4, 'D', 'Point-BERT (fine-tuned)'),
]
for y_values, line_color, marker, label in curves:
    sns.lineplot(x=x_train, y=y_values, color=line_color, linewidth=4.0,
                 marker=marker, markersize=4, label=label)

plt.title('GNO backbone', fontsize=20)
plt.xlabel('Epoch', fontsize=20)
plt.ylabel('Training Loss', fontsize=20)

plt.legend(loc='upper right', frameon=True, fontsize=16)
# Ticks at even epochs 0, 2, ..., 20.
plt.xticks([2 * i for i in range(11)], fontsize=16)
plt.yticks(fontsize=16)
plt.xlim(0, 21)
# NOTE(review): fixed y-range; any loss value above 2 is clipped from view.
plt.ylim(0, 2)

# Light-grey, slightly thicker frame around the axes.
for spine in plt.gca().spines.values():
    spine.set_edgecolor('#CCCCCC')
    spine.set_linewidth(1.5)

plt.savefig('gno-train.pdf', dpi=1200, bbox_inches='tight')
plt.show()