import os
import sys
from textwrap import indent
import torch
import pandas as pd

# Root directory holding one entry per training run; supplied as the first
# command-line argument. getAllInfo() reads this module-level name.
if len(sys.argv) < 2:
    sys.exit('usage: python {} <res_path>'.format(os.path.basename(sys.argv[0])))
res_path = sys.argv[1]



def getinfo(checkpoint_dir, map_location='cpu'):
    """Summarise the SWA metrics stored in one run's checkpoint.

    Loads ``<checkpoint_dir>/checkpoint.pth.tar``. It then finds the position
    in ``result['val_ra_swa']`` that equals ``best_ra_swa``. At that position
    and at the last position, it reports the test robust accuracy
    (``test_ra_swa``) and the test standard accuracy (``test_sa_swa``).

    Args:
        checkpoint_dir: Directory containing ``checkpoint.pth.tar``.
        map_location: Passed to ``torch.load``. Defaults to the CPU because
            only scalar metrics are read, so no GPU is needed.

    Returns:
        A one-row ``pandas.DataFrame`` with these columns:

        - ``file``
        - ``best_ra``, ``final_ra``, and ``diff1`` (best minus final RA)
        - ``best_sa``, ``final_sa``, and ``diff2`` (best minus final SA)
        - ``sparsity`` and ``total_fired_weights``, rounded to 2 decimals,
          only when the checkpoint's ``result`` contains them

    Raises:
        ValueError: If ``best_ra_swa`` does not occur in ``val_ra_swa``.
    """
    print(checkpoint_dir)
    path = os.path.join(checkpoint_dir, 'checkpoint.pth.tar')
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects
    # checkpoints that pickle arbitrary objects -- confirm if loading fails.
    checkpoint = torch.load(path, map_location=map_location)
    best_ra = checkpoint['best_ra_swa']
    end_epoch = checkpoint['epoch']
    print('end_epoch', end_epoch)

    # Metric lists, presumably one entry per evaluated epoch, plus optional
    # scalar summaries.
    all_result = checkpoint['result']

    val_ra = all_result['val_ra_swa']
    try:
        # Exact match is intended: best_ra_swa is presumably copied from
        # this same list when the checkpoint is written.
        best_idx = val_ra.index(best_ra)
    except ValueError:
        raise ValueError(
            'best_ra_swa={!r} not found in val_ra_swa of {}'.format(best_ra, path)
        ) from None

    test_ra = all_result['test_ra_swa']
    test_sa = all_result['test_sa_swa']
    best_test_ra, final_test_ra = test_ra[best_idx], test_ra[-1]
    best_test_sa, final_test_sa = test_sa[best_idx], test_sa[-1]

    result = {
        # normpath so a trailing separator does not yield an empty name.
        'file': [os.path.basename(os.path.normpath(checkpoint_dir))],
        'best_ra': [best_test_ra],
        'final_ra': [final_test_ra],
        'diff1': [best_test_ra - final_test_ra],
        'best_sa': [best_test_sa],
        'final_sa': [final_test_sa],
        'diff2': [best_test_sa - final_test_sa],
    }

    # Optional summaries: compare against None rather than truthiness so a
    # legitimate value of 0 is still reported.
    for key in ('sparsity', 'total_fired_weights'):
        value = all_result.get(key)
        if value is not None:
            result[key] = [round(value, 2)]

    return pd.DataFrame(result)
    

    

    
def getAllInfo(result_dir=None, save_dir='./statis'):
    """Collect SWA statistics for every run under a results directory.

    Each entry of ``result_dir`` that contains a ``checkpoint.pth.tar`` is
    summarised with ``getinfo``. Other entries (stray files, logs, an earlier
    report) are skipped with a printed message. The combined table is
    sorted by run name in descending order and printed. It is then written to
    ``<save_dir>/<basename(result_dir)>_statis_swa.xlsx``.

    Args:
        result_dir: Directory to scan. Defaults to the module-level
            ``res_path``.
        save_dir: Output directory for the Excel file. It is created if
            missing.

    Returns:
        The combined ``pandas.DataFrame`` that was written.
    """
    if result_dir is None:
        result_dir = res_path

    columns = ['file', 'best_ra', 'final_ra', 'diff1', 'best_sa', 'final_sa',
               'diff2', 'sparsity', 'total_fired_weights']

    frames = []
    for name in os.listdir(result_dir):
        run_dir = os.path.join(result_dir, name)
        if not os.path.isfile(os.path.join(run_dir, 'checkpoint.pth.tar')):
            print('skip {}: no checkpoint.pth.tar'.format(run_dir))
            continue
        frames.append(getinfo(run_dir))

    # Concatenate once (DataFrame.append was removed in pandas 2.0 and was
    # quadratic in a loop). Keep the fixed column order, then any extras.
    if frames:
        df_total = pd.concat(frames, ignore_index=True)
        extra = [c for c in df_total.columns if c not in columns]
        df_total = df_total.reindex(columns=columns + extra)
    else:
        df_total = pd.DataFrame(columns=columns)

    df_total = df_total.sort_values(by='file', ascending=False)

    print(df_total)
    # normpath so a trailing separator on result_dir does not produce an
    # empty basename (i.e. a file called "_statis_swa.xlsx").
    base = os.path.basename(os.path.normpath(result_dir))
    file_name = '{}_statis_swa.xlsx'.format(base)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, file_name)
    df_total.to_excel(save_path, index=False)
    return df_total


# Script entry point: gather per-run statistics under the command-line
# res_path and write the Excel summary.
if __name__ == '__main__':
    getAllInfo()
