from __future__ import annotations
from typing import Sequence, Any, Optional, Tuple, Mapping
import pandas as pd
import numpy as np
from scipy import sparse
from models import model_interface
from data import data_interface
import numba
from recourse.recourse_interface import RecourseInterface


#TODO: check distribution of weights in the graph

@numba.jit(nopython=True)
def _get_edge_weight(distance, weight_bias):
    """Turn an l2 distance into a FACE graph edge weight.

    The weight is ``weight_bias - log(distance)``; it is evaluated
    element-wise when `distance` is an array.

    Args:
        distance: a float or a numpy array of l2 distances.
        weight_bias: a float added to every weight.

    Returns:
        a float, or a numpy array shaped like `distance`."""
    log_distance = np.log(distance)
    return weight_bias - log_distance


class FACE:
    """The FACE recourse method.

    Modified from https://github.com/jakeval/MRMC/blob/main/recourse_methods/face_method.py

    This method is based on the Feasible and Actionable Counterfactual
    Explanations method described here: https://arxiv.org/abs/1909.09369.

    This class is partly based on the author's repo at
    https://github.com/RafaelPo/face.

    Attributes:
        model: The model to provide recourse for.
        confidence_threshold: The minimum model confidence required for any
            counterfactual generated. If None, every positively-labeled point
            is a candidate.
        graph: The graph to use with dijkstra's search. It is a sparse matrix
            encoding pairwise edge weights, or None until it is loaded or
            generated.
        candidate_indices: The positional 0-indexed indices of each potential
            counterfactual in the dataset. None until `fit()` is called.
        distance_threshold: The maximum edge length allowed in the graph. This
            should match the same distance_threshold used to generate the
            graph.
        k_directions: The number of recourse paths to generate.
        counterfactual_mode: If True, recourse directions are generated from
            the counterfactuals generated by FACE. Otherwise they are generated
            by taking the first step in the FACE path as the direction.
        weight_bias: A bias term to add when generating edge weights. This is
            necessary when using a large distance_threshold.
    """

    def __init__(
        self,
        data_inter: data_interface.DataInterface,
        model: model_interface.ModelInterface,
        k_directions: int,
        distance_threshold: float = 0.75,
        confidence_threshold: Optional[float] = None,
        graph_filepath: Optional[str] = None,
        counterfactual_mode: bool = True,
        weight_bias: float = 0,
        use_train_data: bool = True,
        graph_sample_size: Optional[int] = None,
        max_iterations: int = 50,
    ):
        """Creates a new FACE object.

        If `graph_filepath` is not provided (or cannot be loaded), the user
        should create the graph using FACE.generate_graph().

        Args:
            data_inter: Provides the train/test split and immutable features.
            model: The model to provide recourse for.
            k_directions: The number of recourse paths to generate.
            distance_threshold: The maximum edge length allowed in the graph.
            confidence_threshold: Optional minimum positive-class probability
                a candidate counterfactual must reach.
            graph_filepath: Optional path to a previously saved .npz graph.
            counterfactual_mode: See the class docstring.
            weight_bias: Bias added to every edge weight.
            use_train_data: If True the train split is used, otherwise the
                test split.
            graph_sample_size: If truthy, only the first `graph_sample_size`
                rows of the split are used.
            max_iterations: Paths containing more than this many dataset
                points are discarded by `generate_paths`.
        """
        if use_train_data:
            features, _, labels, _ = data_inter.get_train_test_split()
        else:
            _, features, _, labels = data_inter.get_train_test_split()

        # Positional indices are used everywhere (graph rows, dijkstra output),
        # so the pandas index is reset to 0..N-1.
        self._features_data = features.copy().reset_index(drop=True)
        self._labels_data = labels.copy().reset_index(drop=True)
        if graph_sample_size:
            self._features_data = self._features_data.head(graph_sample_size)
            self._labels_data = self._labels_data.head(graph_sample_size)

        self.model = model
        self.confidence_threshold = confidence_threshold
        self.graph = None
        if graph_filepath is not None:
            try:
                self.graph = sparse.load_npz(graph_filepath)
            except Exception:
                # Deliberate best-effort: a missing or unreadable file means
                # the graph must be built with generate_graph().
                self.graph = None
        self.candidate_indices = None
        self.distance_threshold = distance_threshold
        self.k_directions = k_directions
        self.counterfactual_mode = counterfactual_mode
        self.weight_bias = weight_bias
        self._data_interface = data_inter

        # Fetch the immutable-feature lists once instead of once per column.
        immutable_feats = self._data_interface.immutable_features
        processed_immutable_feats = (
            self._data_interface.get_processed_immutable_feats()
        )
        self._mutable_mask = [
            0 if col in immutable_feats or col in processed_immutable_feats
            else 1
            for col in self._features_data.columns.values
        ]
        self._max_iterations = max_iterations

    def generate_graph(
        self,
        filepath_to_save_to: Optional[str] = None,
        config: Optional[Mapping[str, Any]] = None,
    ) -> FACE:
        """Generates (and optionally saves) the epsilon-graph.

        Does nothing if a graph is already present, e.g. one loaded from
        `graph_filepath` in the constructor.

        If `filepath_to_save_to` is provided, the graph is saved to that
        path as an .npz file. All directories along the path must already
        exist.

        Args:
            filepath_to_save_to: Optional .npz path to save the graph to.
            config: Currently unused; kept for interface compatibility.

        Returns:
            This FACE instance.

        Raises:
            RuntimeError: If `distance_threshold >= exp(weight_bias)`, which
                would produce non-positive edge weights.
        """
        # `if self.graph:` is ambiguous (raises) for sparse matrices.
        if self.graph is not None:
            return self
        if self.distance_threshold >= np.exp(self.weight_bias):
            raise RuntimeError(
                "Tried to generate a graph with negative edge weights. "
                "This happens if the distance threshold is equal to or "
                "greater than exp(weight_bias). Try increasing weight_bias."
            )
        graph_weights = FACE._get_e_graph_weights(
            self._features_data.values,
            self.distance_threshold,
            weight_bias=self.weight_bias,
        )
        self.graph = sparse.csr_array(graph_weights)
        if filepath_to_save_to is not None:
            sparse.save_npz(filepath_to_save_to, self.graph)
        return self

    def fit(self) -> FACE:
        """Fits FACE by finding candidate counterfactuals in the dataset.

        Candidates are the rows labeled 1. If `confidence_threshold` is set,
        only those whose positive-class probability is at least the
        threshold are kept.

        Returns:
            This FACE instance.

        Raises:
            ValueError: If no candidate remains.
        """
        positive_data = self._features_data.loc[
            self._labels_data[self._labels_data == 1].index
        ]
        candidate_index = positive_data.index
        if self.confidence_threshold is not None and len(positive_data) > 0:
            probs = self.model.predict_proba(
                positive_data.values, pos_label_only=True)
            probs = pd.Series(probs.flatten(), index=positive_data.index)
            candidate_index = probs[probs >= self.confidence_threshold].index
        if len(candidate_index) == 0:
            raise ValueError(
                "Dataset is empty after excluding negative outcome examples."
            )
        self.candidate_indices = np.array(candidate_index)
        return self

    def generate_paths(
        self, poi: pd.DataFrame, num_paths: int
    ) -> Sequence[Sequence[pd.DataFrame]]:
        """Generates FACE paths for the Point of Interest (POI).

        Args:
            poi: The POI, with the same columns as the dataset.
            num_paths: The maximum number of paths to return.

        Returns:
            A list of at most `num_paths` paths, best first. Each path is a
            list starting with `poi` followed by single-row DataFrames of
            dataset points, ending at a candidate counterfactual. A path
            that would contain more than `max_iterations` dataset points is
            replaced by `[poi]`.

        Raises:
            RuntimeError: If the graph has not been built or `fit()` has not
                been called.
        """
        if self.graph is None or self.candidate_indices is None:
            raise RuntimeError(
                "FACE needs a graph and a call to fit() before generating "
                "paths."
            )
        distances, predecessors = self._dijkstra_search(poi, self.graph)
        target_indices = FACE._get_k_best_candidate_indices(
            distances, self.candidate_indices, num_paths
        )

        paths = []
        for target_index in target_indices:
            # Walk the predecessor chain back towards the POI (-1 sentinel),
            # collecting points in reverse order.
            reversed_steps = []
            point_index = target_index
            while (point_index != -1
                   and len(reversed_steps) <= self._max_iterations):
                reversed_steps.append(
                    self._features_data.iloc[point_index].to_frame().T
                )
                point_index = predecessors[point_index]
            if len(reversed_steps) > self._max_iterations:
                paths.append([poi])
            else:
                paths.append([poi] + reversed_steps[::-1])
        return paths

    def _get_k_recourse_directions(
        self, poi: pd.DataFrame, k_directions: int) -> pd.DataFrame:
        """Generates different recourse directions for the poi for each of the
        k_directions.

        Each direction is the difference between a recourse point and the
        POI. The recourse point is the path's counterfactual when
        `counterfactual_mode` is True, else the first step of the path.
        Paths truncated to just the POI carry no recourse and are skipped.

        Args:
            poi: The Point of Interest (POI) to find recourse directions for.
            k_directions: The number of recourse directions to generate.

        Returns:
            A DataFrame with one row per available direction (possibly
            empty)."""
        # A DataFrame POI names its features in `.columns`; a Series POI in
        # `.index`.
        feature_names = (
            poi.columns if isinstance(poi, pd.DataFrame) else poi.index
        )
        usable_paths = [
            path for path in self.generate_paths(poi, k_directions)
            if len(path) > 1
        ]
        if not usable_paths:
            return pd.DataFrame(columns=feature_names)
        step = -1 if self.counterfactual_mode else 1
        recourse_points = np.vstack(
            [path[step].values for path in usable_paths]
        )
        directions = recourse_points - poi.values
        return pd.DataFrame(data=directions, columns=feature_names)

    def get_all_recourse_instructions(
        self, poi: pd.DataFrame
    ) -> Sequence[Optional[Any]]:
        """Generates different recourse instructions for the poi for each of
        the k_directions.

        Whereas recourse directions are vectors in embedded space,
        instructions are human-readable guides for how to follow those
        directions in the original data space.

        Args:
            poi: The Point of Interest (POI) to find recourse instructions for.

        Returns:
            A Sequence of k recourse instructions for the POI. Elements may be
            None if there is no recourse available."""
        recourse_directions = self._get_k_recourse_directions(
            poi, self.k_directions)
        instructions = [
            recourse_directions.iloc[i]
            for i in range(recourse_directions.shape[0])
        ]
        # Pad with None so the result always has k_directions entries.
        instructions += [None] * (self.k_directions - len(instructions))
        return instructions

    def get_kth_recourse_instructions(
        self, poi: pd.DataFrame, dir_index: int
    ) -> Optional[Any]:
        """Generates a single set of recourse instructions for the kth
        direction.

        Args:
            poi: The Point of Interest (POI) to get the kth recourse
                instruction for.
            dir_index: Which of the k sets of instructions to return. FACE
                ignores this argument.

        Returns:
            A single-row DataFrame holding the recourse direction for the
            POI, or None if no recourse is possible."""
        recourse_directions = self._get_k_recourse_directions(poi, 1)
        if len(recourse_directions) == 0:
            return None
        return recourse_directions

    @staticmethod
    @numba.jit(nopython=True)
    def _get_e_graph_weights(
        embedded_data: np.ndarray, epsilon: float, weight_bias: float = 0
    ) -> np.ndarray:
        """Calculates a epsilon-weighted adjacency matrix from the data.

        This function is based on the get_weights_e() function from the
        original FACE authors' repo. (https://github.com/RafaelPo/face).

        This adjacency matrix is the FACE graph. Edges of distance greater
        than epsilon, or of distance exactly 0, are zero-weighted (absent).
        All other edges are weighted by `_get_edge_weight`.

        Args:
            embedded_data: The data to calculate a graph over.
            epsilon: Edges greater than epsilon have weight zero.
            weight_bias: This value is added to all non-zero edge weights. It
                is useful for avoiding negative values when epsilon is large.

        Returns:
            A symmetric adjacency matrix where the i,jth entry is the edge
            weight between points i and j.
        """
        n_samples = embedded_data.shape[0]
        kernel = np.zeros((n_samples, n_samples))
        for i in range(n_samples):
            for j in range(i):
                pairwise_distance = np.linalg.norm(
                    embedded_data[i] - embedded_data[j]
                )
                if (pairwise_distance <= epsilon) and pairwise_distance > 0:
                    kernel[i, j] = _get_edge_weight(
                        pairwise_distance, weight_bias
                    )
                    kernel[j, i] = kernel[i, j]
        return kernel

    def _dijkstra_search(
        self, poi: pd.DataFrame, graph: sparse.csr_array
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Perform dijkstra's search over the FACE graph.

        In addition to the distance array described below, it also returns a
        predecessors array encoding the shortest paths found by dijkstra's
        algorithm. The ith entry of the predecessors array is the index of its
        direct predecessor along the shortest path from the POI to i.

        For example, if the path to point 3 is POI->0->1->2, then
        predecessors[2] = 1.

        If there is no path between the POI and i, then
        predecessors[i] = -9999.

        If the ith entry is preceded by the POI itself, then
        predecessors[i] = -1.

        Args:
            poi: The Point of Interest to start the search from.
            graph: The weighted adjacency matrix to search over.

        Returns:
            A distance array where the ith entry is the length of the shortest
                path between datapoint i and the POI.
            A predecessors array where the ith entry is the index of the point
                preceding point i in the shortest path from the POI to i."""
        # Temporarily add the POI to the graph so it can be the search source.
        graph, new_index = self._add_point_to_graph(poi, graph)
        distances, predecessors = sparse.csgraph.dijkstra(
            graph, indices=new_index, return_predecessors=True
        )
        # Remove the temporarily added point from the results.
        distances = np.delete(distances, new_index)
        predecessors = np.delete(predecessors, new_index)
        predecessors = np.where(predecessors == new_index, -1, predecessors)
        return distances, predecessors

    def _add_point_to_graph(
        self, poi: pd.DataFrame, graph: sparse.csr_array
    ) -> Tuple[sparse.csr_array, int]:
        """Creates a graph identical to its input but with the addition of the
        POI.

        The POI is only connected to dataset points whose processed
        immutable features all equal the POI's. Edge weights come from the
        full feature distance, which for those points equals the distance
        over mutable features.

        Returns:
            A new graph including the POI and the index at which the POI
            appears in the graph."""
        immutable_feats = self._data_interface.get_processed_immutable_feats()
        weight_vector = FACE._calculate_weight_vector(
            poi.values,
            self._features_data.values,
            self.distance_threshold,
            self.weight_bias,
        )
        # Rows whose immutable features differ from the POI get no edge
        # (weight 0 means "no edge" in the sparse graph). NaN compares as
        # unequal, so NaNs also prevent an edge.
        mismatched = (
            self._features_data[immutable_feats].values
            != poi[immutable_feats].values
        ).any(axis=1)
        weight_vector[mismatched] = 0
        return FACE._add_edges_to_graph(weight_vector, graph)

    @staticmethod
    def _calculate_weight_vector(
        poi: np.ndarray,
        data: np.ndarray,
        distance_threshold: float,
        weight_bias: float,
    ) -> np.ndarray:
        """Calculates a weight vector where the ith entry is the edge weight
        between the ith datapoint and the point of interest (POI).

        Args:
            poi: A vector of shape (D) (or (1, D); it is broadcast).
            data: A dataset of shape (N, D).
            distance_threshold: The maximum distance across which an edge will
                be created.
            weight_bias: The bias term to use when calculating edge weight.

        Returns:
            A weight vector where the ith entry is the edge weight between the
            ith datapoint and the poi, or 0 where no edge is created."""
        distances = np.linalg.norm(data - poi, axis=1)
        # Duplicates of the POI and points beyond the threshold get no edge.
        exclude_mask = (distances == 0) | (distances > distance_threshold)

        weights = distances.copy()
        weights[exclude_mask] = 0
        weights[~exclude_mask] = _get_edge_weight(
            distances[~exclude_mask], weight_bias
        )
        return weights

    @staticmethod
    def _add_edges_to_graph(
        weight_vector: np.ndarray, graph: sparse.csr_array
    ) -> Tuple[sparse.csr_array, int]:
        """Creates a graph identical to its input but with the addition of
        weighted edges linking to a newly-added point.

        Args:
            weight_vector: The densely-encoded edge weights to add to the
                graph. Where no edge is added, the edge weight in the vector is
                0.
            graph: The (N, N) adjacency matrix representing a graph's weighted
                edges.

        Returns:
            An (N+1, N+1) adjacency matrix including weighted edges to a newly
            added point. It also returns the index at which the point
            corresponding to the new edge weights was added."""
        weights = np.asarray(weight_vector, dtype=float)
        # Build an explicit (1, N) row; csr_array's handling of 1-D input
        # depends on the scipy version.
        row_update = sparse.csr_array(weights[None, :])
        # The column update has size N+1. The extra entry is the edge weight
        # between the newly added point and itself (which is 0).
        col_update = sparse.csr_array(np.append(weights, 0.0)[:, None])
        graph = sparse.vstack([graph, row_update])
        graph = sparse.hstack([graph, col_update])
        return graph, graph.shape[0] - 1

    @staticmethod
    def _get_k_best_candidate_indices(
        distances: np.ndarray,
        candidate_indices: np.ndarray,
        k_candidates: int,
    ) -> np.ndarray:
        """Attempts to return the indices of the k best candidate
        counterfactuals.

        "Best" means it minimizes the edge weights along the shortest path
        between it and the POI.

        Because the inputs and outputs of this function are numpy arrays, the
        indices it returns don't correspond to pandas indices of the source
        data (if any exist).

        If fewer than k viable candidates exist, it returns only as many as
        possible. If no candidates exist, it returns an empty integer array.

        If two candidates have the same distance, the candidate that appears
        first in `candidate_indices` is preferred.

        Args:
            distances: A distance array where the ith entry is the length of
                the shortest path between datapoint i and the POI.
            candidate_indices: The indices of potential counterfactuals in the
                dataset.
            k_candidates: The number of candidate indices to return.

        Returns:
            An integer array of at most k indices.
        """
        candidates = np.asarray(candidate_indices, dtype=int)
        order = np.argsort(distances[candidates], kind="stable")
        sorted_candidates = candidates[order]
        # Infinite distance means the candidate is unreachable from the POI.
        reachable = sorted_candidates[~np.isinf(distances[sorted_candidates])]
        return reachable[:max(k_candidates, 0)]

class FACERecourse(RecourseInterface):
    """Exposes the FACE method through `RecourseInterface`.

    On construction it builds a `FACE` instance, fits it (selecting
    candidate counterfactuals), and generates the FACE graph unless one was
    loaded from `graph_filepath`. A newly generated graph is saved to
    `graph_filepath` when that path is given.
    """

    def __init__(
        self,
        model: model_interface.ModelInterface,
        data_interface: data_interface.DataInterface,
        k_directions: int,
        distance_threshold: float = 0.75,
        confidence_threshold: Optional[float] = None,
        graph_filepath: Optional[str] = None,
        counterfactual_mode: bool = True,
        weight_bias: float = 0,
        use_train_data: bool = True,
        graph_config: Optional[Mapping[str, Any]] = None,
        graph_sample_size: Optional[int] = None,
        max_iterations: int = 50,
    ) -> None:
        """Creates, fits and prepares the underlying FACE instance.

        Args:
            model: The model to provide recourse for.
            data_interface: Provides the data and immutable features.
            k_directions: The number of paths returned per POI.
            distance_threshold: The maximum edge length in the graph.
            confidence_threshold: Optional minimum positive-class probability
                for candidate counterfactuals.
            graph_filepath: Path to load the graph from; if loading fails,
                the generated graph is saved there.
            counterfactual_mode: Forwarded to FACE.
            weight_bias: Bias added to every edge weight.
            use_train_data: Use the train split (True) or test split (False).
            graph_config: Forwarded to FACE.generate_graph.
            graph_sample_size: Optional number of leading rows to use.
            max_iterations: Maximum number of dataset points in a path.
        """
        self.FACE_instance = FACE(
            data_interface,
            model,
            k_directions,
            distance_threshold=distance_threshold,
            confidence_threshold=confidence_threshold,
            graph_filepath=graph_filepath,
            counterfactual_mode=counterfactual_mode,
            weight_bias=weight_bias,
            use_train_data=use_train_data,
            graph_sample_size=graph_sample_size,
            max_iterations=max_iterations,
        )
        self._k_directions = k_directions
        self.FACE_instance.fit()
        # Only build the graph when none was loaded from graph_filepath.
        if self.FACE_instance.graph is None:
            self.FACE_instance.generate_graph(graph_filepath, graph_config)

    def get_counterfactuals(self, poi: pd.DataFrame) -> Sequence:
        """Returns the last point of each FACE path for `poi`.

        Args:
            poi: The point of interest.

        Returns:
            One counterfactual per path (at most `k_directions`).
        """
        return self.get_counterfactuals_from_paths(self.get_paths(poi))

    def get_paths(self, poi: pd.DataFrame) -> Sequence:
        """Returns up to `k_directions` FACE paths starting at `poi`.

        Args:
            poi: The point of interest.

        Returns:
            The paths produced by `FACE.generate_paths`.
        """
        return self.FACE_instance.generate_paths(poi, self._k_directions)

    def get_counterfactuals_from_paths(self, paths: Sequence) -> Sequence:
        """Extracts the final point of every path.

        Args:
            paths: Paths as returned by `get_paths`.

        Returns:
            A list holding the last element of each path.
        """
        return [path[-1] for path in paths]
    
