#srun --pty --time=2:00 --gres=gpu:v100:2 --mem=64G --cpus-per-task=4 --resv-ports=1  --account=conf-2020-neurips bash -l
import os 
import glob as glob 
import argparse 
import torch

# Command-line options. --v selects the ./output/resnet_cifar_<v>/ results
# tree (and the name of the CSV written later); --dir overrides that location.
parser = argparse.ArgumentParser()
parser.add_argument('--v', default='10',
                    help='suffix of the results tree to scan, i.e. '
                         './output/resnet_cifar_<v>/ (default: 10)')
parser.add_argument('--dir', default=None,
                    help='directory holding the per-run sub-directories '
                         '(default: ./output/resnet_cifar_<v>/)')
args = parser.parse_args()
# Only derive the directory from --v when --dir was not given explicitly;
# previously a user-supplied --dir was silently discarded.
if args.dir is None:
    args.dir = './output/resnet_cifar_{0}/'.format(args.v)
# The rest of the script builds paths by plain string concatenation
# (args.dir + "*/"), so the directory must end with a path separator.
args.dir = os.path.join(args.dir, '')
print(args)

# Minimum number of entries a run directory without epoch.txt must contain
# to be counted as a completed job.
chk_lst = 4

# Run directories classified by the scan loop below.
completed_jobs = []
incompleted_jobs = []

def name_analyzer(dltitle, base_dir=None):
    """Split a run-directory path into its experiment settings.

    After ``base_dir`` is removed, the name is split on ``"_"`` and read from
    the right: a trailing piece (whatever follows the final ``"_"``, e.g. the
    path separator), then ``Mense``, ``version``, the model title (one or two
    tokens) and finally the architecture.  Example:
    ``save_resnet110_iea_nn_1_4_/`` -> ``(4, "1", "iea_nn", "resnet110")``.
    Spaces inside any field are ignored.

    Args:
        dltitle: Path of the run directory.
        base_dir: Prefix removed from ``dltitle`` before parsing.  Defaults to
            the module-level ``args.dir``.

    Returns:
        Tuple ``(Mense, version, model_title, arch)``; ``Mense`` is an ``int``,
        the other three are strings.

    Raises:
        ValueError: If the name has too few fields, ``Mense`` is not an
            integer, or the model title is not one of the known titles.
    """
    two_word_titles = ("drop_iea", "iea_nn")
    one_word_titles = ("normal", "iea", "maxout", "base", "drop")

    if base_dir is None:
        base_dir = args.dir
    relative = dltitle.replace(base_dir, "")
    tokens = [piece.replace(" ", "") for piece in relative.split("_")]
    if len(tokens) < 3:
        raise ValueError("run name {0!r} has too few '_' fields".format(dltitle))

    tokens.pop()  # piece after the final "_" (e.g. the trailing "/")
    Mense = int(tokens.pop())
    version = tokens.pop()

    # Two-word titles must be tried first: "iea_nn" and "drop_iea" would
    # otherwise be mistaken for a one-word title plus a wrong architecture.
    model_title = None
    if len(tokens) >= 2 and "_".join(tokens[-2:]) in two_word_titles:
        model_title = "_".join(tokens[-2:])
        del tokens[-2:]
    elif tokens and tokens[-1] in one_word_titles:
        model_title = tokens.pop()

    if model_title is None:
        raise ValueError(
            "run name {0!r} has no recognised model title".format(dltitle))
    if not tokens:
        raise ValueError(
            "run name {0!r} has no architecture field".format(dltitle))

    arch = tokens.pop()
    return Mense, version, model_title, arch
    # print(dltitle, "= ||","M:",Mense,"V:",version,"Type:",model_title,"Arch:",arch)

# args.save_dir+"_"+str(args.model_type)+"_"+str(args.version)+"_"+str(args.Mense)+"_"

# Scan every run directory under args.dir, write one CSV row per completed
# run that passes the filters below, then append the lists of completed and
# incompleted run directories to the same file.
# Sorted so the rows and job lists come out in a reproducible order.
dirs = sorted(glob.glob(args.dir + "*/"))
version_accept = [0, 1, 2, 3, 4]
M_accept = [2, 4, 8, 16]
model_accept = ["normal", "iea", "maxout", "iea_nn"]
arch_accept = ["resnet56", "resnet110"]
# Model titles whose run directory also provides an iea_acc.txt value
# (written to the "pacc" column).
iea_models = ["iea_nn", "iea", "maxout"]
# epoch.txt value at which a run is treated as completed.
final_epoch = 199


def _write_result_row(out, run_dir, Mense, version, model_title, arch):
    """Write the CSV row of one completed run to the open file ``out``."""
    # Only checkpoint['best_prec1'] is read, so map the checkpoint to the CPU;
    # this lets the script run on machines without a GPU.
    checkpoint = torch.load(os.path.join(run_dir, "model.th"),
                            map_location="cpu")
    if model_title in iea_models:
        with open(os.path.join(run_dir, "iea_acc.txt"), "r") as acc_file:
            pacc = float(acc_file.readline())
    else:
        pacc = "_"  # placeholder for models without an iea_acc.txt value
    out.write("{0},{1},{2},{3},{4},{5}\n".format(
        Mense, version, model_title, arch, checkpoint['best_prec1'], pacc))


# The context manager closes the CSV even if a run directory raises midway.
with open("resnet_cifar{0}_results.csv".format(args.v), "w") as f_:
    f_.write("Mense,version,model_title,arch,acc,pacc\n")
    for d in dirs:
        # A directory that does not follow the run naming scheme must not
        # abort the whole scan.  Depending on how the name is malformed the
        # parser fails with one of these exception types.
        try:
            Mense, version, model_title, arch = name_analyzer(d)
            version_number = int(version)
        except (ValueError, IndexError, UnboundLocalError) as err:
            print("Skipping {0}: cannot parse run name ({1})".format(d, err))
            continue

        if version_number not in version_accept:
            continue
        if int(Mense) not in M_accept:
            continue
        if model_title not in model_accept:
            continue
        if arch not in arch_accept:
            continue

        epoch_path = os.path.join(d, "epoch.txt")
        if os.path.exists(epoch_path):  # New case: progress is in epoch.txt
            with open(epoch_path, 'r') as handle:
                epo = int(handle.readline())
            if epo == final_epoch:
                completed_jobs.append(d)
                _write_result_row(f_, d, Mense, version, model_title, arch)
            else:
                incompleted_jobs.append(d)
                print(epo)
        elif len(glob.glob(d + "/*")) >= chk_lst:  # Old case: count the files
            completed_jobs.append(d)
            _write_result_row(f_, d, Mense, version, model_title, arch)
        else:
            incompleted_jobs.append(d)

    f_.write("Completed jobs\n")
    f_.write(str(completed_jobs))

    f_.write("\nIncompleted jobs\n")
    f_.write(str(incompleted_jobs))