import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from argparse import ArgumentParser, Namespace
from pathlib import Path
import json
import shutil
import pandas as pd
from evaluate.evaluator import Evaluator
from evaluate.bootstrap import BootStraper
from util.scenario import Scenario
from util.interface import CausalModel

def parse_args(argv=None) -> Namespace:
    """Parse command-line options for the per-sample evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (``None`` keeps the standard behavior). Useful
            for calling the parser programmatically or from tests.

    Returns:
        The parsed ``Namespace``. ``llm_names`` is a list of model names
        (empty when ``--llm_names`` is given with no values). ``thresholds``
        is a list of floats.
    """
    default_llm_names = [
        "CogVideoX1.5-5b", "Hailuo", "Pika", "pyramid_flux", "CogVideoX-2b",
        "videocrafter2_base", "CogVideoX-5b", "HunyuanVideo", "Gen-3-Alpha", "Kling",
    ]
    parser = ArgumentParser(description="Evaluate for each sample")
    parser.add_argument("--sample_n_level_3", type=int, default=10, help="Number of False/True samples for each non_root variable for level 3")
    parser.add_argument("--least_needed_sample_level_2", type=int, default=3, help="Least number of samples in a group for level 2")
    # nargs="*" collects space-separated names into a list of strings.
    # The original type=list split one string into characters ("Pika" -> ['P','i','k','a']).
    # A bare `--llm_names` yields [], which the caller treats as "all models".
    parser.add_argument("--llm_names", type=str, nargs="*",
                        default=default_llm_names,
                        help="Model to evaluate. Stay empty to evaluate all models")
    parser.add_argument("--scenario_begin", type=int, default=1, help="Ask scenarios from scenario_begin to scenario_end")
    parser.add_argument("--scenario_end", type=int, default=60, help="Ask scenarios from scenario_begin to scenario_end, -1 to ask all scenarios")
    # Parse each threshold as a float. type=list would yield single characters.
    parser.add_argument("--thresholds", type=float, nargs="+", default=[0.65, 0.75, 0.85, 0.95], help="Thresholds used for metric 3.")
    parser.add_argument("--sample_ratio_threshold_level_3", type=float, default=0.2, help="Y with ratio of Y=1 < t or > 1-t will not be used")
    return parser.parse_args(argv)

def main():
    """Re-run per-sample evaluation for every selected model and scenario.

    Model selection: the models in ``--llm_names`` are used. When that list is
    empty, every sub-directory of ``<repo>/database`` whose name does not end
    with ``"sample"`` is used instead.

    Scenario selection: every scenario except ``"test scenario"`` whose index
    lies in ``[scenario_begin, scenario_end]`` is used. When ``scenario_end``
    is -1, the range filter is skipped entirely.

    For each model/scenario pair, the existing ``sample_results`` folder is
    deleted if present, and then ``Evaluator.evaluate`` is run.
    """
    args: Namespace = parse_args()
    # The database folder sits one level above this script's directory.
    database_folder = Path(__file__).resolve().parent.parent / "database"
    all_scenarios = Scenario.get_all_scenarios()
    if not args.llm_names:
        # No explicit model list: take every model folder, skipping "*sample" folders.
        llm_names = [folder.name for folder in database_folder.iterdir()
                     if folder.is_dir() and not folder.name.endswith("sample")]
    else:
        llm_names = args.llm_names
    for llm_name in llm_names:
        llm_folder = database_folder / llm_name
        for scenario in all_scenarios:
            if scenario == "test scenario":
                continue
            scenario_id = Scenario.get_index(scenario=scenario)
            # scenario_end == -1 disables range filtering (all scenarios are evaluated).
            if args.scenario_end != -1 and not (args.scenario_begin <= scenario_id <= args.scenario_end):
                continue
            results_folder = llm_folder / str(scenario_id) / "sample_results"
            # Clear stale per-sample results before re-evaluating. The folder
            # may not exist yet (e.g. first run). An unconditional rmtree would
            # raise FileNotFoundError and abort the whole batch.
            if results_folder.exists():
                shutil.rmtree(results_folder)
            model = CausalModel(scenario=scenario)
            # NOTE(review): a ratio threshold is passed as `sample_num_threshold`.
            # Confirm that Evaluator interprets this parameter as a ratio.
            evaluator = Evaluator(model=model, sample_n=args.sample_n_level_3, llm_name=llm_name, sample_num_threshold=args.sample_ratio_threshold_level_3)
            evaluator.evaluate(least_needed_sample_level_2=args.least_needed_sample_level_2)

if __name__ == "__main__":
    # Launch the batch evaluation only when run as a script, never on import.
    main()