from abc import ABC, abstractmethod
from collections import defaultdict
from itertools import combinations
from typing import Any, Collection, Dict, Iterator, List, Optional, Tuple

import networkx as nx
import pandas as pd


def _all_subsets(s: Collection[Any]) -> List[Any]:
    for subset_size in range(len(s) + 1):
        for subset in combinations(s, subset_size):
            yield list(subset)


def _no_direct_edge(g: nx.DiGraph, node: Any, second_node: Any) -> bool:
    """Tell whether *node* and *second_node* lack a one-directional edge in *g*.

    The answer is True when neither direction is present or when both
    directions are present, and False when exactly one direction exists.
    """
    forward = bool(g.has_edge(node, second_node))
    backward = bool(g.has_edge(second_node, node))
    # "none or both" is the same as the two directions agreeing.
    return forward == backward


class BaseSCAM(ABC):
    """Abstract driver that removes nodes one by one while building a directed graph.

    ``fit`` repeatedly asks the subclass for a candidate node, removes one node
    per round from ``remaining_nodes`` and records edges in ``result_graph``.
    Subclasses supply the three primitives ``get_unconfounded_leaf``,
    ``are_connected`` and ``orient_edge``.
    """

    def __init__(self, verbose: bool = True):
        """Create an unfitted instance.

        :param verbose: when True, progress messages are printed to stdout
        """
        self.verbose = verbose
        # Removal order while fitting; reversed once ``fit`` finishes.
        self.order: Optional[List[Any]] = None
        # Nodes that have not been removed yet.
        self.remaining_nodes: Optional[List[Any]] = None
        # Graph receiving the discovered edges; a pair of opposite edges
        # represents a "<->" relation.
        self.result_graph = nx.DiGraph()
        # Nodes whose edges were already oriented by ``_orient_edges``.
        self.visited_nodes = set()
        # Maps each node to the set of nodes still considered adjacent to it.
        self.boundaries: Dict[Any, set] = {}
        self.data = None

    @abstractmethod
    def get_unconfounded_leaf(self, relevant_nodes: List[Any], current_nodes: List[Any]) -> Tuple[Any, bool]:
        """Propose the next node to process.

        ``fit`` unpacks the result as ``(node, is_leaf)``: when ``is_leaf`` is
        truthy the node is removed as an unconfounded leaf, otherwise it is
        used as the starting point of ``_remove_confounded_node``.
        """
        raise NotImplementedError()

    @abstractmethod
    def are_connected(self, first_node: Any, second_node: Any, current_nodes: List[Any]) -> bool:
        """Decide whether the two nodes stay adjacent given ``current_nodes``.

        A falsy result makes the callers drop each node from the other's boundary.
        """
        raise NotImplementedError()

    @abstractmethod
    def orient_edge(self, node, second_node, neighbourhood_node, neighbourhood_second):
        """Return the orientation between ``node`` and ``second_node``.

        ``'->'`` adds ``node -> second_node``, ``'<-'`` adds
        ``second_node -> node``; any other value adds both directions.
        """
        raise NotImplementedError()

    def fit(self, data: pd.DataFrame) -> nx.DiGraph:
        """Run the procedure on *data* and return the resulting graph.

        The keys of *data* are the nodes. Every round removes exactly one
        node; when a single node is left, ``self.order`` is set to the
        reversed removal order (last remaining node first) and the graph is
        returned. All per-run state is reset, so an instance can be fitted
        repeatedly.

        :param data: data frame whose keys are the nodes
        :return: ``self.result_graph``
        :raises RuntimeError: if the loop ends without exactly one node left
        """
        if self.verbose:
            print('Training: ', type(self))
        self.data = data
        self.order = []
        self.remaining_nodes = list(data.keys())
        self.result_graph = nx.DiGraph()
        self.result_graph.add_nodes_from(self.remaining_nodes)
        # Reset state left over from an earlier fit: stale visited nodes would
        # make _orient_edges skip nodes of the new run.
        self.visited_nodes = set()
        all_nodes = set(self.remaining_nodes)
        # Initially every node is a potential neighbour of every other node.
        self.boundaries = {node: all_nodes - {node} for node in self.remaining_nodes}

        # With zero or one node there is nothing to order or to connect.
        if len(self.remaining_nodes) <= 1:
            self.order = list(self.remaining_nodes)
            return self.result_graph

        for _ in range(len(self.remaining_nodes)):
            self._initial_prune_all_neighbourhoods(self.remaining_nodes)
            candidate, is_leaf = self.get_unconfounded_leaf(self.remaining_nodes, self.remaining_nodes)
            if is_leaf:
                # candidate can be safely removed
                self._remove_unconfounded_leaf(candidate)
            else:
                # TODO: the candidate is used as the starting node; choosing the node with
                # the smallest boundary might be faster, lowest delta might be more robust.
                self._remove_confounded_node(candidate)

            if len(self.remaining_nodes) == 1:
                self.order = (self.order + self.remaining_nodes)[::-1]
                return self.result_graph

        raise RuntimeError("Somethings wrong")

    def _remove_unconfounded_leaf(self, node: Any):
        """Remove *node* and add an edge from each node in its boundary to it."""
        if self.verbose:
            print("Found unconfounded leaf ", node)
        self.order.append(node)
        self.remaining_nodes.remove(node)
        for parent in self.boundaries[node]:
            if self.verbose:
                print("Add ", parent, "->", node)
            self.result_graph.add_edge(parent, node)

    def _remove_confounded_node(self, node: Any, recursion_history: Optional[Dict[Any, int]] = None):
        """Orient the edges of *node*, descend to a child if there is one, else remove it.

        :param node: node to explore
        :param recursion_history: per-node count of recursive visits; created
            on the outermost call, in which case the starting node is not counted
        """
        if recursion_history is None:
            recursion_history = defaultdict(int)
        else:
            # .get keeps this working when a plain dict is passed in.
            recursion_history[node] = recursion_history.get(node, 0) + 1
        if self.verbose:
            print("Explore confounded node ", node)

        self._orient_edges(node)

        # If node has direct children, recurse on them until one without is found
        max_visits = len(self.result_graph.nodes)
        for child in self.boundaries[node]:
            is_direct_child = (self.result_graph.has_edge(node, child)
                               and not self.result_graph.has_edge(child, node))
            # The visit cap prevents endless recursion if finite samples produce a cycle.
            if is_direct_child and recursion_history.get(child, 0) <= max_visits:
                self._remove_confounded_node(child, recursion_history)
                # Return right away: the recursion may have changed this boundary.
                return
        if self.verbose:
            print("Remove confounded leaf ", node)

        # TODO add post-pruning of bidirected edges (_prune_node_neighbourhood) to oracle
        # Remove node from remaining nodes
        self.order.append(node)
        self.remaining_nodes.remove(node)

    def _orient_edges(self, node: Any):
        """Prune the boundary of *node* and orient the edges to what remains.

        Does nothing when *node* was handled before. Otherwise, for each
        non-empty subset of the boundary, every member without an edge to
        *node* in ``result_graph`` is tested with ``are_connected`` on
        ``subset + [node]`` and dropped from both boundaries on a falsy
        answer. Afterwards ``orient_edge`` decides the direction for each
        remaining neighbour.
        """
        if node in self.visited_nodes:
            return
        self.visited_nodes.add(node)

        graph = self.result_graph
        for subset in _all_subsets(self.boundaries[node]):
            if not subset:
                continue
            full_subset = subset + [node]

            # Prune neighbourhoods direct edges TODO maybe in one function with other neighbourhood pruning?
            for second_node in subset:
                already_linked = graph.has_edge(node, second_node) or graph.has_edge(second_node, node)
                if already_linked:
                    continue
                if not self.are_connected(node, second_node, full_subset):
                    self.boundaries[node].discard(second_node)
                    self.boundaries[second_node].discard(node)

        for second_node in self.boundaries[node]:
            orientation = self.orient_edge(node, second_node, self.boundaries[node], self.boundaries[second_node])
            if orientation == '->':
                if self.verbose:
                    print("Add ", node, "->", second_node)
                graph.add_edge(node, second_node)
            elif orientation == '<-':
                if self.verbose:
                    print("Add ", second_node, "->", node)
                graph.add_edge(second_node, node)
            else:
                # Anything else is stored as a pair of opposite edges.
                if self.verbose:
                    print("Add ", node, "<->", second_node)
                graph.add_edge(node, second_node)
                graph.add_edge(second_node, node)

    def _prune_node_neighbourhood(self, node: Any):
        """Prune boundaries between neighbours of *node* that test as unconnected.

        Only runs when *node* has at least one neighbour linked in both
        directions in ``result_graph``. Enumerates every subset of the other
        remaining nodes, so the cost grows exponentially with their number.
        """
        successors = set(self.result_graph.successors(node))
        if not successors.intersection(self.result_graph.predecessors(node)):
            return
        for subset in _all_subsets(set(self.remaining_nodes) - {node}):
            for first_node, second_node in combinations(subset, 2):
                if first_node in self.boundaries[node] and second_node in self.boundaries[node]:
                    if not self.are_connected(first_node, second_node, subset + [node]):
                        self.boundaries[first_node].discard(second_node)
                        self.boundaries[second_node].discard(first_node)

    def _initial_prune_all_neighbourhoods(self, current_nodes: List[Any]):
        """Drop boundary entries between pairs of *current_nodes* that are not connected.

        Each unordered pair that is still adjacent is tested once with
        ``are_connected``. Afterwards ``self.boundaries`` is rebuilt so that it
        only has keys and members taken from *current_nodes*.
        """
        for node, node_two in combinations(current_nodes, 2):
            if node_two in self.boundaries[node] and not self.are_connected(node, node_two, current_nodes):
                self.boundaries[node].discard(node_two)
                self.boundaries[node_two].discard(node)
        self.boundaries = {n: self.boundaries[n].intersection(current_nodes) for n in current_nodes}
