from pathlib import Path
import pandas as pd
from collections import defaultdict
from sample.sampler import Sampler
from util.interface import CausalModel
from util.scenario import Scenario
from handle.video_asker import VideoAskerV1
import json

class Evaluator:
    """Compute evaluation metrics for one scenario of a causal model.

    Samples carry ground-truth ``true_<var>`` values and observed
    ``observed_<var>`` values and are read through a ``Sampler``.
    Scenario-level metrics are returned by the ``evaluate_metric_*_scenario``
    methods. Per-sample metrics are written by ``evaluate_metric_*`` to JSON
    files under ``database/<llm_name>/<scenario_index>/sample_results``.
    ``evaluate`` runs all of them and gathers the per-sample results into
    ``evaluate_results.csv``.
    """

    # Keys (and their order) of every per-sample result JSON file.
    _METRIC_KEYS = (
        "metric_1_roots_ignore",
        "metric_1_roots_fault",
        "metric_2_truth",
        "metric_2_observe",
        "metric_3_truth",
        "metric_3_observe",
    )

    def __init__(self, model: CausalModel, sample_n=1, llm_name: str="sample", 
                 include_nan=True, rebalance=True, sample_num_threshold=0.2):
        """
        model: causal model whose roots, non-roots and rules are evaluated.
        sample_n: passed as ``repeat`` to ``Sampler.get_sample_index``.
        llm_name: name of the sampled LLM; also the database sub-directory.
        include_nan: whether consider "nan" in video asker. In
            evaluate_metric_1_scenario, True means missing observations are
            skipped or scored 0 (per nan_as_fault); False compares them directly.
        rebalance: stored on the instance; not read by any method shown here.
        sample_num_threshold: for metric 3 with by_truth=False, a non-root is
            skipped when its fraction of rule-positive samples lies outside
            [sample_num_threshold, 1 - sample_num_threshold].

        Creates the database/<llm_name>/<scenario_index>/sample_results
        directories if they do not exist yet.
        """
        self.model = model
        self.roots = model.roots
        if self.model.non_roots:
            non_roots = model.topo_sorted_non_roots()
            self.non_roots = non_roots
            self.variables = self.model.roots + non_roots
        else:
            # Keep both attributes defined for root-only models, so the metric
            # methods see "no non-roots" instead of raising AttributeError.
            self.non_roots = []
            self.variables = list(self.model.roots)
        self.sampler = Sampler(model=model, llm_name=llm_name, cache=True)
        self.sample_n = sample_n
        self.include_nan = include_nan
        self.rebalance = rebalance
        self.sample_num_threshold = sample_num_threshold

        self.llm_name = llm_name
        self.database_path = Path(__file__).resolve().parent.parent / 'database' / llm_name
        self.database_path.mkdir(parents=True, exist_ok=True)
        scenario_index = Scenario.get_index(self.model.scenario)
        self.scenario_path = self.database_path / str(scenario_index)
        self.scenario_path.mkdir(parents=True, exist_ok=True)
        self.scenario_index = scenario_index

        self.sample_results_path = self.scenario_path / "sample_results"
        self.sample_results_path.mkdir(exist_ok=True)

    # ------------------------------------------------------------------ helpers

    def _has_nan(self, row) -> bool:
        """Return True if any ``observed_<var>`` of *row* is missing (``pd.isna``)."""
        return any(pd.isna(row["observed_" + var]) for var in self.variables)

    def _rule_satisfied(self, row, non_root) -> int:
        """Return 1 if the observed parents in *row* satisfy ``model.rules[non_root]``, else 0.

        The rule is treated as an iterable of mappings ``{parent: expected value}``.
        It fires when every parent of at least one term matches its expected value.
        """
        for rule_t in self.model.rules[non_root]:
            if all(row["observed_" + parent] == expected_value
                   for parent, expected_value in rule_t.items()):
                return 1
        return 0

    @staticmethod
    def _match_points(row, compare_vars, nan_as_fault) -> list:
        """Score each variable 1/0 by ``true_<var> == observed_<var>``.

        A missing observation scores 0 when *nan_as_fault*; otherwise it is left out.
        """
        points = []
        for var in compare_vars:
            observed = row["observed_" + var]
            if pd.isna(observed):
                if nan_as_fault:
                    points.append(0)
            else:
                points.append(int(row["true_" + var] == observed))
        return points

    def _metric_2_groups(self, by_truth):
        """Return the groupby used by metric 2: level-2 groups, or groups of equal observed roots."""
        if by_truth:
            group_indexs = self.sampler.get_sample_index_level_2()
            return group_indexs.groupby("group_id")
        sample_df = self.sampler.get_sample_df_have_observation()
        return sample_df.groupby(["observed_" + root for root in self.roots])

    def _group_squared_deviations(self, sub_df, least_needed_sample):
        """Compute metric-2 deviations for one group.

        Returns ``(sample_ids, deviations)``, where ``deviations[sample_id]``
        lists, for each non-root, the squared deviation of the sample's
        observed value from the group mean. Returns None when the group is
        skipped: it has fewer than *least_needed_sample* rows, or at most one
        nan-free sample.
        """
        if len(sub_df) < least_needed_sample:
            return None
        sample_ids = []
        rows = {}  # read every sample once instead of once per variable
        for sample_id in sub_df["sample_id"].tolist():
            row = self.sampler.read_sample(sample_id)
            if self._has_nan(row):
                continue  # ignore samples containing nan
            sample_ids.append(sample_id)
            rows[sample_id] = row
        if len(sample_ids) <= 1:
            return None  # ignore groups with number of samples <= 1
        deviations = defaultdict(list)
        for var in self.non_roots:
            values = [int(rows[sample_id]["observed_" + var]) for sample_id in sample_ids]
            mean_value = sum(values) / len(values)
            for sample_id, value in zip(sample_ids, values):
                deviations[sample_id].append((value - mean_value) ** 2)
        return sample_ids, deviations

    # ------------------------------------------------------- scenario metrics

    def evaluate_metric_3_scenario(self, threshold=-1, by_truth=True) -> float:
        """Scenario-level metric 3: accuracy of the observed non-root values.

        by_truth: True: a sample is correct for a non-root when
            true_{name} == observed_{name}.
            False: the expected value is the model rule applied to the observed
            parents, and the score is the class-balanced accuracy
            TP / (2 * P) + TN / (2 * N) over rule-positive (P) and rule-negative
            (N) samples. A non-root is skipped when its positive fraction lies
            outside [sample_num_threshold, 1 - sample_num_threshold], or when
            only one class is present.
        threshold: -1 averages the raw per-non-root accuracies; any other value
            averages the indicators ``accuracy >= threshold``.

        Samples with any missing observed value are ignored. Returns the mean
        over the non-roots that produced a score, or ``pd.NA`` if none did.
        """
        accs = []  # one accuracy (or thresholded indicator) per non-root
        for non_root in self.non_roots:
            false_index, true_index = self.sampler.get_sample_index(
                non_root=non_root, repeat=self.sample_n)
            correct_cnt = 0
            total_cnt = 0
            truth_observed_pairs = []  # (rule-expected value, observed value)
            for sample_id in false_index + true_index:
                row = self.sampler.read_sample(sample_id=sample_id)
                if self._has_nan(row):
                    continue  # ignore sample with nan
                if by_truth:
                    correct_cnt += row["true_" + non_root] == row["observed_" + non_root]
                else:
                    truth_observed_pairs.append(
                        (self._rule_satisfied(row, non_root), row["observed_" + non_root]))
                total_cnt += 1
            if not by_truth:
                if not truth_observed_pairs:
                    # Diagnostic: no nan-free sample for this non-root.
                    print(self.llm_name, self.model.scenario)
                    continue
                n_pairs = len(truth_observed_pairs)
                n_pos = sum(expected == 1 for expected, _ in truth_observed_pairs)
                n_neg = sum(expected == 0 for expected, _ in truth_observed_pairs)
                pos_ratio = n_pos / n_pairs
                if pos_ratio < self.sample_num_threshold or pos_ratio > 1 - self.sample_num_threshold:
                    continue
                if n_pos == 0 or n_neg == 0:
                    # Balanced accuracy is undefined with a single class; this is
                    # only reachable when sample_num_threshold <= 0.
                    continue
                tp = sum(expected == 1 and observed == 1
                         for expected, observed in truth_observed_pairs)
                tn = sum(expected == 0 and observed == 0
                         for expected, observed in truth_observed_pairs)
                # Scaled so that correct_cnt / total_cnt == TP/(2P) + TN/(2N).
                correct_cnt = tp / (2 * n_pos / n_pairs) + tn / (2 * n_neg / n_pairs)
                assert n_pairs == total_cnt
            if total_cnt > 0:
                acc = correct_cnt / total_cnt
                accs.append(acc if threshold == -1 else int(acc >= threshold))
        return sum(accs) / len(accs) if accs else pd.NA

    def evaluate_metric_2_scenario(self, by_truth=True, least_needed_sample=3) -> float:
        """Scenario-level metric 2: spread of observed non-root values within groups.

        by_truth == True: use groups sampled for level 2. Every sample in the same
            group has the same ground truth value for every variable.
        by_truth == False: ignore the groups sampled for level 2. Instead, construct
            groups by observed values. Every sample in the same group has the same
            observed value for every **root variable**.
        least_needed_sample: groups with fewer rows than this are skipped (in both
            modes). Groups left with <= 1 nan-free sample are also skipped.

        Each sample's score is the mean, over non-roots, of its squared deviation
        from the group mean. A group's estimator is the mean of its samples'
        scores. Returns the mean over groups (0 when every group agrees
        perfectly), or ``pd.NA`` if no group qualified or there are no non-roots.
        """
        if not self.non_roots:
            return pd.NA  # nothing to compare; avoids dividing by zero variables
        estimators = []
        for _, sub_df in self._metric_2_groups(by_truth):
            grouped = self._group_squared_deviations(sub_df, least_needed_sample)
            if grouped is None:
                continue
            sample_ids, deviations = grouped
            mean_dis = [sum(deviations[sample_id]) / len(deviations[sample_id])
                        for sample_id in sample_ids]
            estimators.append(sum(mean_dis) / len(mean_dis))
        return sum(estimators) / len(estimators) if estimators else pd.NA

    def evaluate_metric_1_scenario(self, use_non_root=True, nan_as_fault=True):
        """Scenario-level metric 1: agreement between true_ and observed_ values.

        use_non_root: True compares every variable on the text-consistency
            samples (select_sample="text"). False compares only the roots on all
            samples (select_sample="rule").
        nan_as_fault: when include_nan is set, a missing observation scores 0 if
            True and is left out if False.

        Returns the mean over samples of each sample's fraction of matching
        variables (samples with nothing to score are skipped), or ``pd.NA``.
        """
        if use_non_root:
            compare_vars = self.variables
            select_sample = "text"
            sample_ids = self.sampler.get_sample_indexs_text_consistency()
        else:
            compare_vars = self.roots
            select_sample = "rule"
            sample_ids = self.sampler._get_all_samples_df()["sample_id"].tolist()
        estimators = []
        for sample_id in sample_ids:
            row = self.sampler.read_sample(sample_id=sample_id, select_sample=select_sample)
            if self.include_nan:
                points = self._match_points(row, compare_vars, nan_as_fault)
            else:
                points = [int(row["true_" + var] == row["observed_" + var]) for var in compare_vars]
            if points:
                estimators.append(sum(points) / len(points))
        return sum(estimators) / len(estimators) if estimators else pd.NA

    # ----------------------------------------------------- per-sample metrics

    def evaluate_metric_3(self, by_truth=True):
        """Write per-sample metric 3 as "metric_3_truth" / "metric_3_observe".

        by_truth: True: fraction of the non-roots (whose metric-3 sample set
            contains the sample) with true_{name} == observed_{name}.
            False: weighted fraction of those non-roots whose observed value
            equals the rule applied to the observed parents. For non-root Y_j
            with g_j rule-positive samples out of n_j nan-free ones, positives
            are weighted n_j / (2 g_j) and negatives n_j / (2 (n_j - g_j)), so
            both classes carry equal total weight. Non-roots whose positive
            fraction lies outside [sample_num_threshold, 1 - sample_num_threshold]
            are left out.

        Samples with any missing observed value get no result. Samples with no
        contributing non-root get "nan".
        """
        index_dict = {}
        all_indexs = set()
        for non_root in self.non_roots:
            false_index, true_index = self.sampler.get_sample_index(
                non_root=non_root, repeat=self.sample_n)
            index_dict[non_root] = set(false_index + true_index)
            all_indexs.update(index_dict[non_root])
        all_indexs = sorted(all_indexs)

        # rule_stats[Y_j] = (g_j, n_j), kept only for the non-roots that pass the
        # balance check. It is only needed when scoring against the rule.
        rule_stats = {}
        if not by_truth:
            for non_root in self.non_roots:
                g_j, n_j = 0, 0
                for sample_id in index_dict[non_root]:
                    row = self.sampler.read_sample(sample_id=sample_id)
                    if self._has_nan(row):
                        continue  # ignore sample with nan
                    n_j += 1
                    g_j += self._rule_satisfied(row, non_root)
                if n_j > 0 and self.sample_num_threshold <= g_j / n_j <= 1 - self.sample_num_threshold:
                    rule_stats[non_root] = (g_j, n_j)

        metric_name = "metric_3_truth" if by_truth else "metric_3_observe"
        for sample_id in all_indexs:
            row = self.sampler.read_sample(sample_id=sample_id)
            if self._has_nan(row):
                continue  # ignore sample with nan
            total_cnt, correct_cnt = 0, 0
            for non_root in self.non_roots:
                if sample_id not in index_dict[non_root]:
                    continue
                if by_truth:
                    correct_cnt += row["true_" + non_root] == row["observed_" + non_root]
                    total_cnt += 1
                    continue
                if non_root not in rule_stats:
                    continue
                g_j, n_j = rule_stats[non_root]
                is_true = self._rule_satisfied(row, non_root)
                # This sample was counted in rule_stats, so g_j >= 1 whenever it is
                # rule-positive and n_j - g_j >= 1 whenever it is not.
                weight = n_j / 2 / g_j if is_true else n_j / 2 / (n_j - g_j)
                correct_cnt += weight * int(row["observed_" + non_root] == is_true)
                total_cnt += weight
            metric = correct_cnt / total_cnt if total_cnt > 0 else pd.NA
            self._update_sample_result(sample_id=sample_id, key=metric_name, value=metric)

    def evaluate_metric_2(self, by_truth=True, least_needed_sample=3):
        """Write per-sample metric 2 as "metric_2_truth" / "metric_2_observe".

        by_truth == True: use groups sampled for level 2. Every sample in the same
            group has the same ground truth value for every variable.
        by_truth == False: ignore the groups sampled for level 2. Instead, construct
            groups by observed values. Every sample in the same group has the same
            observed value for every **root variable**.
        least_needed_sample: groups with fewer rows than this are skipped (in both
            modes). Groups left with <= 1 nan-free sample are also skipped.

        Each qualifying sample gets the mean, over non-roots, of its squared
        deviation from its group's mean observed value. Nothing is written when
        there are no non-roots.
        """
        if not self.non_roots:
            return  # nothing to compare; avoids dividing by zero variables
        metric_name = "metric_2_truth" if by_truth else "metric_2_observe"
        for _, sub_df in self._metric_2_groups(by_truth):
            grouped = self._group_squared_deviations(sub_df, least_needed_sample)
            if grouped is None:
                continue
            sample_ids, deviations = grouped
            for sample_id in sample_ids:
                mean_dis = sum(deviations[sample_id]) / len(deviations[sample_id])
                self._update_sample_result(sample_id=sample_id, key=metric_name, value=mean_dis)

    def evaluate_metric_1(self, nan_as_fault=True):
        """Write per-sample root accuracy as "metric_1_roots_fault" / "metric_1_roots_ignore".

        Each sample gets the fraction of roots with true_ == observed_. A
        missing observation scores 0 when *nan_as_fault*; otherwise it is left
        out. The value is "nan" when no root could be scored.
        """
        sample_ids = self.sampler._get_all_samples_df()["sample_id"].tolist()
        metric_name = "metric_1_roots_fault" if nan_as_fault else "metric_1_roots_ignore"
        for sample_id in sample_ids:
            row = self.sampler.read_sample(sample_id=sample_id)
            points = self._match_points(row, self.roots, nan_as_fault)
            mean_point = sum(points) / len(points) if points else pd.NA
            self._update_sample_result(sample_id=sample_id, key=metric_name, value=mean_point)

    def evaluate(self, least_needed_sample_level_2=3):
        """Run every per-sample metric variant, then build evaluate_results.csv and add video paths."""
        self.evaluate_metric_3(by_truth=True)
        self.evaluate_metric_3(by_truth=False)
        self.evaluate_metric_2(by_truth=True, least_needed_sample=least_needed_sample_level_2)
        self.evaluate_metric_2(by_truth=False, least_needed_sample=least_needed_sample_level_2)
        self.evaluate_metric_1(nan_as_fault=True)
        self.evaluate_metric_1(nan_as_fault=False)
        self._gather_results()
        self._add_video_paths()

    def count_nan_scenario(self) -> tuple[int, int, int, int, int]:
        """Count missing observed values.

        Returns ``(nan_cnt, total_cnt, level_1_nan_cnt, level_1_total_cnt,
        level_1_correct_cnt)``:

        - ``nan_cnt`` and ``total_cnt`` cover the de-duplicated metric-3 and
          level-2 samples plus the level-1 text-consistency samples, counting
          one entry per variable.
        - The ``level_1_*`` counts cover only the text-consistency samples.
          ``level_1_correct_cnt`` counts the non-missing observations that
          equal the truth.
        """
        handle_indexs = []
        for non_root in self.non_roots:
            false_index, true_index = self.sampler.get_sample_index(
                non_root=non_root, repeat=self.sample_n)
            handle_indexs += false_index + true_index
        handle_indexs += self.sampler.get_sample_index_level_2()["sample_id"].tolist()
        handle_indexs = sorted(set(handle_indexs))
        total_cnt = len(self.variables) * len(handle_indexs)
        nan_cnt = 0
        for sample_id in handle_indexs:
            row = self.sampler.read_sample(sample_id=sample_id)
            for var in self.variables:
                if pd.isna(row["observed_" + var]):
                    nan_cnt += 1
        sample_ids = self.sampler.get_sample_indexs_text_consistency()
        level_1_total_cnt = len(self.variables) * len(sample_ids)
        total_cnt += level_1_total_cnt
        level_1_nan_cnt = 0
        level_1_correct_cnt = 0
        for sample_id in sample_ids:
            row = self.sampler.read_sample(sample_id=sample_id, select_sample="text")
            for var in self.variables:
                if pd.isna(row["observed_" + var]):
                    nan_cnt += 1
                    level_1_nan_cnt += 1
                elif row["observed_" + var] == row["true_" + var]:
                    level_1_correct_cnt += 1
        return nan_cnt, total_cnt, level_1_nan_cnt, level_1_total_cnt, level_1_correct_cnt

    # ---------------------------------------------------------- persistence

    def _update_sample_result(self, sample_id, key, value, select_sample="rule"):
        """Set *key* to *value* in the sample's result JSON file.

        The file is ``<sample_id>.json``, or ``<sample_id>_text.json`` when
        *select_sample* is not "rule". If the file does not exist, it is created
        with every metric in _METRIC_KEYS set to "nan". Missing values
        (``pd.isna``) are stored as the string "nan".
        """
        suffix = ".json" if select_sample == "rule" else "_text.json"
        save_name = self.sample_results_path / f"{sample_id}{suffix}"
        if save_name.exists():
            with open(save_name, "r") as fp:
                result = json.load(fp)
        else:
            result = dict.fromkeys(self._METRIC_KEYS, "nan")
        result[key] = "nan" if pd.isna(value) else value
        with open(save_name, "w") as fp:
            json.dump(result, fp, indent=4)

    def _gather_results(self):
        """Write evaluate_results.csv.

        The output is the sampler's sample table without the old "metric2" and
        "metric3" columns (if present), plus one column per metric read from
        each sample's ``<sample_id>.json``, matched on the "sample_id" column.
        """
        sample_df = self.sampler._get_all_samples_df()
        # drop() returns a new frame, so the DataFrame handed out by the sampler
        # is left untouched; errors="ignore" tolerates absent columns.
        sample_df = sample_df.drop(columns=["metric2", "metric3"], errors="ignore")
        sample_ids = sample_df["sample_id"].tolist()
        results = {}
        for sample_id in sample_ids:
            result_file = self.sample_results_path / f"{sample_id}.json"
            with open(result_file, "r") as fp:
                results[sample_id] = json.load(fp)
        # Metric names in first-seen order. Unlike results[0], this does not
        # assume that a sample with id 0 exists.
        metrics = dict.fromkeys(key for result in results.values() for key in result)
        for metric in metrics:
            col_metric = {sample_id: result.get(metric, "nan")
                          for sample_id, result in results.items()}
            # Match on "sample_id": assigning the dict directly would align its
            # keys with the DataFrame index instead.
            sample_df[metric] = sample_df["sample_id"].map(col_metric)
        output_file = self.scenario_path / "evaluate_results.csv"
        sample_df.to_csv(output_file)

    def _add_video_paths(self):
        """Add a "video_path" column (from VideoAskerV1, select_sample="rule") to evaluate_results.csv."""
        evaluate_results_file = self.scenario_path / "evaluate_results.csv"
        evaluate_results = pd.read_csv(evaluate_results_file, index_col=0)
        video_asker = VideoAskerV1(model=self.model, llm_name=self.llm_name)
        video_paths_dict = video_asker.get_video_paths(select_sample="rule")
        # NOTE(review): a dict assigned to a column aligns on the DataFrame index
        # (the index saved by _gather_results), not on "sample_id". Confirm that
        # get_video_paths keys match that index.
        evaluate_results["video_path"] = video_paths_dict
        evaluate_results.to_csv(evaluate_results_file)
