import os
import lib
import json
import copy
from typing import List, Optional
from tqdm import tqdm


# Directory containing this script. Made absolute before taking dirname: when
# __file__ is a bare file name, dirname() returns "", which previously turned
# main_dir into the filesystem root and made relpath() below raise ValueError.
my_dir = os.path.dirname(os.path.abspath(__file__))
# Two directories above this script (presumably the project root that holds main.py).
main_dir = os.path.abspath(os.path.join(my_dir, "..", ".."))
# This script's directory expressed relative to main_dir.
my_rel_dir = os.path.relpath(my_dir, main_dir)
# Working directory at import time; get_info() chdir()s back to it before writing results.
curr_dir = os.getcwd()

# Prefix of the per-checkpoint result file names.
result_name = "result_"
# Evaluation tasks to enable; each becomes a "-lm.eval.<task>.enabled 1" flag in TESTS.
EVAL_TASKS = ("lambada", "cbt", "hellaswag", "piqa", "blimp", "ai2arc", "race", "siqa")
TESTS = " ".join(f"-lm.eval.{task}.enabled 1" for task in EVAL_TASKS)

def get_info(id, patch_ckpt=None, bs: Optional[int] = None, *, force: bool = True):
    """Run the post-validation evaluation for one checkpoint and return its results.

    Builds a ``python3 main.py --name post_validate ...`` command that restores
    ``patch_ckpt`` with every evaluation flag in ``TESTS`` enabled, and runs it via
    ``lib.run_command``. It extracts the lines between the ``Validate returned:``
    line and the next line starting with ``-------``, and writes them to
    ``{id}/{result_name}_{checkpoint_stem}.json``. Finally it parses that file as JSON.

    Args:
        id: Relative directory (resolved against the working directory) in which
            the result file is stored; created if it does not exist.
        patch_ckpt: Path of the checkpoint to evaluate. Required; the ``None``
            default only exists for backward compatibility of the signature.
        bs: Optional batch size, forwarded as ``--batch_size`` when not None.
        force: If True (the default, matching the previous hard-coded
            behaviour) the evaluation is always re-run. If False, an existing
            result file is reused and no command is run.

    Returns:
        The parsed JSON content of the result file.

    Raises:
        TypeError: If ``patch_ckpt`` is None.
        ValueError: If the command output lacks a ``Validate returned:`` line
            or the ``-------`` line that terminates the result block.
    """
    if patch_ckpt is None:
        # Both the result file name and the command need the checkpoint path.
        raise TypeError("get_info() requires a checkpoint path (patch_ckpt)")

    dest_dir = f"{id}/"
    # Only the part of the file name before its first '.' is used.
    ckpt_stem = os.path.basename(patch_ckpt).split('.')[0]
    # NOTE: result_name already ends in '_', so the name contains '__'; kept so
    # existing result files keep their names.
    res_path = f"{dest_dir}{result_name}_{ckpt_stem}.json"

    if force or not os.path.isfile(res_path):
        bs_arg = "" if bs is None else f"--batch_size {bs}"
        cmd = f"python3 main.py --name post_validate --restore {patch_ckpt} --test_only 1 -reset 1 -lm.eval.enabled 1 {TESTS} --keep_alive 0 {bs_arg}"
        print("Validate command: ", cmd)
        out = lib.run_command(cmd)

        lines = out.splitlines()
        marker = 'Validate returned:'
        try:
            start_line = lines.index(marker)
        except ValueError:
            raise ValueError(f"Validation output has no {marker!r} line:\n{out}") from None

        # First line after the marker that starts the '-------' separator.
        end_line = next(
            (i for i in range(start_line + 1, len(lines)) if lines[i].startswith("-------")),
            None,
        )
        if end_line is None:
            # Explicit raise instead of assert: assert is stripped under -O.
            raise ValueError(f"Validation output has no '-------' line after {marker!r}:\n{out}")

        res = "\n".join(lines[start_line + 1:end_line])
        # Return to the import-time working directory before using the relative
        # res_path; presumably lib.run_command may change the cwd — TODO confirm.
        os.chdir(curr_dir)

        os.makedirs(dest_dir, exist_ok=True)
        with open(res_path, "w", encoding="utf-8") as f:
            f.write(res)

    with open(res_path, "r", encoding="utf-8") as f:
        return json.load(f)


if __name__ == "__main__":
    # Runs to evaluate: one (result directory, checkpoint path, batch size or None)
    # tuple per entry, passed positionally to get_info(). Fill in before running.
    list_eval = []

    for run_dir, ckpt_file, batch_size in list_eval:
        get_info(run_dir, ckpt_file, batch_size)
