from sklearn.neighbors import KDTree, LocalOutlierFactor
from sklearn.metrics import pairwise_distances
from scipy.spatial.distance import cdist, pdist
import numpy as np
import re

class RecourseEvaluator:
    """
    Evaluate counterfactual explanations (CEs) produced for ``test_inputs``.

    Each ``evaluate_*`` method returns a metric rounded to 4 decimals. CEs are
    given either as one point per test input (``multi=False``) or as a group of
    points per test input (``multi=True``). A CE group whose entries are all -1
    is treated as "no CE found" by the validity and input-robustness metrics.

    Args:
        test_inputs (np.ndarray): Explained inputs, shape (n_inputs, n_features).
        test_model: Classifier with ``predict``; label 1 counts as a valid CE.
        dataset_train (np.ndarray): Training rows used as the LOF reference set.
        max_train_size (int): Number of training rows kept (a random subset is
            taken when there are more).
        n_perturbations (int): Gaussian perturbations generated per test input.
        noise_std (float): Standard deviation of those perturbations.
    """

    def __init__(self, test_inputs, test_model, dataset_train, max_train_size=5000,
                 n_perturbations=5, noise_std=0.05):
        self.test_inputs = test_inputs
        self.test_model = test_model
        if len(dataset_train) > max_train_size:
            # permutation() shuffles a copy, so the caller's array is not reordered in place.
            dataset_train = np.random.permutation(dataset_train)
        self.dataset_train = dataset_train[:max_train_size]
        # Perturbed copies of every test input, consumed by evaluate_input_robustness.
        n_features = self.test_inputs.shape[1]
        self.test_rob_inputs = np.zeros((len(self.test_inputs), n_perturbations, n_features))
        for i, x in enumerate(self.test_inputs):
            self.test_rob_inputs[i] = np.random.normal(x, noise_std, (n_perturbations, len(x)))

    @staticmethod
    def _is_placeholder(ce_group):
        """Return True when ``ce_group`` is the all -1 marker used for "no CE found"."""
        return np.allclose(ce_group, -1)

    def evaluate_validity(self, ces, multi=False):
        """
        Fraction of test inputs whose CE is predicted as class 1 by ``test_model``.

        With ``multi=True`` an input counts when any point of its CE group is
        predicted 1; placeholder (all -1) groups never count.
        """
        n_inputs = len(self.test_inputs)
        if not multi:
            return np.round(np.sum(self.test_model.predict(ces) == 1) / n_inputs, 4)
        n_valid = 0
        for ce_group in ces:
            if self._is_placeholder(ce_group):
                continue
            if (self.test_model.predict(ce_group) == 1).any():
                n_valid += 1
        return np.round(n_valid / n_inputs, 4)

    def _lof_score(self, point):
        """Fit LOF on ``point`` plus the training data and return ``point``'s outlier factor."""
        # Refit per point so the CE itself is part of the neighbourhood graph;
        # the last fitted model stays available as ``self.lof``.
        self.lof = LocalOutlierFactor(n_neighbors=10)
        self.lof.fit(np.concatenate((point.reshape(1, -1), self.dataset_train), axis=0))
        return -1 * self.lof.negative_outlier_factor_[0]

    def evaluate_plausibility(self, ces, multi=False):
        """
        Mean local outlier factor of the CEs with respect to the training data.

        Averages over test inputs when ``multi=False`` and over every individual
        CE point when ``multi=True``.
        """
        lof_score = 0
        total_points = 0
        for ce in ces:
            points = ce if multi else (ce,)
            for point in points:
                total_points += 1
                lof_score += self._lof_score(point)
        to_divide = total_points if multi else len(self.test_inputs)
        return np.round(lof_score / to_divide, 4)

    def normalised_l2(self, xp, x):
        """L2 distance from ``x`` divided by the feature count (minimum over rows if ``xp`` is 2-D)."""
        if xp.ndim == 2:
            sum_sq_diff = np.sum((xp - x)**2, axis=1)
            l2_distances = np.sqrt(sum_sq_diff)
            normalized_distances = l2_distances / xp.shape[1]
            return np.round(np.min(normalized_distances), 4)
        else:
            l2_distance = np.sqrt(np.sum((xp - x)**2))
            return np.round(l2_distance / xp.shape[0], 4)

    def normalised_l1(self, xp, x):
        """L1 distance from ``x`` divided by the feature count (minimum over rows if ``xp`` is 2-D)."""
        if xp.ndim == 2:
            l1_distances = np.sum(np.abs(xp - x), axis=1)
            normalized_distances = l1_distances / xp.shape[-1]
            return np.round(np.min(normalized_distances), 4)
        else:
            l1_distance = np.sum(np.abs(xp - x))
            return np.round(l1_distance / xp.shape[-1], 4)

    def evaluate_cost(self, ces):
        """Mean normalised L1 distance between each test input and its (closest) CE."""
        cost = 0
        for i, x in enumerate(self.test_inputs):
            cost += self.normalised_l1(ces[i], x)
        return np.round(cost / len(self.test_inputs), 4)

    @staticmethod
    def _generate_perturbed_ces(perturb_inputs, ce_function, name, util_vars):
        """Call ``ce_function`` on the perturbed inputs with the arguments method ``name`` needs."""
        if name == 'dice':
            return ce_function(perturb_inputs, util_vars['dice_exp'])
        if name == 'icce':
            return ce_function(perturb_inputs, util_vars['X1_class1_clf'], util_vars['tree'], util_vars['clf'])
        if name == 'nnce':
            return ce_function(perturb_inputs, util_vars['X1_class1_clf'], util_vars['tree'])
        if name == 'face':
            return ce_function(perturb_inputs, util_vars['face_exp'])
        if name == 'stce':
            return ce_function(perturb_inputs, util_vars['X1_class1_clf'], util_vars['clf'], util_vars['tree'])
        ours_output_index = {'ours-first-ce': 0, 'ours-middle-ce': 1, 'ours-last-ce': 2}
        if name in ours_output_index:
            outputs = ce_function(perturb_inputs, util_vars['cgmvae'], util_vars['clf'],
                                  util_vars['cluster_centroids'], util_vars['device'])
            return outputs[ours_output_index[name]]
        return None

    def evaluate_input_robustness(self, ces, ce_function, name, util_vars):
        """
        Average distance between each input's CE(s) and the CEs regenerated for its perturbations.

        ``ce_function`` is re-run on ``self.test_rob_inputs[i]`` with the
        method-specific arguments taken from ``util_vars``.

        * Point-based methods ('nnce', 'face', 'stce') score the mean pairwise
          normalised L1 distance among the original CE and the regenerated ones.
        * Set-based methods ('dice', 'icce', names containing 'ours') score the
          mean ``set_distance_max`` over all pairs of CE sets. They skip inputs
          whose original CE set is the all -1 placeholder, without regenerating
          CEs for them.

        Returns:
            float: Mean distance over scored inputs, rounded to 4 decimals, or
            NaN if every input was skipped. Unrecognised names score 0.
        """
        point_based = name in ('nnce', 'face', 'stce')
        set_based = name in ('dice', 'icce') or 'ours' in name
        total = 0
        n_skipped = 0
        for i in range(len(self.test_inputs)):
            orig_ces = ces[i]
            if set_based and self._is_placeholder(orig_ces):
                # No CE was found for this input; skip it before paying for regeneration.
                n_skipped += 1
                continue
            perturb_ces = self._generate_perturbed_ces(self.test_rob_inputs[i], ce_function, name, util_vars)
            if point_based:
                all_ces = np.concatenate([orig_ces.reshape(1, -1), perturb_ces], axis=0)
                dist_matrix = pairwise_distances(all_ces, metric='manhattan') / all_ces.shape[1]
                total += np.mean(dist_matrix[np.triu_indices(all_ces.shape[0], k=1)])
            elif set_based:
                perturb_ces = np.array(perturb_ces)
                all_sets = np.concatenate([orig_ces.reshape(1, orig_ces.shape[0], -1), perturb_ces], axis=0)
                pair_dists = [set_distance_max(all_sets[a], all_sets[b])
                              for a in range(len(all_sets)) for b in range(a + 1, len(all_sets))]
                total += sum(pair_dists) / len(pair_dists)
        n_scored = len(self.test_inputs) - n_skipped
        if n_scored == 0:
            # Every input lacked a CE; there is nothing to average.
            return np.nan
        return np.round(total / n_scored, 4)

    def evaluate_model_robustness(self, ces, rt_clfs, multi=False):
        """
        Fraction of test inputs whose CE stays valid under every model in ``rt_clfs``.

        A CE (or, with ``multi=True``, a CE group with at least one point
        predicted 1) counts only if every model predicts class 1 for it.
        """
        n_robust = 0
        for xp in ces:
            batch = xp if multi else xp.reshape(1, -1)
            if all((m.predict(batch) == 1).any() for m in rt_clfs):
                n_robust += 1
        return np.round(n_robust / len(self.test_inputs), 4)

    def evaluate_diversity(self, ces):
        """
        Mean pairwise normalised L1 distance within each input's CE group.

        Returns -1 when ``ces`` is not a 3-D (inputs, points, features) array.
        """
        ces = np.array(ces)
        if ces.ndim != 3:
            return -1
        div = 0
        for ce in ces:
            div += np.mean(pdist(ce, metric='cityblock') / ce.shape[-1])
        return np.round(div / len(self.test_inputs), 4)

