"""
Environments where the recourse recommender and the predictor are trained
"""

import gymnasium
import numpy as np
import pandas as pd
from data_generation import generate_synthetic_data
from recourse import UstunRecourse, WachterRecourse, DiCERecourse
from copy import deepcopy

# environment with a single candidate; given their current features and a goal score, the action is a counterfactual recommendation to reach said score

# Phase 1 of training: the reward depends only on the absolute error between the counterfactual's score and the goal score
# Phase 2 of training: the reward depends on both that error and the cost of the recommendation
class SingleCandidateEnv(gymnasium.Env):
    """Single-candidate recourse environment.

    Each episode concerns one candidate (a one-row feature DataFrame) and a goal
    score, both supplied to ``reset`` (normally by ``TrainingEnvWrapper`` or
    ``TestEnvWrapper``). An action is a counterfactual feature vector: it is
    clipped to [0, 1] and never lowers a feature below its current value. The
    candidate then implements each recommended feature change at random, and the
    model score is recomputed.

    Rewards:
        * phase 1: ``-5 * |cf_score - goal_score|``; the difficulty estimates are
          also refined from the observed feature changes.
        * phase 2: ``-10 * cost``, minus ``300 * (error - 0.01)`` when the absolute
          error exceeds 0.01.

    ``cost`` is the sum of absolute feature changes weighted by the current
    difficulty estimates.
    """

    def __init__(self, X_test, model, feature_difficulties, difficulty_estimates, difficulty_update_counts,
                 steps=100, current_idx=0, data_gen_func=generate_synthetic_data, seed=0, phase=1):
        """
        Args:
            X_test: pool of candidate feature rows; one is sampled as the initial candidate.
            model: scoring model exposing ``predict_proba``; its last column is the score.
            feature_difficulties: mapping feature key -> true difficulty (ordered by
                sorted key), used only to simulate whether candidates implement changes.
            difficulty_estimates: per-feature difficulty estimates indexed by feature
                position; updated in place during phase 1.
            difficulty_update_counts: per-feature update counters, updated in place;
                they drive the decaying learning rate of the estimates.
            steps: maximum number of steps per episode.
            current_idx: initial step counter (reset to 0 on every ``reset``).
            data_gen_func: stored for callers; not used by this class.
            seed: initial seed; incremented on every ``reset``.
            phase: 1 or 2, selects the reward (see class docstring).
        """
        super().__init__()

        # Initialize parameters
        self.data_gen_func = data_gen_func
        self.seed = seed
        self.rng = np.random.default_rng(self.seed)
        self.X_test = X_test.reset_index(drop=True)
        self.model = model
        self.current_idx = current_idx
        self.episode_idx = 0
        self.steps = steps
        self.phase = phase

        # Needed for feature difficulty estimation
        self.difficulty_estimates = difficulty_estimates
        self.difficulty_update_counts = difficulty_update_counts
        # True difficulties, unknown to the agent, used to simulate feature changes
        self.feature_difficulties_array = [feature_difficulties[key] for key in sorted(feature_difficulties.keys())]

        # Initialize first candidate
        sample_idx = self.rng.integers(0, len(self.X_test))
        self.X_ = pd.DataFrame([self.X_test.iloc[sample_idx]]).reset_index(drop=True)

        # Define observation and action space
        num_features = self.X_.shape[1]
        self.observation_space = gymnasium.spaces.Dict({
            "goal_score": gymnasium.spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
            "features": gymnasium.spaces.Box(low=0.0, high=1.0, shape=(1, num_features), dtype=np.float32),
        })

        # The action is the counterfactual vector
        self.action_space = gymnasium.spaces.Box(low=0.01, high=0.99, shape=(1, num_features), dtype=np.float32)

    def reset(self, seed=None, options=None, external_X=None, external_goal_score=None):
        """Start an episode for the given candidate and goal score.

        Args:
            seed, options: accepted for Gymnasium API compatibility but ignored; the
                episode RNG is reseeded from an internal counter that is incremented
                on every reset.
            external_X: DataFrame with the candidate's features (required).
            external_goal_score: goal score for the candidate (required).

        Returns:
            ``(observation, {})``.

        Raises:
            ValueError: if ``external_X`` or ``external_goal_score`` is missing.
        """
        if external_X is None or external_goal_score is None:
            raise ValueError(
                "reset() requires external_X and external_goal_score; "
                "wrap the environment in TrainingEnvWrapper or TestEnvWrapper."
            )

        # Increase the episode count and reseed the per-episode RNG
        self.episode_idx += 1
        self.seed += 1
        self.rng = np.random.default_rng(self.seed)

        self.X_ = external_X.reset_index(drop=True)

        # Index associated with the time step
        self.current_idx = 0

        # Number of times each candidate has applied
        self.num_applied_dict = {idx: 1 for idx in self.X_.index}

        probs = self.model.predict_proba(self.X_)[:, -1]
        self.scores_ = pd.Series(probs, index=self.X_.index)

        self.goal_score = external_goal_score

        return self._get_observation(), {}

    def _generate_counterfactuals(self, action_df):
        """Turn an action into admissible counterfactuals and score them.

        The action is clipped to [0, 1], and any feature it would lower is kept at
        its current value.

        Returns:
            ``(counterfactuals, cf_scores_series)`` indexed like ``self.X_``.
        """
        action_df = action_df.clip(0, 1)

        counterfactuals = self.X_.copy()
        counterfactuals.loc[action_df.index] = action_df

        # Recommendations may only increase features
        counterfactuals = counterfactuals.where(counterfactuals >= self.X_, self.X_)

        cf_scores = self.model.predict_proba(counterfactuals)[:, -1]
        cf_scores_series = pd.Series(cf_scores, index=self.X_.index)

        return counterfactuals, cf_scores_series

    def _update_candidate_features(self):
        """Simulate which recommended feature changes each candidate implements.

        For each feature, the attainability of the recommended move is
        ``1 / (eps + |cf - x| * cf) - 1``. The change is implemented (in full) with
        probability ``1 - exp(-beta * attainability / true_difficulty)``, so more
        attainable moves on less difficult features are more likely. The result is
        clipped to [0, 1].

        Returns:
            DataFrame of updated features with the same index/columns as ``self.X_``.
        """
        epsilon = 1e-3  # avoids division by zero when no change is recommended
        beta = 0.01
        inverse_difficulties = 1 / np.asarray(self.feature_difficulties_array, dtype=float)

        new_factuals = {}
        for candidate_idx in self.X_.index:
            current_features = self.X_.loc[candidate_idx].to_numpy(dtype=float)
            counterfactual_features = self.counterfactuals.loc[candidate_idx].to_numpy(dtype=float)

            # The recommended value itself scales the distance (placeholder for a
            # more specific difficulty scaling)
            norm_dist = np.abs(counterfactual_features - current_features)
            attainability = 1 / (epsilon + norm_dist * counterfactual_features) - 1

            # The exponent must be negative. With the former positive exponent the
            # "probability" was <= 0 for every non-negative attainability, so
            # recommended changes were essentially never implemented. The sign now
            # matches the estimator in _update_difficulty_estimates.
            probability = 1 - np.exp(-beta * inverse_difficulties * attainability)

            # Decide which feature changes to apply
            apply_change = self.rng.uniform(size=counterfactual_features.shape) < probability
            applied_feature_change = (counterfactual_features - current_features) * apply_change

            new_factuals[candidate_idx] = np.clip(current_features + applied_feature_change, 0, 1)

        return pd.DataFrame.from_dict(new_factuals, orient='index', columns=self.X_.columns)

    def _update_difficulty_estimates(self, new_factuals, counterfactuals):
        """Refine the per-feature difficulty estimates from observed changes.

        For each feature with a recommended change, the estimated probability of
        the change being implemented is ``1 - exp(-attainability / (eps + d_hat))``.
        ``d_hat`` then takes a step of ``0.05 / (1 + n_updates)`` times
        ``(p - observed) * attainability`` and is clipped to [0, 1].

        Must be called while ``self.X_`` still holds the pre-step features.
        """
        base_lr = 0.05
        epsilon = 1e-3
        alpha = 1.0  # scaling factor of the estimated probability

        for candidate_idx in self.X_.index:
            current_features = self.X_.loc[candidate_idx].values
            new_features = new_factuals.loc[candidate_idx].values
            counterfactual_features = counterfactuals.loc[candidate_idx].values

            for i in range(self.X_.shape[1]):
                norm_dist = np.abs(counterfactual_features[i] - current_features[i])
                if norm_dist == 0:
                    continue  # no intended change, nothing to learn

                attainability = max(0, 1 / (epsilon + norm_dist * counterfactual_features[i]) - 1)

                # Estimated probability of change
                estimated_difficulty = self.difficulty_estimates[i]
                p = 1 - np.exp(-alpha * attainability * (1 / (epsilon + estimated_difficulty)))

                # Did the change occur?
                y = int(new_features[i] != current_features[i])

                # Gradient step with decaying learning rate
                gradient = (p - y) * attainability
                lr_i = base_lr / (1 + self.difficulty_update_counts[i])

                self.difficulty_estimates[i] += lr_i * gradient
                self.difficulty_estimates[i] = np.clip(self.difficulty_estimates[i], 0, 1)

                self.difficulty_update_counts[i] += 1

    def step(self, action):
        """Apply a counterfactual recommendation and let the candidate react.

        Args:
            action: counterfactual feature vector shaped like the action space,
                ``(1, n_features)``.

        Returns:
            ``(observation, reward, done, False, {})``. ``done`` is True once the
            candidate's updated score reaches the goal score, or after
            ``self.steps`` steps in the episode.

        Raises:
            ValueError: if ``self.phase`` is neither 1 nor 2.
        """
        if self.phase not in (1, 2):
            raise ValueError(f"phase must be 1 or 2, got {self.phase!r}")

        action_df = pd.DataFrame(action, index=self.X_.index, columns=self.X_.columns)
        counterfactuals, cf_scores_series = self._generate_counterfactuals(action_df)
        self.counterfactuals = counterfactuals.copy()

        # NOTE: despite the name (kept for compatibility with self.squared_error),
        # this is the *absolute* error between the counterfactual score and the goal.
        squared_error = (np.abs(cf_scores_series - self.goal_score)).sum()

        feature_changes = np.abs(counterfactuals - self.X_)
        weighted_feature_changes = feature_changes * self.difficulty_estimates
        cost = weighted_feature_changes.sum(axis=1).iloc[0]

        new_factuals = self._update_candidate_features()
        if self.phase == 1:
            # Must run before self.X_ is overwritten: the estimator compares the
            # pre-step features (self.X_) with new_factuals. Previously self.X_ was
            # replaced first, so implemented changes were never observed.
            self._update_difficulty_estimates(new_factuals, counterfactuals)
        self.X_ = new_factuals.copy()

        # Update scores after feature changes
        self.scores_ = pd.Series(self.model.predict_proba(self.X_)[:, -1], index=self.X_.index)

        if self.phase == 1:
            reward = -5 * squared_error
        else:
            tolerance = 0.01  # error up to which no penalty is applied
            penalty_base = 300  # penalty per unit of error above the tolerance
            penalty = penalty_base * max(0.0, squared_error - tolerance)
            reward = -10 * cost - penalty

        self.squared_error = squared_error
        self.reward = reward
        self.cost = cost

        # Count this step before testing the step limit so an episode lasts at most
        # self.steps steps (the counter used to be incremented twice per step).
        self.current_idx += 1
        goal_reached = self.scores_.iloc[0] >= self.goal_score
        done = bool(goal_reached or self.current_idx >= self.steps)

        return self._get_observation(), reward, done, False, {}

    def _get_observation(self):
        """Return the goal score and current features as float32 arrays."""
        return {
            "goal_score": np.array([self.goal_score], dtype=np.float32),
            "features": np.asarray(self.X_.values, dtype=np.float32),
        }
    
