from typing import Dict, Optional, List, Set, Union
from graphical_models import DAG
import itertools as itr
from conditional_independence import CI_Tester, partial_correlation_test
from conditional_independence import InvarianceTester
from graphical_models import UndirectedEdge
from graphical_model_learning.utils.core_utils import powerset, iszero
import random
from graphical_model_learning.algorithms.undirected import threshold_ug, partial_correlation_threshold
from graphical_models import UndirectedGraph
import numpy as np
from tqdm import trange, tqdm
from math import factorial
from causaldag import partial_correlation_suffstat, partial_correlation_test, MemoizedCI_Tester
from causaldag import gauss_invariance_suffstat, gauss_invariance_test, MemoizedInvarianceTester
from causaldag import min_degree_alg_amat, gsp


def jci_gsp(all_data, alpha):
    """Estimate a causal DAG with JCI-GSP from data collected in several contexts.

    Each dataset in ``all_data`` is treated as one context. The datasets are
    stacked row-wise, and each row is prefixed with one indicator column per
    context. The pooled matrix therefore has ``num_of_contexts + nodenum``
    columns:

    - columns ``0 .. num_of_contexts-1`` are one-hot context indicators;
    - the remaining columns are the system variables.

    A memoized partial-correlation CI tester is built on the pooled matrix and
    handed to ``jci_gsp_func``.

    Parameters
    ----------
    all_data : sequence of 2-D array-likes
        One ``(samples, nodenum)`` array per context. All must have the same
        number of columns.
    alpha : float
        Significance level passed to the partial-correlation CI tester.

    Returns
    -------
    dict
        ``'dag'``: ``to_amat()[0]`` of the DAG estimated by ``jci_gsp_func``
        (presumably its adjacency matrix over the system variables).
        ``'meta_dag_obj_est'``: the estimated graph over context and system
        nodes, as returned by ``jci_gsp_func``.

    Raises
    ------
    ValueError
        If ``all_data`` is empty, or if a dataset is not 2-D or has a column
        count different from the first dataset.
    """
    num_of_contexts = len(all_data)
    if num_of_contexts == 0:
        raise ValueError("all_data must contain at least one dataset")

    first = np.asarray(all_data[0])
    if first.ndim != 2:
        raise ValueError(f"dataset 0 must be 2-D, got shape {first.shape}")
    nodenum = first.shape[1]
    num_cols = num_of_contexts + nodenum

    # Build one block per context and stack them once, instead of extending a
    # Python list row by row.
    blocks = []
    for i, data in enumerate(all_data):
        data = np.asarray(data)
        if data.ndim != 2 or data.shape[1] != nodenum:
            raise ValueError(
                f"dataset {i} has shape {data.shape}; expected 2-D with {nodenum} columns"
            )
        block = np.zeros((data.shape[0], num_cols))
        block[:, i] = 1  # one-hot indicator marking context i
        block[:, num_of_contexts:] = data
        blocks.append(block)
    # NOTE(review): every row has exactly one 1 among the indicator columns, so
    # those columns sum to 1 and are linearly dependent. Confirm the
    # partial-correlation test tolerates conditioning on all of them.
    combined_data = np.vstack(blocks)

    # Node labels: 0..num_of_contexts-1 are contexts, the rest are system variables.
    nodes = set(range(combined_data.shape[1]))

    # Form sufficient statistics and a memoized CI tester over the pooled data.
    obs_suffstat = partial_correlation_suffstat(combined_data)
    ci_tester = MemoizedCI_Tester(partial_correlation_test, obs_suffstat, alpha=alpha)

    # Run JCI-GSP. No intervention targets are assumed known for any context.
    setting_list = [dict(known_interventions=[]) for _ in range(num_of_contexts)]
    dag_est, meta_dag_obj_est = jci_gsp_func(
        setting_list, nodes, num_of_contexts, ci_tester, initial_undirected='threshold'
    )
    return {'dag': dag_est.to_amat()[0], 'meta_dag_obj_est': meta_dag_obj_est}


