"""
Membership Inference Attack via Correspondence (MIA-C)

A minimal implementation, accompanying an ICLR paper submission, that demonstrates
correspondence-based membership inference attacks on data curation.

This code illustrates the core attack methodology without production dependencies.
"""

import numpy as np
from typing import Dict, Tuple, Optional
from sklearn.metrics import roc_curve, auc


def _topk_inner_product(
    queries: np.ndarray,
    database: np.ndarray,
    k: int,
    use_gpu: bool,
    dense_sims: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Return the k highest inner-product neighbours of every query, best first.

    Uses an exact FAISS ``IndexFlatIP`` when FAISS is importable (optionally on
    GPU) and otherwise an exact NumPy search over the precomputed ``dense_sims``
    (queries @ database.T). Requires 1 <= k <= len(database).

    Returns:
        (similarities, neighbors), both of shape (n_queries, k).
    """
    try:
        import faiss
    except ImportError:
        faiss = None

    if faiss is not None:
        index = faiss.IndexFlatIP(queries.shape[1])
        if use_gpu and faiss.get_num_gpus() > 0:
            # Bound to a local so the GPU resources outlive every use of the index.
            gpu_resources = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(gpu_resources, 0, index)
        index.add(database)
        return index.search(queries, k)

    # Exact NumPy equivalent of IndexFlatIP.search: partial top-k, then sort descending.
    top = np.argpartition(-dense_sims, k - 1, axis=1)[:, :k]
    top_sims = np.take_along_axis(dense_sims, top, axis=1)
    order = np.argsort(-top_sims, axis=1, kind="stable")
    return (np.take_along_axis(top_sims, order, axis=1),
            np.take_along_axis(top, order, axis=1))


def find_optimal_correspondences(
    target_embeddings: np.ndarray,
    pool_embeddings: np.ndarray,
    k_candidates: int = 50,
    alpha: float = 0.5,
    use_gpu: bool = False
) -> Dict:
    """
    Find optimal correspondences between target and pool embeddings.

    Nearest-neighbour search uses an exact inner-product FAISS index when FAISS
    is installed and falls back to an equivalent exact NumPy search otherwise.

    Args:
        target_embeddings: Target dataset embeddings (n_targets, d)
        pool_embeddings: Pool candidate embeddings (n_pool, d)
        k_candidates: Number of nearest neighbors to consider per target
            (clamped to the range [1, n_pool])
        alpha: Balance between attraction (α) and repulsion (1-α)
        use_gpu: Use GPU acceleration if FAISS with GPU support is available

    Returns:
        Dict with:
            "correspondences": chosen pool index per target, shape (n_targets,)
            "scores": combined score of each chosen pool sample, shape (n_targets,)
            "attraction_scores": per-pool summed neighbour similarity, shape (n_pool,)
            "repulsion_scores": per-pool mean similarity to all targets, shape (n_pool,)
            "unique_correspondences": number of distinct chosen pool indices
            "total_candidates": effective k times n_targets

    Raises:
        ValueError: if ``pool_embeddings`` is empty.
    """
    # FAISS requires C-contiguous float32 input; use the same arrays for NumPy math.
    targets = np.ascontiguousarray(target_embeddings, dtype=np.float32)
    pool = np.ascontiguousarray(pool_embeddings, dtype=np.float32)
    n_targets = targets.shape[0]
    n_pool = pool.shape[0]
    if n_pool == 0:
        raise ValueError("pool_embeddings must contain at least one sample")

    # FAISS pads results with index -1 when k exceeds the index size, which would
    # silently corrupt the scores below; never request more neighbours than exist.
    k = max(1, min(int(k_candidates), n_pool))

    # Dense (n_targets, n_pool) similarity matrix; column j belongs to pool sample j.
    dense_sims = targets @ pool.T

    similarities, neighbors = _topk_inner_product(targets, pool, k, use_gpu, dense_sims)

    # Attraction: summed similarity from every target listing the sample among its k neighbours.
    attraction_scores = np.zeros(n_pool)
    np.add.at(attraction_scores, neighbors.ravel(), similarities.ravel())

    # Repulsion: average similarity to all targets, indexed by pool position.
    # (Averaging the columns of a k=n_pool search would mix pool samples, since
    # each search row is ordered by rank rather than by pool index.)
    if n_targets > 0:
        repulsion_scores = dense_sims.mean(axis=0, dtype=np.float64)
    else:
        repulsion_scores = np.zeros(n_pool)

    # Combined correspondence score: α * attraction + (1-α) * (-repulsion)
    correspondence_scores = alpha * attraction_scores + (1 - alpha) * (-repulsion_scores)

    # Among each target's k candidates pick the highest combined score; argmax keeps
    # the earliest (i.e. most similar) candidate on ties.
    candidate_scores = correspondence_scores[neighbors]  # (n_targets, k)
    best = np.argmax(candidate_scores, axis=1)
    rows = np.arange(n_targets)
    correspondences = neighbors[rows, best].astype(int)
    selected_scores = candidate_scores[rows, best]

    return {
        "correspondences": correspondences,
        "scores": selected_scores,
        "attraction_scores": attraction_scores,
        "repulsion_scores": repulsion_scores,
        "unique_correspondences": len(np.unique(correspondences)),
        "total_candidates": k * n_targets
    }


def compute_baseline_rankings(pool_embeddings: np.ndarray, target_embeddings: np.ndarray) -> np.ndarray:
    """
    Rank every pool sample by its closest match among the targets.

    Each pool sample is scored by its maximum inner-product similarity to any
    target. The scores are then turned into percentiles, with tied scores
    receiving the average of their ranks.

    Args:
        pool_embeddings: Pool embeddings (n_pool, d)
        target_embeddings: Target embeddings (n_targets, d)

    Returns:
        Array of length n_pool with percentiles in (0, 100]
    """
    from scipy import stats

    # Reduce the (n_targets, n_pool) similarity matrix to the best target per pool column.
    best_match = np.amax(np.dot(target_embeddings, pool_embeddings.T), axis=0)

    ranks = stats.rankdata(best_match, method="average")
    n_pool = len(best_match)
    return ranks / n_pool * 100


def select_marked_samples(
    correspondences: np.ndarray,
    baseline_percentiles: np.ndarray,
    max_marked: int = 1000,
    selection_rate: float = 0.25
) -> np.ndarray:
    """
    Select pool samples to mark for the attack based on crossing potential.

    Each distinct corresponding pool sample receives an information score equal
    to crossing_score * uniqueness_score. Here uniqueness_score is
    1 / (number of targets sharing that pool sample). crossing_score is a sigmoid
    in the distance below the selection threshold (midpoint 10 points, scale 5)
    for samples below the threshold, and a fixed 0.1 for samples at or above it.

    Args:
        correspondences: Pool indices that correspond to targets
        baseline_percentiles: Baseline percentile rankings, indexed by pool index
        max_marked: Maximum number of samples to mark (negative values mark none)
        selection_rate: Expected curation selection rate (e.g., 0.25 for top 25%)

    Returns:
        Integer array of distinct pool indices, ordered by decreasing information
        score (ties broken by ascending pool index), at most ``max_marked`` long.
    """
    selection_threshold = (1.0 - selection_rate) * 100  # e.g., 75th percentile

    # Distinct pool indices (sorted ascending) and how many targets map to each.
    # The integer cast keeps the result integral even for empty input.
    corr = np.asarray(correspondences).astype(np.intp, copy=False)
    pool_indices, n_sharing = np.unique(corr, return_counts=True)
    uniqueness_score = 1.0 / n_sharing

    baseline_pct = np.asarray(baseline_percentiles)[pool_indices]

    # Crossing potential: how likely the sample is to cross the threshold when its target is included.
    distance_from_threshold = selection_threshold - baseline_pct
    crossing_score = np.where(
        baseline_pct < selection_threshold,
        1.0 / (1.0 + np.exp((distance_from_threshold - 10) / 5)),
        0.1,  # Already above threshold
    )

    info_score = crossing_score * uniqueness_score

    # Highest information first; the stable sort keeps ascending pool index among ties.
    order = np.argsort(-info_score, kind="stable")
    n_select = max(0, min(max_marked, len(order)))
    return pool_indices[order[:n_select]]


def compute_mia_scores(
    selected_pool_indices: np.ndarray,
    correspondences: np.ndarray,
    baseline_percentiles: np.ndarray,
    marked_samples: np.ndarray,
    selection_rate: float = 0.25,
    temperature: float = 10.0
) -> np.ndarray:
    """
    Compute MIA scores based on selection surprise.

    A target is scored only when its corresponding pool sample was marked. The
    surprise is the observed selection outcome (1 if selected, else 0) minus a
    sigmoid estimate of the selection probability given the sample's baseline
    percentile. The score is that surprise divided by the number of targets
    sharing the same correspondence. Targets whose correspondence was not
    marked score 0.0.

    Args:
        selected_pool_indices: Pool indices that were selected during curation
        correspondences: Correspondence mapping (target_idx -> pool_idx)
        baseline_percentiles: Baseline percentile rankings for pool samples
        marked_samples: Pool indices that were marked for monitoring
        selection_rate: Expected selection rate
        temperature: Sigmoid temperature of the expected-selection model

    Returns:
        Float array of shape (n_targets,); higher = more likely to be member
    """
    correspondences = np.asarray(correspondences)
    mia_scores = np.zeros(len(correspondences))

    # Vectorised membership tests; list(...) accepts ndarrays, lists and sets alike.
    is_marked = np.isin(correspondences, np.asarray(list(marked_samples)))
    if not is_marked.any():
        return mia_scores

    # Number of targets sharing each target's correspondence (uniqueness weight),
    # computed once instead of rescanning all correspondences per target.
    _, inverse, counts = np.unique(correspondences, return_inverse=True, return_counts=True)
    n_sharing = counts[inverse.reshape(-1)]

    marked_pool = correspondences[is_marked]
    baseline_pct = np.asarray(baseline_percentiles)[marked_pool]
    selection_threshold = (1.0 - selection_rate) * 100

    # Sigmoid probability model: P(selected | baseline percentile)
    expected_prob = 1.0 / (1.0 + np.exp(-(baseline_pct - selection_threshold) / temperature))

    was_selected = np.isin(marked_pool, np.asarray(list(selected_pool_indices)))
    # Selected despite a low expected probability -> large positive surprise;
    # not selected despite a high one -> large negative surprise.
    surprise = np.where(was_selected, 1.0 - expected_prob, -expected_prob)

    mia_scores[is_marked] = surprise / n_sharing[is_marked]
    return mia_scores


def evaluate_attack(mia_scores: np.ndarray, ground_truth: np.ndarray) -> Dict:
    """
    Summarise how well MIA scores separate members from non-members.

    Args:
        mia_scores: One score per target; larger means "more likely a member"
        ground_truth: Labels aligned with ``mia_scores`` (1=member, 0=non-member)

    Returns:
        Dictionary with the ROC curve ("fpr", "tpr", "thresholds"), its area
        ("roc_auc"), the per-class score mean/std (0.0 for an empty class), and
        the member / non-member counts
    """
    def _mean_and_std(values: np.ndarray) -> Tuple[float, float]:
        # An empty class reports zeros rather than NaN.
        if len(values) == 0:
            return 0.0, 0.0
        return np.mean(values), np.std(values)

    fpr, tpr, thresholds = roc_curve(ground_truth, mia_scores)
    area = auc(fpr, tpr)

    member_mean, member_std = _mean_and_std(mia_scores[ground_truth == 1])
    nonmember_mean, nonmember_std = _mean_and_std(mia_scores[ground_truth == 0])

    summary = {}
    summary["roc_auc"] = area
    summary["fpr"] = fpr
    summary["tpr"] = tpr
    summary["thresholds"] = thresholds
    summary["member_score_mean"] = member_mean
    summary["member_score_std"] = member_std
    summary["nonmember_score_mean"] = nonmember_mean
    summary["nonmember_score_std"] = nonmember_std
    summary["n_members"] = np.sum(ground_truth)
    summary["n_nonmembers"] = np.sum(ground_truth == 0)
    return summary


def run_correspondence_attack(
    target_embeddings: np.ndarray,
    pool_embeddings: np.ndarray,
    victim_indices: np.ndarray,
    selected_pool_indices: np.ndarray,
    alpha: float = 0.5,
    k_candidates: int = 50,
    selection_rate: float = 0.25,
    max_marked: int = 1000,
    use_gpu: bool = False
) -> Dict:
    """
    Run the complete correspondence-based MIA attack.

    Pipeline: find correspondences -> baseline percentiles -> choose marked
    pool samples -> per-target surprise scores -> ROC evaluation against the
    known victim labels.

    Args:
        target_embeddings: Full target dataset embeddings (n_targets, d)
        pool_embeddings: Pool candidate embeddings (n_pool, d)
        victim_indices: Indices of targets that are members (victims)
        selected_pool_indices: Pool indices selected by curation
        alpha: Correspondence balance parameter
        k_candidates: Number of candidates per target
        selection_rate: Curation selection rate
        max_marked: Maximum number of pool samples to mark for monitoring
        use_gpu: Request GPU nearest-neighbour search when available

    Returns:
        Dictionary with attack results and evaluation
    """
    # Step 1: Find correspondences
    corr_result = find_optimal_correspondences(
        target_embeddings, pool_embeddings,
        k_candidates=k_candidates, alpha=alpha, use_gpu=use_gpu
    )
    correspondences = corr_result["correspondences"]

    # Step 2: Compute baseline rankings
    baseline_percentiles = compute_baseline_rankings(pool_embeddings, target_embeddings)

    # Step 3: Select samples to mark
    marked_samples = select_marked_samples(
        correspondences, baseline_percentiles,
        max_marked=max_marked, selection_rate=selection_rate
    )

    # Step 4: Compute MIA scores
    mia_scores = compute_mia_scores(
        selected_pool_indices, correspondences, baseline_percentiles,
        marked_samples, selection_rate=selection_rate
    )

    # Step 5: Create ground truth labels (1 for victims/members, 0 otherwise)
    ground_truth = np.zeros(len(target_embeddings), dtype=int)
    ground_truth[victim_indices] = 1

    # Step 6: Evaluate attack
    evaluation = evaluate_attack(mia_scores, ground_truth)

    return {
        "correspondences": correspondences,
        "mia_scores": mia_scores,
        "ground_truth": ground_truth,
        "baseline_percentiles": baseline_percentiles,
        "marked_samples": marked_samples,
        "evaluation": evaluation,
        "correspondence_stats": corr_result
    }


def example_usage():
    """
    Run the attack end to end on synthetic, unit-normalised random embeddings.

    The simulated curator keeps the top ``selection_rate`` fraction of the pool,
    ranked by maximum similarity to the *member* targets only. Members and
    non-members therefore leave different footprints on the selection the
    attacker observes.

    Returns:
        The result dictionary from ``run_correspondence_attack``.
    """
    np.random.seed(42)

    # Create synthetic embeddings
    d = 512  # embedding dimension
    n_targets = 1000
    n_pool = 10000
    n_victims = 500  # 50% are victims
    selection_rate = 0.25

    # Generate embeddings (rows normalised to unit length)
    target_embeddings = np.random.randn(n_targets, d)
    target_embeddings = target_embeddings / np.linalg.norm(target_embeddings, axis=1, keepdims=True)

    pool_embeddings = np.random.randn(n_pool, d)
    pool_embeddings = pool_embeddings / np.linalg.norm(pool_embeddings, axis=1, keepdims=True)

    # Create victims (first half are members)
    victim_indices = np.arange(n_victims)

    # Simulate curation with the members' data only. If non-members also shaped
    # the selection, the member/non-member labels would carry no signal at all.
    member_similarities = np.dot(target_embeddings[victim_indices], pool_embeddings.T)
    max_similarities = np.max(member_similarities, axis=0)
    n_selected = int(selection_rate * n_pool)
    # Explicit start index: a `[-n_selected:]` slice would select everything when n_selected == 0.
    selected_pool_indices = np.argsort(max_similarities)[len(max_similarities) - n_selected:]

    # Run attack
    result = run_correspondence_attack(
        target_embeddings=target_embeddings,
        pool_embeddings=pool_embeddings,
        victim_indices=victim_indices,
        selected_pool_indices=selected_pool_indices,
        alpha=0.5,
        k_candidates=50,
        selection_rate=selection_rate
    )

    # Print results
    eval_result = result["evaluation"]
    print("Correspondence-based MIA Attack Results:")
    print(f"  ROC AUC: {eval_result['roc_auc']:.4f}")
    print(f"  Members: {eval_result['n_members']}")
    print(f"  Non-members: {eval_result['n_nonmembers']}")
    print(f"  Member score: {eval_result['member_score_mean']:.4f} ± {eval_result['member_score_std']:.4f}")
    print(f"  Non-member score: {eval_result['nonmember_score_mean']:.4f} ± {eval_result['nonmember_score_std']:.4f}")

    return result


if __name__ == "__main__":
    # Running this file directly executes the synthetic end-to-end demonstration.
    example_usage()