class TrainingEnvWrapper(gymnasium.Wrapper):
    """Wrapper that draws a random candidate and goal score on every reset.

    On each ``reset``, one row of ``X_test`` is drawn with the wrapped environment's
    RNG. A goal score is drawn uniformly in ``[current_score, 1)``. Both are
    forwarded to the wrapped environment's ``reset``.
    """

    def __init__(self, env, X_test):
        super().__init__(env)
        self.X_test = X_test.copy().reset_index(drop=True)

    def reset(self, **kwargs):
        """Sample a candidate and goal score, then reset the wrapped env.

        Raises:
            ValueError: if ``X_test`` is empty.
        """
        if len(self.X_test) == 0:
            raise ValueError("No more candidates left.")

        rng = self.env.rng

        # Sample the single candidate's features
        sample_idx = rng.choice(len(self.X_test), size=1, replace=False)
        self.X_ = self.X_test.iloc[sample_idx].reset_index(drop=True)

        # Sample a goal score at least as high as the candidate's current score.
        # Uses the env's seeded RNG (was the global np.random) for reproducibility.
        current_score = self.env.model.predict_proba(self.X_)[:, -1][0]
        self.goal_score = current_score + (1 - current_score) * rng.uniform()

        obs, info = self.env.reset(external_X=self.X_, external_goal_score=self.goal_score, **kwargs)
        return obs, info
    
