
from sklearn.tree import DecisionTreeClassifier, export_text, export_graphviz
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import re
import os
import hashlib

def coin_toss(length, seed=None):
    """
    Simulate a two-party commit-and-reveal coin toss inside one process.

    Each party contributes ``length`` random bytes; the result is the
    byte-wise XOR of both contributions.

    Parameters:
        length: number of random bytes each party samples.
        seed: when not None, both contributions are drawn (party 1 first,
            then party 2) from one ``np.random.default_rng(seed)`` generator,
            which makes the output reproducible. When None, ``os.urandom``
            is used instead.

    Returns:
        ``bytes`` of size ``length`` equal to r1 XOR r2.

    Raises:
        ValueError: if a revealed value does not hash to its commitment.
            Both values live in the same process here, so the checks only
            mirror the protocol steps.
    """
    # Sampling: party 1's share is always drawn before party 2's.
    if seed is None:
        r1 = os.urandom(length)
        r2 = os.urandom(length)
    else:
        generator = np.random.default_rng(seed)
        r1 = generator.integers(0, 256, size=length, dtype=np.uint8).tobytes()
        r2 = generator.integers(0, 256, size=length, dtype=np.uint8).tobytes()

    # Commit phase: each party publishes the SHA3-256 digest of its share
    # (the exchange is simulated by plain local variables).
    commitment_1 = hashlib.sha3_256(r1).digest()
    commitment_2 = hashlib.sha3_256(r2).digest()

    # Reveal phase, step 1: party 2 verifies party 1's opening.
    if hashlib.sha3_256(r1).digest() != commitment_1:
        raise ValueError("Party 2: h1 does not match H(r1), aborting.")

    # Reveal phase, step 2: party 1 verifies party 2's opening.
    if hashlib.sha3_256(r2).digest() != commitment_2:
        raise ValueError("Party 1: h2 does not match H(r2), aborting.")

    # Joint output: XOR of the two equally long shares, done on big integers.
    combined = int.from_bytes(r1, "big") ^ int.from_bytes(r2, "big")
    return combined.to_bytes(len(r1), "big")

# === Fairness Metrics ===
def compute_dp(preds, sensitive, sensitive_values=(1, 2)):
    """
    Compare the mean prediction of two groups (DP, presumably demographic parity).

    Parameters:
        preds: predictions; must support boolean-mask indexing and ``.mean()``.
        sensitive: group membership, compared element-wise with ``==``.
        sensitive_values: the two group identifiers; only the first two
            entries are used.

    Returns:
        ``(abs(rate_a - rate_b), rate_a, rate_b)`` where ``rate_a`` and
        ``rate_b`` are the mean predictions of the first and second group.
    """
    first_mask = sensitive == sensitive_values[0]
    second_mask = sensitive == sensitive_values[1]

    first_rate = preds[first_mask].mean()
    second_rate = preds[second_mask].mean()

    gap = abs(first_rate - second_rate)
    return gap, first_rate, second_rate

def compute_eqod(y_true, preds, sensitive, sensitive_values=(1, 2)):
    """
    Compute Equalized Odds: the TPR and FPR gaps between two groups.

    The TPR gap on its own corresponds to equal opportunity.

    Parameters:
        y_true: ground-truth labels (1 is positive, 0 is negative).
        preds: predicted labels (1 is a positive prediction).
        sensitive: group membership, compared element-wise with ``==``.
        sensitive_values: the two group identifiers; only the first two
            entries are used.

    Returns:
        ``((tpr_gap, fpr_gap), (tpr_a, tpr_b), (fpr_a, fpr_b))`` with the
        gaps given as absolute differences.
    """
    predicted_positive = preds == 1
    group_masks = [sensitive == sensitive_values[0], sensitive == sensitive_values[1]]

    def _positive_rate(label, in_group):
        # Share of positive predictions among the group's samples whose
        # true label equals `label`.
        hits = (predicted_positive & (y_true == label) & in_group).sum()
        support = ((y_true == label) & in_group).sum()
        return hits / support

    # True Positive Rate per group, then False Positive Rate per group.
    tpr_a, tpr_b = [_positive_rate(1, mask) for mask in group_masks]
    fpr_a, fpr_b = [_positive_rate(0, mask) for mask in group_masks]

    gaps = (abs(tpr_a - tpr_b), abs(fpr_a - fpr_b))
    return gaps, (tpr_a, tpr_b), (fpr_a, fpr_b)

