"""
Modified version of https://github.com/joaopfonseca/game-of-recourse/blob/main/game_viz/recourse.py
New recourse methods added: Wachter et al. (2017), DiCE (Mothilal et al., 2020)
"""

import warnings
from typing import Union, Optional, Callable
import numpy as np
from base import BaseRecourse


class UstunRecourse(BaseRecourse):
    """
    Counterfactual generator for a linear decision function.

    The agent is moved along the model's coefficient vector (restricted to the
    features that may change) until the linear decision boundary is reached.
    Features that would leave their ``[lb, ub]`` range are pinned to the bound
    and the remaining features are re-projected.

    Parameters
    ----------
    model
        Fitted classifier exposing ``predict_proba``.
    n_features : int, optional
        Number of features (those with the largest absolute coefficient among
        the usable ones) that are allowed to change. ``None`` uses all of them.
    threshold, categorical, immutable, step_direction, y_desired
        Forwarded unchanged to ``BaseRecourse``.
    """

    def __init__(
        self,
        model,
        n_features: Optional[int] = None,
        threshold=0.5,
        categorical: Optional[Union[list, np.ndarray]] = None,
        immutable: Optional[Union[list, np.ndarray]] = None,
        step_direction: Optional[dict] = None,
        y_desired: Union[int, str] = 1,
    ):
        super().__init__(
            model=model,
            threshold=threshold,
            categorical=categorical,
            immutable=immutable,
            step_direction=step_direction,
            y_desired=y_desired,
        )

        self.n_features = n_features

    @staticmethod
    def _step_to_boundary(agent, direction, intercept, coefficients):
        """Return the step length along ``direction`` that zeroes the linear score."""
        return (-intercept - np.dot(agent.values, coefficients.T)) / np.dot(
            direction, coefficients.T
        )

    def _counterfactual(self, agent, action_set):
        """
        Compute a counterfactual for a single agent.

        Parameters
        ----------
        agent : pandas.Series
            Feature values of one agent, indexed by feature name.
        action_set
            Object indexable by feature name (items expose ``step_direction``
            and ``actionable``) and exposing ``lb`` / ``ub`` bounds.

        Returns
        -------
        pandas.Series
            A copy of the agent when it already meets the threshold or when no
            feature can be moved; otherwise the counterfactual, with the
            categorical values re-attached after the non-categorical ones.

        Warns
        -----
        UserWarning
            When no feature is usable, when every moving feature ended up
            pinned at a bound, or when the result still violates a bound.
        """
        failure_msg = (
            "Could not generate a counterfactual to reach the desired threshold."
        )
        agent_original = agent.copy()

        # Do not change if the agent is over the threshold
        if self.model.predict_proba(agent.to_frame().T)[0, -1] >= self.threshold:
            return agent_original

        categorical_vals = agent_original[self.categorical].values
        # Work in float so that bound values can be written back without dtype issues.
        agent = agent_original.drop(self.categorical).astype(float)

        intercept, coefficients, _ = self._get_coefficients()

        # Direction of movement: the coefficient vector, restricted below.
        base_vector = np.array(coefficients, dtype=float).reshape(-1)
        n_features = (
            base_vector.shape[0] if self.n_features is None else self.n_features
        )

        # A feature may move only if it is actionable and its allowed step
        # direction agrees with the sign of its coefficient (0 = unrestricted).
        is_usable = np.array(
            [
                action_set[col].step_direction in [np.sign(coeff), 0]
                and action_set[col].actionable
                for col, coeff in zip(agent.index, base_vector)
            ],
            dtype=bool,
        )
        base_vector[~is_usable] = 0

        # Use features with highest contribution towards the threshold
        rejected_features = np.argsort(np.abs(base_vector))[:-n_features]
        base_vector[rejected_features] = 0

        norm = np.linalg.norm(base_vector)
        if norm == 0:
            # Nothing can move. Dividing by the zero norm would yield a NaN
            # counterfactual that fails every bound check below and drags all
            # features (immutable ones included) onto their bounds.
            warnings.warn(failure_msg)
            return agent_original

        base_vector = base_vector / norm
        multiplier = self._step_to_boundary(
            agent, base_vector, intercept, coefficients
        )
        counterfactual = agent + multiplier * base_vector

        # NOTE(review): the bounds are compared position-wise with the
        # non-categorical features; confirm that action_set.lb / ub do not
        # contain entries for the categorical columns.
        lb, ub = np.array(action_set.lb).flatten(), np.array(action_set.ub).flatten()

        # Check if base_vector adjustments are not generating invalid counterfactuals
        exhausted = False
        for _ in range(agent.shape[0]):
            # Adjust vector according to features' bounds
            lb_valid = counterfactual >= lb
            ub_valid = counterfactual <= ub

            if lb_valid.all() and ub_valid.all():
                break

            if not lb_valid.all():
                # Fix values to its lower bound (positional indexing: the
                # Series is labelled by feature name, idx holds positions)
                idx = np.where(~lb_valid)[0]
                agent.iloc[idx] = lb[idx]
                base_vector[idx] = 0

            if not ub_valid.all():
                # Fix values to its upper bound
                idx = np.where(~ub_valid)[0]
                agent.iloc[idx] = ub[idx]
                base_vector[idx] = 0

            if (base_vector == 0).all():
                # All max/min boundaries have been met: no direction is left
                # to move along, so the boundary cannot be reached.
                counterfactual = agent
                exhausted = True
            else:
                # Redefine counterfactual after adjusting the base vector
                base_vector = base_vector / np.linalg.norm(base_vector)
                multiplier = self._step_to_boundary(
                    agent, base_vector, intercept, coefficients
                )
                counterfactual = agent + multiplier * base_vector

        within_bounds = (counterfactual >= lb).all() and (counterfactual <= ub).all()
        if exhausted or not within_bounds:
            warnings.warn(failure_msg)

        # Categorical features are never changed; put the original values back.
        for cat_feat, value in zip(self.categorical, categorical_vals):
            counterfactual[cat_feat] = value

        return counterfactual


