import pandas as pd
import numpy as np
import time
import itertools
from dataclasses import dataclass
from Utils.PMB_learn import PMB_Learner
from Utils.PMB_CI_test import CItest_method
from Utils.Est_Effect import est_reg_con
@dataclass
class PAGEdge:
    """Integer codes for PAG edge-end marks.

    The values are plain class-level integer constants (no dataclass fields
    are declared), so they are read as ``PAGEdge.ARROW`` etc.
    """

    NONE = 0      # no edge mark
    CIRCLE = 1    # circle mark ("o")
    ARROW = 2     # arrowhead mark (">")
    TAIL = 3      # tail mark ("-")


class LCS:
    """Local causal search for the effect of ``treatment`` on ``outcome``.

    The local structure around the treatment is learned with ``PMB_Learner``.
    Three rules are then tried in order:

    * Rule 1: find a node ``S`` and set ``Z`` with ``S`` dependent on the
      outcome given ``Z`` but independent given ``Z + [treatment]``. ``Z`` is
      returned as an adjustment set.
    * Rule 2: if every neighbour of the treatment is a parent, a child, or a
      "safe" possible parent, adjust for parents plus safe possible parents.
      Report ``0`` when treatment and outcome are independent given them.
    * Rule 3: report ``0`` when treatment and outcome are independent given
      some ``Z``, or when some ``S`` is dependent on the treatment but
      independent of the outcome given ``Z``.

    ``run`` returns ``(result, elapsed_seconds, ci_test_count)``. ``result``
    is a list of adjustment-variable names, ``0``, or ``None`` (no
    conclusion).
    """

    def __init__(
        self,
        continuous_data: pd.DataFrame,
        treatment: str,
        outcome: str,
        ci_type: str,
        max_K: int = 8,
        alpha: float = 0.01,
        Max_time: int = None,
    ):
        """Store the configuration and build the CI tester.

        Args:
            continuous_data: Full dataset. Columns whose name starts with
                ``"L"`` are treated as latent and removed from the data
                handed to the structure learner.
            treatment: Treatment column name.
            outcome: Outcome column name.
            ci_type: CI test type passed to ``CItest_method``. ``"D_sep"``
                selects oracle initialisation in ``run``.
            max_K: Maximum conditioning-set size; also caps the size of the
                candidate adjustment subsets.
            alpha: Significance level for the CI test.
            Max_time: Accepted for interface compatibility; currently unused.
        """
        self.treatment_name = str(treatment)
        self.outcome_name = str(outcome)
        self.ci_type = ci_type
        self.max_K = max_K
        self.alpha = alpha
        self.continuous_data = continuous_data
        # Naming convention: a leading "L" marks a latent (hidden) column.
        self.latent_nodes = [c for c in continuous_data.columns if str(c).startswith("L")]
        self.observed_data = continuous_data.drop(columns=self.latent_nodes, errors="ignore")
        self.CI_test = CItest_method(continuous_data, method_type=ci_type, alpha=alpha)
        self.Nodes_dict = {}
        self.treatment_node = None
        self.outcome_node = None
        self.learner = None
        self.pag = None
        self.mixgraph = None
        self.G = None
        self.Pa = set()
        self.Ch = set()
        self.MB = set()
        self.Pastar = set()
        self.verbose = False
        self.popa = set()
        # Filled by only_clear_pa_ch_or_safe_possible_pa_of_treatment.
        self.pa_treatment = set()
        self.safe_possible_pa_treatment = set()
        self.ch_treatment = set()
        self.pa_safe_treatment = set()

    def _initialize_from_learner(self, learner_ci_type):
        """Run ``PMB_Learner`` on the observed data and cache the treatment's local structure."""
        self.learner = PMB_Learner(
            self.observed_data,
            ci_type=learner_ci_type,
            max_K=self.max_K,
        )
        self.Nodes_dict = self.learner.Nodes_dict
        self.treatment_node = self.Nodes_dict[self.treatment_name]
        self.outcome_node = self.Nodes_dict[self.outcome_name]

        res = self.learner.Pastar_learner(self.treatment_name)
        self.Pastar = {self.Nodes_dict[name] for name in res["Pastar"]}
        self.mixgraph = res["PAG.MixGraph"]
        self.pag = self.mixgraph
        self.Ch = self.mixgraph.get_children(self.treatment_node)
        self.Pa = self.mixgraph.get_parents(self.treatment_node)
        self.MB = self.learner.get_mb(self.treatment_node)

    def _oracle_initialization(self):
        """Initialise the local structure with the ``"D_sep"`` learner."""
        self._initialize_from_learner("D_sep")

    def _data_initialization(self):
        """Initialise the local structure from data with the ``"Fisher_Z"`` learner."""
        # NOTE(review): structure learning always uses Fisher_Z here, while the
        # rule-level CI tests use self.ci_type. Confirm this split is intended.
        self._initialize_from_learner("Fisher_Z")

    def _ci(self, X, Y, Z):
        """Return the ``(is_independent, p_value)`` pair from the CI tester."""
        return self.CI_test(X, Y, Z)

    @staticmethod
    def _as_bool(flag):
        """Normalise a CI flag (e.g. ``numpy.bool_``) to ``bool``, keeping ``None``.

        Needed because the rules compare flags with ``is True`` / ``is False``,
        and ``numpy.False_ is False`` evaluates to False.
        """
        return None if flag is None else bool(flag)

    def _get_ci_num(self) -> int:
        """Return the CI-test counter, or 0 if the tester keeps no ``_ci_num``."""
        return getattr(self.CI_test, "_ci_num", 0)

    def _check_not_identifiable(self):
        """Return True when ``has_into_Edge(X, v)`` holds for no neighbour ``v`` of the treatment."""
        X = self.treatment_node
        return not any(
            self.mixgraph.has_into_Edge(X, v) for v in self.mixgraph.get_adj_nodes(X)
        )

    def only_clear_pa_ch_or_safe_possible_pa_of_treatment(self, pag) -> bool:
        """Check that every neighbour of the treatment has a usable edge type.

        Each neighbour ``v`` is classified as follows:

        * ``has_into_Edge(v, X)``: parent.
        * ``has_into_Edge(X, v)``: child.
        * ``has_circ_arrow_Edge(v, X)``: safe possible parent, unless two
          other neighbours ``a``, ``b`` of ``v`` satisfy
          ``is_collider(a, v, b)``.

        Any other case returns False. On success, the parent, child and
        safe-possible-parent sets are stored on ``self`` and True is returned.
        """
        X = self.treatment_node
        assert X is not None and X.name in self.Nodes_dict

        Pa = set()
        Ch = set()
        safe_possible_pa = set()

        for v in pag.get_adj_nodes(X):
            if pag.has_into_Edge(v, X):
                Pa.add(v)
                continue
            if pag.has_into_Edge(X, v):
                Ch.add(v)
                continue
            if not pag.has_circ_arrow_Edge(v, X):
                # Any other edge type makes the neighbourhood unusable.
                return False

            neighbors_v = pag.get_adj_nodes(v)
            for a in neighbors_v:
                if a == X:
                    continue
                for b in neighbors_v:
                    if b == X or b == a:
                        continue
                    if pag.is_collider(a, v, b):
                        return False
            safe_possible_pa.add(v)

        self.pa_treatment = Pa
        self.safe_possible_pa_treatment = safe_possible_pa
        self.ch_treatment = Ch
        self.pa_safe_treatment = Pa | safe_possible_pa
        return True

    def _build_candidate_sets(self):
        """Build the candidate ``S`` nodes and the candidate ``Z`` subsets.

        Returns:
            ``(Can_S, Z_subsets)``:

            * ``Can_S`` is MB minus {treatment, outcome} minus children.
            * ``Z_subsets`` holds every subset (as tuples, up to size
              ``max_K``) of the members of ``Can_S`` that test independent of
              the treatment given ``Pastar``.
        """
        X, Y = self.treatment_node, self.outcome_node
        # difference() accepts any iterable, so MB/Ch need not already be sets.
        Can_S = set(self.MB).difference({X, Y}, self.Ch)

        Pastar_names = [u.name for u in self.Pastar]
        Can_Z = set()
        for v in Can_S:
            Z_ci = [z for z in Pastar_names if z != X.name and z != v.name]
            indep, _ = self._ci(X.name, v.name, Z_ci)
            if indep:
                Can_Z.add(v)

        L = list(Can_Z)
        max_r = min(len(L), self.max_K)
        Z_subsets = [Z for r in range(max_r + 1) for Z in itertools.combinations(L, r)]
        return Can_S, Z_subsets

    def _Rule1_positive_effect(self, S, Z):
        """Run the two CI tests of Rule 1 for node ``S`` and conditioning set ``Z``.

        Returns None when ``Z`` contains ``S``, the treatment or the outcome.
        Otherwise returns a dict with the (bool-normalised) test results. The
        second test (adding the treatment) is run only when the first reports
        dependence.
        """
        X, Y = self.treatment_node, self.outcome_node
        ZN = [v.name for v in Z]
        if S.name in ZN or X.name in ZN or Y.name in ZN:
            return None

        if self.verbose:
            print(f"Testing Rule 1: S={S.name}, Z={ZN}")
        CI_one, P_one = self._ci(S.name, Y.name, ZN)
        CI_one = self._as_bool(CI_one)

        CI_two, P_two = None, None
        if CI_one is False:
            CI_two, P_two = self._ci(S.name, Y.name, ZN + [X.name])
            CI_two = self._as_bool(CI_two)

        return {
            "CI_one": CI_one,
            "P_one": P_one,
            "CI_two": CI_two,
            "P_two": P_two,
            "Z": ZN,
        }

    def _Rule_one(self, Can_S, Z_subsets):
        """Return the first ``Z`` (as names) satisfying Rule 1, else None."""
        for s in Can_S:
            for Z in Z_subsets:
                res = self._Rule1_positive_effect(s, Z)
                if res is None:
                    continue
                if res["CI_one"] is False and res["CI_two"] is True:
                    return res["Z"]
        return None

    def _Rule_two(self):
        """Apply Rule 2 using parents plus safe possible parents of the treatment.

        Returns None when the neighbourhood check fails. Returns ``{"Z": 0}``
        when the treatment is independent of the outcome given that set.
        Otherwise returns ``{"Z": [names]}``.
        """
        if not self.only_clear_pa_ch_or_safe_possible_pa_of_treatment(self.pag):
            return None

        Pa_total = self.pa_safe_treatment
        indep, _ = self._ci(
            self.treatment_node.name,
            self.outcome_node.name,
            [v.name for v in Pa_total],
        )
        if indep:
            return {"Z": 0}
        return {"Z": [v.name for v in Pa_total]}

    def _Rule3_case1(self, Z):
        """Return the CI result for treatment vs outcome given names ``Z``.

        Returns False if ``Z`` contains the treatment or the outcome.
        """
        X, Y = self.treatment_node, self.outcome_node
        if X.name in Z or Y.name in Z:
            return False

        if self.verbose:
            print(f"Testing Rule 3 (Case 1): Z={Z}")

        CI, _ = self._ci(self.treatment_name, self.outcome_name, Z)
        return CI

    def _Rule3_case2(self, S, Z):
        """Return whether ``S`` depends on the treatment yet is independent of the outcome given ``Z``.

        Returns False if ``Z`` contains ``S``, the treatment or the outcome.
        """
        X, Y = self.treatment_node, self.outcome_node
        if S.name in Z or X.name in Z or Y.name in Z:
            return False

        if self.verbose:
            print(f"Testing Rule 3 (Case 2): S={S.name}, Z={Z}")

        CI_sx, _ = self._ci(S.name, self.treatment_name, Z)
        if CI_sx:
            return False

        CI_sy, _ = self._ci(S.name, self.outcome_name, Z)
        return CI_sy

    def _Rule_three(self, Can_S, Z_subsets):
        """Return 0 if either Rule 3 case fires for some ``Z`` (and ``S``), else None."""
        for Z in Z_subsets:
            if self._Rule3_case1([v.name for v in Z]):
                return 0
        for S in Can_S:
            for Z in Z_subsets:
                if self._Rule3_case2(S, [v.name for v in Z]):
                    return 0
        return None

    def run(self):
        """Execute the search.

        Returns:
            ``(result, elapsed_seconds, ci_test_count)``. ``result`` is a list
            of adjustment names (Rule 1/2), ``0`` (Rule 2/3), or ``None``.
        """
        start = time.time()

        def finish(result):
            return result, time.time() - start, self._get_ci_num()

        if hasattr(self.CI_test, "_ci_num"):
            self.CI_test._ci_num = 0
        if self.ci_type == "D_sep":
            self._oracle_initialization()
        else:
            self._data_initialization()

        if self._check_not_identifiable():
            return finish(None)

        Can_S, Z_subsets = self._build_candidate_sets()

        Z_r1 = self._Rule_one(Can_S, Z_subsets)
        if Z_r1 is not None:
            return finish(Z_r1)

        r2 = self._Rule_two()
        if r2 is not None:
            # r2["Z"] is either 0 (no effect) or a list of adjustment names.
            return finish(r2["Z"])

        if self._Rule_three(Can_S, Z_subsets) == 0:
            return finish(0)

        return finish(None)