def compute_mrd(y_true, y_proba, sensitive, sensitive_values=(1, 2)):
    """
    Compute the Mean Residual Difference between two groups.

    A residual is ``y_true - y_proba``; the metric is the absolute
    difference between the two groups' mean residuals.

    Parameters:
        y_true: ground-truth values.
        y_proba: predicted scores, subtracted element-wise from ``y_true``.
        sensitive: group membership, compared element-wise with ``==``.
        sensitive_values: the two group identifiers; only the first two
            entries are used.

    Returns:
        ``(abs(mean_a - mean_b), mean_a, mean_b)``.
    """
    errors = y_true - y_proba

    first_mask = sensitive == sensitive_values[0]
    second_mask = sensitive == sensitive_values[1]
    mean_a = errors[first_mask].mean()
    mean_b = errors[second_mask].mean()

    gap = abs(mean_a - mean_b)
    return gap, mean_a, mean_b

def welch_residual_ttest(alpha, y_test, y_proba, A_test, model_type, scale=1, sensitive_values=(1, 2)):
    """
    Perform Welch's (unequal-variance) t-test on the mean residuals of two groups in A_test.

    Residuals are ``y_test * scale - round(y_proba * scale)``. The scaled
    predictions are rounded and cast to ``int``, so with the default
    ``scale=1`` scores in [0, 1] collapse to 0 or 1.

    Parameters:
        alpha: significance level of the two-sided test.
        y_test: ground-truth values.
        y_proba: predicted scores.
        A_test: group membership, compared element-wise with ``==``.
        model_type: label used only in the printed report header.
        scale: multiplier applied to both ``y_test`` and ``y_proba``.
        sensitive_values: the two group identifiers; only the first two
            entries are used.

    Returns:
        ``(t_stat, critical_value, p_value, means, variances, n1, n2)`` where
        ``means`` and ``variances`` are 2-tuples (first group, second group),
        variances use ``ddof=1``, and ``critical_value`` is the two-sided
        critical t value at ``alpha`` for the Welch-Satterthwaite degrees
        of freedom.

    Side effects:
        Prints the per-group residual means/variances and the test verdict.
    """
    y_test_scaled = y_test * scale
    # Predictions are rounded to integers after scaling (see docstring).
    y_proba_scaled = np.round(y_proba * scale).astype(int)

    in_group0 = A_test == sensitive_values[0]
    in_group1 = A_test == sensitive_values[1]
    residuals_group0 = y_test_scaled[in_group0] - y_proba_scaled[in_group0]
    residuals_group1 = y_test_scaled[in_group1] - y_proba_scaled[in_group1]

    n1 = residuals_group0.shape[0]
    n2 = residuals_group1.shape[0]
    s1_sq = np.var(residuals_group0, ddof=1)
    s2_sq = np.var(residuals_group1, ddof=1)
    t_stat, p_value = stats.ttest_ind(residuals_group0, residuals_group1, equal_var=False)

    # Welch-Satterthwaite approximation of the degrees of freedom.
    df = (s1_sq/n1 + s2_sq/n2)**2 / ((s1_sq/n1)**2/(n1-1) + (s2_sq/n2)**2/(n2-1))
    critical_value = stats.t.ppf(1 - alpha/2, df)
    means = (residuals_group0.mean(), residuals_group1.mean())
    variances = (s1_sq, s2_sq)

    print(f"=== Residual Means and Variances ({model_type}, scaled by {scale}) ===")
    print(f"Group {sensitive_values[0]} mean: {means[0]:.4f}, variance: {variances[0]:.4f}")
    print(f"Group {sensitive_values[1]} mean: {means[1]:.4f}, variance: {variances[1]:.4f}")

    print(f"T-test (unequal var) between two residuals: t={t_stat:.4f}, critical value={critical_value:.4f}")
    # Report the alpha that was actually used for the decision rather than
    # a hard-coded 0.05.
    if abs(t_stat) > critical_value:
        print(f"Difference in means is statistically significant at alpha={alpha}\n")
    else:
        print(f"No statistically significant difference in means at alpha={alpha}\n")

    return t_stat, critical_value, p_value, means, variances, n1, n2

def print_sample_and_decision_path(clf, X_samples, feature_names, sample_idx=1):
    """
    Print the sample at position sample_idx, one ``index: value`` line per
    feature, followed by the decision path this sample takes through clf.

    Parameters:
        clf: fitted tree model exposing ``decision_path`` and ``tree_``.
        X_samples: samples supporting ``.iloc`` (e.g. a pandas DataFrame);
            columns are assumed to be in the feature order clf was fitted on.
        feature_names: sequence mapping a feature position to its display name.
        sample_idx: positional row index of the sample to print.

    Returns:
        None. Output goes to stdout; the path is printed as
        ``name <= thr -> name > thr -> LEAF id``.
    """
    sample = X_samples.iloc[sample_idx]
    for idx, value in enumerate(sample):
        print(f"{idx}: {value}")

    # Only the requested row is needed, so evaluate the decision path for a
    # one-row frame instead of for every sample.
    path_matrix = clf.decision_path(X_samples.iloc[[sample_idx]])
    # Node ids visited by the sample (the single row of the indicator matrix).
    node_indices = path_matrix[0].indices

    tree = clf.tree_
    path_str = []
    for node_id in node_indices:
        if tree.children_left[node_id] == tree.children_right[node_id]:  # leaf node
            path_str.append(f"LEAF {node_id}")
            continue
        feature = tree.feature[node_id]
        thresh = tree.threshold[node_id]
        # tree.feature holds column positions, so read the value
        # positionally. Plain `sample[feature]` is label-based for integer
        # column labels and a deprecated fallback for other labels.
        value = sample.iloc[feature]
        comparison = '<=' if value <= thresh else '>'
        path_str.append(f"{feature_names[feature]} {comparison} {thresh:.2f}")
    print(" -> ".join(path_str))