class WachterRecourse(BaseRecourse):
    """
    Counterfactual generation in the spirit of Wachter et al. (2017):
    'Counterfactual Explanations without Opening the Black Box'

    Runs projected gradient descent on x', minimizing:
        L(x') = λ * D(x', x) + max(0, threshold - f(x'))^2
    where D is the squared L2 distance (each coordinate divided by
    ``feature_scale`` when given) and f is the last column of
    ``model.predict_proba``. The prediction term is differentiated with forward
    finite differences, so the model is only queried as a black box.

    Parameters
    ----------
    model
        Fitted classifier exposing ``predict_proba``.
    threshold, categorical, immutable, step_direction, y_desired
        Forwarded unchanged to ``BaseRecourse``.
    max_iter : int
        Maximum number of gradient steps.
    lambda_param : float
        Weight λ of the distance term.
    lr : float
        Gradient-descent step size.
    tol : float
        The descent stops once the gradient norm drops below this value.
    finite_diff_eps : float
        Perturbation used for the finite-difference gradient.
    feature_scale : float or array-like, optional
        Per-feature (or global) scale of the distance term. ``None`` applies
        no scaling.
    """

    def __init__(
        self,
        model,
        threshold: float = 0.5,
        categorical: Optional[Union[list, np.ndarray]] = None,
        immutable: Optional[Union[list, np.ndarray]] = None,
        step_direction: Optional[dict] = None,
        y_desired: Union[int, str] = 1,
        # Wachter-specific hyperparameters
        max_iter: int = 1000,
        lambda_param: float = 1e-6,
        lr: float = 0.01,
        tol: float = 1e-4,
        finite_diff_eps: float = 1e-4,
        feature_scale: Optional[Union[np.ndarray, float]] = None,
    ):
        super().__init__(
            model=model,
            threshold=threshold,
            categorical=categorical,
            immutable=immutable,
            step_direction=step_direction,
            y_desired=y_desired,
        )
        self.max_iter = max_iter
        self.lambda_param = lambda_param
        self.lr = lr
        self.tol = tol
        self.finite_diff_eps = finite_diff_eps
        self.feature_scale = feature_scale  # None → distance is left unscaled

    # -------------------- core API --------------------
    def _counterfactual(self, agent, action_set):
        """
        Compute a counterfactual for a single agent.

        Parameters
        ----------
        agent : pandas.Series
            Numeric feature values of one agent, indexed by feature name.
        action_set
            Object exposing ``actionable`` (one flag per feature) and the
            ``lb`` / ``ub`` bounds.

        Returns
        -------
        pandas.Series
            The optimized point, indexed like ``agent``. Non-actionable
            features receive no gradient update; every feature is clipped to
            ``[lb, ub]``.

        Warns
        -----
        UserWarning
            When the returned point is still predicted below the threshold.
        """
        import pandas as pd

        x = agent.values.astype(float)
        threshold = float(self.threshold)   # use decision threshold
        eps = self.finite_diff_eps
        actionable_mask = np.array(action_set.actionable, dtype=bool)
        # Only these coordinates can ever receive a non-zero gradient.
        actionable_idx = np.flatnonzero(actionable_mask)

        # Loop invariants, built once instead of on every iteration.
        lb = np.array(action_set.lb).flatten()
        ub = np.array(action_set.ub).flatten()
        if self.feature_scale is None:
            scale_sq = 1.0
        else:
            scale_sq = np.asarray(self.feature_scale, dtype=float) ** 2

        # Gradient of the (scaled) squared L2 distance to the original point
        def distance_grad(z):
            return 2 * (z - x) / scale_sq

        # Prediction probability function – keep DataFrame for model
        def pred_prob(z):
            X_df = pd.DataFrame([z], columns=agent.index)
            return self.model.predict_proba(X_df)[0, -1]

        # hinge-style: only penalize if below threshold
        def hinge_loss(z):
            return max(0, threshold - pred_prob(z)) ** 2

        # Prediction loss gradient via forward finite differences
        def pred_loss_grad(z):
            base = hinge_loss(z)
            g = np.zeros_like(z)
            # Non-actionable coordinates are zeroed by the caller anyway, so
            # probing them would only cost extra model evaluations.
            for i in actionable_idx:
                z_eps = z.copy()
                z_eps[i] += eps
                g[i] = (hinge_loss(z_eps) - base) / eps
            return g

        z = x.copy()

        for _ in range(self.max_iter):
            g = self.lambda_param * distance_grad(z) + pred_loss_grad(z)
            g[~actionable_mask] = 0.0
            z = np.clip(z - self.lr * g, lb, ub)
            if np.linalg.norm(g) < self.tol:
                break

        # NOTE(review): the squared hinge vanishes as f(x') approaches the
        # threshold from below, so the descent tends to stall just short of
        # it; report that explicitly rather than returning silently.
        if pred_prob(z) < threshold:
            warnings.warn(
                "Could not generate a counterfactual to reach the desired threshold."
            )

        return pd.Series(z, index=agent.index)

