import matplotlib.pyplot as plt
import os
import sys
import torch
import pandas as pd
import re

# Root directory whose entries getAllInfo() lists and processes; taken from
# the first command-line argument. This is evaluated at import time, so
# importing this module without an argument raises IndexError.
res_path = sys.argv[1]

def getinfo(checkpoint_dir, save_path):
    """Write a reduced copy of one run's checkpoint into *save_path*.

    Loads ``<checkpoint_dir>/checkpoint.pth.tar`` onto the CPU and saves a new
    ``<save_path>/checkpoint.pth.tar`` that keeps only the ``best_sa``,
    ``best_ra``, ``epoch`` and ``result`` entries; every other entry of the
    original checkpoint is dropped.

    Args:
        checkpoint_dir: Directory expected to contain ``checkpoint.pth.tar``.
        save_path: Output directory; created if it does not exist yet.

    Returns:
        None. When *checkpoint_dir* has no checkpoint file, prints
        ``'no checkpoint'`` and writes nothing.

    Raises:
        ValueError: If the output file resolves to the input checkpoint, which
            would overwrite the full checkpoint with the reduced one.
        KeyError: If the loaded checkpoint lacks one of the kept keys.
    """
    ckpt_name = 'checkpoint.pth.tar'
    # Output key order matches the original: best_sa, best_ra, epoch, result.
    # Per the author's note, 'result' holds: train_acc val_sa val_ra test_sa test_ra.
    kept_keys = ('best_sa', 'best_ra', 'epoch', 'result')

    print(checkpoint_dir)
    src_file = os.path.join(checkpoint_dir, ckpt_name)
    if not os.path.isfile(src_file):
        print('no checkpoint')
        return

    dst_file = os.path.join(save_path, ckpt_name)
    # Guard against destroying the source: the output name is identical, so
    # equal directories (or symlinks to them) would clobber the full checkpoint.
    if os.path.realpath(src_file) == os.path.realpath(dst_file):
        raise ValueError(
            f'refusing to overwrite source checkpoint {src_file!r} with its reduced copy'
        )

    # NOTE: torch.load unpickles the file; only run this on trusted checkpoints.
    checkpoint = torch.load(src_file, map_location=torch.device('cpu'))
    checkpoint_state = {key: checkpoint[key] for key in kept_keys}

    os.makedirs(save_path, exist_ok=True)
    torch.save(checkpoint_state, dst_file)


def _mirror_parts(path):
    """Split *path* into relative components that are safe to nest under another directory.

    ``os.path.join`` discards every segment before an absolute one, so joining
    an absolute input root onto the output root would point straight back at
    the input tree. Dropping the drive, root separator, ``.`` and ``..`` parts
    keeps the result inside the output root, while ordinary relative paths
    (``results/exp1`` -> ``['results', 'exp1']``) map exactly as before.
    """
    _, tail = os.path.splitdrive(os.path.normpath(path))
    return [part for part in tail.split(os.sep)
            if part not in ('', os.curdir, os.pardir)]


def getAllInfo(root=None, save_dir='./mini_cps'):
    """Run ``getinfo`` on every sub-directory of *root*, in sorted order.

    Each ``<root>/<name>`` directory is paired with the output directory
    ``<save_dir>/<root as a relative path>/<name>``, which is created before
    ``getinfo`` is called. Non-directory entries are skipped.

    Args:
        root: Directory whose sub-directories are processed. Defaults to the
            module-level ``res_path`` (the first command-line argument).
        save_dir: Directory under which the mirrored output tree is created.
    """
    root = res_path if root is None else root
    out_root = os.path.join(save_dir, *_mirror_parts(root))
    for name in sorted(os.listdir(root)):
        run_dir = os.path.join(root, name)
        if not os.path.isdir(run_dir):
            continue  # not a run directory (e.g. a plain file)
        save_path = os.path.join(out_root, name)
        # Create the output directory before handing it to getinfo.
        os.makedirs(save_path, exist_ok=True)
        getinfo(run_dir, save_path)

if __name__ == '__main__':
    # Process every entry of res_path only when executed as a script, not on import.
    getAllInfo()