def evaluate_ces(r_eval, ces, rt_clfs, multi, name='ce', ce_function=None, util_vars=None):
    """
    Run every metric of ``r_eval`` on ``ces`` and stack the results.

    Returns:
        np.ndarray: ``[validity, cost, plausibility, diversity,
        model_robustness, input_robustness]``, computed in that order.
    """
    metrics = [
        r_eval.evaluate_validity(ces, multi=multi),
        r_eval.evaluate_cost(ces),
        r_eval.evaluate_plausibility(ces, multi=multi),
        r_eval.evaluate_diversity(ces),
        r_eval.evaluate_model_robustness(ces, rt_clfs, multi=multi),
        r_eval.evaluate_input_robustness(ces, ce_function, name, util_vars),
    ]
    return np.array(metrics)

def set_distance_max(S1, S2):
    """
    Symmetric set distance between two point sets.

    Point-to-point distance is the L1 (cityblock) distance divided by the number
    of features. The result is the average of the two directed Hausdorff
    distances:
    0.5 * (max_{a in S1} min_{b in S2} d(a, b) + max_{b in S2} min_{a in S1} d(b, a)).

    Args:
        S1 (np.ndarray): First set of points, shape (n_samples_1, n_features).
        S2 (np.ndarray): Second set of points, shape (n_samples_2, n_features).

    Returns:
        float: The set distance (0 when the sets coincide).
    """
    n_features = S1.shape[1]
    d = cdist(S1, S2, 'cityblock') / n_features
    # Worst-case distance from a point of S1 to its nearest neighbour in S2...
    forward = d.min(axis=1).max()
    # ...and from a point of S2 to its nearest neighbour in S1.
    backward = d.min(axis=0).max()
    return 0.5 * (forward + backward)