# Wrapper needed to specify the inputs: candidate's features and goal score
class TestEnvWrapper(gymnasium.Wrapper):
    """Wrapper that pins the candidate and goal score used on every reset.

    Args:
        env: the wrapped environment, whose ``reset`` accepts ``external_X`` and
            ``external_goal_score``.
        X_: the candidate's features (the index is reset to 0..n-1).
        goal_score: the goal score forwarded on every reset.
    """

    def __init__(self, env, X_, goal_score):
        super().__init__(env)
        self.X_, self.goal_score = X_.reset_index(drop=True), goal_score

    def reset(self, **kwargs):
        """Reset the wrapped env with the stored candidate and goal score."""
        fixed_inputs = {"external_X": self.X_, "external_goal_score": self.goal_score}
        observation, info = self.env.reset(**fixed_inputs, **kwargs)
        return observation, info
    
# Function to compute Gini coefficient
def gini(array):
    """Return the Gini coefficient of ``array``.

    Computed as the mean absolute difference over all ordered pairs divided by
    twice the mean. If any value is negative, all values are first shifted so
    that the minimum becomes 0. Returns 0 for empty input or when the (shifted)
    mean is 0. The input is not modified.
    """
    values = np.asarray(array, dtype=float).ravel()
    if values.size == 0:
        # np.amin raises on a zero-size array
        return 0
    minimum = values.min()
    if minimum < 0:
        values = values - minimum  # values must be non-negative
    mean = values.mean()
    if mean == 0:
        return 0
    mean_abs_diff = np.abs(np.subtract.outer(values, values)).mean()
    return mean_abs_diff / (2 * mean)


    
