from typing import AbstractSet, Any, List, Optional, Tuple

import networkx as nx

from benchmark.data.generate_data import marginalize
from causal_discovery.base_scam import BaseSCAM, _all_subsets


class OracleSCAM(BaseSCAM):
    """``BaseSCAM`` variant whose queries are answered from a known ground-truth graph.

    No data is consulted. Leaf and orientation queries marginalise a copy of
    ``ground_truth`` onto the nodes in question. Connectivity queries use
    d-separation on ``ground_truth`` itself.
    """

    def __init__(self, ground_truth: nx.DiGraph, verbose: bool = False):
        """Store a private copy of the ground-truth graph.

        Args:
            ground_truth: The true causal graph. It is copied, so later changes
                made by the caller do not affect this oracle.
            verbose: Forwarded to ``BaseSCAM.__init__``.
        """
        super().__init__(verbose)
        self.ground_truth = ground_truth.copy()

    def _oracle_marginal(self, nodes: List[Any]):
        """Return the ground truth marginalised onto *nodes*.

        Confounding and unobserved direct paths are indicated in the result.
        ``marginalize`` always receives a fresh copy of the ground truth.
        NOTE(review): presumably this is because ``marginalize`` mutates its
        input; confirm.
        """
        return marginalize(self.ground_truth.copy(),
                           nodes,
                           indicate_confounding=True,
                           indicate_unobs_direct_paths=True
                           )

    @staticmethod
    def _is_sink_with_parent(graph, sink: Any, parent: Any) -> bool:
        """Return True if ``parent -> sink`` is an edge of *graph* and *sink* has no successors."""
        # ``has_edge`` goes first: it is False (rather than raising) when
        # *sink* is absent, and when True it guarantees *sink* is present
        # for the out-degree lookup.
        return graph.has_edge(parent, sink) and graph.out_degree(sink) == 0

    def get_unconfounded_leaf(self, relevant_nodes: List[Any], current_nodes: List[Any]) -> Tuple[Any, bool]:
        """Return a sink of the ground truth marginalised onto *current_nodes*.

        Args:
            relevant_nodes: Ignored by this oracle implementation.
            current_nodes: Nodes onto which the ground truth is marginalised.

        Returns:
            ``(node, True)`` for the first node, in graph iteration order, of
            the marginal graph that has no successors. If there is no such
            node, returns ``(current_nodes[0], False)``.

        Raises:
            IndexError: If no sink is found and *current_nodes* is empty.
        """
        marginal_graph = self._oracle_marginal(current_nodes)
        for node in marginal_graph:
            if marginal_graph.out_degree(node) == 0:
                return node, True
        return current_nodes[0], False

    def orient_edge(self, node: Any, second_node: Any,
                    neighbourhood_node: AbstractSet[Any],
                    neighbourhood_second: AbstractSet[Any]) -> str:
        """Orient the edge between *node* and *second_node* using the ground truth.

        Args:
            node: One endpoint of the edge.
            second_node: The other endpoint.
            neighbourhood_node: Set-like collection of neighbours of *node*.
            neighbourhood_second: Set-like collection of neighbours of *second_node*.

        Returns:
            ``'<-'`` (``second_node -> node``) if, for some subset S of
            ``neighbourhood_node - {second_node}``, the marginal onto
            ``S + [node, second_node]`` has *node* as a sink with an edge from
            *second_node*.

            ``'->'`` (``node -> second_node``) if the symmetric condition holds
            for some subset of ``neighbourhood_second - {node}``.

            ``'-'`` if neither holds. All subsets for ``'<-'`` are tried
            before any subset for ``'->'``.
        """
        for subset in _all_subsets(neighbourhood_node - {second_node}):
            marginal_graph = self._oracle_marginal(subset + [node, second_node])
            if self._is_sink_with_parent(marginal_graph, sink=node, parent=second_node):
                return '<-'
        for subset in _all_subsets(neighbourhood_second - {node}):
            marginal_graph = self._oracle_marginal(subset + [node, second_node])
            if self._is_sink_with_parent(marginal_graph, sink=second_node, parent=node):
                return '->'
        return '-'

    def are_connected(self, first_node: Any, second_node: Any, current_nodes: List[Any]) -> bool:
        """Return True unless the two nodes are d-separated in the ground truth.

        The conditioning set is every node in *current_nodes* other than
        *first_node* and *second_node*.
        """
        # ``nx.d_separated`` was deprecated in NetworkX 3.3 (scheduled for
        # removal in 3.5) in favour of ``nx.is_d_separator``. Prefer the new
        # name and fall back to the old one on older releases.
        is_d_separator = getattr(nx, "is_d_separator", None) or nx.d_separated
        conditioning_set = set(current_nodes) - {first_node, second_node}
        return not is_d_separator(self.ground_truth,
                                  {first_node},
                                  {second_node},
                                  conditioning_set
                                  )
