"""
TRAK-based Membership Inference Attack

A minimal implementation for ICLR paper submission demonstrating
TRAK gradient-based membership inference attacks on data curation.

This code illustrates the core attack methodology without production dependencies.
"""

import numpy as np
import jax
import jax.numpy as jnp
from typing import Dict, Tuple, Optional
from sklearn.metrics import roc_curve, auc


def compute_xtx_inverse(gradients: jnp.ndarray, regularization: float = 1e-2) -> jnp.ndarray:
    """
    Invert the ridge-regularized Gram matrix of a gradient matrix.

    Forms ``G^T G + regularization * I`` and returns its explicit inverse.

    Args:
        gradients: Gradient matrix (n_samples, d_features)
        regularization: Ridge term added to the diagonal before inversion

    Returns:
        Inverse of the regularized Gram matrix (d_features, d_features)
    """
    gram = gradients.T @ gradients
    # The ridge term keeps the matrix invertible when the Gram matrix is rank-deficient.
    ridge = regularization * jnp.eye(gram.shape[0])
    return jnp.linalg.inv(gram + ridge)


def compute_contrastive_scores(
    target_gradients: jnp.ndarray,
    contrastive_gradients: jnp.ndarray,
    xtx_inv: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Compute TRAK contrastive scores for membership inference.

    For every target ``i`` the contrastive samples are scored twice:

    * H0 (target ``i`` absent): against the mean of all target gradients
      except ``i`` (leave-one-out mean).
    * H1 (target ``i`` present): against the mean of all target gradients.
      This row is the same for every target and is simply repeated.

    A score is ``c^T @ xtx_inv @ mean`` for a contrastive gradient ``c``.

    Args:
        target_gradients: Target dataset gradients (n_targets, d)
        contrastive_gradients: Contrastive gradients (n_contrastive, d)
        xtx_inv: Inverse covariance matrix (d, d)

    Returns:
        Tuple of (scores_h0, scores_h1), each of shape (n_targets, n_contrastive)

    Raises:
        ValueError: If fewer than two targets are given; the leave-one-out
            mean would divide by zero and yield NaN scores.
    """
    n_targets = target_gradients.shape[0]
    if n_targets < 2:
        raise ValueError(
            "compute_contrastive_scores needs at least 2 targets to form "
            f"leave-one-out means, got {n_targets}"
        )

    # Project the contrastive gradients through xtx_inv once; row j is
    # c_j^T @ xtx_inv, shared by every hypothesis and every target.
    projected = jnp.dot(contrastive_gradients, xtx_inv)  # (n_contrastive, d)

    # All leave-one-out means at once: (sum - g_i) / (n - 1) for each row i.
    gradient_sum = jnp.sum(target_gradients, axis=0)
    loo_means = (gradient_sum[None, :] - target_gradients) / (n_targets - 1)

    # scores_h0[i, j] = c_j^T @ xtx_inv @ loo_mean_i
    scores_h0 = jnp.dot(loo_means, projected.T)  # (n_targets, n_contrastive)

    # Under H1 the full mean is used, so one row is repeated for every target.
    target_mean = jnp.mean(target_gradients, axis=0)
    scores_h1 = jnp.tile(jnp.dot(projected, target_mean), (n_targets, 1))

    return scores_h0, scores_h1


def find_optimal_targets(
    scores_h0: jnp.ndarray,
    scores_h1: jnp.ndarray,
    selection_threshold: float = 75.0,
    max_targets: int = 100
) -> jnp.ndarray:
    """
    Find targets that are most vulnerable to membership inference.

    Each row of scores is turned into rank percentiles
    (``rank / n_contrastive * 100``); a contrastive sample counts as selected
    when its percentile is at or above ``selection_threshold``. A target's
    vulnerability is the number of contrastive samples whose selected status
    differs between the H0 and H1 rows.

    Args:
        scores_h0: Scores under null hypothesis (n_targets, n_contrastive)
        scores_h1: Scores under alternative hypothesis (n_targets, n_contrastive)
        selection_threshold: Percentile threshold for curation selection
        max_targets: Maximum number of targets to select; values <= 0 select none

    Returns:
        Indices of the most vulnerable targets, ordered by ascending
        vulnerability (most vulnerable last). At most
        ``min(max_targets, n_targets)`` indices are returned.
    """
    n_targets, n_contrastive = scores_h0.shape

    def _selected_mask(scores: jnp.ndarray) -> jnp.ndarray:
        # Double argsort along each row gives every sample's rank in that row.
        ranks = jnp.argsort(jnp.argsort(scores, axis=1), axis=1)
        percentiles = (ranks / n_contrastive) * 100
        return percentiles >= selection_threshold

    # Count, per target, the samples that cross the threshold between hypotheses.
    status_changed = _selected_mask(scores_h0) != _selected_mask(scores_h1)
    vulnerability_scores = jnp.sum(status_changed, axis=1)

    # Clamp so that a zero or negative max_targets selects nothing.
    n_select = max(0, min(max_targets, n_targets))

    # Slice by explicit start index: `[-n_select:]` would return *every*
    # target when n_select is 0.
    ascending = jnp.argsort(vulnerability_scores)
    return ascending[n_targets - n_select:]


def compute_membership_signal(
    selected_contrastive_indices: jnp.ndarray,
    scores_h0: jnp.ndarray,
    scores_h1: jnp.ndarray,
    target_indices: jnp.ndarray,
    selection_threshold: float = 75.0
) -> jnp.ndarray:
    """
    Compute membership inference signal based on curation selection.

    For each requested target, the H0 and H1 score rows are converted to rank
    percentiles and mapped to a selection probability with a logistic curve
    centred on ``selection_threshold`` (scale 10). The signal is the average,
    over contrastive samples, of the log-likelihood ratio H1 vs H0 of the
    observed selected / not-selected outcome.

    Args:
        selected_contrastive_indices: Indices of contrastive samples selected
            by curation (JAX array, NumPy array or sequence of ints). Indices
            outside ``[0, n_contrastive)`` are ignored.
        scores_h0: Scores under null hypothesis (n_targets, n_contrastive)
        scores_h1: Scores under alternative hypothesis (n_targets, n_contrastive)
        target_indices: Indices of targets to analyze
        selection_threshold: Curation selection threshold percentile

    Returns:
        Membership signals for each target (higher = more likely member),
        shape (len(target_indices),)
    """
    n_contrastive = scores_h0.shape[1]
    eps = 1e-8  # keeps the logarithms finite when a probability saturates

    # Build a boolean "was selected" mask. Iterating a JAX array into a Python
    # set fails because 0-d JAX arrays are unhashable, so go through NumPy.
    selected_idx = np.asarray(selected_contrastive_indices).astype(int).ravel()
    was_selected = jnp.asarray(np.isin(np.arange(n_contrastive), selected_idx))

    rows = jnp.asarray(target_indices, dtype=int).reshape(-1)

    def _selection_probability(scores: jnp.ndarray) -> jnp.ndarray:
        # Rank percentile of every contrastive sample within its target's row.
        ranks = jnp.argsort(jnp.argsort(scores, axis=1), axis=1)
        percentiles = (ranks / n_contrastive) * 100
        # Soft version of "percentile >= threshold".
        return 1.0 / (1.0 + jnp.exp(-(percentiles - selection_threshold) / 10))

    p_h0 = _selection_probability(scores_h0[rows])
    p_h1 = _selection_probability(scores_h1[rows])

    # Log-likelihood ratio of the observed outcome for every (target, sample).
    llr_if_selected = jnp.log(p_h1 + eps) - jnp.log(p_h0 + eps)
    llr_if_rejected = jnp.log(1 - p_h1 + eps) - jnp.log(1 - p_h0 + eps)
    llr = jnp.where(was_selected[None, :], llr_if_selected, llr_if_rejected)

    # Average over contrastive samples.
    return jnp.mean(llr, axis=1)


def evaluate_trak_attack(
    membership_signals: jnp.ndarray,
    ground_truth: jnp.ndarray
) -> Dict:
    """
    Summarize how well the membership signals separate members from non-members.

    Args:
        membership_signals: Membership inference signals, one per target
        ground_truth: Binary labels (1=member, 0=non-member)

    Returns:
        Dictionary holding the ROC curve (``fpr``, ``tpr``, ``thresholds``),
        its area (``roc_auc``), the mean/std of the signal within each class
        (0.0 when a class is empty) and the class counts.
    """
    def _mean_and_std(values):
        # An empty class reports 0.0 for both statistics.
        if len(values) > 0:
            return float(jnp.mean(values)), float(jnp.std(values))
        return 0.0, 0.0

    # ROC curve and the area under it
    false_pos_rate, true_pos_rate, cutoffs = roc_curve(ground_truth, membership_signals)
    area = auc(false_pos_rate, true_pos_rate)

    # Per-class signal statistics
    is_member = ground_truth == 1
    is_nonmember = ground_truth == 0
    member_mean, member_std = _mean_and_std(membership_signals[is_member])
    nonmember_mean, nonmember_std = _mean_and_std(membership_signals[is_nonmember])

    return {
        "roc_auc": area,
        "fpr": false_pos_rate,
        "tpr": true_pos_rate,
        "thresholds": cutoffs,
        "member_signal_mean": member_mean,
        "member_signal_std": member_std,
        "nonmember_signal_mean": nonmember_mean,
        "nonmember_signal_std": nonmember_std,
        "n_members": int(jnp.sum(ground_truth)),
        "n_nonmembers": int(jnp.sum(is_nonmember))
    }


def run_trak_attack(
    target_gradients: jnp.ndarray,
    contrastive_gradients: jnp.ndarray,
    victim_indices: jnp.ndarray,
    selected_contrastive_indices: jnp.ndarray,
    regularization: float = 1e-2,
    selection_threshold: float = 75.0,
    max_targets: int = 100
) -> Dict:
    """
    Run the complete TRAK-based membership inference attack.

    Pipeline: invert the regularized covariance of all gradients, score the
    contrastive samples under both hypotheses, pick the most vulnerable
    targets, compute their membership signals, and evaluate against the
    ground truth derived from ``victim_indices``. Targets that were not
    picked as vulnerable keep a signal of 0 and are still included in the
    evaluation.

    Args:
        target_gradients: Target dataset gradients (n_targets, d)
        contrastive_gradients: Contrastive gradients for attack (n_contrastive, d)
        victim_indices: Indices of targets that are members
        selected_contrastive_indices: Contrastive samples selected by curation
        regularization: Ridge regularization parameter
        selection_threshold: Curation selection threshold percentile
        max_targets: Maximum targets to analyze

    Returns:
        Dictionary with keys ``membership_signals``, ``ground_truth``,
        ``vulnerable_targets``, ``scores_h0``, ``scores_h1`` and ``evaluation``
    """
    n_targets = len(target_gradients)

    # 1. Inverse covariance over targets and contrastive samples together
    stacked = jnp.concatenate([target_gradients, contrastive_gradients], axis=0)
    inverse_covariance = compute_xtx_inverse(stacked, regularization)

    # 2. Contrastive scores under H0 and H1
    scores_h0, scores_h1 = compute_contrastive_scores(
        target_gradients, contrastive_gradients, inverse_covariance
    )

    # 3. Targets whose presence flips the most selection decisions
    vulnerable = find_optimal_targets(
        scores_h0, scores_h1, selection_threshold, max_targets
    )

    # 4. Signals for the vulnerable targets, scattered into an all-zero vector
    vulnerable_signals = compute_membership_signal(
        selected_contrastive_indices, scores_h0, scores_h1,
        vulnerable, selection_threshold
    )
    signals = jnp.zeros(n_targets).at[vulnerable].set(vulnerable_signals)

    # 5. Ground truth: 1 at the victim positions, 0 elsewhere
    labels = jnp.zeros(n_targets, dtype=int).at[victim_indices].set(1)

    # 6. Evaluation
    evaluation = evaluate_trak_attack(signals, labels)

    return {
        "membership_signals": signals,
        "ground_truth": labels,
        "vulnerable_targets": vulnerable,
        "scores_h0": scores_h0,
        "scores_h1": scores_h1,
        "evaluation": evaluation
    }


def example_usage():
    """Run the attack end-to-end on random unit-norm gradients and print a summary."""
    np.random.seed(42)

    # Problem sizes
    dim = 512  # gradient dimension
    num_targets = 1000
    num_contrastive = 5000
    num_victims = 500  # 50% are victims

    def _random_unit_rows(count):
        # Gaussian rows scaled to unit L2 norm, returned as a JAX array.
        raw = np.random.randn(count, dim)
        return jnp.array(raw / np.linalg.norm(raw, axis=1, keepdims=True))

    # Targets are drawn before contrastive samples so the seeded stream is
    # consumed in a fixed order.
    target_gradients = _random_unit_rows(num_targets)
    contrastive_gradients = _random_unit_rows(num_contrastive)

    # The first half of the targets are labelled as members
    victim_indices = jnp.arange(num_victims)

    # Simulated curation: keep the 25% of contrastive samples most similar
    # to the mean target gradient
    mean_target = jnp.mean(target_gradients, axis=0)
    similarity = jnp.dot(contrastive_gradients, mean_target)
    keep_count = int(0.25 * num_contrastive)
    selected_contrastive_indices = jnp.argsort(similarity)[-keep_count:]

    result = run_trak_attack(
        target_gradients=target_gradients,
        contrastive_gradients=contrastive_gradients,
        victim_indices=victim_indices,
        selected_contrastive_indices=selected_contrastive_indices,
        regularization=1e-2,
        selection_threshold=75.0,
        max_targets=100
    )

    # Report
    metrics = result["evaluation"]
    summary = [
        "TRAK-based MIA Attack Results:",
        f"  ROC AUC: {metrics['roc_auc']:.4f}",
        f"  Members: {metrics['n_members']}",
        f"  Non-members: {metrics['n_nonmembers']}",
        f"  Vulnerable targets found: {len(result['vulnerable_targets'])}",
        f"  Member signal: {metrics['member_signal_mean']:.4f} ± {metrics['member_signal_std']:.4f}",
        f"  Non-member signal: {metrics['nonmember_signal_mean']:.4f} ± {metrics['nonmember_signal_std']:.4f}",
    ]
    for line in summary:
        print(line)

    return result


# Script entry point: run the synthetic demo only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    example_usage()