#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
@author: Haoyue
@file: verify_augmented_graph_issue.py
@time: 9/21/24 09:02
@desc: 
"""
from itertools import chain, combinations

import causallearn.utils.cit as cit
import numpy as np

from methods.MAG_tools import get_skeleton_and_sepsets, get_PAG_from_skeleton_and_sepsets


def powerset(iterable):
    """Return all subsets of *iterable* as a list of tuples, ordered by size.

    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    items = tuple(iterable)
    subsets = []
    # Grow the result one subset size at a time, from the empty tuple up to the full tuple.
    for size in range(len(items) + 1):
        subsets.extend(combinations(items, size))
    return subsets


class Method(object):
    """Structure learning from one observational and several interventional datasets.

    ``all_data[0]`` is treated as the observational domain and ``all_data[k]``
    (``k >= 1``) as the k-th interventional domain. Conditional-independence (CI)
    results are memoised in ``self.CI_cache`` so that each test runs at most once.
    """

    def __init__(self, all_data, CI_type='fisherz', alpha=0.05):
        """Store the datasets and CI-test settings.

        :param all_data: sequence of 2-D arrays, one per domain; the number of
            variables is read from ``all_data[0].shape[1]``.
        :param CI_type: ``'fisherz'`` or ``'kci'``; any other value makes
            ``is_cond_ind`` raise ``ValueError``.
        :param alpha: significance level; a p-value ``<= alpha`` is read as dependence.
        """
        self.all_data = all_data
        self.CI_type = CI_type
        self.alpha = alpha
        self.nodenum = all_data[0].shape[1]
        self.num_of_intervs = len(all_data) - 1
        self.CI_cache = dict()

    def is_cond_ind(self, x, y, Z, data):
        """Run one (uncached) CI test of column ``x`` vs column ``y`` given columns ``Z``.

        :param x: column index of the first variable.
        :param y: column index of the second variable (must differ from ``x``).
        :param Z: iterable of conditioning column indices, or ``None`` for no conditioning.
        :param data: 2-D array handed to ``cit.CIT``.
        :return: ``True`` when the test does not reject independence at level ``self.alpha``.
        :raises ValueError: if ``self.CI_type`` is neither ``'fisherz'`` nor ``'kci'``.
        """
        Z = set() if Z is None else set(Z)
        assert x != y and {x, y} & Z == set()
        x, y = min(x, y), max(x, y)
        if self.CI_type == 'fisherz':
            CI_test = cit.CIT(data, 'fisherz')
        elif self.CI_type == 'kci':
            CI_test = cit.CIT(data, 'kci', kernelX='Gaussian', kernelY='Gaussian',
                              kernelZ='Gaussian', est_width='median',
                              use_gp=False, approx=False)
        else:
            raise ValueError("Unknown ci type.")
        pval = CI_test(x, y, Z)
        # Deliberately `not <=` rather than `>`: the two differ only if pval is NaN,
        # in which case this form keeps reporting "independent".
        return not pval <= self.alpha

    def is_CI_in_domain_k(self, x, y, Z, k):
        """Cached CI test of ``x`` vs ``y`` given ``Z`` on the data of domain ``k``.

        :param x: column index of the first variable.
        :param y: column index of the second variable (must differ from ``x``).
        :param Z: iterable of conditioning column indices, or ``None``.
        :param k: domain index, ``0 <= k <= self.num_of_intervs``.
        :return: ``True`` if judged conditionally independent in domain ``k``.
        """
        assert 0 <= k <= self.num_of_intervs
        Z = set() if Z is None else set(Z)
        assert x != y and len({x, y} & Z) == 0
        # Order the pair so (x, y) and (y, x) share a single cache entry.
        x, y = min(x, y), max(x, y)
        cachekey = (x, y, frozenset(Z), k)
        if cachekey not in self.CI_cache:
            self.CI_cache[cachekey] = self.is_cond_ind(x, y, Z, self.all_data[k])
        return self.CI_cache[cachekey]

    def is_CI_in_observational_domain(self, x, y, Z):
        """Cached CI test of ``x`` vs ``y`` given ``Z`` on the observational data (domain 0).

        ``Z`` may be ``None`` or any iterable of column indices; argument
        validation is performed by ``is_CI_in_domain_k``.
        """
        return self.is_CI_in_domain_k(x, y, Z, k=0)

    def is_conditional_distribution_invariant_in_domain_pk_and_p0(self, x, Z, k):
        """Test whether p(x | Z) is the same in domain 0 and domain ``k``.

        The two datasets are stacked row-wise and a 0/1 column marking the domain
        of each row is appended; invariance is then the CI statement
        "column ``x`` is independent of that indicator column given ``Z``".

        :param x: column index of the variable of interest.
        :param Z: iterable of conditioning column indices (``None`` means empty).
        :param k: interventional domain index, ``1 <= k <= self.num_of_intervs``.
        :return: ``True`` if the conditional distribution is judged invariant.
        """
        assert 1 <= k <= self.num_of_intervs
        Z = set() if Z is None else set(Z)
        assert x not in Z
        # The indicator is appended as the last column of the pooled data, so its
        # index equals the original number of columns.
        y = self.all_data[0].shape[1]
        # The (0, k) tuple keeps these keys distinct from the single-domain keys,
        # whose last component is a plain int.
        cachekey = (x, y, frozenset(Z), (0, k))
        if cachekey not in self.CI_cache:
            # Build the pooled dataset only on a cache miss; it copies both arrays.
            obs_data = self.all_data[0]
            intervened_data = self.all_data[k]
            data = np.vstack([obs_data, intervened_data])
            zeta = np.hstack([[0] * len(obs_data), [1] * len(intervened_data)])
            data_with_zeta = np.hstack([data, zeta[:, np.newaxis]])
            self.CI_cache[cachekey] = self.is_cond_ind(x, y, Z, data_with_zeta)
        return self.CI_cache[cachekey]

    def _neighbours_in_skeleton(self, node, skeleton_edges):
        """Return the set of nodes ``j`` whose sorted pair with ``node`` is in ``skeleton_edges``."""
        return {j for j in range(self.nodenum) if tuple(sorted((node, j))) in skeleton_edges}

    def run_algo(self):
        """Run the full pipeline and return the estimated graph.

        Also sets ``self.pag_edges_from_observational`` and
        ``self.pag_edges_from_interventional``.

        :return: dict with keys ``'dag'`` (``nodenum x nodenum`` array, ``dag[i, j] == 1``
            for each directed edge i -> j among the X nodes),
            ``'directed_edges_from_interventional'`` (set of ``(i, j)`` int pairs),
            ``'pag_edges_from_observational'`` and ``'pag_edges_from_interventional'``.
        """
        # ======= step 1: FCI on X on observational domain only. mainly to get skeleton and sepsets for vstrucs.
        #         the orientation rules are FCI's R0-R4, with additionally: all ⚬-> to -> (since we know there're no latents)
        # TODO: add nodelist and CI_tester, and self.MAG_of_original_G.observed_nodes
        nodelist = list(range(self.nodenum))
        adjacencies_from_observational, Sepsets \
            = get_skeleton_and_sepsets(nodelist, self.is_CI_in_observational_domain)      # (x,y) with x<y
        self.pag_edges_from_observational = get_PAG_from_skeleton_and_sepsets(
            nodelist=nodelist,
            skeleton_edges=adjacencies_from_observational,
            sepsets=Sepsets,
            sure_no_latents=True)

        # Indices of the interventional domains (domain 0 is observational).
        interv_domains = range(1, 1 + self.num_of_intervs)

        # ======= (the whole step 2 is technically just running FCI over I and X again; we decouple it to escape from bothering over Is)
        # ======= step 2.1: find the adjacencies between I and X, the (pseudo)-intervention targets of each domain.
        #         by "pseudo" we mean the adjacent Xs contains but not limited to the actual targets.
        adjacencies_between_I_and_X = set()
        Sepsets_in_Gtwin = {}
        for k, kprime in combinations(interv_domains, 2):
            Sepsets_in_Gtwin[(f'I{k}', f'I{kprime}')] = Sepsets_in_Gtwin[(f'I{kprime}', f'I{k}')] = set()  # ensure indices mutually indep.
        for k in interv_domains:
            for i in range(self.nodenum):
                # we search if there's a subset C s.t. p(Xi|XC) is invariant between domain p(0) and p(k)
                #   can show that such search on C can be restricted to the adjacencies of Xi in observational skeleton.
                i_adjacencies = self._neighbours_in_skeleton(i, adjacencies_from_observational)
                found_invariant = False
                for C in powerset(i_adjacencies):
                    # to check whether p(Xi|XC) is invariant, we check whether Ik _||_ Xi | XC and other Ikprimes=0
                    #   if d-separation on twin G^I holds, by Markov property, the CI holds.
                    #   if not, will there also be unfaithfulness as in those among X?
                    #       the answer is no, as in this case [given XC* yields however dsep] => [Ik affects XC (so XC* cannot be given)]
                    # therefore, here we can safely use d-separation on twin G^I to check CI.
                    if self.is_conditional_distribution_invariant_in_domain_pk_and_p0(i, C, k):
                        found_invariant = True
                        Sepsets_in_Gtwin[(f'I{k}', f'X{i}')] = Sepsets_in_Gtwin[(f'X{i}', f'I{k}')] = \
                            {f'X{c}' for c in C} | {f'I{kprime}' for kprime in interv_domains if kprime != k}
                        break
                if not found_invariant:
                    adjacencies_between_I_and_X.add((f'I{k}', f'X{i}'))

        # ======= step 2.2: some additional adjacencies among X have to be added.
        #         though we know they are incorrect (as the correct ones are just those from step 1's observational domain),
        #            we have to add them to ensure the next step's FCI orientation rules are correct.
        #         basically it's because: an Xi_||_Xj | XC in p(0) may be dependent in p(k).
        #         [*]: Xi--Xj can be cut off, iff there exists a C, s.t. Xi_||_Xj | XC holds in all domains from p(0) to p(K).
        adjacencies_between_X_and_X = set()
        for x, y in set(combinations(range(self.nodenum), 2)):
            if (x, y) in adjacencies_from_observational:  # if adj in p(0), must be adj in all p(k). no need to check anymore.
                adjacencies_between_X_and_X.add((f'X{x}', f'X{y}'))
                continue
            # we can show that we only need to search sepset in Xi and Xj's adjacencies in observational domain (instead from the whole X).
            x_adjacencies_in_obsv = self._neighbours_in_skeleton(x, adjacencies_from_observational)
            y_adjacencies_in_obsv = self._neighbours_in_skeleton(y, adjacencies_from_observational)
            # NOTE: this is a set of tuples, so candidates are visited in set order,
            #       not by increasing size; the first sepset found need not be minimal.
            choose_Z_from = set(powerset(x_adjacencies_in_obsv)) | set(powerset(y_adjacencies_in_obsv))
            found_such_sepset = False
            for Z in choose_Z_from:
                if all(self.is_CI_in_domain_k(x, y, Z, k) for k in range(0, 1 + self.num_of_intervs)):
                    # note: the sepset found here may be different from that found in step 1.
                    found_such_sepset = True
                    Xz_and_all_Iks = {f'X{z}' for z in Z} | {f'I{k}' for k in interv_domains}
                    Sepsets_in_Gtwin[(f'X{x}', f'X{y}')] = Sepsets_in_Gtwin[(f'X{y}', f'X{x}')] = Xz_and_all_Iks
                    break
            if not found_such_sepset:
                adjacencies_between_X_and_X.add((f'X{x}', f'X{y}'))

        # ======= step 2.3: apply FCI rules on those adjacencies and sepsets, and known directions.
        directed_edges_from_observational = {(f'X{x}', f'X{y}') for x, y in self.pag_edges_from_observational['->']}
        twin_nodelist = [f'X{i}' for i in range(self.nodenum)] + [f'I{i}' for i in interv_domains]
        self.pag_edges_from_interventional = get_PAG_from_skeleton_and_sepsets(
            nodelist=twin_nodelist,
            skeleton_edges=adjacencies_between_X_and_X | adjacencies_between_I_and_X,
            sepsets=Sepsets_in_Gtwin,
            sure_direct_edges=directed_edges_from_observational | adjacencies_between_I_and_X)
        # Keep only X -> X edges and strip the 'X' prefix to recover integer node indices.
        directed_edges_from_interventional = {
            (int(x[1:]), int(y[1:])) for x, y in self.pag_edges_from_interventional['->'] if x[0] == 'X' and y[0] == 'X'
        }
        # Convert directed edges into dag
        dag = np.zeros((self.nodenum, self.nodenum))
        for i, j in directed_edges_from_interventional:
            # Add edge i -> j
            dag[i, j] = 1
        params_est = {
            'dag': dag,
            'directed_edges_from_interventional': directed_edges_from_interventional,
            'pag_edges_from_observational': self.pag_edges_from_observational,
            'pag_edges_from_interventional': self.pag_edges_from_interventional
        }
        return params_est