from collections import defaultdict
from itertools import combinations
from typing import Any, Collection, Iterator, List

import networkx as nx
import pandas as pd
from causallearn.graph.GraphNode import GraphNode
from causallearn.utils.FAS import fas

from causal_discovery.scamuv import SCAMUV
from causal_discovery.score_independence import ScoreIndependence


def _all_subsets(s: Collection[Any]) -> List[Any]:
    for subset_size in range(len(s) + 1):
        for subset in combinations(s, subset_size):
            yield list(subset)


class OAFAS(SCAMUV):
    """Skeleton-then-orient causal discovery built on top of ``SCAMUV``.

    ``fit`` works in two phases:

    1. It estimates an undirected skeleton with causallearn's ``fas``, using
       ``ScoreIndependence`` as the conditional-independence test.
    2. It orients every skeleton edge by querying
       ``self.get_unconfounded_leaf`` (inherited from ``SCAMUV``) on subsets
       of each node's neighbourhood.

    The result is an ``nx.DiGraph`` over the data columns. A single arc
    ``a -> b`` is an oriented edge. A pair left bidirected (``a <-> b``) is
    encoded as the two arcs ``a -> b`` and ``b -> a``.
    """

    def __init__(self,
                 alpha: float,
                 regression: str = 'gam',
                 eta_g: float = 0.001,
                 eta_h: float = 0.001,
                 var_eps: float = 1e-5,
                 cv: int = 2,
                 verbose: bool = False):
        """Configure the learner.

        Args:
            alpha: significance level passed to ``fas`` in the skeleton
                phase; also forwarded to ``SCAMUV``.
            regression: forwarded unchanged to ``SCAMUV.__init__``.
            eta_g: forwarded unchanged to ``SCAMUV.__init__``.
            eta_h: forwarded unchanged to ``SCAMUV.__init__``.
            var_eps: forwarded unchanged to ``SCAMUV.__init__``.
            cv: forwarded unchanged to ``SCAMUV.__init__``.
            verbose: print each edge decision and enable ``fas`` verbosity.
        """
        super().__init__(alpha, regression, eta_g, eta_h, var_eps, cv, verbose)
        # Read directly by fit()/_orient_edges(); set explicitly here rather
        # than relying on how SCAMUV stores its arguments.
        self.alpha = alpha
        self.verbose = verbose
        # Maps each column label to its set of skeleton neighbours; filled by fit().
        self.boundaries = None
        # Not read anywhere in this class; kept for backward compatibility.
        self.visited_nodes = set()

    def fit(self, data: pd.DataFrame) -> nx.DiGraph:
        """Learn a causal graph over the columns of *data*.

        Args:
            data: one column per variable; column labels become node names.

        Returns:
            An ``nx.DiGraph`` containing every column as a node, also stored
            as ``self.result_graph``. Bidirected pairs appear as two arcs.
        """
        self.data = data
        self.boundaries = {node: set() for node in data.keys()}

        # Phase 1: undirected skeleton via causallearn's adjacency search.
        cit = ScoreIndependence(data.to_numpy())
        skeleton = fas(data.to_numpy(),
                       [GraphNode(n) for n in data.keys()],
                       alpha=self.alpha,
                       independence_test_method=cit,
                       verbose=self.verbose,
                       )[0]
        for edge in skeleton.get_graph_edges():
            first = edge.get_node1().get_name()
            second = edge.get_node2().get_name()
            self.boundaries[first].add(second)
            self.boundaries[second].add(first)

        # Phase 2: orient the skeleton edges around each node, in column order.
        self.result_graph = nx.DiGraph()
        self.result_graph.add_nodes_from(data.keys())
        for node in data.keys():
            self._orient_edges(node)

        return self.result_graph

    def _ordered_boundary(self, node: Any) -> List[Any]:
        """Return *node*'s skeleton neighbours in ``self.data`` column order.

        ``self.boundaries`` stores neighbours in a ``set``. Iterating that set
        follows hash order, which for ``str`` labels varies with
        ``PYTHONHASHSEED``. Ordering by column makes subset enumeration,
        verbose output and edge-insertion order reproducible across runs.
        """
        neighbours = self.boundaries[node]
        # dict.fromkeys de-duplicates labels while preserving column order.
        return [col for col in dict.fromkeys(self.data.keys()) if col in neighbours]

    def _orient_edges(self, node: Any) -> None:
        """Orient every skeleton edge incident to *node* and add it to the result.

        For each non-empty subset S of *node*'s neighbours, the method calls
        ``get_unconfounded_leaf`` on S plus *node*. When that returns a leaf
        (anything other than ``None``), every other member of the set is
        recorded as a candidate parent of that leaf. Each neighbour ``nb`` is
        then connected to *node* as follows:

        * ``node -> nb`` if only that direction was ever recorded;
        * ``nb -> node`` if only the reverse direction was recorded;
        * both arcs (bidirected) if both or neither direction was recorded.

        Edges are only ever added to ``self.result_graph``. A pair oriented
        differently when its other endpoint is processed therefore ends up
        with both arcs.
        """
        neighbours = self._ordered_boundary(node)

        # (parent, child) pairs supported by at least one queried subset.
        candidate_arcs = set()
        for subset in _all_subsets(neighbours):
            if not subset:
                continue  # the empty subset would query {node} alone
            members = subset + [node]
            # NOTE(review): presumably returns a member of `members` that is an
            # unconfounded sink, or None; semantics live in SCAMUV — confirm.
            leaf = self.get_unconfounded_leaf(members, members)
            if leaf is None:
                continue
            candidate_arcs.update((parent, leaf) for parent in members if parent != leaf)

        for neighbour in neighbours:
            forward = (node, neighbour) in candidate_arcs
            backward = (neighbour, node) in candidate_arcs
            if forward and not backward:
                if self.verbose:
                    print("Add ", node, "->", neighbour)
                self.result_graph.add_edge(node, neighbour)
            elif backward and not forward:
                if self.verbose:
                    print("Add ", neighbour, "->", node)
                self.result_graph.add_edge(neighbour, node)
            else:
                # Both or neither direction supported: keep the pair bidirected.
                if self.verbose:
                    print("Add ", node, "<->", neighbour)
                self.result_graph.add_edge(node, neighbour)
                self.result_graph.add_edge(neighbour, node)
