from eval_kit import EvalKit
import os
import pandas as pd
from fire import Fire

def eval_model(model_path, data_path):
    """Evaluate one checkpoint and write loss/accuracy CSVs beside it.

    Args:
        model_path: Mapping with a ``"name"`` entry (used only in log
            messages) and a ``"path"`` entry (location of the checkpoint).
        data_path: Dataset location relative to the ``data_law/`` directory.

    Returns:
        Whatever ``EvalKit.eval_pretrain_by_accuracy`` returns (presumably a
        DataFrame -- TODO confirm against EvalKit), or ``None`` when both
        output files already exist and the evaluation is skipped.
    """
    print(f"evaluate {model_path['name']}")
    # Strip only the trailing extension. A plain str.replace(".pt", ...) would
    # also rewrite ".pt" occurring in a directory name, and would leave a path
    # without ".pt" unchanged so the output path would equal the checkpoint.
    output_stem = os.path.splitext(model_path["path"])[0]
    loss_output_path = f"{output_stem}_loss.csv"
    accuracy_output_path = f"{output_stem}_accuracy.csv"
    if os.path.exists(loss_output_path) and os.path.exists(accuracy_output_path):
        print(f"skip {model_path['name']} because {loss_output_path} and {accuracy_output_path} exist")
        return
    evaluator = EvalKit(model_path["path"])
    full_data_path = f"data_law/{data_path}"
    # The loss evaluation is run for its output file; its return value is unused.
    evaluator.eval_pretrain_by_profile(data_path=full_data_path, output_path=loss_output_path, every_n=40, processes_per_gpu=1)
    df = evaluator.eval_pretrain_by_accuracy(data_path=full_data_path, output_path=accuracy_output_path, every_n=40, processes_per_gpu=1)
    return df

def eval_all_checkpoints(model_dir, data_path, strategy="last"):
    """Evaluate the checkpoints found in ``model_dir``.

    Checkpoints are files in ``model_dir`` whose names start with
    ``state_step`` and end with ``.pt``.

    Args:
        model_dir: Directory to scan for checkpoint files.
        data_path: Dataset location forwarded unchanged to ``eval_model``.
        strategy: ``"last"`` evaluates only the checkpoint with the largest
            step number; ``"all"`` evaluates every checkpoint, in the order
            ``os.listdir`` returns them.

    Raises:
        ValueError: If ``strategy`` is neither ``"last"`` nor ``"all"``, or
            (for ``"last"``) if the text between ``state_step`` and the first
            following ``.`` in a file name is not an integer.
    """
    # Fail fast with a clear message instead of hitting an unbound local later.
    if strategy not in ("last", "all"):
        raise ValueError(f"unknown strategy {strategy!r}; expected 'last' or 'all'")
    print(f"evaluate all checkpoints in {model_dir}")
    ckpts = [
        x for x in os.listdir(model_dir) if x.startswith("state_step") and x.endswith(".pt")
    ]
    if strategy == "last":
        # choose the checkpoint with the largest state_step
        ckpts.sort(key=lambda x: int(x.split("state_step")[1].split(".")[0]))
        # Slicing (not indexing) keeps an empty directory a no-op, like "all".
        ckpts = ckpts[-1:]
    print(ckpts)
    for ckpt in ckpts:
        model_path = {"name": ckpt, "path": os.path.join(model_dir, ckpt)}
        eval_model(model_path, data_path)

# When run as a script, expose eval_all_checkpoints as a command-line
# interface via Fire (its parameters become CLI arguments).
if __name__ == "__main__":
    Fire(eval_all_checkpoints)