def check_rule_modified(array1, array2, rule_string):
    """
    Evaluate a linear inequality rule such as ``"+y_0 - y_3 >= 1.5"``.

    The rule has the form ``<lhs> <op> <rhs>``:

    * ``<op>`` is ``>`` or ``>=``.
    * ``<rhs>`` is a number.
    * ``<lhs>`` holds one or two signed terms ``+y_<i>`` / ``-y_<i>``; the sign
      is mandatory.

    The first term reads from ``array1``; a second term, if present, reads from
    ``array2``. For a single-term rule whose right-hand side is zero, the
    right-hand side becomes ``array2`` at the same index (negated when written
    as ``-0``). For example, ``"+y_2 >= 0"`` means ``array1[2] >= array2[2]``.

    Args:
        array1 (np.ndarray): Values for the first (or only) term.
        array2 (np.ndarray): Reference values; must have the same shape as ``array1``.
        rule_string (str): The rule to evaluate.

    Returns:
        bool: Whether the rule holds.

    Raises:
        ValueError: On a shape mismatch, a rule without ``>``/``>=``, a
            non-numeric right-hand side, or a term count other than one or two.
        IndexError: When a term's index is outside the arrays.
    """
    if array1.shape != array2.shape:
        raise ValueError("Input arrays must have the same shape.")

    # Split on the operator while keeping it: [lhs, op, rhs].
    pieces = re.split(r'(>=|>)', rule_string)
    if len(pieces) != 3:
        raise ValueError(f"Rule '{rule_string}' is malformed. Expected format with '>' or '>='.")
    lhs_text, op, rhs_text = pieces
    rhs_text = rhs_text.strip()
    threshold = float(rhs_text)

    terms = re.findall(r'([+\-]\s*y_(\d+))', lhs_text)
    n_terms = len(terms)
    if n_terms not in (1, 2):
        raise ValueError(f"Rule must contain 1 or 2 variables. Found {n_terms}.")

    first_sign, first_idx_text = terms[0]
    first_idx = int(first_idx_text)
    size = len(array1)
    if first_idx < 0 or first_idx >= size:
        raise IndexError(f"Index {first_idx} from rule is out of bounds for the arrays (length {size}).")

    if n_terms == 1 and threshold == 0.0:
        # A lone variable compared to (-)0 is compared to the reference value instead.
        threshold = -1 * array2[first_idx] if '-' in rhs_text else array2[first_idx]

    first_value = array1[first_idx]
    total = 0.0 + (-first_value if '-' in first_sign else first_value)

    if n_terms == 2:
        second_sign, second_idx_text = terms[1]
        second_idx = int(second_idx_text)
        if second_idx < 0 or second_idx >= len(array2):
            raise IndexError(f"Index {second_idx} is out of bounds for array2 (length {len(array2)}).")
        second_value = array2[second_idx]
        total = total + (-second_value if '-' in second_sign else second_value)

    return total >= threshold if op == '>=' else total > threshold

