from argparse import ArgumentParser, ArgumentTypeError, Namespace
from pathlib import Path
from typing import Optional, Sequence

from pandas import DataFrame

from evaluate.evaluator import Evaluator
from util.scenario import Scenario
from util.interface import CausalModel

# Models evaluated when --llm_names is not given on the command line.
_DEFAULT_LLM_NAMES = (
    "CogVideoX1.5-5b", "Hailuo", "Pika", "pyramid_flux", "CogVideoX-2b", "videocrafter2_base",
    "CogVideoX-5b", "HunyuanVideo", "Gen-3-Alpha", "Kling",
)

# Accepted (case-insensitive) spellings for boolean command-line values.
_TRUE_STRINGS = frozenset({"true", "t", "yes", "y", "1"})
_FALSE_STRINGS = frozenset({"false", "f", "no", "n", "0"})


def _str2bool(value: str) -> bool:
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    ``type=bool`` cannot be used with argparse: it calls ``bool(value)``, which
    is ``True`` for every non-empty string, including ``"False"``.

    Raises:
        ArgumentTypeError: if *value* is not a recognised boolean spelling;
            argparse reports this as a usage error.
    """
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise ArgumentTypeError(f"expected a boolean value (true/false), got {value!r}")


def parse_args(argv: Optional[Sequence[str]] = None) -> Namespace:
    """Parse the evaluation command-line options.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        The parsed options. ``llm_names`` is a list of strings; an empty list
        means "evaluate every model folder found in the database".
    """
    parser = ArgumentParser(description="Evaluate")
    parser.add_argument("--sample_n_level_3", type=int, default=10, help="Number of False/True samples for each non_root variable for level 3")
    parser.add_argument("--by_truth_level_2", type=_str2bool, default=True, help="Group by true value (True) or observed value (False)")
    parser.add_argument("--least_needed_sample_level_2", type=int, default=3, help="Least number of samples in a group for level 2")
    parser.add_argument("--by_truth_level_3", type=_str2bool, default=True, help="Fit rule by ground truth roots or observed roots")
    # nargs="*" yields a list of whole names ("--llm_names Pika Kling"); a bare
    # "--llm_names" yields [] which selects every model folder.
    parser.add_argument("--llm_names", nargs="*", default=list(_DEFAULT_LLM_NAMES), metavar="LLM_NAME",
                        help="Models to evaluate (space-separated). Pass the flag with no names to evaluate all models")
    parser.add_argument("--scenario_begin", type=int, default=1, help="Ask scenarios from scenario_begin to scenario_end")
    parser.add_argument("--scenario_end", type=int, default=60, help="Ask scenarios from scenario_begin to scenario_end, -1 to ask all scenarios")
    parser.add_argument("--sample_ratio_threshold_level_3", type=float, default=0.001, help="Y with ratio of Y=1 < t or > 1-t will not be used")
    return parser.parse_args(argv)

# Column order of the per-model results CSV.
_RESULT_COLUMNS = [
    "scenario_id", "nan_cnt", "total_cnt", "metric1_all_ignore", "metric1_roots_ignore",
    "metric1_all_fault", "metric1_roots_fault", "metric2_by_truth",
    "metric2_by_observe", "metric3_by_truth", "metric3_by_observe",
    "level_1_nan_cnt", "level_1_total_cnt", "level_1_correct_cnt",
]


def _resolve_llm_names(args: Namespace, database_folder: Path) -> list:
    """Return the model names to evaluate.

    Uses ``args.llm_names`` when non-empty. Otherwise returns every
    sub-directory of *database_folder* whose name does not end with
    ``"sample"``, sorted so that runs process models in a stable order.
    """
    if args.llm_names:
        return list(args.llm_names)
    return sorted(
        folder.name
        for folder in database_folder.iterdir()
        if folder.is_dir() and not folder.name.endswith("sample")
    )


def _in_scenario_range(scenario_id: int, args: Namespace) -> bool:
    """Return True if *scenario_id* lies in [scenario_begin, scenario_end].

    ``scenario_end == -1`` selects every scenario. In that case
    ``scenario_begin`` is ignored too.
    """
    if args.scenario_end == -1:
        return True
    return args.scenario_begin <= scenario_id <= args.scenario_end


def _evaluate_scenario(evaluator, scenario_id: int, args: Namespace) -> dict:
    """Compute every metric for one scenario and return it as a results row.

    The evaluator methods are called in the same order as before. The values
    of a dict literal are evaluated top to bottom, so any internal state the
    evaluator keeps between calls sees the same sequence.
    """
    (nan_cnt, total_cnt, level_1_nan_cnt,
     level_1_total_cnt, level_1_correct_cnt) = evaluator.count_nan_scenario()
    least_needed = args.least_needed_sample_level_2
    return {
        "scenario_id": scenario_id,
        "nan_cnt": nan_cnt,
        "total_cnt": total_cnt,
        "metric1_all_ignore": evaluator.evaluate_metric_1_scenario(use_non_root=True, nan_as_fault=False),
        "metric1_roots_ignore": evaluator.evaluate_metric_1_scenario(use_non_root=False, nan_as_fault=False),
        "metric1_all_fault": evaluator.evaluate_metric_1_scenario(use_non_root=True, nan_as_fault=True),
        "metric1_roots_fault": evaluator.evaluate_metric_1_scenario(use_non_root=False, nan_as_fault=True),
        "metric2_by_truth": evaluator.evaluate_metric_2_scenario(by_truth=True, least_needed_sample=least_needed),
        "metric2_by_observe": evaluator.evaluate_metric_2_scenario(by_truth=False, least_needed_sample=least_needed),
        "metric3_by_truth": evaluator.evaluate_metric_3_scenario(by_truth=True),
        "metric3_by_observe": evaluator.evaluate_metric_3_scenario(by_truth=False),
        "level_1_nan_cnt": level_1_nan_cnt,
        "level_1_total_cnt": level_1_total_cnt,
        "level_1_correct_cnt": level_1_correct_cnt,
    }


def main():
    """Evaluate each selected model on each selected scenario.

    For every model name, writes
    ``<this file's dir>/database/<llm_name>/results_<llm_name>.csv`` with one
    row per evaluated scenario (columns as in ``_RESULT_COLUMNS``).
    The scenario named ``"test scenario"`` is always skipped.

    NOTE: ``--by_truth_level_2`` and ``--by_truth_level_3`` are parsed but not
    read here. Metrics 2 and 3 are always computed both by truth and by
    observation.
    """
    args: Namespace = parse_args()
    database_folder = Path(__file__).resolve().parent / "database"
    all_scenarios = Scenario.get_all_scenarios()
    for llm_name in _resolve_llm_names(args, database_folder):
        rows = []
        for scenario in all_scenarios:
            if scenario == "test scenario":
                continue
            scenario_id = Scenario.get_index(scenario=scenario)
            if not _in_scenario_range(scenario_id, args):
                continue
            evaluator = Evaluator(
                model=CausalModel(scenario=scenario),
                sample_n=args.sample_n_level_3,
                llm_name=llm_name,
                sample_num_threshold=args.sample_ratio_threshold_level_3,
            )
            rows.append(_evaluate_scenario(evaluator, scenario_id, args))
        # Build the frame once: growing it row by row via .loc is quadratic.
        # dtype=object keeps each value as returned, instead of letting pandas
        # coerce e.g. int/None mixes to float.
        results = DataFrame(rows, columns=_RESULT_COLUMNS, dtype=object)
        results.to_csv(database_folder / llm_name / f"results_{llm_name}.csv")

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
