import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from pathlib import Path
from pandas import DataFrame
from argparse import ArgumentParser, Namespace
from evaluate.evaluator import Evaluator
from util.scenario import Scenario
from util.interface import CausalModel

def _str2bool(value: str) -> bool:
    """Convert a command-line token ("true"/"false", "1"/"0", ...) to a bool.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the
    raw string, so every non-empty value -- including "False" -- is True.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean
            spelling; argparse reports it as a normal usage error.
    """
    # Imported locally so this block does not depend on the file-level imports.
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0"):
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args() -> Namespace:
    """Parse the command-line options of the evaluation script.

    Returns:
        Namespace with the attributes ``sample_n_level_3``,
        ``by_truth_level_2``, ``least_needed_sample_level_2``,
        ``by_truth_level_3``, ``llm_names`` (list of str), ``scenario_begin``,
        ``scenario_end``, ``thresholds`` (list of float) and
        ``sample_ratio_threshold_level_3``.
    """
    parser = ArgumentParser(description="Evaluate")
    parser.add_argument("--sample_n_level_3", type=int, default=10, help="Number of False/True samples for each non_root variable for level 3")
    # Booleans go through _str2bool so that "--by_truth_level_2 False" really yields False.
    parser.add_argument("--by_truth_level_2", type=_str2bool, default=True, help="Group by true value (True) or observed value (False)")
    parser.add_argument("--least_needed_sample_level_2", type=int, default=3, help="Least number of samples in a group for level 2")
    parser.add_argument("--by_truth_level_3", type=_str2bool, default=True, help="Fit rule by ground truth roots or observed roots")
    # nargs="*" collects whole model names; type=list would split one name into characters.
    # Passing the bare flag produces an empty list.
    parser.add_argument("--llm_names", type=str, nargs="*",
                        default=["CogVideoX1.5-5b", "Hailuo", "Pika", "pyramid_flux", "CogVideoX-2b", "videocrafter2_base",
                                 "CogVideoX-5b", "HunyuanVideo", "Gen-3-Alpha", "Kling"],
                        help="Models to evaluate (space separated). Stay empty to evaluate all models")
    parser.add_argument("--scenario_begin", type=int, default=1, help="Ask scenarios from scenario_begin to scenario_end")
    parser.add_argument("--scenario_end", type=int, default=60, help="Ask scenarios from scenario_begin to scenario_end, -1 to ask all scenarios")
    # Each threshold is parsed as a float, matching the type of the default values.
    parser.add_argument("--thresholds", type=float, nargs="+", default=[0.65, 0.75, 0.85, 0.95], help="Thresholds used for metric 3 (space separated).")
    parser.add_argument("--sample_ratio_threshold_level_3", type=float, default=0.001, help="Y with ratio of Y=1 < t or > 1-t will not be used")

    # parser.add_argument("--nan_as_fault_level_1", type=bool, default=False, help="See NAN as fault in level 1. If False, not use NAN data.")
    return parser.parse_args()

def main():
    """Compute metric 3 at every threshold for each model and scenario.

    For each selected model, one row per scenario is collected (the scenario
    id plus the truth-based and observation-based metric 3 values for each
    threshold) and written to
    ``database/<llm_name>/results_threshold_<llm_name>.csv``.
    """
    args: Namespace = parse_args()
    thresholds = args.thresholds
    database_folder = Path(__file__).resolve().parent.parent / "database"
    all_scenarios = Scenario.get_all_scenarios()

    columns = ["scenario_id"]
    columns += [f"metric_3_truth_{t}" for t in thresholds]
    columns += [f"metric_3_observe_{t}" for t in thresholds]

    # An empty model list means: evaluate every model folder in the database,
    # except folders whose name ends with "sample".
    if args.llm_names:
        llm_names = args.llm_names
    else:
        llm_names = [
            entry.name
            for entry in database_folder.iterdir()
            if entry.is_dir() and not entry.name.endswith("sample")
        ]

    for llm_name in llm_names:
        results = DataFrame(columns=columns)
        for scenario in all_scenarios:
            if scenario == "test scenario":
                continue
            scenario_id = Scenario.get_index(scenario=scenario)
            # scenario_end == -1 switches the range filter off entirely.
            keep_all = args.scenario_end == -1
            if not (keep_all or args.scenario_begin <= scenario_id <= args.scenario_end):
                continue
            evaluator = Evaluator(
                model=CausalModel(scenario=scenario),
                sample_n=args.sample_n_level_3,
                llm_name=llm_name,
                sample_num_threshold=args.sample_ratio_threshold_level_3,
            )
            row = {"scenario_id": scenario_id}
            for threshold in thresholds:
                # Truth-based evaluation runs before the observation-based one.
                row[f"metric_3_truth_{threshold}"] = evaluator.evaluate_metric_3_scenario(
                    threshold=threshold, by_truth=True
                )
                row[f"metric_3_observe_{threshold}"] = evaluator.evaluate_metric_3_scenario(
                    threshold=threshold, by_truth=False
                )
            # Append the row at the next integer index label.
            results.loc[len(results)] = row
            # NOTE: the per-sample evaluation (evaluator.evaluate with the
            # by_truth_level_3 / by_truth_level_2 / least_needed_sample_level_2
            # options) is currently disabled here.
        results.to_csv(database_folder / llm_name / f"results_threshold_{llm_name}.csv")

# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()