def _parse_enforceable_rule(rule_string, reference_array):
    """Parse one rule into ``(index, negated, offset, rhs, strict)``, or None if it is not enforceable."""
    parts = re.split(r'(>=|>)', rule_string)
    if len(parts) != 3:
        return None
    lhs_str, operator, rhs_str = parts
    rhs_str_stripped = rhs_str.strip()
    rhs = float(rhs_str_stripped)
    variable_terms = re.findall(r'([+\-]\s*y_(\d+))', lhs_str)
    if not (1 <= len(variable_terms) <= 2):
        return None

    term1_str, index1_str = variable_terms[0]
    index = int(index1_str)
    # A lone variable compared to (-)0 is compared to the reference value instead.
    if len(variable_terms) == 1 and rhs == 0.0:
        rhs = -reference_array[index] if '-' in rhs_str_stripped else reference_array[index]

    # The second term (if any) reads the reference, so its signed value is a constant offset.
    offset = None
    if len(variable_terms) == 2:
        term2_str, index2_str = variable_terms[1]
        value2 = reference_array[int(index2_str)]
        offset = -value2 if '-' in term2_str else value2

    return index, '-' in term1_str, offset, rhs, operator == '>'


def enforce_constraints(ce_array, reference_array, rules_list, margin=1e-6):
    """
    Minimally modify counterfactual point(s) so that each rule is satisfied.

    Rules use the ``check_rule_modified`` syntax. For each point and each rule
    in order, if the rule fails, the variable of the rule's FIRST term is
    shifted by exactly the missing amount. Strict ``>`` rules get an extra
    ``margin``. Rules that cannot be parsed (no ``>``/``>=``, or not one or two
    terms) are skipped. A later rule may undo an earlier correction on the same
    variable.

    Args:
        ce_array (np.ndarray): The counterfactual(s) to modify; 1-D or 2-D.
            Not modified; a corrected copy is returned.
        reference_array (np.ndarray): The reference point (``array2`` in the rules).
        rules_list (list): Rule strings to enforce.
        margin (float): Extra slack added for strict ``>`` rules. Note that very
            small margins may be lost to rounding for low-precision dtypes.

    Returns:
        np.ndarray: The corrected counterfactual(s), in the input's shape.
    """
    was_1d = ce_array.ndim == 1
    if was_1d:
        ce_array = ce_array.reshape(1, -1)
    modified_ce = ce_array.copy()
    if modified_ce.shape[0] == 0:
        # No points: nothing to correct.
        return modified_ce

    # Parsing depends only on the rule and the reference, so do it once, not once per point.
    parsed_rules = [parsed for parsed in (_parse_enforceable_rule(r, reference_array) for r in rules_list)
                    if parsed is not None]

    for ce_point in modified_ce:  # rows are views, so edits land in modified_ce
        for index, negated, offset, rhs, strict in parsed_rules:
            value1 = ce_point[index]
            lhs_value = 0.0 + (-value1 if negated else value1)
            if offset is not None:
                lhs_value = lhs_value + offset

            is_satisfied = (lhs_value > rhs) if strict else (lhs_value >= rhs)
            if is_satisfied:
                continue

            deficit = rhs - lhs_value
            if strict:
                deficit += margin
            # Move the first-term variable so that its signed contribution grows by `deficit`.
            if negated:
                ce_point[index] -= deficit
            else:
                ce_point[index] += deficit

    return modified_ce.squeeze(0) if was_1d else modified_ce