def alg_LCS(data, Tr_X, Out_Y, max_k=3, CI_type="Fisher_Z", alpha=0.01):
    """Run LCS and, when an adjustment set is found, estimate the effect.

    Returns:
        A dict with these keys:

        * ``"VASs"``: ``0`` or ``None`` passed through from ``LCS.run``, or a
          one-element list ``[{"VAS": sorted_names}]``.
        * ``"ATE"``: mean of the ``est_reg_con`` estimates, or ``None`` when
          no adjustment set was produced.
        * ``"CI_num"``: CI-test count reported by ``LCS.run``.
    """
    searcher = LCS(
        continuous_data=data,
        treatment=Tr_X,
        outcome=Out_Y,
        ci_type=CI_type,
        max_K=max_k,
        alpha=alpha,
    )
    adjustment, _elapsed, n_ci = searcher.run()

    if adjustment in (0, None):
        # No adjustment set: either "no effect" (0) or "no conclusion" (None).
        return {"VASs": adjustment, "ATE": None, "CI_num": n_ci}

    vas_list = [{"VAS": sorted(adjustment)}]
    estimates = []
    for entry in vas_list:
        estimate, _ = est_reg_con(Tr=Tr_X, Y=Out_Y, adjset=entry["VAS"], dataset=data)
        estimates.append(estimate)

    ate = np.mean(estimates) if estimates else None
    return {"VASs": vas_list, "ATE": ate, "CI_num": n_ci}