class DiCERecourse(BaseRecourse):
    """
    DiCE-inspired counterfactual generator for a linear decision function that:
      • samples random candidate directions over the usable features,
      • moves every feature only in the direction that raises the score,
      • enforces actionability & bounds,
      • re-projects to the decision boundary after each adjustment,
      • returns the valid CF with the lowest cost.

    Parameters
    ----------
    model
        Fitted classifier exposing ``predict_proba``.
    num_counterfactuals : int
        Number of random directions that are tried.
    threshold, categorical, immutable, step_direction, y_desired
        Forwarded unchanged to ``BaseRecourse``.
    cost_fn : callable, optional
        ``cost_fn(x_orig, x_cf) -> float``; defaults to the L1 distance over
        the non-categorical features.
    tol : float
        Absolute tolerance used to decide whether clipping moved a feature.
    max_reproj_iters : int
        Maximum number of project / clip rounds per candidate.
    """

    def __init__(
        self,
        model,
        num_counterfactuals: int = 30,
        threshold: float = 0.5,
        categorical: Optional[Union[list, np.ndarray]] = None,
        immutable: Optional[Union[list, np.ndarray]] = None,
        step_direction: Optional[dict] = None,
        y_desired: Union[int, str] = 1,
        cost_fn: Optional[Callable] = None,
        tol: float = 1e-6,
        max_reproj_iters: int = 50,
    ):
        super().__init__(
            model=model,
            threshold=threshold,
            categorical=categorical,
            immutable=immutable,
            step_direction=step_direction,
            y_desired=y_desired,
        )
        self.num_counterfactuals = num_counterfactuals
        self.cost_fn = cost_fn if cost_fn is not None else self._default_cost
        self.tol = tol
        self.max_reproj_iters = max_reproj_iters

    def _categorical_labels(self):
        """Return the categorical feature labels as a list (empty when unset)."""
        # An explicit None test: truth-testing a NumPy array is ambiguous.
        if self.categorical is None:
            return []
        return list(self.categorical)

    # ---- cost ----
    def _default_cost(self, x_orig, x_cf):
        """Default cost = L1 distance over non-categorical features."""
        categorical = self._categorical_labels()
        x_o = x_orig.drop(categorical, errors="ignore")
        x_c = x_cf.drop(categorical, errors="ignore")
        return float(np.sum(np.abs((x_o - x_c).values)))

    # ---- projection ----
    def _walk_to_boundary(self, x_start, direction, w, b, lb, ub):
        """
        Move from ``x_start`` along ``direction`` onto ``w·x + b = 0``.

        Each round projects onto the boundary, clips to ``[lb, ub]``, locks the
        coordinates that were clipped and renormalizes what is left of the
        direction.

        Returns
        -------
        (numpy.ndarray, bool)
            The final point and whether it lies on the boundary without any
            clipping having been necessary in the last round.
        """
        x = x_start.copy()
        d = direction.copy()

        for _ in range(self.max_reproj_iters):
            # Direction has collapsed: nothing left to move.
            if not np.any(d):
                break

            # Solve w·(x + t d) + b = 0 for t. The score uses the FULL weight
            # vector: features that are fixed still contribute w_i * x_i.
            denom = float(np.dot(w, d))
            if abs(denom) < 1e-12:
                break
            t = -(b + float(np.dot(w, x))) / denom
            if t < 0:
                # A negative step would move features against their allowed
                # direction; the caller checks the current point instead.
                break
            x_proj = x + t * d

            # Clip to bounds
            x_clip = np.minimum(np.maximum(x_proj, lb), ub)
            hit = np.abs(x_clip - x_proj) > self.tol
            x = x_clip

            # If clipping did nothing, we are on the boundary (within tol)
            if not hit.any():
                return x, True

            # Lock features that hit bounds so they remain fixed
            d[hit] = 0.0

            # Renormalize remaining direction to avoid step shrinking
            norm = np.linalg.norm(d)
            if norm > 0:
                d = d / norm

        return x, False

    # ---- core ----
    def _counterfactual(self, agent, action_set):
        """
        Compute the cheapest valid counterfactual among random candidates.

        Parameters
        ----------
        agent : pandas.Series
            Feature values of one agent, indexed by feature name.
        action_set
            Object indexable by feature name (items expose ``step_direction``
            and ``actionable``) and exposing ``lb`` / ``ub`` bounds.

        Returns
        -------
        pandas.Series
            A copy of the agent when it already meets the threshold or when no
            candidate is valid; otherwise the lowest-cost candidate, with the
            categorical values re-attached after the non-categorical ones.

        Warns
        -----
        UserWarning
            When no valid candidate was found, or when the selected candidate
            is predicted below the threshold.
        """
        import pandas as pd

        agent_original = agent.copy()

        # Already positive? return original
        if self.model.predict_proba(agent.to_frame().T)[0, -1] >= self.threshold:
            return agent_original

        # Split numeric / categorical
        categorical = self._categorical_labels()
        categorical_vals = agent_original[categorical].values
        agent_num = agent_original.drop(categorical, errors="ignore")
        x_start = agent_num.values.astype(float)

        # Linear decision function w·x + b.
        # NOTE(review): projecting onto w·x + b = 0 presumes that
        # _get_coefficients already accounts for the decision threshold.
        intercept, coefficients, _ = self._get_coefficients()
        w = np.asarray(coefficients, dtype=float).reshape(-1)
        b = float(np.asarray(intercept, dtype=float).reshape(-1)[0])

        # NOTE(review): bounds are matched position-wise with the
        # non-categorical features; confirm lb / ub carry no categorical entries.
        lb = np.array(action_set.lb).flatten().astype(float)
        ub = np.array(action_set.ub).flatten().astype(float)

        # Mask: actionable & allowed step direction relative to w
        usable = np.array(
            [
                (action_set[col].actionable)
                and (action_set[col].step_direction in [np.sign(w_i), 0])
                for col, w_i in zip(agent_num.index, w)
            ],
            dtype=bool,
        )

        # Every usable feature may only move in the direction that raises the
        # score. This is what the step_direction test above presupposes; a
        # freely signed random direction could violate it.
        move_sign = np.sign(w)
        move_sign[~usable] = 0.0

        def assemble(x):
            """Build the full counterfactual (numeric values + categoricals)."""
            cf = pd.Series(x, index=agent_num.index, name=agent_num.name)
            for cat_feat, value in zip(categorical, categorical_vals):
                cf[cat_feat] = value
            return cf

        def positive_proba(cf):
            """Score a full counterfactual, columns in the agent's own order."""
            row = cf.reindex(agent_original.index).to_frame().T
            return self.model.predict_proba(row)[0, -1]

        candidates = []
        costs = []

        # Generate candidates by random directions
        for _ in range(self.num_counterfactuals):
            # Random magnitudes over usable coords, signed towards the boundary
            direction = np.abs(np.random.randn(x_start.shape[0])) * move_sign
            norm = np.linalg.norm(direction)
            if norm == 0:
                continue

            x_cf, on_boundary = self._walk_to_boundary(
                x_start, direction / norm, w, b, lb, ub
            )
            cf = assemble(x_cf)

            # A point that did not settle on the boundary is accepted only if
            # the model itself classifies it at or above the threshold.
            if not on_boundary and not (positive_proba(cf) >= self.threshold):
                continue

            # Cost and keep
            candidates.append(cf)
            costs.append(self.cost_fn(agent_original, cf))

        if not candidates:
            warnings.warn("Could not generate a counterfactual to reach the desired threshold.")
            return agent_original

        # Select best by cost
        best = candidates[int(np.argmin(costs))]

        # Ensure final is on/over threshold (small slack for rounding)
        if positive_proba(best) + 1e-12 < self.threshold:
            warnings.warn("Best candidate slipped below threshold due to numerical issues.")
        return best