from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def compute_nearest_neighbor_similarity(pool_sample, target_samples):
    """Return the best cosine similarity between one vector and a set of rows.

    Args:
        pool_sample: a single 1-D vector (the ``"i,ji->j"`` contraction below
            requires rank 1).
        target_samples: 2-D array; each row is compared with ``pool_sample``.

    Returns:
        Scalar array: the maximum, over rows of ``target_samples``, of the
        cosine similarity with ``pool_sample``. A zero-norm input vector or row
        is divided by zero and produces non-finite values.
    """
    # Project both sides onto the unit sphere first, so the dot product is a cosine.
    query_length = jnp.linalg.norm(pool_sample, keepdims=True)
    unit_query = pool_sample / query_length

    row_lengths = jnp.linalg.norm(target_samples, axis=1, keepdims=True)
    unit_rows = target_samples / row_lengths

    cosines = jnp.einsum("i,ji->j", unit_query, unit_rows)
    return jnp.max(cosines, axis=-1)


@jax.jit
def compute_nearest_neighbor_similarities(pool_samples, target_samples):
    """Score every row of ``pool_samples`` by its best cosine match in ``target_samples``.

    Each row is passed to ``compute_nearest_neighbor_similarity`` against the
    full ``target_samples`` matrix. ``jax.lax.map`` is given ``batch_size=10_000``,
    so rows are vectorized at most 10,000 at a time.

    Returns:
        1-D array with one similarity score per row of ``pool_samples``.
    """

    def best_match(row):
        return compute_nearest_neighbor_similarity(row, target_samples)

    return jax.lax.map(best_match, pool_samples, batch_size=10_000)


@jax.jit
def compute_nearest_neighbor_idx(pool_sample, target_samples):
    """Return the row index of ``target_samples`` most cosine-similar to ``pool_sample``.

    Args:
        pool_sample: a single 1-D vector (the ``"i,ji->j"`` contraction below
            requires rank 1).
        target_samples: 2-D array; each row is a candidate neighbour.

    Returns:
        Scalar integer array: the argmax over rows of the cosine similarity.
        A zero-norm input vector or row is divided by zero and produces
        non-finite similarities.
    """
    # Normalize before contracting so the einsum yields cosines, not raw dot products.
    query_length = jnp.linalg.norm(pool_sample, keepdims=True)
    unit_query = pool_sample / query_length

    row_lengths = jnp.linalg.norm(target_samples, axis=1, keepdims=True)
    unit_rows = target_samples / row_lengths

    cosines = jnp.einsum("i,ji->j", unit_query, unit_rows)
    return jnp.argmax(cosines, axis=-1)


@jax.jit
def compute_nearest_neighbor_idxs(pool_samples, target_samples):
    """Find, for every row of ``pool_samples``, its nearest row in ``target_samples``.

    Each row is passed to ``compute_nearest_neighbor_idx`` against the full
    ``target_samples`` matrix. ``jax.lax.map`` is given ``batch_size=10_000``,
    so rows are vectorized at most 10,000 at a time.

    Returns:
        1-D integer array of indices into ``target_samples``, one per pool row.
    """

    def nearest_row(row):
        return compute_nearest_neighbor_idx(row, target_samples)

    return jax.lax.map(nearest_row, pool_samples, batch_size=10_000)


def _top_k_mask(pool_samples, target_samples, k):
    """Mark the ``k`` pool rows whose nearest-neighbour cosine similarity to ``target_samples`` is highest."""
    scores = compute_nearest_neighbor_similarities(pool_samples, target_samples)
    top_idx = jax.lax.top_k(scores, k)[1]
    return jnp.zeros(pool_samples.shape[0], dtype=bool).at[top_idx].set(True)


def _jaccard(mask_a, mask_b):
    """Return |A & B| / |A | B| for two boolean masks, or 0 when the union is empty."""
    union_size = (mask_a | mask_b).sum()
    if union_size > 0:
        return (mask_a & mask_b).sum() / union_size
    return 0


