from argparse import ArgumentParser, Namespace
from pathlib import Path
from util.scenario import Scenario
from handle.video_asker import VideoAskerV1
from handle.handler import Handler
from util.interface import CausalModel

def parse_args(argv: "list[str] | None" = None) -> Namespace:
    """Parse the command-line options for the video-asking script.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like ``ArgumentParser.parse_args``.

    Returns:
        Namespace with ``sample_n_level_3``, ``llm_names`` (list of str),
        ``scenario_begin`` and ``scenario_end``.
    """
    parser = ArgumentParser(description="Ask for observed variables and evaluate.")
    parser.add_argument("--sample_n_level_3", type=int, default=10, help="Number of False/True samples for each non_root variable for level 3")
    # nargs="*" collects whole names ("--llm_names Hailuo Pika"); the previous
    # type=list split a single value into characters and rejected multiple values.
    parser.add_argument("--llm_names", type=str, nargs="*",
                        default=["CogVideoX1.5-5b", "Hailuo", "Pika", "pyramid_flux", "CogVideoX-2b", "videocrafter2_base",
                                 "CogVideoX-5b", "HunyuanVideo", "Gen-3-Alpha", "Kling"],
                        help="Name of models to be asked. Stay empty to ask all models")
    parser.add_argument("--scenario_begin", type=int, default=1, help="Ask scenarios from scenario_begin to scenario_end")
    parser.add_argument("--scenario_end", type=int, default=60, help="Ask scenarios from scenario_begin to scenario_end, -1 to ask all scenarios")
    args = parser.parse_args(argv)
    # Drop empty names so `--llm_names ""` (the old way to pass an empty list)
    # still yields [] and therefore means "ask all models".
    args.llm_names = [name for name in args.llm_names if name]
    return args

def main():
    """Ask every selected model about every selected scenario.

    Models come from ``--llm_names``. When that list is empty, every
    sub-directory of ``database/`` (next to this file) is used, except
    directories whose name ends with ``"sample"``.

    Scenarios are all entries from ``Scenario.get_all_scenarios()`` except
    ``"test scenario"``. They are restricted to indices in
    ``[scenario_begin, scenario_end]`` unless ``scenario_end`` is -1.

    For each (model, scenario) pair a ``Handler`` is built and
    ``ask_videos()`` is called.
    """
    args: Namespace = parse_args()
    database_folder = Path(__file__).parent / "database"
    if args.llm_names:
        llm_names = args.llm_names
    else:
        # iterdir() yields entries in arbitrary order; sort for reproducible runs.
        llm_names = sorted(
            folder.name
            for folder in database_folder.iterdir()
            if folder.is_dir() and not folder.name.endswith("sample")
        )

    # The scenario filter is independent of the model, so compute it once.
    # Materialising a list also guards against get_all_scenarios() returning a
    # one-shot iterator, which would be exhausted after the first model.
    selected_scenarios = []
    for scenario in Scenario.get_all_scenarios():
        if scenario == "test scenario":
            continue
        scenario_id = Scenario.get_index(scenario=scenario)
        if args.scenario_end != -1 and not (args.scenario_begin <= scenario_id <= args.scenario_end):
            continue
        selected_scenarios.append(scenario)

    for llm_name in llm_names:
        for scenario in selected_scenarios:
            # A fresh CausalModel per (model, scenario) pair, as before; the
            # asker/handler may keep state on it (not verified here).
            model = CausalModel(scenario=scenario)
            video_asker = VideoAskerV1(model=model, llm_name=llm_name)
            handler = Handler(model=model, video_asker=video_asker, sample_n=args.sample_n_level_3,
                              llm_name=llm_name)
            handler.ask_videos()

if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()