import argparse
import json
import os
import sys

from llava_evaluator_batch import eval_one_video

# Episode IDs evaluated by this script. The set is split into two halves so the
# work can be run in parallel from two terminals (--eval_1st_half / --eval_2nd_half).
FIRST_HALF_TO_TEST = ['P03_22', 'P03_23', 'P04_24', 'P04_25', 'P04_33', 'P06_10', 'P06_11', 'P06_12', 'P06_13', 'P06_14']

SECOND_HALF_TO_TEST = ['P08_16', 'P08_17', 'P09_07', 'P18_01', 'P18_02', 'P18_03', 'P18_06', 'P18_07', 'P23_05', 'P30_07']

# The full set is derived from the two halves (same contents, same order) so the
# three lists cannot drift apart when an episode is added or removed.
ALL_EP_ID_TO_TEST = FIRST_HALF_TO_TEST + SECOND_HALF_TO_TEST


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--prompt_version', type=str, required=True)
    parser.add_argument('--eval_all', action="store_true", help="if included, evaluate all the video")
    parser.add_argument('--eval_1st_half', action="store_true", help="if included, evaluate the first half of the video")
    parser.add_argument('--eval_2nd_half', action="store_true", help="if included, evaluate the second half of the video")
    parser.add_argument('-l', '--ep_id_list', nargs='+', help="a list of ep id to generate")
    parser.add_argument('--model', type=str, default='gpt-3.5-turbo')

    main_args = parser.parse_args()

    # Aggregated per-episode results for this prompt version. An existing file is
    # reloaded so episodes evaluated in earlier runs are kept; newly evaluated
    # episodes overwrite their previous entry.
    overall_eval_dict_fp = os.path.join("eval_data", main_args.prompt_version, f"{main_args.prompt_version}_overall_eval.json")

    if os.path.exists(overall_eval_dict_fp):
        with open(overall_eval_dict_fp, "r", encoding="utf-8") as fin:
            overall_eval_dict = json.load(fin)
    else:
        overall_eval_dict = {}

    # Episode selection. If several flags are given, the first match wins:
    # --eval_all > --eval_1st_half > --eval_2nd_half > --ep_id_list.
    if main_args.eval_all:
        answer = input("WARNING! It's better to parallelize the work by having 2 terminal (one does first half, another does second half).\nAre you sure you want to use 1 TERMINAL to general for all video? (y/n): ")
        if answer.strip().lower() != "y":
            sys.exit()
        ep_id_to_gen = ALL_EP_ID_TO_TEST
    elif main_args.eval_1st_half:
        ep_id_to_gen = FIRST_HALF_TO_TEST
    elif main_args.eval_2nd_half:
        ep_id_to_gen = SECOND_HALF_TO_TEST
    elif main_args.ep_id_list:
        ep_id_to_gen = main_args.ep_id_list
    else:
        # Previously this fell through with ep_id_list=None and crashed with a
        # TypeError in the loop below; report a proper usage error instead.
        parser.error("no episodes selected: pass --eval_all, --eval_1st_half, --eval_2nd_half or -l/--ep_id_list")

    print(ep_id_to_gen)

    for ep_id in ep_id_to_gen:
        # An episode is only evaluated when both its annotation and the model's
        # prediction file for this prompt version exist on disk.
        ann_fp = os.path.join("eval_data/annotation", f"{ep_id}_ann.json")
        pred_fp = os.path.join("raw_data", main_args.prompt_version, f"{main_args.prompt_version}_{ep_id}_output.json")
        has_ann_file = os.path.exists(ann_fp)
        has_pred_file = os.path.exists(pred_fp)

        print(f"{ep_id}, has_ann_file={has_ann_file}, has_pred_file={has_pred_file}")
        if not (has_ann_file and has_pred_file):
            continue

        eval_dict = eval_one_video(ep_id, main_args.prompt_version, main_args.model)
        overall_eval_dict[ep_id] = eval_dict

        # Checkpoint after every episode so finished work survives a later crash.
        # Make sure the output directory exists, write to a temporary file, and
        # atomically swap it in so an interruption mid-write cannot corrupt the
        # results already saved from earlier episodes/runs.
        os.makedirs(os.path.dirname(overall_eval_dict_fp), exist_ok=True)
        tmp_fp = overall_eval_dict_fp + ".tmp"
        with open(tmp_fp, "w", encoding="utf-8") as fout:
            json.dump(overall_eval_dict, fout, indent=4)
        os.replace(tmp_fp, overall_eval_dict_fp)