import matplotlib.pyplot as plt
import os
import sys
import torch
import pandas as pd
import re

# Root directory holding one sub-directory per run (each is expected to contain
# a checkpoint.pth.tar); taken from the first command-line argument. Note this
# runs at import time and raises IndexError if no argument was given.
res_path = sys.argv[1]

def getinfo(checkpoint_dir, flag, map_location='cpu'):
    """Load one run's checkpoint and plot its accuracy curves on the current figure.

    Reads ``<checkpoint_dir>/checkpoint.pth.tar`` and draws three lines on the
    current matplotlib axes, in this order: ``train_acc``, ``test_sa`` (labelled
    ``SA``) and ``test_ra`` (labelled ``RA``). Each label is prefixed with
    ``flag`` so several runs can share one legend.

    Args:
        checkpoint_dir: Directory containing ``checkpoint.pth.tar``.
        flag: Prefix prepended to every legend label (e.g. ``'dense '``).
        map_location: Passed to ``torch.load``. Defaults to ``'cpu'`` so the
            script also works on machines without a GPU; plotting never needs
            the data on an accelerator (and matplotlib cannot plot CUDA tensors).

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        KeyError: If the checkpoint lacks ``'result'`` or one of the plotted keys.
    """
    print(checkpoint_dir)
    checkpoint = torch.load(
        os.path.join(checkpoint_dir, 'checkpoint.pth.tar'),
        map_location=map_location,
    )
    # 'result' holds the metric series: train_acc, val_sa, val_ra, test_sa,
    # test_ra (presumably one value per epoch -- confirm against the trainer).
    # Only the train and test curves are plotted.
    all_result = checkpoint['result']

    # Order matters: it determines the colors matplotlib assigns to the curves.
    for key, suffix in (('train_acc', 'train_acc'), ('test_sa', 'SA'), ('test_ra', 'RA')):
        plt.plot(all_result[key], label=flag + suffix)


def _save_current_figure(save_dir, index):
    """Finalize the current figure (y-range, outside legend) and save it as ``<index>.png``."""
    plt.ylim(0, 100)
    # Anchor the legend just outside the right edge of the axes and shrink the
    # plot area so the legend is not clipped in the saved image.
    plt.legend(bbox_to_anchor=(1.05, 0), loc=3, borderaxespad=0)
    plt.subplots_adjust(right=0.7)
    plt.savefig(os.path.join(save_dir, '{}.png'.format(index)))
    plt.close()


def _run_label(name):
    """Build the legend prefix for a run directory name.

    Names with a ``dense`` token (after splitting on ``_``) map to ``'dense '``.
    Otherwise the label is the first number found in the third token plus the
    first character of the fourth token, e.g. ``'a_b_r0.5_omp'`` -> ``'0.5 o '``.

    Raises:
        ValueError: If the name does not have that shape.
    """
    parts = name.split('_')
    if 'dense' in parts:
        return 'dense '
    try:
        # Raw string: '\d' in a plain literal is an invalid escape sequence
        # (SyntaxWarning on Python >= 3.12).
        number = re.findall(r"\d+[\.\d+]*", parts[2])[0]
        method_initial = parts[3][0]
    except IndexError:
        raise ValueError(
            'cannot build a plot label from run name {!r}: expected at least four '
            "'_'-separated parts with a number in the third, or a 'dense' part".format(name)
        ) from None
    return number + ' ' + method_initial + ' '


def getAllInfo(res_dir=None, save_dir='./images', runs_per_figure=4):
    """Plot the accuracy curves of every run under ``res_dir``, a few runs per figure.

    Run directories are processed in sorted name order. Each one must contain a
    ``checkpoint.pth.tar`` that ``getinfo`` can read. Every ``runs_per_figure``
    runs share one figure, and the figures are written to
    ``<save_dir>/1.png``, ``2.png``, ... .

    Args:
        res_dir: Directory holding one sub-directory per run. Defaults to the
            module-level ``res_path`` (the first command-line argument).
        save_dir: Output directory for the PNG files; created if missing.
        runs_per_figure: Maximum number of runs drawn on a single figure.

    Raises:
        ValueError: If ``runs_per_figure`` < 1 or a run name cannot be labelled.
    """
    if res_dir is None:
        res_dir = res_path
    if runs_per_figure < 1:
        raise ValueError('runs_per_figure must be >= 1, got {}'.format(runs_per_figure))
    os.makedirs(save_dir, exist_ok=True)

    # Only directories can contain a checkpoint.pth.tar; stray files (logs,
    # .DS_Store, ...) would otherwise crash torch.load.
    run_names = sorted(
        name for name in os.listdir(res_dir)
        if os.path.isdir(os.path.join(res_dir, name))
    )
    if not run_names:
        # Nothing to plot: don't write an empty figure.
        return

    index = 1
    plt.figure(index)
    for count, name in enumerate(run_names):
        # The current figure is full: save it and start the next one.
        if count and count % runs_per_figure == 0:
            _save_current_figure(save_dir, index)
            index += 1
            plt.figure(index)
        getinfo(os.path.join(res_dir, name), _run_label(name))

    # Save the last (possibly partially filled) figure.
    _save_current_figure(save_dir, index)




if __name__ == "__main__":
    # Script entry point: plot every run found under the results directory
    # given on the command line and save the figures.
    getAllInfo()


