from itertools import product
from collections import defaultdict
from random import randint
from random import sample as sample_without_replacement
from pathlib import Path
import shutil
import pandas as pd
from pandas import read_csv, DataFrame, Series
from util.interface import CausalModel
from util.scenario import Scenario

class Sampler:
    def __init__(self, model: CausalModel, llm_name: str="sample", cache=False):
        """
        Bind the sampler to a causal model and prepare its sample store.

        self.model and self.roots are always set. Everything else (variable
            order, full truth table, storage folders, cache) is only set up
            when the model has non-root variables.
        The storage folder is <two directories above this file>/database/
            <llm_name>/<scenario index>; it is created if missing.
        cache: when True, both sample tables are loaded once here and
            read_sample serves rows from these in-memory copies. Intended for
            read-only use.
        """
        self.model = model
        self.roots = model.roots
        if not self.model.non_roots:
            # Root-only model: nothing to derive and nothing to store.
            return

        ordered_non_roots = model.topo_sorted_non_roots()
        self.non_roots = ordered_non_roots
        self.variables = self.model.roots + ordered_non_roots
        self.full_table = self._get_full_table()
        self.full_table_index_dict = self._get_full_table_index_dict()

        self.llm_name = llm_name
        base_dir = Path(__file__).resolve().parent.parent
        self.database_path = base_dir / 'database' / llm_name
        self.database_path.mkdir(parents=True, exist_ok=True)
        self.scenario_path = self.database_path / str(Scenario.get_index(self.model.scenario))
        self.scenario_path.mkdir(parents=True, exist_ok=True)

        self.cache = cache
        if not cache:
            return
        self.sample_df = self._get_all_samples_df()
        self.sample_df_text = self._get_sample_df_text_consistency()

    def enumerate_all_sample(self, repeat=1):
        """
        Enumerate all 2^n assignments of the n root variables.

        Returns a generator (iterate it with "for sample in samples" or
            materialise it with list(samples)). Each sample is a dictionary
            mapping every root variable to a value in {True, False}.
        When repeat != 1, each assignment is yielded "repeat" times in a row,
            so there are 2^n * repeat samples in total.
        Every yielded sample is its own dictionary object, so a caller may
            modify one without affecting its repeats.
        """
        roots = self.model.roots
        for values in product([False, True], repeat=len(roots)):
            for _ in range(repeat):
                # Build a fresh dict per yield: handing out one shared dict
                # `repeat` times would alias the repeats to each other.
                yield dict(zip(roots, values))
    
    def generate_sample_for_rules(self, repeat=1) -> dict[str, tuple[list[int], list[int]]]:
        """
        For each non-root variable Y, generate "repeat" samples for both Y=1 and Y=0.
        The distribution of root variables X conditioning on Y=1 or Y=0 is uniform.
        Returns a dictionary with name of non-root variables Y as keys, and values are 
            a two-dimensional tuple, whose first element is a list of samples for Y=0,
            and the second element is a list of samples for Y=1. Each sample is a
            integer representing "sample_id" in all_samples.csv.
        When using this method, non-roots and rules in the model should not be None.
        Example:
        roots = ["a", "b", "c"]
        non_roots = ["d", "e"]
        rules = {"d": some rules, "e": some rules}
        repeat = 2
        Then the return value may be:
        
        """
        sample_index_dict = {}
        for non_root in self.non_roots:
            false_sample_index, true_sample_index = self._get_sample_index(non_root)
            exist_sample_num = len(false_sample_index)
            if exist_sample_num >= repeat:
                sample_index_dict[non_root] = (false_sample_index[:repeat], true_sample_index[:repeat])
            else:
                all_samples: DataFrame = self._get_all_samples_df()
                need_sample_num: int = repeat - exist_sample_num
                sample_0, sample_1 = self._sample_from_full_table(non_root, need_sample_num)
                exist_sample_index = set(false_sample_index + true_sample_index)
                false_new, all_samples = self._update_and_get_new_index(all_samples, sample_0, exist_sample_index)
                true_new, all_samples = self._update_and_get_new_index(all_samples, sample_1, exist_sample_index)
                self._write_all_samples(all_samples)
                false_all = false_sample_index + false_new
                true_all = true_sample_index + true_new
                self._write_sample_index(non_root, false_all, true_all)
                sample_index_dict[non_root] = (false_all, true_all)
                
        return sample_index_dict
    
    def get_sample_index(self, non_root, repeat=1) -> tuple[list[int], list[int]]:
        """
        A simple version of self.generate_sample_for_rules with more parameters.
        """
        sample_index_dict = self.generate_sample_for_rules(repeat)
        return sample_index_dict[non_root]

    def read_sample(self, sample_id: int, select_sample: str="rule") -> Series:
        """
        Return the row of sample whose sample_id is selected.
        Return value: pd.Series containing keys:
        sample_id: int,
        prompt: str | None; if prompt is None, it has not been generated.
        true_{name} for name in variables: bool; true values of all variables for the sample.
        observed_{name} for name in variables: bool | None; observed value for each variable.
            If None, it has not been questioned by LLM.
        If sample_select == "rule", then get sample from "sample_all.csv"
        If sample_select == "text", then get sample from "sample_text_consistency.csv"
        """
        if select_sample == "rule":
            if self.cache:
                sample_df = self.sample_df
            else:
                sample_df: DataFrame = self._get_all_samples_df()
        elif select_sample == "text":
            if self.cache:
                sample_df = self.sample_df_text
            else:
                sample_df: DataFrame = self._get_sample_df_text_consistency()
        else:
            raise ValueError(f"sample_select can only be rule or text, now it is {select_sample}")
        if len(sample_df) <= sample_id:
            raise ValueError(f"sample_id {sample_id} does not exist in scenario {self.model.scenario}.")
        return sample_df.loc[sample_id]
    
    def update_sample(self, sample_id: int, key: str, value, select_sample: str="rule"):
        """
        Update the row of sample whose sample_id is selected.
        Set row[key] = value, where key should be one of these:
        sample_id, prompt, true_{name}, observed_{name}
        where name is in self.variables.
        If sample_select == "rule", then get sample from "sample_all.csv"
        If sample_select == "text", then get sample from "sample_text_consistency.csv"
        """
        if select_sample == "rule":
            sample_df: DataFrame = self._get_all_samples_df()
        elif select_sample == "text":
            sample_df: DataFrame = self._get_sample_df_text_consistency()
        else:
            raise ValueError(f"sample_select can only be rule or text, now it is {select_sample}")
        if len(sample_df) <= sample_id:
            raise ValueError(f"sample_id {sample_id} does not exist in scenario {self.model.scenario}.")
        sample_df.loc[sample_id, key] = value
        if select_sample == "rule":
            self._write_all_samples(sample_df)
        elif select_sample == "text":
            self._write_sample_df_text_consistency(sample_df)

    def _get_full_table(self):
        """
        Build the truth table of all variables over every root assignment.

        Returns a list with one row per combination of root values (in the
            order produced by itertools.product over [False, True]); each row
            lists the values of self.variables, roots first. A non-root is True
            when at least one of its rule terms matches the row, i.e. every
            parent named in the term has the expected value; otherwise False.
        """
        # See ../test/test_sampler.py -> test_generate_all_sample_table for an example.
        column_of = {name: position for position, name in enumerate(self.variables)}
        padding = [False] * len(self.non_roots)
        table = []
        for root_values in product([False, True], repeat=len(self.roots)):
            row = [*root_values, *padding]
            # Non-roots are visited in self.non_roots order, so values set for
            # earlier ones are visible to the rules of later ones.
            for non_root in self.non_roots:
                matches = [
                    all(row[column_of[parent]] == expected_value
                        for parent, expected_value in rule_term.items())
                    for rule_term in self.model.rules[non_root]
                ]
                if any(matches):
                    row[column_of[non_root]] = True
            table.append(row)
        return table
    
    def _sample_from_full_table(self, non_root, repeat):
        """
        Draw "repeat" root assignments with non_root=False and "repeat" with
            non_root=True.

        Each draw picks, uniformly and with replacement, one row of
            self.full_table among those where non_root has the requested value.
        Returns a tuple (samples for False, samples for True); every sample is
            a dictionary mapping each root variable to its value in the row.
        Raises ValueError when repeat > 0 and non_root never takes one of the
            two values, since that value cannot be sampled.
        """
        roots = self.model.roots
        n = len(roots)
        col_index = self.variables.index(non_root)
        rows_by_value = {False: [], True: []}
        for i, row in enumerate(self.full_table):
            rows_by_value[True if row[col_index] else False].append(i)

        if repeat > 0:
            # Fail with a message that names the variable, instead of the
            # "empty range" error randint would raise on an empty row list.
            for value, candidates in rows_by_value.items():
                if not candidates:
                    raise ValueError(
                        f"Cannot sample {non_root}={value}: "
                        "no combination of root values produces it."
                    )

        def draw(candidates):
            last = len(candidates) - 1
            return [candidates[randint(0, last)] for _ in range(repeat)]

        def as_sample(row_index):
            return dict(zip(roots, self.full_table[row_index][:n]))

        # All False draws happen before the True draws.
        sample_0 = [as_sample(row_index) for row_index in draw(rows_by_value[False])]
        sample_1 = [as_sample(row_index) for row_index in draw(rows_by_value[True])]
        return (sample_0, sample_1)
    
    def _get_sample_index(self, non_root):
        """
        Read ../database/{scenario_index}/sample_{non_root}.pkl (binary DataFrame),
        and get two lists of sample_id for non_root.
        Backward-compat: if .pkl not found, fall back to CSV.
        The first list contains sample_ids for non_root = False,
        and the second contains sample_ids for non_root = True.
        """
        file_path_pkl = self.scenario_path / f"sample_{non_root}.pkl"
        file_path_csv = self.scenario_path / f"sample_{non_root}.csv"
        if file_path_pkl.exists():
            sample_df = pd.read_pickle(file_path_pkl)
        elif file_path_csv.exists():
            sample_df = read_csv(file_path_csv, index_col=0)
        else:
            return [], []
        false_sample_indexs = sample_df["False"].tolist()
        true_sample_indexs = sample_df["True"].tolist()
        return false_sample_indexs, true_sample_indexs
    
    def _write_sample_index(self, non_root, false_index, true_index):
        file_path_pkl = self.scenario_path / f"sample_{non_root}.pkl"
        sample_df = DataFrame({"False": false_index, "True": true_index})
        sample_df.to_pickle(file_path_pkl)
    
    def _get_all_samples_df(self):
        file_path_pkl = self.scenario_path / 'all_samples.pkl'
        file_path_csv = self.scenario_path / 'all_samples.csv'
        if not file_path_pkl.exists() and not file_path_csv.exists():
            true_value_cols = ["true_" + name for name in self.variables]
            observed_value_cols = ["observed_" + name for name in self.variables]
            columns = ["sample_id", "prompt"] + true_value_cols + observed_value_cols
            return DataFrame(columns=columns)
        if file_path_pkl.exists():
            samples_df = pd.read_pickle(file_path_pkl)
        else:
            samples_df = read_csv(file_path_csv, index_col=0)
        return samples_df
    
    def _write_all_samples(self, all_samples_df: DataFrame):
        file_path_pkl = self.scenario_path / 'all_samples.pkl'
        all_samples_df.to_pickle(file_path_pkl)
    
    def _update_and_get_new_index(self, all_sample_df: DataFrame,
                                   new_samples: list[dict[str, bool]],
                                   exist_sample_index: set):
        """
        Assign a sample_id to each entry of new_samples, reusing stored rows
            where possible and appending new rows otherwise.

        all_sample_df: table of all samples; rows may be appended to it in place.
        new_samples: dictionaries of values keyed by variable name (the caller
            in this class passes root assignments).
        exist_sample_index: sample_ids that must not be reused; every id
            assigned here is added to this set (the caller's set is mutated).
        Returns (list of assigned sample_ids in the order of new_samples,
            the updated all_sample_df).
        """
        res_index = []
        for sample in new_samples:
            names = sample.keys()
            # Look for a stored row that is not taken yet and whose true values
            # agree with the sample on every name the sample provides.
            for row_ind in all_sample_df.index:
                row: Series = all_sample_df.loc[row_ind]
                if row["sample_id"] in exist_sample_index:
                    continue
                if all(row["true_" + name] == sample[name] for name in names):
                    res_index.append(row["sample_id"])
                    exist_sample_index.add(row["sample_id"])
                    break
            else:
                # No reusable row was found (the loop ended without break):
                # append a new row whose id is the current number of rows.
                # NOTE(review): this assumes stored sample_ids are 0..len-1 — confirm.
                new_index: int = len(all_sample_df)
                new_row = {"sample_id": new_index, "prompt": None}
                # The full table supplies the values of all variables for
                # this combination of root values.
                root_values = tuple([sample[root] for root in self.roots])
                all_values = self.full_table[self.full_table_index_dict[root_values]]
                new_row.update({"true_" + variable: value for variable, value in zip(self.variables, all_values)})
                new_row.update({"observed_" + variable: None for variable in self.variables})
                res_index.append(new_index)
                exist_sample_index.add(new_index)
                all_sample_df.loc[new_index] = new_row
        return res_index, all_sample_df


    def _get_full_table_index_dict(self) -> dict[tuple, int]:
        """
        Map each combination of root values to the index of its row in
            self.full_table.
        Example:
        roots = ["a", "b"]
        full_table = [
            [False, False, False],
            [False, True, True],
            [True, False, False],
            [True, True, True]
        ]
        Explanation: the combinations of root values are (False, False),
            (False, True), (True, False), (True, True), found at row index
            0, 1, 2, 3 respectively. So the returned dict is:
        res = {(False, False): 0, (False, True): 1, (True, False): 2, (True, True): 3}
        """
        root_count = len(self.roots)
        return {
            tuple(row[:root_count]): position
            for position, row in enumerate(self.full_table)
        }
    
    def generate_sample_by_text_consistency(self, repeat=1) -> DataFrame:
        """
        Generate n="repeat" samples for testing text consistency.
        Returns a DataFrame of samples, also saved in sample_text_consistency.csv.
        If there are enough samples, they will be read from that file, not
            generated randomly.
        """
        sample_df = self._get_sample_df_text_consistency()
        if len(sample_df) >= repeat:
            return sample_df[:repeat]
        else:
            new_df = self._new_sample_df_text_consistency(start_id=len(sample_df), repeat=repeat - len(sample_df))
            full_df = pd.concat([sample_df, new_df], axis=0, ignore_index=True)
            self._write_sample_df_text_consistency(full_df)
            return full_df

    def _get_sample_df_text_consistency(self) -> DataFrame:
        file_path_pkl = self.scenario_path / f"sample_text_consistency.pkl"
        file_path_csv = self.scenario_path / f"sample_text_consistency.csv"
        if not file_path_pkl.exists() and not file_path_csv.exists():
            true_value_cols = ["true_" + name for name in self.variables]
            observed_value_cols = ["observed_" + name for name in self.variables]
            columns = ["sample_id", "prompt"] + true_value_cols + observed_value_cols + ["metric1"]
            return DataFrame(columns=columns)
        if file_path_pkl.exists():
            sample_df = pd.read_pickle(file_path_pkl)
        else:
            sample_df = read_csv(file_path_csv, index_col=0)
        return sample_df
    
    def get_sample_indexs_text_consistency(self) -> list[int]:
        sample_df = self._get_sample_df_text_consistency()
        return sample_df["sample_id"].tolist()
    
    def _write_sample_df_text_consistency(self, sample_df):
        file_path_pkl = self.scenario_path / f"sample_text_consistency.pkl"
        sample_df.to_pickle(file_path_pkl)

    def _new_sample_df_text_consistency(self, start_id=0, repeat=1) -> DataFrame:
        """
        Draw "repeat" rows of self.full_table uniformly (with replacement) and
            return them as a new text-consistency table.

        The true_{name} columns hold the drawn values and sample_id counts up
            from start_id. The prompt, observed_{name} and metric1 columns are
            present but left unfilled.
        """
        last_row = len(self.full_table) - 1
        picked_rows = [self.full_table[randint(0, last_row)] for _ in range(repeat)]

        data = {
            "true_" + name: [row[column] for row in picked_rows]
            for column, name in enumerate(self.variables)
        }
        data["sample_id"] = list(range(start_id, start_id + repeat))

        layout = ["sample_id", "prompt"]
        layout += ["true_" + name for name in self.variables]
        layout += ["observed_" + name for name in self.variables]
        layout += ["metric1"]
        # Columns absent from `data` are created empty by the constructor.
        return DataFrame(data, columns=layout)
    
    def generate_sample_gen_consistency(self, repeat=3, sample_num=1, mode="sample efficiency") -> DataFrame:
        if mode == "sample efficiency":
            return self._generate_sample_gen_consistency_sample_efficiency(repeat=repeat, sample_num=sample_num)
        elif mode == "uniform":
            return self._generate_sample_gen_consistency_uniform(repeat=repeat, sample_num=sample_num)
        else:
            raise ValueError(f"mode {mode} not implemented.")
            
    def _generate_sample_gen_consistency_uniform(self, repeat: int, sample_num: int) -> DataFrame:
        """
        Build level-2 sample groups from root-value combinations picked at random.

        If the stored group index already has at least sample_num groups and
            its last group has at least "repeat" members, it is returned as is.
        Otherwise sample_num rows of the full table (capped at 2^len(roots))
            are drawn without replacement. For each drawn row, up to "repeat"
            stored samples with the same root values are reused and the
            shortfall is filled with new rows. The all-samples table and the
            group index file are then rewritten.
        Returns a DataFrame with columns group_id and sample_id.
        """
        group_indexs = self.get_sample_index_level_2()
        group_dfs = list(group_indexs.groupby("group_id"))
        # Only the size of the last group is checked before reusing the stored index.
        if len(group_dfs) >= sample_num and len(group_dfs[-1][1]) >= repeat:
            return group_indexs
        max_comb = 2 ** len(self.roots)
        sample_num = min(sample_num, max_comb)
        selected_rows = sample_without_replacement(range(max_comb), sample_num)
        selected_values = [self.full_table[row] for row in selected_rows]
        all_sample_df = self._get_all_samples_df()
        # New rows get ids continuing after the current number of stored rows.
        new_index = len(all_sample_df)
        new_sample_df = DataFrame(columns=all_sample_df.columns)
        # The group index is rebuilt from scratch; the stored one is discarded.
        group_indexs = DataFrame(columns=["group_id", "sample_id"])
        for group_id, values in enumerate(selected_values):
            exist_cnt = 0
            # Reuse stored rows whose root values match this combination.
            # NOTE(review): the index label is recorded as the sample_id here —
            # assumes labels equal the sample_id column; confirm.
            for sample_id in all_sample_df.index:
                row = all_sample_df.loc[sample_id]
                if all(row["true_" + root] == values[root_i] for root_i, root in enumerate(self.roots)):
                    exist_cnt += 1
                    group_indexs.loc[len(group_indexs)] = {"group_id": group_id, "sample_id": sample_id}
                    if exist_cnt == repeat:
                        break
            # Fill the remaining slots of the group with new rows; only the
            # true_* values and sample_id are set on them.
            for _ in range(repeat - exist_cnt):
                new_row = {"true_" + name: value for name, value in zip(self.variables, values)}
                new_row["sample_id"] = new_index
                new_sample_df.loc[len(new_sample_df)] = new_row
                group_indexs.loc[len(group_indexs)] = {"group_id": group_id, "sample_id": new_index}
                new_index += 1
        sample_df = pd.concat([all_sample_df, new_sample_df], axis=0, ignore_index=True)
        self._write_all_samples(sample_df)
        self._write_sample_index_level_2(group_indexs)
        return group_indexs
    
    def _generate_sample_gen_consistency_sample_efficiency(self, repeat: int, sample_num: int) -> DataFrame:
        """
        repeat: each sample in level 2 has n="repeat" numbers of same samples.
        sample_num: ensure there are at least "sample_num" groups of level 2 samples.
        """
        sample_df = self._get_all_samples_df()
        new_samples, group_indexs = self._construct_level_2_samples(repeat, sample_num, sample_df, len(sample_df))
        sample_df = pd.concat([sample_df, new_samples], axis=0, ignore_index=True)
        self._write_all_samples(sample_df)
        self._write_sample_index_level_2(group_indexs)
        return group_indexs

    def _construct_level_2_samples(self, repeat: int, sample_num: int,
                       sample_df: DataFrame, start_id=0) -> tuple[DataFrame, DataFrame]:
        value_indexs = defaultdict(list)
        for row_id in sample_df.index:
            row = sample_df.loc[row_id]
            value_indexs[tuple(row["true_" + var] for var in self.variables)].append(row["sample_id"])
        ordered_values = sorted([(value, len(ids)) for value, ids in value_indexs.items()], key=lambda x: -x[1])
        new_samples = DataFrame(columns=sample_df.columns)
        group_indexs = DataFrame(columns=["group_id", "sample_id"])
        group_id = 0
        new_id = start_id
        for value, cnt in ordered_values[:sample_num]:
            selected_indexs = value_indexs[value]
            for selected_index in selected_indexs:
                group_indexs.loc[len(group_indexs)] = {"group_id": group_id, "sample_id": selected_index}
            if cnt < repeat:
                for _ in range(repeat - cnt):
                    new_row = {"true_" + name: value[i] for i, name in enumerate(self.variables)}
                    new_row["sample_id"] = new_id
                    group_indexs.loc[len(group_indexs)] = {"group_id": group_id, "sample_id": new_id}
                    new_samples.loc[len(new_samples)] = new_row
                    new_id += 1
            group_id += 1
        return new_samples, group_indexs
    
    def get_sample_index_level_2(self) -> DataFrame:
        file_path_pkl = self.scenario_path / f"sample_index_level_2.pkl"
        file_path_csv = self.scenario_path / f"sample_index_level_2.csv"
        if not file_path_pkl.exists() and not file_path_csv.exists():
            return DataFrame(columns=["group_id", "sample_id"])
        if file_path_pkl.exists():
            group_indexs = pd.read_pickle(file_path_pkl)
        else:
            group_indexs = read_csv(file_path_csv, index_col=0)
        return group_indexs

    def _write_sample_index_level_2(self, group_indexs: DataFrame):
        file_path_pkl = self.scenario_path / f"sample_index_level_2.pkl"
        group_indexs.to_pickle(file_path_pkl)
        
    @classmethod
    def copy_samples(cls, target_llm_name: str, source_llm_name: str="sample"):
        """
        Copy the whole sample folder of one LLM name to another.

        Both folders live under the database directory located two levels
            above this file. The target folder is created if needed, and
            files already present in it may be overwritten by the copy.
        Raises ValueError when the source folder does not exist.
        """
        database_root = Path(__file__).resolve().parent.parent / 'database'
        source_folder = database_root / source_llm_name
        if not source_folder.exists():
            raise ValueError("Copy failed: source folder not exist.")
        target_folder = database_root / target_llm_name
        target_folder.mkdir(parents=True, exist_ok=True)
        shutil.copytree(source_folder, target_folder, dirs_exist_ok=True)

    def get_sample_df_have_observation(self) -> DataFrame:
        all_sample_df: DataFrame = self._get_all_samples_df()
        result_df = all_sample_df.dropna(subset=["observed_" + name for name in self.variables])
        return result_df
    
    def get_sample_df(self, select_sample) -> DataFrame:
        if select_sample == "rule":
            return self._get_all_samples_df()
        elif select_sample == "text":
            return self._get_sample_df_text_consistency()
        else:
            raise ValueError(f"select_sample = {select_sample}, not in {{rule, text}}")