def export_leaf_labels(clf, feature_names=None, decimals=1, spacing=1):
    """
    Export decision tree with:
    - Internal node thresholds scaled to integers by 10^decimals (including negative values)
    - Leaf node class labels (majority class) instead of probabilities
    - Class labels are not scaled

    Parameters:
        clf: fitted tree passed to ``export_text``.
        feature_names: forwarded to ``export_text``.
        decimals: decimals used by ``export_text`` and exponent of the
            integer scale (``10 ** decimals``).
        spacing: forwarded to ``export_text``.

    Returns:
        The ``export_text`` report as one string, with every trailing
        ``<= value`` / ``> value`` rewritten to ``<= int`` / ``> int``.
    """
    scale = 10 ** decimals
    text = export_text(clf, feature_names=feature_names, spacing=spacing, decimals=decimals)

    # The threshold is expected to be the last token of a split line.
    # Anchoring on end-of-line keeps digits that follow '>' or '<=' inside a
    # feature name (e.g. "income_>50K <= 0.5") from being rescaled as well.
    threshold_pattern = re.compile(r'(<=|>)\s*(-?\d+(?:\.\d+)?)\s*$')

    def _scale_threshold(match):
        # Keep the comparison operator, replace the value by its scaled integer.
        scaled_threshold = int(round(float(match.group(2)) * scale))
        return f"{match.group(1)} {scaled_threshold}"

    modified_lines = [
        threshold_pattern.sub(_scale_threshold, line)
        for line in text.splitlines()
    ]
    return "\n".join(modified_lines)

# === Tree Export with Probabilities ===
def export_leaf_probabilities(clf, feature_names=None, decimals=1, spacing=1):
    """
    Export decision tree with:
    - Internal node thresholds scaled to integers by 10^decimals (including negative values)
    - Leaf node class probabilities (for class 1) scaled to integers by 10^decimals
    - Probabilities are scaled by 10^decimals

    "Class 1" is the second entry of a leaf's ``weights: [...]`` list; a leaf
    with a single weight or with zero total weight gets probability 0.

    Parameters:
        clf: fitted tree passed to ``export_text`` (with ``show_weights=True``).
        feature_names: forwarded to ``export_text``.
        decimals: decimals used by ``export_text`` and exponent of the
            integer scale (``10 ** decimals``).
        spacing: forwarded to ``export_text``.

    Returns:
        The rewritten report as one string. On leaf lines the weights list
        is removed and the text after ``class: `` is replaced by the scaled
        probability.
    """
    scale = 10 ** decimals
    text = export_text(clf, feature_names=feature_names, show_weights=True, spacing=spacing, decimals=decimals)

    # The threshold is expected to be the last token of a split line.
    # Anchoring on end-of-line keeps digits that follow '>' or '<=' inside a
    # feature name (e.g. "income_>50K <= 0.5") from being rescaled as well.
    threshold_pattern = re.compile(r'(<=|>)\s*(-?\d+(?:\.\d+)?)\s*$')
    weights_pattern = re.compile(r'weights: \[([^\]]+)\]')

    def _scale_threshold(match):
        # Keep the comparison operator, replace the value by its scaled integer.
        scaled_threshold = int(round(float(match.group(2)) * scale))
        return f"{match.group(1)} {scaled_threshold}"

    modified_lines = []
    for line in text.splitlines():
        # --- Process thresholds (<= and >, including negative values) ---
        line = threshold_pattern.sub(_scale_threshold, line)

        # --- Process leaf class probability ---
        if "class:" in line and "weights:" in line:
            weights_match = weights_pattern.search(line)
            if weights_match:
                counts = np.array([float(x.strip()) for x in weights_match.group(1).split(',')])
                total = counts.sum()
                probs = counts / total if total > 0 else np.zeros_like(counts)
                prob_class_1 = probs[1] if len(probs) > 1 else 0.0
                prob_scaled = int(round(prob_class_1 * scale))
                # Replace the whole label, not just leading digits: labels
                # may be floats ("1.0"), negative, or strings.
                line = re.sub(r'class: .*$', f'class: {prob_scaled}', line)
            line = weights_pattern.sub('', line).rstrip()

        modified_lines.append(line)

    return "\n".join(modified_lines)

