import os
import argparse
from ..utils import create_experiment_id


def parse_args():
    """Parse the command-line options for evaluating a run's checkpoints.

    Returns:
        argparse.Namespace with attributes ``run_dir`` (required),
        ``save_dir`` (default None), ``experiment_id`` (default None) and
        ``remove_ckpt_dir`` (a store_true flag, default False).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--run_dir", type=str, required=True,
                     help="The directory of the run.")
    cli.add_argument("--save_dir", type=str, default=None,
                     help="The directory to save the evaluation results.")
    cli.add_argument("--experiment_id", type=str, default=None,
                     help="The experiment id.")
    cli.add_argument("--remove_ckpt_dir", action="store_true",
                     help="If given, the checkpoint directory will be removed after evaluation.")
    return cli.parse_args()


def eval_dialog_sum_run(
        run_dir, save_dir=None, experiment_id=None, remove_ckpt_dir=False):
    """Evaluate every ``checkpoint-*`` directory found directly in ``run_dir``.

    For each checkpoint, ``eval.dialog_sum.eval_ckpt`` is launched as a
    subprocess (with ``--use_vllm``), writing into
    ``<save_dir>/checkpoint-<step>``. Checkpoints are processed in ascending
    step order.

    Args:
        run_dir: Run directory containing ``checkpoint-<step>`` subdirectories.
        save_dir: Root directory for evaluation results. Defaults to
            ``<run_dir>/eval_results`` and is created if missing.
        experiment_id: Identifier forwarded to the evaluation script.
            Generated with ``create_experiment_id()`` when not given.
        remove_ckpt_dir: If True, delete each checkpoint directory after its
            evaluation exits with status 0. A failed evaluation keeps it.
    """
    import shlex
    import shutil
    import subprocess
    import sys

    if experiment_id is None:
        experiment_id = create_experiment_id()
    if save_dir is None:
        save_dir = os.path.join(run_dir, 'eval_results')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        print(f'{save_dir} does not exist. Created.')

    ckpt_dir_dict = {
        d.split('-')[-1]: os.path.join(run_dir, d)
        for d in os.listdir(run_dir)
        # Only real directories: skip stray files sharing the prefix
        # (e.g. "checkpoint-5.json").
        if d.startswith('checkpoint-') and os.path.isdir(os.path.join(run_dir, d))
    }

    def _step_key(step):
        # Numeric steps sort numerically ("20" before "100"); any non-numeric
        # suffix sorts after them, alphabetically.
        return (0, int(step), '') if step.isdigit() else (1, 0, step)

    for ckpt_step in sorted(ckpt_dir_dict, key=_step_key):
        ckpt_dir = ckpt_dir_dict[ckpt_step]
        ckpt_save_dir = os.path.join(save_dir, f'checkpoint-{ckpt_step}')
        # Pass an argument list (no shell) so paths containing spaces or
        # shell metacharacters reach the evaluator verbatim. sys.executable
        # pins the evaluator to the interpreter running this script.
        eval_command = [
            sys.executable, '-m', 'eval.dialog_sum.eval_ckpt',
            '--model_name_or_path', ckpt_dir,
            '--save_dir', ckpt_save_dir,
            '--experiment_id', str(experiment_id),
            '--use_vllm',
        ]
        print(' '.join(shlex.quote(arg) for arg in eval_command))
        result = subprocess.run(eval_command, check=False)
        if result.returncode != 0:
            # Keep the checkpoint on failure so the evaluation can be rerun.
            print(f'Evaluation of {ckpt_dir} failed with exit code {result.returncode}.')
            continue
        if remove_ckpt_dir:
            print(f'Removing {ckpt_dir}')
            shutil.rmtree(ckpt_dir)


def main():
    """Command-line entry point: read the CLI options and evaluate the run."""
    cli_args = parse_args()
    eval_dialog_sum_run(
        run_dir=cli_args.run_dir,
        save_dir=cli_args.save_dir,
        experiment_id=cli_args.experiment_id,
        remove_ckpt_dir=cli_args.remove_ckpt_dir,
    )


if __name__ == "__main__":
    main()