def evaluate_actionability(X_test_ces, path_ces, rule_lists, clf, enforce=True, n_clusters=5):
    """
    Measure how often a CE path contains a point that is both valid and rule-compliant.

    ``path_ces`` holds, per test input, either one path (3-D, presumably
    ``(n_inputs, n_points, n_features)``) or one path per cluster (4-D,
    ``(n_inputs, n_clusters, n_points, n_features)``). Each point is:

    1. optionally corrected with ``enforce_constraints``;
    2. checked against every rule of its input with ``check_rule_modified``;
    3. classified with ``clf.predict``.

    Args:
        X_test_ces: Reference point per test input (``array2`` in the rules).
        path_ces (np.ndarray): CE paths as described above. When ``enforce`` is
            True the corrected points are written back into this array in place.
        rule_lists: One list of rule strings per test input.
        clf: Classifier with ``predict``; a nonzero prediction counts as valid.
        enforce (bool): Whether to correct points before checking them.
        n_clusters (int): Kept for backward compatibility; the cluster count of a
            4-D ``path_ces`` is read from ``path_ces.shape[1]``.

    Returns:
        tuple: ``(score, valids, constraint_satisfied)``.

        * ``score`` is the fraction (rounded to 4 decimals) of test inputs with
          at least one point that is valid AND satisfies all of its rules.
        * ``valids`` holds the per-point predictions.
        * ``constraint_satisfied`` holds 1/0 rule-satisfaction flags per point.
    """
    n_inputs = len(X_test_ces)
    n_points = path_ces.shape[-2]
    if path_ces.ndim == 4:
        grid_shape = (n_inputs, path_ces.shape[1], n_points)
    else:
        grid_shape = (n_inputs, n_points)
    valids = np.zeros(grid_shape)
    constraint_satisfied = np.zeros(grid_shape)

    def _check_path(points, reference, rule_list, satisfied_out):
        # Optionally correct, then rule-check every point of one path (writes into satisfied_out).
        for k in range(len(points)):
            if enforce:
                points[k] = enforce_constraints(points[k], reference, rule_list)
            satisfied_out[k] = all(check_rule_modified(points[k], reference, rule_str)
                                   for rule_str in rule_list)

    for i, path in enumerate(path_ces):
        rule_list = rule_lists[i]
        reference = X_test_ces[i]
        if path.ndim == 2:
            _check_path(path, reference, rule_list, constraint_satisfied[i])
            valids[i] = clf.predict(path)
        elif path.ndim == 3:
            for j, ce_path in enumerate(path):
                _check_path(ce_path, reference, rule_list, constraint_satisfied[i, j])
                valids[i, j] = clf.predict(ce_path)

    # An input counts only when the SAME point is both valid and constraint-satisfying.
    hits = [np.logical_and(v, c).any() for v, c in zip(valids, constraint_satisfied)]
    return np.round(np.sum(hits) / n_inputs, 4), valids, constraint_satisfied

