from collections import defaultdict
from typing import Any, List, Tuple, Optional

import networkx as nx
import numpy as np
import pandas as pd
from causallearn.utils.cit import CIT
from dodiscover.toporder._base import SteinMixin, CAMPruning
from dodiscover.toporder.utils import kernel_width
from numpy._typing import NDArray
from pygam import LinearGAM
from scipy.stats import mannwhitneyu, ttest_ind, ttest_1samp
from sklearn.kernel_ridge import KernelRidge
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold
from xgboost import XGBRegressor

from causal_discovery.base_scam import BaseSCAM, _all_subsets


class SCAMUV(BaseSCAM, SteinMixin, CAMPruning):

    def __init__(self,
                 alpha_confounded_leaf: float,
                 alpha_orientation: float,
                 alpha_separations: float,
                 alpha_ridge: float = 0.01,
                 regression: str = 'kernel_ridge',
                 eta_g: float = 0.001,
                 eta_h: float = 0.001,
                 var_eps: float = 1e-5,
                 cv: int = 5,
                 n_splines: int = 10,
                 splines_degree: int = 3,
                 alpha_pruning: float = 0.001,
                 prune_kci: bool = False,
                 use_cache: bool = True,
                 verbose: bool = False):
        """Store the hyper-parameters and create an empty score cache.

        :param alpha_confounded_leaf: significance level compared against the
            p-value in ``get_unconfounded_leaf``.
        :param alpha_orientation: significance level compared against the
            p-values in ``orient_edge``.
        :param alpha_separations: significance level compared against the
            p-values in ``are_connected`` and ``prune_edge``.
        :param alpha_ridge: ``alpha`` passed to ``KernelRidge`` when
            ``regression == 'kernel_ridge'``.
        :param regression: name of the regressor built by ``get_regression``
            (``'kernel_ridge'``, ``'gam'``, ``'xgboost'`` or ``'linear'``).
        :param eta_g: value handed to ``self.score`` when estimating the score.
        :param eta_h: value handed to ``self._hessian_col``.
        :param var_eps: stored as ``self.score_eps``.
        :param cv: number of folds used by ``_cv_predict``.
        :param n_splines: stored as ``self.n_splines`` (presumably consumed by
            the inherited pruning code -- not read in this class).
        :param splines_degree: stored as ``self.degree`` (presumably consumed by
            the inherited pruning code -- not read in this class).
        :param alpha_pruning: stored as ``self.alpha``, the attribute name the
            pruning method expects for its threshold.
        :param prune_kci: stored as ``self.prune_kci``.
        :param use_cache: whether ``_get_score_and_helpers`` memoises results.
        :param verbose: forwarded to the parent constructor.
        """
        super().__init__(verbose=verbose)

        # Significance levels of the individual statistical tests.
        self.alpha_confounded_leaf = alpha_confounded_leaf
        self.alpha_orientation = alpha_orientation
        self.alpha_separations = alpha_separations

        # Regression used for the residual computations.
        self.regression = regression
        self.alpha_ridge = alpha_ridge
        self.cv = cv

        # Parameters of the score / Hessian estimation.
        self.eta_g = eta_g
        self.eta_h = eta_h
        self.score_eps = var_eps

        # Memoisation of per-subset score computations; missing keys read as None.
        self.use_cache = use_cache
        self.cache = defaultdict(lambda: None)
        self.ordered_cache = set()

        # Attributes read by the pruning step.
        self.n_splines = n_splines
        self.degree = splines_degree
        self.alpha = alpha_pruning  # pruning method expects threshold to be called just 'alpha'
        self.prune_kci = prune_kci

    def fit(self, data: pd.DataFrame) -> nx.DiGraph:
        """Run the parent ``fit`` and then prune the resulting graph.

        After the parent class has processed ``data`` the inferred order is
        printed, directed edges are pruned first and bidirected edges second.

        :param data: observations handed to the parent ``fit`` (presumably one
            column per variable -- confirm against the base class).
        :return: ``self.result_graph`` after both pruning passes.
        """
        super().fit(data)
        print('SCAM order: ', self.order)
        # Directed edges must be pruned before the bidirected ones.
        for prune_step in (self._prune_directed_edges, self._prune_bidirected_edges):
            prune_step()
        return self.result_graph

    def _prune_directed_edges(self):
        """Drop directed edges whose source is not selected as a parent.

        For every node, the predecessors that are connected by a one-way edge
        (no edge back from the node) are passed to ``self._variable_selection``
        together with an edge-free prior-knowledge graph. Every candidate that
        is not returned by the selection loses its edge into the node.
        """
        graph = self.result_graph
        for child in graph.nodes:
            # Bidirected neighbours are handled elsewhere; keep only one-way parents.
            candidates = [parent for parent in graph.predecessors(child)
                          if not graph.has_edge(child, parent)]
            if not candidates:
                continue

            # Prior knowledge without any edges: same nodes, nothing forced in.
            no_prior = nx.DiGraph()
            no_prior.add_nodes_from(graph.nodes)

            selected = self._variable_selection(
                self.data[candidates].to_numpy(),
                self.data[child].to_numpy(),
                candidates,
                child,
                no_prior,
            )
            for rejected in set(candidates) - set(selected):
                graph.remove_edge(rejected, child)

    def _prune_bidirected_edges(self):
        """Remove bidirected edges that ``prune_edge`` judges to be spurious.

        Every ordered pair of data columns that is connected in both directions
        is tested. The conditioning candidates are all successors and
        predecessors of the first node, excluding the pair itself. When
        ``prune_edge`` returns ``True`` both directions are removed and the scan
        over partners of the current first node stops.

        NOTE(review): ``self.prune_kci`` is not consulted here. The previous
        implementation built a KCI test object when the flag was set but never
        used it, so that dead construction was dropped; results are unchanged.
        """
        graph = self.result_graph
        columns = self.data.keys()
        for node in columns:
            for second_node in columns:
                is_bidirected = graph.has_edge(node, second_node) and graph.has_edge(second_node, node)
                if not is_bidirected:
                    continue
                # Adjacent nodes of `node` in either direction, without the pair itself.
                neighbourhood = (set(graph.successors(node)) | set(graph.predecessors(node))) \
                    - {node, second_node}
                if self.prune_edge(node, second_node, neighbourhood):
                    graph.remove_edge(node, second_node)
                    graph.remove_edge(second_node, node)
                    # At most one bidirected edge is removed per first node in this pass.
                    break

    def get_regression(self, num_nodes: int):
        if self.regression == 'kernel_ridge':
            return KernelRidge(kernel='rbf', gamma=0.01, alpha=self.alpha_ridge)
        elif self.regression == 'gam':
            return LinearGAM()
        elif self.regression == 'xgboost':
            return XGBRegressor()
        elif self.regression == 'linear':
            return LinearRegression()
        else:
            raise NotImplementedError(self.regression)

    def _get_delta(self, node: Any, current_nodes: List[Any]) -> NDArray:
        """Return the squared result of regressing ``node``'s score on its residual.

        Steps:
          1. take the score column of ``node`` computed on ``current_nodes``;
          2. regress ``node`` on all other columns of ``current_nodes`` via
             ``_cv_predict``;
          3. regress the score column on the output of step 2 via ``_cv_predict``;
          4. return the element-wise squared absolute value of that output.

        :param node: the variable of interest; must be contained in ``current_nodes``.
        :param current_nodes: the variables the score is computed on.
        :return: one value per sample.
        """
        ordered = list(np.sort(current_nodes))  # Sort for more cache hits
        _, all_scores, _, _ = self._get_score_and_helpers(ordered)
        node_score = all_scores[:, ordered.index(node)]

        frame = self.data[ordered]
        is_other = frame.columns != node
        target = frame[node].to_numpy()
        node_residuals = self._cv_predict(frame.loc[:, is_other].to_numpy(), target)

        # The regressors expect a 2-D design matrix, hence the extra trailing axis.
        design = np.expand_dims(node_residuals, axis=-1)
        score_residual = self._cv_predict(design, node_score)
        # NOTE: a normalisation by (|score| + self.score_eps) is currently disabled.
        return np.abs(score_residual) ** 2

    def get_unconfounded_leaf(self, relevant_nodes: List[Any], current_nodes: List[Any]) -> Optional[Any]:
        deltas = {node: self._get_delta(node, current_nodes) for node in relevant_nodes}
        # ref_deltas = {node: np.concatenate([d for (n, d) in deltas.items() if n != node]) for node in deltas.keys()}
        # p_vals_vs_all = {node: mannwhitneyu(deltas[node], ref_deltas[node], alternative='less').pvalue
        #                 for node in deltas.keys()}
        # current_leaf = min(p_vals_vs_all, key=lambda n: p_vals_vs_all[n])
        current_leaf = min(deltas, key=lambda n: np.mean(deltas[n]))
        # p_vals_vs_all.pop(current_leaf)
        current_delta = deltas.pop(current_leaf)
        # reference_delta = deltas[min(p_vals_vs_all, key=lambda n: p_vals_vs_all[n])]
        reference_delta = deltas[min(deltas, key=lambda n: np.mean(deltas[n]))]
        p = ttest_ind(current_delta, reference_delta, alternative='less', equal_var=False).pvalue
        return current_leaf, p < self.alpha_confounded_leaf
        # if p < self.alpha_leaf:
        #    return current_leaf
        # else:
        #    return None

    def orient_edge(self, node, second_node, neighbourhood_node, neighbourhood_second):
        """Decide the orientation of the edge between ``node`` and ``second_node``.

        For each endpoint, the delta is computed for every subset of its
        neighbourhood (without the other endpoint) joined with the pair itself,
        and the delta vector with the smallest mean is kept. The two kept
        vectors are compared with one-sided Welch t-tests.

        :param node: first endpoint.
        :param second_node: second endpoint.
        :param neighbourhood_node: set of neighbours of ``node``.
        :param neighbourhood_second: set of neighbours of ``second_node``.
        :return: ``'<-'`` if the deltas of ``node`` are significantly smaller,
            ``'->'`` if they are significantly larger, ``'-'`` otherwise
            (significance level ``self.alpha_orientation``).
        """
        pair = [node, second_node]

        def smallest_delta(target, pool):
            # Keyed by frozenset so subsets with equal content share one entry.
            per_subset = {}
            for subset in _all_subsets(pool):
                per_subset[frozenset(subset)] = self._get_delta(target, subset + pair)
            return min(per_subset.values(), key=np.mean)

        first_best = smallest_delta(node, neighbourhood_node - {second_node})
        second_best = smallest_delta(second_node, neighbourhood_second - {node})

        def one_sided_p(alternative):
            return ttest_ind(first_best, second_best, alternative=alternative, equal_var=False).pvalue

        if one_sided_p('less') < self.alpha_orientation:
            return '<-'
        if one_sided_p('greater') < self.alpha_orientation:
            return '->'
        return '-'

    def are_connected(self, first_node: Any, second_node: Any, current_nodes: List[Any]) -> bool:
        """Test whether the cross Hessian terms of two nodes differ from zero.

        The entries ``[:, second]`` of the Hessian column of ``first_node`` and
        ``[:, first]`` of the Hessian column of ``second_node`` are pooled and
        tested against a mean of zero with a one-sample t-test.

        :param first_node: first node; must be in ``current_nodes``.
        :param second_node: second node; must be in ``current_nodes``.
        :param current_nodes: nodes the score and Hessian are computed on.
        :return: true when the p-value is below ``self.alpha_separations``.
        """
        ordered = list(np.sort(current_nodes))  # Sort for more cache hits
        X, score, kernel, s = self._get_score_and_helpers(ordered)
        i = ordered.index(first_node)
        j = ordered.index(second_node)

        # TODO: evaluate whether using only one of the two columns performs better.
        cross_terms = np.concatenate((
            self._hessian_col(X, score, i, self.eta_h, kernel, s)[:, j],
            self._hessian_col(X, score, j, self.eta_h, kernel, s)[:, i],
        ))
        p_value = ttest_1samp(cross_terms, 0).pvalue
        return p_value < self.alpha_separations

    def prune_edge(self, first_node, second_node, neighbourhood) -> bool:
        """Decide whether the edge between two nodes should be removed.

        For every subset of ``neighbourhood`` joined with the two nodes, the
        cross Hessian entries of the pair are collected. The collection with
        the smallest mean absolute value is tested against a mean of zero with
        a one-sample t-test.

        :param first_node: first endpoint of the edge.
        :param second_node: second endpoint of the edge.
        :param neighbourhood: candidate conditioning nodes; every subset is tried.
        :return: true when the p-value exceeds ``self.alpha_separations``, i.e.
            the weakest cross terms are not significantly different from zero.
        """
        deltas = []
        for subset in _all_subsets(neighbourhood):
            # Sorted like in `_get_delta` / `are_connected`, so the same node set
            # maps to the same cache entry regardless of iteration order.
            current_nodes = list(np.sort(subset + [first_node, second_node]))
            X, score, kernel, s = self._get_score_and_helpers(current_nodes)
            first_idx = current_nodes.index(first_node)
            second_idx = current_nodes.index(second_node)
            first_col = self._hessian_col(X, score, first_idx, self.eta_h, kernel, s)[:, second_idx]
            second_col = self._hessian_col(X, score, second_idx, self.eta_h, kernel, s)[:, first_idx]
            deltas.append(np.concatenate((first_col, second_col)))
        delta_min = min(deltas, key=lambda d: np.mean(np.abs(d)))
        return ttest_1samp(delta_min, 0).pvalue > self.alpha_separations

    def _cache_key(self, subset: List[Any]) -> str:
        return '.'.join(map(str, subset))

    def _get_score_and_helpers(self, subset: List[Any]) -> Tuple[NDArray, NDArray, NDArray, float]:
        key = self._cache_key(subset)
        if self.cache[key] is None:
            X = self.data[subset].to_numpy()
            _, d = X.shape
            s = kernel_width(X)
            kernel = self._evaluate_kernel(X, s=s)
            nablaK = self._evaluate_nablaK(kernel, X, s)
            score = self.score(X, self.eta_g, kernel, nablaK)
            if self.use_cache:
                self.cache[key] = X, score, kernel, s
            else:
                return X, score, kernel, s
        return self.cache[key]

    def _cv_predict(self, X: NDArray, y: NDArray) -> NDArray:
        """Return the residuals ``y - prediction`` of regressing ``y`` on ``X``.

        With ``self.cv == 1`` the model is fitted and evaluated on the full
        data (in-sample residuals). Otherwise a K-fold split without shuffling
        is used: for every fold the model is fitted on the remaining folds, and
        the residuals on the held-out fold are collected. Because the folds are
        not shuffled, the concatenated residuals keep the sample order of ``y``.

        :param X: 2-D design matrix; ``X.shape[1]`` is passed to ``get_regression``.
        :param y: 1-D target, aligned with the rows of ``X``.
        :return: residuals, one per sample.
        """
        n_features = X.shape[1]
        if self.cv == 1:
            # Return residuals here too: the K-fold branch and all callers work
            # with residuals, not with raw predictions.
            fitted = self.get_regression(n_features).fit(X, y)
            return y - fitted.predict(X)

        residuals = []
        for train_index, test_index in KFold(n_splits=self.cv, shuffle=False).split(X, y):
            reg = self.get_regression(n_features).fit(X[train_index, :], y[train_index])
            residuals.append(y[test_index] - reg.predict(X[test_index, :]))
        return np.concatenate(residuals, axis=0)