def jci_gsp_func(
        setting_list: List[Dict],
        nodes: set,
        num_of_contexts: int,
        combined_ci_tester: CI_Tester,
        depth: int = 4,
        nruns: int = 5,
        verbose: bool = False,
        initial_undirected: Optional[Union[str, UndirectedGraph]] = 'threshold',
):
    """
    Run GSP on a joint graph of context nodes and system nodes (JCI-style).

    Nodes are integer-labelled:

    - ``0 .. num_of_contexts-1`` are the context nodes;
    - ``num_of_contexts .. len(nodes)-1`` are the system nodes.

    Only ``len(nodes)`` is used; the contents of ``nodes`` are ignored. This
    therefore assumes the caller's labels are exactly ``range(len(nodes))``
    with the context nodes first, which is how ``jci_gsp`` in this file builds
    them.

    The context nodes are constrained in the call to ``gsp``:

    - every ordered pair of distinct context nodes is a fixed adjacency;
    - each pair produced by ``itertools.combinations`` over the context nodes
      is a fixed order;
    - every context node is fixed to precede every system node.

    Parameters
    ----------
    setting_list
        One dict per context. It is not read by the active code, only by the
        commented-out ``known_interventions`` logic below, and is kept for
        interface compatibility.
    nodes
        Set of all nodes. Only its size is used (see above).
    num_of_contexts
        Number of context nodes.
    combined_ci_tester
        CI tester over context and system nodes. It is passed to
        ``threshold_ug`` (when ``initial_undirected == 'threshold'``) and to
        ``gsp``.
    depth, nruns, verbose
        Forwarded to ``gsp``. ``nruns`` is also the number of starting
        permutations generated here.
    initial_undirected
        Controls how the starting permutations are built:

        - ``'threshold'``: build an undirected graph over the system nodes
          with ``threshold_ug``, then derive the permutations from it with
          ``min_degree_alg_amat``;
        - any other truthy value (e.g. an ``UndirectedGraph``): use that graph
          the same way;
        - a falsy value (e.g. ``None``): use random permutations of the
          system nodes;
        - any other string: raise ``ValueError``.

    Returns
    -------
    (est_dag, est_meta_dag)
        ``est_meta_dag`` is the graph returned by ``gsp`` over context and
        system nodes. ``est_dag`` is its induced subgraph on the system nodes.

    Raises
    ------
    ValueError
        If ``initial_undirected`` is a string other than ``'threshold'``.
    """
    # CREATE NEW NODES AND OTHER INPUT TO ALGORITHM
    # Context nodes take labels 0..num_of_contexts-1. `nodes` is rebound to
    # the remaining (system) labels num_of_contexts..len(nodes)-1.
    context_nodes = set(list(range(num_of_contexts)))
    nodes = set(list(range(num_of_contexts, len(nodes))))
    # Earlier string-labelled variant, kept for reference (unused):
    # context_nodes = ['c%d' % i for i in range(len(setting_list))]
    # Every ordered pair of distinct context nodes must stay adjacent in the estimate.
    context_adjacencies = set(itr.permutations(context_nodes, r=2))
    # known_iv_adjacencies = set.union(*(
    #     {('c%s' % i, node) for node in setting['known_interventions']} for i, setting in enumerate(setting_list)
    # ))
    # Always empty here: any 'known_interventions' in setting_list are ignored.
    known_iv_adjacencies = set()
    # Context pairs keep their combinations() order, and every context node
    # precedes every system node.
    fixed_orders = set(itr.combinations(context_nodes, 2)) | set(itr.product(context_nodes, nodes))

    # === DO SMART INITIALIZATION
    if isinstance(initial_undirected, str):
        if initial_undirected == 'threshold':
            # Threshold graph is built over the system nodes only.
            initial_undirected = threshold_ug(set(nodes), combined_ci_tester)
        else:
            raise ValueError("initial_undirected must be one of 'threshold', or an UndirectedGraph")
    if initial_undirected:
        # NOTE(review): elsewhere in this file `DAG.to_amat()` is indexed with
        # `[0]`, which suggests `to_amat` may return an `(amat, node_list)`
        # tuple. Confirm that `min_degree_alg_amat` accepts what
        # `UndirectedGraph.to_amat()` returns here.
        amat = initial_undirected.to_amat()
        # NOTE(review): the permutation is presumably in matrix row indices
        # (0..p-1), while system nodes are labelled from num_of_contexts
        # upward. Unless `min_degree_alg_amat` maps back to node labels, the
        # starting permutations would name context labels and omit the highest
        # system labels. Verify this.
        # NOTE(review): if `min_degree_alg_amat` is deterministic, all nruns
        # starting permutations are identical.
        initial_permutations = [min_degree_alg_amat(amat) for _ in range(nruns)]
    else:
        initial_permutations = [random.sample(list(nodes), len(nodes)) for _ in range(nruns)]
    # NOTE(review): in both branches the starting permutations cover only the
    # system nodes, yet gsp below runs over system + context nodes with the
    # contexts fixed first. Confirm gsp inserts the missing context nodes;
    # otherwise they should be prepended to each permutation.

    # === RUN GSP ON FULL DAG
    est_meta_dag = gsp(
        nodes | set(context_nodes),
        combined_ci_tester,
        depth=depth,
        nruns=nruns,
        initial_permutations=initial_permutations,
        fixed_orders=fixed_orders,
        fixed_adjacencies=context_adjacencies | known_iv_adjacencies,
        verbose=verbose
    )
    # === PROCESS OUTPUT
    # Earlier intervention-target extraction for string-labelled contexts (unused):
    # learned_intervention_targets = {
    #     int(node[1:]): {child for child in est_meta_dag.children_of(node) if not isinstance(child, str)}
    #     for node in context_nodes
    # }
    # learned_intervention_targets = {
    #     int(node[1:]): {child for child in est_meta_dag.children_of(node) if not isinstance(child, str)}
    #     for node in context_nodes
    # }
    # learned_intervention_targets = [learned_intervention_targets[i] for i in range(len(setting_list))]
    # Restrict the joint estimate to the system nodes.
    est_dag = est_meta_dag.induced_subgraph(nodes)
    return est_dag, est_meta_dag