from pathlib import Path
from collections import defaultdict
import numpy as np
import pandas as pd
from arch.bootstrap import IIDBootstrap
from sample.sampler import Sampler
from util.interface import CausalModel
from util.scenario import Scenario
import random

# Seeded level-3 sampling in BootStraper.add_samples asks
# Sampler.get_sample_index for a pool of (at least) this many repeats and then
# draws the level-3 indices at random from that pool.
MAX_SAMPLE_N_LEVEL_3 = 50

class BootStraper:
    """Accumulate per-sample evaluation scores for one LLM and summarize them by bootstrap.

    Scores live in ``self.metrics`` (metric name -> list of numbers). They are
    appended by :meth:`add_samples`, and each list is resampled with ``arch``'s
    ``IIDBootstrap`` by the ``bootstrap_*`` methods.

    Metric families:
      * ``metric_1_*``: agreement between ``true_<var>`` and ``observed_<var>``.
        ``*_fault`` counts a missing (NaN) observation as a mismatch, while
        ``*_ignore`` leaves missing observations out.
      * ``metric_2_*``: within-group mean squared deviation of observed values.
      * ``metric_3_*``: per-non-root accuracy, or pass/fail against a threshold.
    """

    def __init__(self, llm_name: str, include_nan=True, random_seed=None):
        """Create an empty score store for *llm_name*.

        Args:
            llm_name: Name of the evaluated LLM; also forwarded to ``Sampler``.
            include_nan: Must stay True; ``add_samples`` raises
                ``NotImplementedError`` otherwise.
            random_seed: When not None, seeds the global ``random`` and
                ``numpy.random`` state. It also makes ``add_samples`` pick
                samples and groups at random instead of taking the first N.
        """
        self.llm_name = llm_name
        # Created eagerly (filesystem side effect); not otherwise used in this class.
        self.database_path = Path(__file__).resolve().parent.parent / 'database' / llm_name
        self.database_path.mkdir(parents=True, exist_ok=True)
        self.metrics = {
            "metric_1_all_ignore": [],
            "metric_1_roots_ignore": [],
            "metric_1_all_fault": [],
            "metric_1_roots_fault": [],
            "metric_2_truth": [],
            "metric_2_observe": [],
            "metric_3_truth": [],
            "metric_3_observe": []
        }
        self.include_nan = include_nan
        self.random_seed = random_seed
        if self.random_seed is not None:
            # Global seeding: this affects every other user of random / np.random.
            random.seed(self.random_seed)
            np.random.seed(self.random_seed)

    def add_samples(self, scenario: str, sample_n_level_1: int, group_n_level_2: int,
                    group_size_level_2: int, sample_n_level_3: int, threshold_level_3=-1):
        """Score one scenario and append the results to ``self.metrics``.

        All three metric families are computed first and merged only at the
        end, so a failure part-way leaves ``self.metrics`` untouched.

        Args:
            scenario: Scenario name passed to ``CausalModel``.
            sample_n_level_1: Number of text-consistency samples scored for metric 1.
            group_n_level_2: Maximum number of groups considered per metric-2 variant.
            group_size_level_2: Groups with fewer rows than this are skipped.
            sample_n_level_3: Number of indices taken from each of the two
                (false/true) index lists returned by ``get_sample_index``, per non-root.
            threshold_level_3: -1 records raw accuracy; any other value records
                ``int(accuracy >= threshold_level_3)``.

        Raises:
            ValueError: Fewer than ``sample_n_level_1`` text-consistency samples exist.
            NotImplementedError: ``include_nan`` is False.
        """
        model = CausalModel(scenario=scenario)
        sampler = Sampler(model=model, llm_name=self.llm_name)
        # The order matters: it fixes the order in which the global RNG is consumed.
        collected = {}
        collected.update(self._collect_metric_1(model, sampler, sample_n_level_1))
        collected.update(self._collect_metric_2(model, sampler, group_n_level_2, group_size_level_2))
        collected.update(self._collect_metric_3(model, sampler, sample_n_level_3, threshold_level_3))
        for key, scores in collected.items():
            self.metrics[key] += scores

    def _collect_metric_1(self, model, sampler, sample_n_level_1):
        """Compute metric 1 (true vs. observed agreement) on text-consistency samples."""
        sample_ids = sampler.get_sample_indexs_text_consistency()
        if len(sample_ids) < sample_n_level_1:
            raise ValueError(f"The number of level 1 samples is {len(sample_ids)}, not enough for needed {sample_n_level_1}.")
        if self.random_seed is not None:
            select_ids = random.sample(sample_ids, sample_n_level_1)
        else:
            select_ids = sample_ids[:sample_n_level_1]
        if not self.include_nan:
            raise NotImplementedError("include_nan should be true.")
        scores = {
            "metric_1_all_fault": [],
            "metric_1_roots_fault": [],
            "metric_1_all_ignore": [],
            "metric_1_roots_ignore": [],
        }
        for sample_id in select_ids:
            row = sampler.read_sample(sample_id=sample_id, select_sample="text")
            all_fault, all_ignore = self._match_rates(row, model.variables)
            roots_fault, roots_ignore = self._match_rates(row, model.roots)
            scores["metric_1_all_fault"].append(all_fault)
            scores["metric_1_roots_fault"].append(roots_fault)
            # An *_ignore rate is undefined when every observation is missing; skip it.
            if all_ignore is not None:
                scores["metric_1_all_ignore"].append(all_ignore)
            if roots_ignore is not None:
                scores["metric_1_roots_ignore"].append(roots_ignore)
        return scores

    @staticmethod
    def _match_rates(row, variables):
        """Return ``(fault_rate, ignore_rate)`` of true/observed agreement over *variables*.

        ``fault_rate`` treats a NaN observation as a mismatch. ``ignore_rate``
        drops NaN observations and is None when all of them are NaN.
        Raises ZeroDivisionError when *variables* is empty.
        """
        fault_points = []
        observed_points = []
        for var in variables:
            observed = row["observed_" + var]
            if pd.isna(observed):
                fault_points.append(0)
                continue
            point = int(row["true_" + var] == observed)
            fault_points.append(point)
            observed_points.append(point)
        fault_rate = sum(fault_points) / len(fault_points)
        ignore_rate = sum(observed_points) / len(observed_points) if observed_points else None
        return fault_rate, ignore_rate

    def _collect_metric_2(self, model, sampler, group_n_level_2, group_size_level_2):
        """Compute metric 2 (within-group dispersion of observed values).

        ``truth`` groups come from ``sampler.get_sample_index_level_2()``, keyed by
        ``group_id``, and compare every variable. ``observe`` groups share the
        same observed root values and compare only non-root variables. In both
        cases the groups are ordered largest first.
        """
        group_index = sampler.get_sample_index_level_2()
        groups_truth = sorted(group_index.groupby("group_id"),
                              key=lambda group: len(group[1]), reverse=True)
        observed_df = sampler.get_sample_df_have_observation()
        groups_observe = sorted(observed_df.groupby(["observed_" + root for root in model.roots]),
                                key=lambda group: len(group[1]), reverse=True)
        truth_scores = self._group_dispersions(model, sampler, groups_truth, model.variables,
                                               group_n_level_2, group_size_level_2)
        observe_scores = self._group_dispersions(model, sampler, groups_observe, model.non_roots,
                                                 group_n_level_2, group_size_level_2)
        return {"metric_2_truth": truth_scores, "metric_2_observe": observe_scores}

    def _group_dispersions(self, model, sampler, groups, compare_vars,
                           group_n_level_2, group_size_level_2):
        """Return one mean-squared-deviation score per selected, sufficiently large group."""
        if self.random_seed is not None:
            selected = random.sample(groups, min(group_n_level_2, len(groups)))
        else:
            selected = groups[:group_n_level_2]
        scores = []
        for _, sub_df in selected:
            if len(sub_df) < group_size_level_2:
                continue
            # Read each row once and reuse it for every compared variable.
            rows = []
            for sample_id in sub_df["sample_id"].tolist():
                row = sampler.read_sample(sample_id)
                if any(pd.isna(row["observed_" + var]) for var in model.variables):
                    continue  # ignore samples containing nan
                rows.append(row)
            if len(rows) <= 1:
                continue  # a single sample has no dispersion
            scores.append(self._mean_squared_deviation(rows, compare_vars))
        return scores

    @staticmethod
    def _mean_squared_deviation(rows, compare_vars):
        """Average, over *rows*, of each row's mean squared deviation from the per-variable mean.

        Each ``observed_<var>`` value is cast with ``int()`` before use.
        """
        per_row = [[] for _ in rows]
        for var in compare_vars:
            values = [int(row["observed_" + var]) for row in rows]
            centre = sum(values) / len(values)
            for deviations, value in zip(per_row, values):
                deviations.append((value - centre) ** 2)
        row_means = [sum(deviations) / len(deviations) for deviations in per_row]
        return sum(row_means) / len(row_means)

    def _collect_metric_3(self, model, sampler, sample_n_level_3, threshold_level_3):
        """Compute metric 3 (per-non-root accuracy of observed values).

        ``truth`` compares ``observed_<non_root>`` with ``true_<non_root>``.
        ``observe`` compares it with the ``sampler.full_table`` entry selected
        by the sample's observed root values. Samples with any NaN observation
        are skipped, and so are non-roots left with no usable sample.
        """
        results = {"metric_3_truth": [], "metric_3_observe": []}
        for key, by_truth in (("metric_3_truth", True), ("metric_3_observe", False)):
            for non_root in model.non_roots:
                false_index, true_index = self._draw_level_3_indices(sampler, non_root, sample_n_level_3)
                non_root_pos = None if by_truth else model.variables.index(non_root)
                correct_cnt = 0
                total_cnt = 0
                for sample_id in false_index + true_index:
                    row = sampler.read_sample(sample_id=sample_id)
                    if any(pd.isna(row["observed_" + var]) for var in model.variables):
                        continue  # ignore samples with nan
                    if by_truth:
                        expected = row["true_" + non_root]
                    else:
                        root_value = tuple(row["observed_" + var] for var in model.roots)
                        expected = sampler.full_table[sampler.full_table_index_dict[root_value]][non_root_pos]
                    correct_cnt += expected == row["observed_" + non_root]
                    total_cnt += 1
                if total_cnt == 0:
                    continue  # ignore situations with zero sample
                accuracy = correct_cnt / total_cnt
                if threshold_level_3 == -1:
                    results[key].append(accuracy)
                else:
                    results[key].append(int(accuracy >= threshold_level_3))
        return results

    def _draw_level_3_indices(self, sampler, non_root, sample_n_level_3):
        """Return the ``(false_index, true_index)`` sample-id lists for *non_root*.

        Unseeded: exactly what ``get_sample_index(repeat=sample_n_level_3)`` returns.
        Seeded: a pool of at least ``MAX_SAMPLE_N_LEVEL_3`` repeats is requested,
        and up to ``sample_n_level_3`` ids are drawn at random from each list.
        Clamping to the pool size avoids the ``ValueError`` that
        ``random.sample`` raises on a short pool.
        """
        if self.random_seed is None:
            false_index, true_index = sampler.get_sample_index(non_root=non_root, repeat=sample_n_level_3)
            return false_index, true_index
        pool_size = max(MAX_SAMPLE_N_LEVEL_3, sample_n_level_3)
        false_pool, true_pool = sampler.get_sample_index(non_root=non_root, repeat=pool_size)
        false_index = random.sample(false_pool, min(sample_n_level_3, len(false_pool)))
        true_index = random.sample(true_pool, min(sample_n_level_3, len(true_pool)))
        return false_index, true_index

    def add_samples_small_scenarios(self, scenario_num=15, threshold_level_3=-1,
                                    sample_n_level_1=10, group_n_level_2=5,
                                    group_size_level_2=3, sample_n_level_3=10):
        """Run :meth:`add_samples` for every scenario whose index lies in ``[1, scenario_num]``.

        The sample-size keywords default to the values previously hard-coded here.
        """
        for scenario in Scenario.get_all_scenarios():
            if not 1 <= Scenario.get_index(scenario) <= scenario_num:
                continue
            self.add_samples(scenario=scenario, sample_n_level_1=sample_n_level_1,
                             group_n_level_2=group_n_level_2, group_size_level_2=group_size_level_2,
                             sample_n_level_3=sample_n_level_3, threshold_level_3=threshold_level_3)

    def _iter_metrics(self, level_3_only):
        """Yield ``(name, scores array)`` for each metric, optionally only ``metric_3_*``."""
        for key, scores in self.metrics.items():
            if level_3_only and not key.startswith("metric_3"):
                continue
            yield key, np.array(scores)

    def bootstrap_confint(self, method='percentile', reps=1000, size=0.95, level_3_only=False):
        """Return a bootstrap confidence interval of the mean for each metric.

        Each value is whatever ``IIDBootstrap.conf_int`` returns (row 0 holds
        the lower bound, row 1 the upper). A metric with no scores gets a
        NaN-filled ``(2, 1)`` array, because bootstrapping an empty sample is
        meaningless.
        """
        results = {}
        for key, scores in self._iter_metrics(level_3_only):
            if scores.size == 0:
                results[key] = np.full((2, 1), np.nan)
                continue
            bs = IIDBootstrap(scores)
            results[key] = bs.conf_int(np.mean, method=method, reps=reps, size=size)
        return results

    def bootstrap_std(self, reps=1000, level_3_only=False):
        """Return the standard deviation of *reps* bootstrap means per metric (NaN if no scores)."""
        results = {}
        for key, scores in self._iter_metrics(level_3_only):
            if scores.size == 0:
                results[key] = np.nan
                continue
            bootstrap_means = IIDBootstrap(scores).apply(np.mean, reps=reps)
            results[key] = np.std(bootstrap_means)
        return results

    def bootstrap_mean(self, reps=1000, level_3_only=False):
        """Return the average of *reps* bootstrap means per metric (NaN if no scores)."""
        results = {}
        for key, scores in self._iter_metrics(level_3_only):
            if scores.size == 0:
                results[key] = np.nan
                continue
            bootstrap_means = IIDBootstrap(scores).apply(np.mean, reps=reps)
            results[key] = np.mean(bootstrap_means)
        return results