def imagebased_attack_one_target(
    pool_samples,
    target_samples,
    victim_scores,
    *,
    pool_size=0.5,
    patience=5,
    min_improvement=0.001,
):
    """Vote on which target samples explain a victim's selection from the pool.

    First select the ``pool_size`` fraction of ``pool_samples`` whose nearest
    target has the highest cosine similarity, and compare that selection with
    the victim's. Then iterate:

    - Each over-selected pool sample casts a -1 vote for its nearest
      still-kept target. Over-selected means selected by us but not by the
      victim.
    - Each under-selected pool sample casts a +1 vote for its nearest
      still-kept target.
    - Targets whose vote total is negative are dropped, and the selection is
      recomputed from the remaining targets.

    The loop stops when:

    - the selection matches the victim's exactly (Jaccard similarity 1.0);
    - Jaccard similarity improved by less than ``min_improvement`` over the
      last ``patience`` iterations; or
    - every target has been dropped.

    Args:
        pool_samples: 2-D array, one candidate sample per row.
        target_samples: 2-D array, one target sample per row, with the same
            feature dimension as ``pool_samples``.
        victim_scores: binary selection over the pool, shape ``[N]`` with
            ``N = len(pool_samples)``. Nonzero entries form the victim's
            selection.
        pool_size: fraction of the pool selected at each step.
        patience: number of iterations over which improvement is measured for
            early stopping.
        min_improvement: minimum Jaccard gain required over ``patience``
            iterations to keep iterating.

    Returns:
        Float array of length ``len(target_samples)`` with the accumulated
        votes. Negative entries are the targets that were dropped.
    """
    num_targets = target_samples.shape[0]
    votes = jnp.zeros(num_targets)
    if num_targets == 0:
        # Nothing to vote on; max/argmax over an empty target set would raise.
        return votes

    k = int(pool_size * len(pool_samples))
    victim_mask = victim_scores.astype(bool)
    selected = _top_k_mask(pool_samples, target_samples, k)
    jaccard_similarity = _jaccard(selected, victim_mask)

    removed = jnp.zeros(num_targets, dtype=bool)
    # The history includes the initial similarity. That lets early stopping
    # compare the current value with the one exactly `patience` iterations
    # earlier (the previous code looked only `patience - 1` iterations back).
    history = [jaccard_similarity]

    i = 0
    while jaccard_similarity < 1.0:
        # Positions within the kept subset map back to original target indices via kept_idx.
        kept_idx = jnp.where(~removed)[0]
        kept_targets = target_samples[kept_idx]
        overweighted = selected & ~victim_mask
        underweighted = ~selected & victim_mask

        probably_removed = compute_nearest_neighbor_idxs(
            pool_samples[overweighted], kept_targets
        )
        probably_added = compute_nearest_neighbor_idxs(
            pool_samples[underweighted], kept_targets
        )
        # .at[...].add accumulates repeated indices, so each pool sample counts once.
        votes = votes.at[kept_idx[probably_removed]].add(-1)
        votes = votes.at[kept_idx[probably_added]].add(1)

        removed = votes < 0
        print(i, "removing ", probably_removed.shape, end=" ")
        print("adding ", probably_added.shape, "total", removed.sum(), end="..")

        if removed.all():
            # No targets left to rank against; the next similarity pass would fail.
            print("All target samples removed; stopping.")
            break

        selected = _top_k_mask(pool_samples, target_samples[~removed], k)
        jaccard_similarity = _jaccard(selected, victim_mask)
        history.append(jaccard_similarity)

        print(f"Jaccard similarity: {jaccard_similarity:.4f}")

        if len(history) > patience:
            improvement = jaccard_similarity - history[-1 - patience]
            if improvement < min_improvement:
                print(
                    f"Early stopping: Improvement ({improvement:.4f}) below threshold ({min_improvement})"
                )
                break

        i += 1
    return votes


def imagebased_attack(pool_samples, target_samples, victim_scores):
    """Run ``imagebased_attack_one_target`` once per victim selection and stack the votes.

    Args:
        pool_samples: candidate pool, forwarded unchanged to every run.
        target_samples: target set, forwarded unchanged to every run.
        victim_scores: iterable of per-victim binary selections over the pool.
            Each one is passed as ``victim_scores`` to a single-target run.

    Returns:
        Array of shape ``(num_victims, len(target_samples))`` whose rows are
        the vote vectors. When there are no victims, returns an empty
        ``(0, len(target_samples))`` array.
    """
    all_votes = [
        imagebased_attack_one_target(pool_samples, target_samples, vs)
        for vs in victim_scores
    ]
    if not all_votes:
        # jnp.stack raises on an empty sequence; return a well-shaped empty result instead.
        return jnp.zeros((0, target_samples.shape[0]))
    return jnp.stack(all_votes)