class CandidateEnv(gymnasium.Env):
    def __init__(
        self,
        X_test,
        model,
        feature_difficulties,
        base_single_env,
        rl_model2,
        steps=100,
        decay_factor_distance=0.7,
        decay_factor_num=0.03,
        decay_factor_combination=0.07,
        current_idx=0,
        threshold=9,
        growth_k=10,
        data_gen_func=generate_synthetic_data,
        seed=0,
        t_validity=1,
        beta=0.05,
        alpha=7,
        tau=5,
        method="ours",
        baseline=False
    ):
        """Build the multi-candidate recourse environment.

        Args:
            X_test: pool of candidate feature rows that candidates are sampled from.
            model: scoring model exposing ``predict_proba`` and ``coef_``.
            feature_difficulties: mapping feature key -> true difficulty; stored
                ordered by sorted key.
            base_single_env: single-candidate env used to query ``rl_model2``.
            rl_model2: pre-trained recommender that produces counterfactuals.
            steps: episode length in decision rounds.
            threshold: number of applicants accepted per round.
            growth_k: number of new candidates arriving each round.
            t_validity: steps after which a non-reapplying candidate is forgotten.
            method: counterfactual generator ("ours", "Ustun", "Wachter", "DiCE").
            baseline: if True, the recommended goal score is the last threshold.
            Remaining arguments are stored as attributes for other methods.
        """
        super().__init__()

        # Collaborators and configuration
        self.X_test = X_test.reset_index(drop=True)
        self.model = model
        self.base_single_env = base_single_env
        self.rl_model2 = rl_model2
        self.data_gen_func = data_gen_func
        self.method = method
        self.baseline = baseline

        # Horizon and population dynamics
        self.steps = steps
        self.current_idx = current_idx
        self.threshold = threshold
        self.growth_k = growth_k
        self.t_validity = t_validity

        # Not read in this constructor; presumably consumed by the probability
        # helpers defined elsewhere in the class
        self.decay_factor_distance = decay_factor_distance
        self.decay_factor_num = decay_factor_num
        self.decay_factor_combination = decay_factor_combination
        self.alpha, self.beta, self.tau = alpha, beta, tau

        # Seeded RNG
        self.seed = seed
        self.rng = np.random.default_rng(seed)

        # True per-feature difficulties ordered by sorted key, and |model weights|
        self.feature_difficulties_array = [feature_difficulties[k] for k in sorted(feature_difficulties)]
        self.model_weights = np.abs(self.model.coef_).flatten()

        # Per-episode metric totals (one entry appended per finished episode)
        self.cost_sum, self.ginis_sum = [], []
        self.reliabilities_sum, self.implementing_sum = [], []

        # Initial pool of 20 sampled candidates
        initial_rows = self.rng.integers(0, len(self.X_test), size=20)
        self.X_ = self.X_test.iloc[initial_rows].reset_index(drop=True)
        # Hidden state (not observed by the agent): features of every candidate in
        # the system, and their features after applying the latest recommendation
        self.all_candidates_ = self.X_.copy()
        self.counterfactuals_all = self.X_.copy()

        # Application tracking
        self.last_application = dict.fromkeys(self.X_.index, 0)
        self.num_applied_dict_all = dict.fromkeys(self.X_.index, 0)

        self.episode_idx = 0
        self.threshold_ = None
        self.max_candidates = 200
        self.max_id = self.X_.index.max()
        num_features = self.X_.shape[1]

        Box = gymnasium.spaces.Box
        per_candidate = (self.max_candidates,)
        per_candidate_features = (self.max_candidates, num_features)
        self.observation_space = gymnasium.spaces.Dict({
            # Ids of candidates the agent believes are in the system (rejected within the last t_validity steps)
            "past_indices": Box(low=-1, high=np.inf, shape=per_candidate, dtype=np.int32),
            # Ids of the candidates applying at the current step
            "reapplying_indices": Box(low=-1, high=np.inf, shape=per_candidate, dtype=np.int32),
            # Last-seen features of the candidates the agent believes are in the system
            "past_features": Box(low=-1.0, high=1.0, shape=per_candidate_features, dtype=np.float32),
            # Current features of the candidates applying at the current step
            "features": Box(low=-1.0, high=1.0, shape=per_candidate_features, dtype=np.float32),
            # Goal scores recommended to the candidates the agent believes are in the system
            "goal_scores": Box(low=-1.0, high=1.0, shape=(self.max_candidates, 1), dtype=np.float32),
            # Current scores of the candidates applying at the current step
            "scores": Box(low=-1.0, high=1.0, shape=per_candidate, dtype=np.float32),
            # Last-seen scores of the candidates the agent believes are in the system
            "past_scores": Box(low=-1.0, high=1.0, shape=per_candidate, dtype=np.float32),
            # Number of times each candidate has applied
            "num_applied_all": Box(low=-1, high=1000, shape=per_candidate, dtype=np.int32),
            # Last step at which each candidate applied
            "last_application_all": Box(low=-1, high=1000, shape=per_candidate, dtype=np.int32),
            # Current decision threshold
            "threshold": Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
            # Binary outcomes of the candidates applying at the current step
            "outcomes": Box(low=-1.0, high=1.0, shape=per_candidate, dtype=np.float32),
        })

        # Action: the recommended (counterfactual) goal score
        self.action_space = Box(low=0.01, high=0.99, shape=(1,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Start a new episode with a fresh pool of 20 sampled candidates.

        ``seed`` and ``options`` are accepted for Gymnasium compatibility but are
        ignored. The episode RNG is reseeded from ``self.seed``, which is
        incremented on every reset.

        Returns:
            The initial observation and an empty info dict.
        """
        # Episode bookkeeping and per-episode RNG
        self.episode_idx += 1
        self.current_idx = 0
        self.seed += 1
        self.rng = np.random.default_rng(self.seed)

        # Per-episode metric accumulators
        self.cost_running_sum = self.gini_running_sum = 0
        self.reliability_running_sum = self.implementing_running_sum = 0

        # Initial candidate pool and the frames that mirror it
        drawn = self.rng.integers(0, len(self.X_test), size=20)
        self.X_ = self.X_test.iloc[drawn].reset_index(drop=True)
        self.all_candidates_ = self.X_.copy()
        self.counterfactuals_all = self.X_.copy()
        self.past_features = self.X_.copy()

        candidate_ids = self.X_.index
        self.last_application = dict.fromkeys(candidate_ids, 0)
        self.num_applied_dict_all = dict.fromkeys(candidate_ids, 0)

        # Every initial candidate starts with a goal score of 0.5
        self.goal_scores = pd.DataFrame(index=candidate_ids, columns=["goal_score"])
        self.goal_scores["goal_score"] = 0.5

        self.max_id = self.all_candidates_.index.max()

        # Accept the top `self.threshold` candidates by predicted probability
        n_rejected = max(0, len(self.X_) - self.threshold)
        scores = self.model.predict_proba(self.X_)[:, -1]
        accepted = np.zeros(scores.shape, dtype=int)
        accepted[np.argsort(scores)[n_rejected:]] = 1
        score_series = pd.Series(scores, index=candidate_ids)
        if n_rejected == 0:
            self.threshold_ = np.nan
        else:
            self.threshold_ = score_series.sort_values(ascending=True).iloc[n_rejected]

        self.outcome_ = pd.Series(accepted, index=candidate_ids)
        self.scores_ = pd.Series(score_series, index=candidate_ids)
        self.past_scores = self.scores_.copy()

        return self._get_observation(), {}
    
    def _pad_array(self, arr, target_shape, is_matrix=False):
        """Zero-pad ``arr`` along its first axis up to ``target_shape``.

        With ``is_matrix=True``, ``target_shape`` is ``(rows, cols)`` and zero
        rows of width ``cols`` are appended. Otherwise ``target_shape`` is an
        integer length and zeros are appended to the 1-D input. Inputs that are
        already long enough are returned unchanged (never truncated).
        """
        if is_matrix:
            missing = target_shape[0] - arr.shape[0]
            if missing > 0:
                arr = np.vstack([arr, np.zeros((missing, target_shape[1]))])
        else:
            missing = target_shape - len(arr)
            if missing > 0:
                arr = np.concatenate([arr, np.zeros(missing)])
        return arr


    def step(self, action):
        """Advance the multi-candidate environment by one decision round.

        Args:
            action: array-like whose first element is the recommended goal score.
                It is ignored when ``self.baseline`` is set; the last threshold (or
                0.5 if undefined) is used instead.

        Returns:
            ``(observation, reward, done, False, {})``, where ``done`` is True once
            ``self.steps`` rounds have elapsed.

        Raises:
            ValueError: if ``self.method`` is not "ours", "Ustun", "Wachter" or "DiCE".
        """
        # Validate up front (previously an unknown method surfaced later as an
        # UnboundLocalError, after the state had already been mutated)
        if self.method not in ("ours", "Ustun", "Wachter", "DiCE"):
            raise ValueError(
                f"Unknown method {self.method!r}; expected 'ours', 'Ustun', 'Wachter' or 'DiCE'."
            )

        # Identify accepted and non-accepted candidates and update dictionaries
        self._update_candidates_after_acceptance()
        # Remove candidates that haven't applied in the last t_validity time steps
        self._remove_expired_candidates()

        self.current_idx += 1
        single_action = action[0]

        # If baseline, set the goal score to the last-seen threshold
        if self.baseline:
            single_action = 0.5 if np.isnan(self.threshold_) else self.threshold_

        # Update the goal score for those who have just received the recommendation
        self._update_goal_scores(single_action)

        # Counterfactuals via _generate_counterfactuals_<method>
        generate_counterfactuals = getattr(self, f"_generate_counterfactuals_{self.method}")
        cf_scores_all_series = generate_counterfactuals(single_action)

        # Scores of the candidates currently in the system
        curr_scores_all = self.model.predict_proba(self.all_candidates_)[:, -1]
        curr_scores_all_series = pd.Series(curr_scores_all, index=self.all_candidates_.index)

        leaving_probabilities = self._compute_leaving_probabilities(curr_scores_all_series, cf_scores_all_series)

        # One uniform draw per candidate, in index order; a set keeps the
        # membership test below O(1) instead of O(n)
        leaving_indices = {
            idx for idx in self.all_candidates_.index
            if self.rng.uniform() < leaving_probabilities[idx]
        }
        staying_indices_all = [idx for idx in self.all_candidates_.index if idx not in leaving_indices]

        # Remove leaving candidates from the state variables
        self.all_candidates_ = self.all_candidates_.loc[staying_indices_all].copy()
        self.counterfactuals_all = self.counterfactuals_all.loc[staying_indices_all].copy()

        # Candidates that stay update their features based on the recommendation
        self._update_candidate_features()

        # Candidates in the environment decide whether to reapply
        if len(self.all_candidates_) > 0:
            curr_scores_all = self.model.predict_proba(self.all_candidates_)[:, -1]
            curr_scores_all_series = pd.Series(curr_scores_all, index=self.all_candidates_.index)
            reapply_probabilities = self._compute_reapply_probabilities(curr_scores_all_series, cf_scores_all_series)
            reapplying_indices = [idx for idx in self.all_candidates_.index if self.rng.uniform() < reapply_probabilities[idx]]
        else:
            reapplying_indices = []

        self.X_ = self.all_candidates_.loc[reapplying_indices]

        # New arrivals get fresh, never-used ids
        sample_idxs = self.rng.integers(0, len(self.X_test), size=self.growth_k)
        new_agents = self.X_test.iloc[sample_idxs].reset_index(drop=True)
        new_agents.index = range(self.max_id + 1, self.max_id + new_agents.shape[0] + 1)
        self.max_id = new_agents.index.max()

        for idx in new_agents.index:
            self.num_applied_dict_all[idx] = 1

        if len(self.X_) > 0:
            self.X_ = pd.concat([self.X_, new_agents])
        else:
            self.X_ = new_agents.copy()

        if len(self.all_candidates_) > 0:
            self.all_candidates_ = pd.concat([self.all_candidates_, new_agents])
        else:
            self.all_candidates_ = new_agents.copy()

        # self.threshold is the number of candidates that get accepted. Clamp at 0
        # (as reset does): with fewer applicants than self.threshold, a negative
        # index would accept only the last |k| and pick a wrong threshold.
        threshold_index_ = max(0, self.X_.shape[0] - self.threshold)
        probs = self.model.predict_proba(self.X_)[:, -1]
        pred = np.zeros(probs.shape, dtype=int)
        idx = np.argsort(probs)[threshold_index_:]
        pred[idx] = 1

        probabilities = pd.Series(probs, index=self.X_.index)
        probabilities_sorted = probabilities.sort_values(ascending=True)
        self.threshold_ = probabilities_sorted.iloc[threshold_index_] if threshold_index_ != 0 else np.nan

        self.outcome_, self.scores_ = (pd.Series(pred, index=self.X_.index), pd.Series(probabilities, index=self.X_.index))

        # Refresh known candidates' last-seen scores, then add the newcomers
        self.past_scores.update(self.scores_)
        self.past_scores = pd.concat([self.past_scores, self.scores_.loc[new_agents.index]])

        reward = self._compute_reward(new_agents, cf_scores_all_series)

        # Update past features with the current features of the candidates that just applied
        self.past_features.update(self.X_)
        if len(self.past_features) > 0:
            self.past_features = pd.concat([self.past_features, new_agents])
        else:
            self.past_features = new_agents.copy()

        done = self.current_idx >= self.steps

        self.cost_running_sum += self.cost_change
        self.gini_running_sum += self.gini
        self.reliability_running_sum += self.recourse_reliability
        self.implementing_running_sum += self.portion_implementing
        if done:
            self.ginis_sum.append(self.gini_running_sum)
            self.cost_sum.append(self.cost_running_sum)
            self.reliabilities_sum.append(self.reliability_running_sum)
            self.implementing_sum.append(self.implementing_running_sum)

        return self._get_observation(), reward, done, False, {}


    def _update_candidates_after_acceptance(self):
        """Apply the previous round's accept/reject outcomes (``self.outcome_``).

        Rejected applicants get their last application set to the current step and
        their application count incremented. Accepted applicants are removed from
        the tracking dicts and from every candidate DataFrame/Series.
        """
        accepted_candidates = self.outcome_[self.outcome_ == 1].index
        non_accepted_candidates = self.outcome_[self.outcome_ == 0].index

        for idx in non_accepted_candidates:
            self.last_application[idx] = self.current_idx
            # Increment the count (was `= +1`, which reset it to 1 every time).
            # `.get` covers candidates whose entry was dropped on expiry by
            # _remove_expired_candidates but who reapplied from all_candidates_.
            self.num_applied_dict_all[idx] = self.num_applied_dict_all.get(idx, 0) + 1
        for idx in accepted_candidates:
            self.num_applied_dict_all.pop(idx, None)
            self.last_application.pop(idx, None)

        # Remove accepted candidates from all frames
        self.X_ = self.X_.loc[non_accepted_candidates].copy()
        for frame in (
            self.all_candidates_,
            self.counterfactuals_all,
            self.past_features,
            self.goal_scores,
            self.scores_,
            self.past_scores,
        ):
            frame.drop(index=accepted_candidates, errors="ignore", inplace=True)


    def _remove_expired_candidates(self):
        """Forget candidates whose last application is ``t_validity`` or more steps old.

        Candidates with no recorded application count as expired. Expired
        candidates are dropped from the agent-visible state (past features, goal
        scores, past scores) and from both application-tracking dicts.
        """
        def is_expired(candidate):
            elapsed = self.current_idx - self.last_application.get(candidate, -np.inf)
            return elapsed >= self.t_validity

        expired = list(filter(is_expired, self.past_features.index))

        for frame in (self.past_features, self.goal_scores, self.past_scores):
            frame.drop(index=expired, errors="ignore", inplace=True)
        for candidate in expired:
            self.num_applied_dict_all.pop(candidate, None)
            self.last_application.pop(candidate, None)


    def _update_goal_scores(self, single_action):
        """Record ``single_action`` as the goal score of every candidate in ``X_``.

        Existing rows of ``self.goal_scores`` are overwritten; candidates not yet
        present are appended.
        """
        fresh = pd.DataFrame(single_action, index=self.X_.index, columns=["goal_score"])
        self.goal_scores.update(fresh)
        unseen = fresh.loc[self.X_.index.difference(self.goal_scores.index)]
        if len(self.goal_scores) == 0:
            self.goal_scores = unseen.copy()
        else:
            self.goal_scores = pd.concat([self.goal_scores, unseen])

    def _generate_counterfactuals_ours(self, single_action):
        """Generate recommendations with the pre-trained single-candidate agent.

        Every candidate in ``self.X_`` whose current score is below
        ``single_action`` gets a counterfactual predicted by ``rl_model2``, run
        through ``TestEnvWrapper`` on ``base_single_env``. Other candidates keep
        their current features.

        Side effects: sets ``self.cf_indices``, ``self.gini`` (Gini coefficient of
        the recommended candidates' counterfactual scores) and ``self.cost_change``,
        and merges the new counterfactuals into ``self.counterfactuals_all``.

        Returns:
            Series of model scores of ``self.counterfactuals_all``.
        """
        counterfactuals = self.X_.copy()
        self.cf_indices = []

        # Candidates already at or above the goal keep their features (counterfactuals is a copy of X_)
        for idx, row in self.X_.iterrows():
            candidate_df = row.to_frame().T
            if self.model.predict_proba(candidate_df)[0, -1] < single_action:
                test_env = TestEnvWrapper(self.base_single_env, candidate_df, single_action)
                obs, _ = test_env.reset()
                action_array, _ = self.rl_model2.predict(obs, deterministic=True)
                # Action presumably shaped (1, n_features) like SingleCandidateEnv's
                # action space; flatten it into a single row
                counterfactuals.loc[idx] = np.clip(np.ravel(action_array), 0, 1)
                self.cf_indices.append(idx)

        # Recommendations may only increase features
        counterfactuals_new = counterfactuals.where(counterfactuals >= self.X_, self.X_)
        if self.cf_indices:
            cf_scores = self.model.predict_proba(counterfactuals_new.loc[self.cf_indices])[:, -1]
            self.gini = gini(cf_scores)
        else:
            self.gini = 0

        # Cost of the recommendations, weighted by the true feature difficulties
        feature_changes = np.abs(counterfactuals_new - self.X_)
        weighted_feature_changes = feature_changes * self.feature_difficulties_array
        row_costs = weighted_feature_changes.sum(axis=1)
        # Guard: `.iloc[0]` raised IndexError when no candidate applied this round.
        # NOTE(review): only the first candidate's cost is recorded (as in
        # _generate_counterfactuals_Ustun); confirm whether a total over
        # self.cf_indices was intended.
        self.cost_change = row_costs.iloc[0] if len(row_costs) > 0 else 0.0

        # Overwrite existing rows of counterfactuals_all and append new ones
        self.counterfactuals_all.update(counterfactuals)
        new_rows = counterfactuals.loc[~counterfactuals.index.isin(self.counterfactuals_all.index)]
        self.counterfactuals_all = pd.concat([self.counterfactuals_all, new_rows])
        self.counterfactuals_all = self.counterfactuals_all.where(
            self.counterfactuals_all >= self.all_candidates_, self.all_candidates_
        )

        # The comparison above requires identically-labelled frames, so this index
        # equals all_candidates_.index; label by the frame actually scored.
        cf_scores_all = self.model.predict_proba(self.counterfactuals_all)[:, -1]
        return pd.Series(cf_scores_all, index=self.counterfactuals_all.index)
    

    def _generate_counterfactuals_Ustun(self, single_action):
        """Generate recommendations for ``self.X_`` with the Ustun recourse method.

        Args:
            single_action: target score; passed to the recourse method as ``threshold``
                and used to select the rows of ``self.X_`` that need recourse.

        Returns:
            pd.Series of model scores (last ``predict_proba`` column) for every row of
            ``self.counterfactuals_all``, indexed like ``self.all_candidates_``.
            Also sets ``self.cf_indices``, ``self.gini`` and ``self.cost_change``.
        """
        recourse = UstunRecourse(model=self.model, threshold=single_action, n_features=self.X_.shape[1],
                                 categorical=[], immutable=[])
        counterfactuals = self._solve_recourse_counterfactuals(recourse)
        return self._postprocess_recourse_counterfactuals(counterfactuals, single_action)

    def _generate_counterfactuals_Wachter(self, single_action):
        """Generate recommendations for ``self.X_`` with the Wachter recourse method.

        Same contract as ``_generate_counterfactuals_Ustun``; only the solver differs.
        """
        recourse = WachterRecourse(
            model=self.model,
            threshold=single_action,
            categorical=[],
            immutable=[],
            lambda_param=0.01,
            lr=0.1,
            max_iter=200,
            tol=5e-4,
        )
        counterfactuals = self._solve_recourse_counterfactuals(recourse)
        return self._postprocess_recourse_counterfactuals(counterfactuals, single_action)

    def _generate_counterfactuals_DiCE(self, single_action):
        """Generate recommendations for ``self.X_`` with the DiCE recourse method.

        Same contract as ``_generate_counterfactuals_Ustun``; only the solver differs.
        """
        recourse = DiCERecourse(
            model=self.model,
            threshold=single_action,
            categorical=[],
            immutable=[],
            num_counterfactuals=30,  # DiCE-specific argument (tunable)
        )
        counterfactuals = self._solve_recourse_counterfactuals(recourse)
        return self._postprocess_recourse_counterfactuals(counterfactuals, single_action)

    def _solve_recourse_counterfactuals(self, recourse):
        """Configure ``recourse`` on ``self.X_`` and return its raw counterfactuals."""
        recourse.set_actions(self.X_)
        # Solver bounds for the action set (presumably applied to every feature).
        recourse.action_set_.ub = 0.999
        recourse.action_set_.lb = 0.001
        # The solver is run on a deep copy, so the configured object is not touched by it.
        return deepcopy(recourse).counterfactual(self.X_)

    def _postprocess_recourse_counterfactuals(self, counterfactuals, single_action):
        """Clamp raw counterfactuals, update bookkeeping state and score all candidates.

        Sets:
            self.cf_indices: labels of rows of ``self.X_`` whose current score is below
                ``single_action``.
            self.gini: ``gini`` of the clamped counterfactual scores of those rows
                (0 when there are none).
            self.cost_change: difficulty-weighted absolute feature change of the first
                row of ``self.X_``.
            self.counterfactuals_all: merged with ``counterfactuals`` and clamped so no
                value is below the matching value in ``self.all_candidates_``.

        Returns:
            pd.Series of model scores for ``self.counterfactuals_all``, indexed like
            ``self.all_candidates_``.
        """
        # Rows currently scoring below the threshold, found with one batched prediction.
        if len(self.X_):
            current_scores = self.model.predict_proba(self.X_)[:, -1]
            below_threshold = np.ravel(current_scores < single_action)
            self.cf_indices = list(self.X_.index[below_threshold])
        else:
            self.cf_indices = []

        # Recommendations never ask a candidate to lower a feature.
        counterfactuals_new = counterfactuals.where(counterfactuals >= self.X_, self.X_)

        # Gini index of the counterfactual scores of the rows that need recourse.
        if self.cf_indices:
            cf_scores = self.model.predict_proba(counterfactuals_new.loc[self.cf_indices])[:, -1]
            self.gini = gini(cf_scores)
        else:
            self.gini = 0

        # Cost: absolute feature changes weighted by the true (hidden) feature difficulties.
        feature_changes = np.abs(counterfactuals_new - self.X_)
        weighted_feature_changes = feature_changes * self.feature_difficulties_array
        # NOTE(review): only the first row's cost is kept; confirm this is intended
        # whenever self.X_ holds more than one candidate.
        self.cost_change = weighted_feature_changes.sum(axis=1).iloc[0]

        # Merge into the running table: overwrite known rows, append unseen ones, then
        # clamp against the candidates' current features. The `>=` comparison requires
        # counterfactuals_all and all_candidates_ to be identically labelled.
        self.counterfactuals_all.update(counterfactuals)
        new_rows = counterfactuals.loc[~counterfactuals.index.isin(self.counterfactuals_all.index)]
        self.counterfactuals_all = pd.concat([self.counterfactuals_all, new_rows])
        self.counterfactuals_all = self.counterfactuals_all.where(
            self.counterfactuals_all >= self.all_candidates_, self.all_candidates_
        )

        cf_scores_all = self.model.predict_proba(self.counterfactuals_all)[:, -1]
        return pd.Series(cf_scores_all, index=self.all_candidates_.index)


    def _compute_leaving_probabilities(self, curr_scores_all_series, cf_scores_all_series):
        # Compute probabilities of leaving for all candidates
        leaving_probabilities = {}
        for idx in self.all_candidates_.index:
            distance_to_threshold = cf_scores_all_series[idx] - curr_scores_all_series[idx]
            stay_prob = np.exp(-(
                self.decay_factor_distance * distance_to_threshold
                + self.decay_factor_num * self.num_applied_dict_all[idx]
                + self.decay_factor_combination * self.num_applied_dict_all[idx] * distance_to_threshold
            ))
            stay_prob = np.clip(stay_prob, 0, 1)
            leaving_probabilities[idx] = 1 - stay_prob
        return leaving_probabilities


    def _update_candidate_features(self):
        # Candidates that stay update their features based on the recommendation
        new_factuals = {}
        for candidate_idx in self.all_candidates_.index:
            attainability_dict = {}
            current_features = self.all_candidates_.loc[candidate_idx].values
            counterfactual_features = self.counterfactuals_all.loc[candidate_idx].values

            magnitude_difficulties = counterfactual_features
            overall_difficulties = magnitude_difficulties

            # Iterate over each feature
            for i in range(self.all_candidates_.shape[1]):
                overall_difficulty_value = overall_difficulties[i]
                epsilon = 1e-3
                norm_dist = np.abs(counterfactual_features[i] - current_features[i])
                attainability = 1 / (epsilon + norm_dist * overall_difficulty_value) - 1
                attainability_dict[i] = attainability

            attainability_values = np.array(list(attainability_dict.values()))
            probability = 1 - np.exp(-self.beta * (1/np.array(self.feature_difficulties_array)) * attainability_values)
            apply_change = self.rng.uniform(size=counterfactual_features.shape) < probability
            applied_feature_change = (counterfactual_features - current_features) * apply_change
            new_features = current_features + applied_feature_change
            new_features = np.clip(new_features, 0, 1)
            new_factuals[candidate_idx] = new_features

        self.all_candidates_ = pd.DataFrame.from_dict(new_factuals, orient='index', columns=self.all_candidates_.columns)


    def _compute_reapply_probabilities(self, curr_scores_all_series, cf_scores_all_series):
        # Candidates in the environment decide whether to reapply based on the improvement in score and the time since last application
        reapply_probabilities = {}
        for idx in self.all_candidates_.index:
            distance_to_threshold = cf_scores_all_series[idx] - curr_scores_all_series[idx]
            decay = 7
            base_prob = np.exp(-decay * distance_to_threshold)
            elapsed_time = self.current_idx - self.last_application[idx]
            scaling_factor = min(1, elapsed_time / self.t_validity)
            adjusted_prob = (1 - scaling_factor) * base_prob + scaling_factor * 1
            reapply_probabilities[idx] = adjusted_prob
        return reapply_probabilities


    def _compute_reward(self, new_agents, cf_scores_all_series):
        # Candidates that perfectly implemented the recourse and got accepted
        accepted_candidates = self.outcome_[self.outcome_ == 1].index
        valid_recourse_candidates = [
            idx for idx in accepted_candidates
            if idx not in new_agents.index and (self.scores_[idx] - cf_scores_all_series[idx]) >= 0
        ]

        # Candidates that perfectly implemented the recourse
        implementing_recourse_candidates = [
            idx for idx in self.X_.index
            if idx not in new_agents.index and (self.scores_[idx] - cf_scores_all_series[idx]) >= 0
        ]

        num_implementing = len(implementing_recourse_candidates)
        num_valid_recourse = len(valid_recourse_candidates)

        # Recourse reliability
        if num_implementing == 0:
            self.recourse_reliability = 1
        else:
            self.recourse_reliability = num_valid_recourse / num_implementing

        # Recourse Feasibility
        if len(self.past_features) == 0:
            self.portion_implementing = 0
        else:
            self.portion_implementing = num_implementing / len(self.past_features)

        # Reward function
        reward_term1 = 1 + 0.90 * np.log(self.recourse_reliability + 0.01)
        reward_term2 = 1 + 0.90 * np.log(self.portion_implementing + 0.01)
        return self.alpha * reward_term1 + self.tau * reward_term2


    def _get_observation(self):
        # Pad all arrays to have the same length of max_candidates
        features = self._pad_array(self.X_.values, (self.max_candidates, self.X_.shape[1]), is_matrix=True)
        past_features = self._pad_array(self.past_features.values, (self.max_candidates, self.X_.shape[1]), is_matrix=True)
        reapplying_indices = self._pad_array(self.X_.index.to_numpy(), self.max_candidates)
        past_indices = self._pad_array(self.past_features.index.to_numpy(), self.max_candidates)
        scores = self._pad_array(self.scores_.values, self.max_candidates)
        past_scores = self._pad_array(self.past_scores.values, self.max_candidates)
        goal_scores = self._pad_array(self.goal_scores.values, (self.max_candidates, 1), is_matrix=True)
        outcomes = self._pad_array(self.outcome_.values, self.max_candidates)
        num_applied_all = self._pad_array(pd.Series(self.num_applied_dict_all).values, self.max_candidates)
        last_application_all = self._pad_array(pd.Series(self.last_application).values, self.max_candidates)

        obs = {
            "past_indices": past_indices,
            "reapplying_indices": reapplying_indices,
            "past_features": past_features,
            "features": features,
            "goal_scores": goal_scores,
            "scores": scores,
            "past_scores": past_scores,
            "num_applied_all": num_applied_all,
            "last_application_all": last_application_all,
            "threshold": np.array([self.threshold_], dtype=np.float32),
            "outcomes": outcomes,
        }
        return obs