import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from argparse import ArgumentParser, Namespace
from pathlib import Path
from evaluate.evaluator import Evaluator
from util.scenario import Scenario
from util.interface import CausalModel

def parse_args() -> Namespace:
    """Parse the command-line options of the per-sample evaluation script.

    List-valued options take space-separated values, e.g.
    ``--llm_names Pika Kling`` or ``--thresholds 0.65 0.95``. Passing
    ``--llm_names`` with no values yields an empty list.

    Returns:
        Namespace with the parsed options, read from ``sys.argv``.
    """
    parser = ArgumentParser(description="Evaluate for each sample")
    parser.add_argument(
        "--sample_n_level_3",
        type=int,
        default=10,
        help="Number of False/True samples for each non_root variable for level 3",
    )
    parser.add_argument(
        "--least_needed_sample_level_2",
        type=int,
        default=3,
        help="Least number of samples in a group for level 2",
    )
    # NOTE: ``type=list`` would call ``list()`` on the single argument string
    # and split it into characters; ``nargs`` collects whole values instead.
    parser.add_argument(
        "--llm_names",
        type=str,
        nargs="*",
        default=[
            "CogVideoX1.5-5b", "Hailuo", "Pika", "pyramid_flux", "CogVideoX-2b",
            "videocrafter2_base", "CogVideoX-5b", "HunyuanVideo", "Gen-3-Alpha",
            "Kling",
        ],
        help="Model to evaluate. Stay empty to evaluate all models",
    )
    parser.add_argument(
        "--scenario_begin",
        type=int,
        default=1,
        help="Ask scenarios from scenario_begin to scenario_end",
    )
    parser.add_argument(
        "--scenario_end",
        type=int,
        default=60,
        help="Ask scenarios from scenario_begin to scenario_end, -1 to ask all scenarios",
    )
    # Each threshold is converted to float individually (see note above).
    parser.add_argument(
        "--thresholds",
        type=float,
        nargs="+",
        default=[0.65, 0.75, 0.85, 0.95],
        help="Thresholds used for metric 3.",
    )
    parser.add_argument(
        "--sample_ratio_threshold_level_3",
        type=float,
        default=0.2,
        help="Y with ratio of Y=1 < t or > 1-t will not be used",
    )
    return parser.parse_args()

def main():
    """Create an ``Evaluator`` for each selected model/scenario pair.

    Model names come from ``--llm_names``; when that list is empty, every
    sub-directory of the ``database`` folder (located one level above this
    file's directory) whose name does not end with ``"sample"`` is used.
    Scenarios come from ``Scenario.get_all_scenarios()``; the entry equal to
    ``"test scenario"`` is always skipped, and unless ``--scenario_end`` is
    ``-1`` only scenarios whose index lies in
    ``[scenario_begin, scenario_end]`` are processed.

    For every remaining pair a ``CausalModel`` and an ``Evaluator`` are
    built and ``Evaluator._add_video_paths()`` is called. The
    ``--least_needed_sample_level_2`` and ``--thresholds`` options are parsed
    but not consumed here.
    """
    args: Namespace = parse_args()
    database_folder = Path(__file__).resolve().parent.parent / "database"
    all_scenarios = Scenario.get_all_scenarios()

    if args.llm_names:
        llm_names = args.llm_names
    else:
        # No explicit selection: fall back to every model folder on disk.
        llm_names = [
            folder.name
            for folder in database_folder.iterdir()
            if folder.is_dir() and not folder.name.endswith("sample")
        ]

    # scenario_end == -1 disables the index range filter entirely.
    restrict_range = args.scenario_end != -1

    for llm_name in llm_names:
        for scenario in all_scenarios:
            if scenario == "test scenario":
                continue
            scenario_id = Scenario.get_index(scenario=scenario)
            if restrict_range and not (args.scenario_begin <= scenario_id <= args.scenario_end):
                continue
            # A fresh model is built per (model, scenario) pair, as before.
            model = CausalModel(scenario=scenario)
            evaluator = Evaluator(
                model=model,
                sample_n=args.sample_n_level_3,
                llm_name=llm_name,
                # The level-3 ratio threshold is passed under the
                # evaluator's ``sample_num_threshold`` keyword.
                sample_num_threshold=args.sample_ratio_threshold_level_3,
            )
            evaluator._add_video_paths()

# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()