import matplotlib.pyplot as plt
import os
import sys
import torch
import pandas as pd
import re

# Root directory holding one sub-directory per run (each expected to contain
# checkpoint.pth.tar); taken from the first command-line argument. This is
# read at import time, so importing the module with no argument raises
# IndexError.
res_path = sys.argv[1]

def getinfo(checkpoint_dir, keyword, flag, map_location='cpu'):
    """Plot one metric curve from a run's checkpoint onto the current figure.

    Loads ``<checkpoint_dir>/checkpoint.pth.tar`` and draws
    ``checkpoint['result'][keyword]`` with ``plt.plot`` on the current
    matplotlib axes, labelled ``flag + keyword``.

    Args:
        checkpoint_dir: Directory containing ``checkpoint.pth.tar``.
        keyword: Key into ``checkpoint['result']``; the caller in this file
            uses ``'train_acc'``, ``'test_sa'`` and ``'test_ra'``.
        flag: Prefix prepended to ``keyword`` to form the legend label.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so
            the script also runs on machines without CUDA. Plotting only
            needs host memory, and matplotlib cannot draw CUDA tensors.
    """
    print(checkpoint_dir)
    # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True; if
    # the checkpoint stores arbitrary Python objects this may need
    # weights_only=False (only ever for trusted files).
    checkpoint = torch.load(
        os.path.join(checkpoint_dir, 'checkpoint.pth.tar'),
        map_location=map_location,
    )
    all_result = checkpoint['result']

    plt.plot(all_result[keyword], label=flag + keyword)


def _make_label(run_name):
    """Build the legend prefix for one run directory name.

    Names containing a ``dense`` underscore-separated token are labelled
    ``'dense '``. Otherwise the label is the first numeric run found in the
    third ``_``-separated token, a space, the first character of the fourth
    token, and a trailing space (presumably e.g. a sparsity ratio plus a
    method initial. TODO confirm the naming scheme). Names that do not have
    that shape fall back to the full name, so one odd directory does not
    abort the whole plot.
    """
    parts = run_name.split('_')
    if 'dense' in parts:
        return 'dense '
    if len(parts) >= 4 and parts[3]:
        # Raw string: '\d' in a plain literal is an invalid escape sequence.
        # NOTE(review): the character class also admits '+'; r"\d+(?:\.\d+)*"
        # was probably intended, but the original pattern is kept as-is.
        match = re.search(r"\d+[\.\d+]*", parts[2])
        if match:
            return match.group(0) + ' ' + parts[3][0] + ' '
    return run_name + ' '


def getAllInfo(res_dir=None, save_dir='./images',
               keywords=('train_acc', 'test_sa', 'test_ra')):
    """Save one PNG per metric, overlaying that metric for every run.

    Every sub-directory of ``res_dir`` is treated as one run and is passed to
    ``getinfo``. For each entry in ``keywords``, a fresh figure is drawn with
    one curve per run (in sorted name order) and the y-axis fixed to 0-100.
    The legend is placed to the right of the axes, and the figure is saved
    to ``<save_dir>/<keyword>.png``.

    Args:
        res_dir: Directory of run sub-directories. ``None`` (default) uses the
            module-level ``res_path`` taken from the command line.
        save_dir: Output directory for the PNGs; created if missing.
        keywords: Keys of ``checkpoint['result']`` to plot, one figure each.
    """
    if res_dir is None:
        res_dir = res_path
    os.makedirs(save_dir, exist_ok=True)

    # Only directories can hold checkpoint.pth.tar. Skipping stray files
    # (e.g. .DS_Store, logs) avoids crashing in the label parser or loader.
    run_names = sorted(
        name for name in os.listdir(res_dir)
        if os.path.isdir(os.path.join(res_dir, name))
    )
    # The legend prefix depends only on the run name, so compute it once
    # rather than once per keyword.
    runs = [(os.path.join(res_dir, name), _make_label(name)) for name in run_names]

    for keyword in keywords:
        fig = plt.figure()
        try:
            for run_dir, label in runs:
                getinfo(run_dir, keyword, label)
            # Presumably the metrics are percentages. TODO confirm.
            plt.ylim(0, 100)
            plt.legend(bbox_to_anchor=(1.05, 0), loc=3, borderaxespad=0)
            # Leave room on the right for the out-of-axes legend.
            plt.subplots_adjust(right=0.7)
            plt.savefig(os.path.join(save_dir, '{}.png'.format(keyword)))
        finally:
            # Release the figure even if a checkpoint fails to load, and do
            # not leave a trailing empty figure open after the last keyword.
            plt.close(fig)


# Script entry point: build the metric plots for the run directory given as
# the first command-line argument (read into res_path at module import).
if __name__ == '__main__':
    getAllInfo()


