from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def compute_nearest_neighbor_similarity(pool_sample, target_samples):
    """Return the highest cosine similarity between one sample and a target set.

    Args:
        pool_sample: 1-D vector (the ``"i"`` einsum operand fixes its rank).
        target_samples: 2-D array, one candidate vector per row.

    Returns:
        Scalar: the maximum over rows of the cosine similarity between
        ``pool_sample`` and that row. A zero-norm input divides by zero and
        produces NaN/inf.
    """
    # Scale both operands to unit length so the dot product is a cosine.
    unit_pool = pool_sample / jnp.linalg.norm(pool_sample, keepdims=True)
    row_norms = jnp.linalg.norm(target_samples, axis=1, keepdims=True)
    unit_targets = target_samples / row_norms

    similarities = jnp.einsum("i,ji->j", unit_pool, unit_targets)
    return jnp.max(similarities, axis=-1)


@jax.jit
def compute_nearest_neighbor_similarities(pool_samples, target_samples):
    """Return, for each pool sample, its best cosine similarity to the targets.

    Applies ``compute_nearest_neighbor_similarity`` row by row over
    ``pool_samples`` via ``jax.lax.map``. The map runs in batches of 10,000
    rows, which bounds peak memory compared with a single full vmap.

    Args:
        pool_samples: Array whose leading axis indexes pool samples.
        target_samples: 2-D array, one candidate vector per row (shared by
            every pool sample).

    Returns:
        Array with one maximum-similarity value per pool sample.
    """

    def best_similarity(sample):
        return compute_nearest_neighbor_similarity(
            sample, target_samples=target_samples
        )

    return jax.lax.map(best_similarity, pool_samples, batch_size=10_000)


@jax.jit
def compute_nearest_neighbor_idx(pool_sample, target_samples):
    """Return the row index of the target most cosine-similar to ``pool_sample``.

    Args:
        pool_sample: 1-D vector (the ``"i"`` einsum operand fixes its rank).
        target_samples: 2-D array, one candidate vector per row.

    Returns:
        Integer scalar index into ``target_samples``. ``argmax`` picks the
        first index on ties. A zero-norm input divides by zero and produces
        NaN/inf similarities.
    """
    # Normalise to unit length so the dot product below is a cosine.
    unit_pool = pool_sample / jnp.linalg.norm(pool_sample, keepdims=True)
    row_norms = jnp.linalg.norm(target_samples, axis=1, keepdims=True)
    unit_targets = target_samples / row_norms

    similarities = jnp.einsum("i,ji->j", unit_pool, unit_targets)
    return jnp.argmax(similarities, axis=-1)


@jax.jit
def compute_nearest_neighbor_idxs(pool_samples, target_samples):
    """Return, for each pool sample, the index of its most similar target row.

    Applies ``compute_nearest_neighbor_idx`` row by row over ``pool_samples``
    via ``jax.lax.map``. The map runs in batches of 10,000 rows.

    Args:
        pool_samples: Array whose leading axis indexes pool samples.
        target_samples: 2-D array, one candidate vector per row (shared by
            every pool sample).

    Returns:
        Integer array with one target index per pool sample.
    """

    def nearest_index(sample):
        return compute_nearest_neighbor_idx(sample, target_samples=target_samples)

    return jax.lax.map(nearest_index, pool_samples, batch_size=10_000)


@jax.jit
def compute_votes(pool_sample, target_samples, victim_score):
    """Score every target row against one pool sample and its victim score.

    Cosine similarities between ``pool_sample`` and each row of
    ``target_samples`` are compared with ``victim_score``:

    * targets whose similarity is strictly greater than ``victim_score``
      get ``-1``;
    * all other targets get ``0``;
    * the single target whose similarity is closest to ``victim_score``
      (the first one on ties, per ``argmin``) is then set to ``+1``,
      overriding the value above.

    Args:
        pool_sample: 1-D vector (the ``"i"`` einsum operand fixes its rank).
        target_samples: 2-D array, one candidate vector per row.
        victim_score: Scalar similarity to compare against. Presumably this
            is the nearest-neighbour score reported for this pool sample;
            TODO confirm with the caller.

    Returns:
        Integer vector of length ``target_samples.shape[0]`` with entries in
        ``{-1, 0, 1}``. A zero-norm input divides by zero and yields NaN/inf
        similarities.
    """
    pool_normalized = pool_sample / jnp.linalg.norm(pool_sample)
    targets_normalized = target_samples / jnp.linalg.norm(
        target_samples, axis=1, keepdims=True
    )
    sims = jnp.einsum("i,ji->j", pool_normalized, targets_normalized)
    victim_nn_idx = jnp.argmin(jnp.abs(sims - victim_score))
    # Python int scalars broadcast against ``sims``. The result keeps JAX's
    # default (weak) integer dtype, the same as the former
    # ``-ones_like(int)`` / ``zeros_like(int)`` operands produced.
    votes = jnp.where(sims > victim_score, -1, 0)
    return votes.at[victim_nn_idx].set(1)


@jax.jit
def process_batch(pool_samples_batch, victim_scores_batch, target_samples):
    """Return the element-wise sum of ``compute_votes`` over a batch.

    Each pool sample is paired with its own victim score, and every pair is
    scored against the same ``target_samples``. The per-sample vote vectors
    are then summed along the batch axis.

    Args:
        pool_samples_batch: Array whose leading axis indexes pool samples.
        victim_scores_batch: One score per pool sample (same leading size).
        target_samples: 2-D array of candidate rows, shared across the batch.

    Returns:
        Vote totals with one entry per target row.
    """
    # Map over axis 0 of the samples and the scores; broadcast the targets.
    batched_votes = jax.vmap(compute_votes, in_axes=(0, None, 0))
    per_sample_votes = batched_votes(
        pool_samples_batch, target_samples, victim_scores_batch
    )
    return jnp.sum(per_sample_votes, axis=0)


def imagebased_attack(pool_samples, target_samples, victim_scores, batch_size=5_000):
    """Accumulate per-target votes for every row of ``victim_scores``.

    For each row ``k``, the pool is split into contiguous chunks of at most
    ``batch_size`` samples. Each chunk is scored with ``process_batch``
    against ``victim_scores[k]``, and the chunk totals are summed.

    Args:
        pool_samples: Array whose leading axis indexes pool samples.
        target_samples: 2-D array of candidate rows.
        victim_scores: 2-D array indexed as ``[k, pool_index]``. Row ``k``
            supplies one score per pool sample.
        batch_size: Maximum number of pool samples per ``process_batch``
            call. Must be positive.

    Returns:
        Array of shape ``(victim_scores.shape[0], target_samples.shape[0])``
        in JAX's default float dtype. If ``victim_scores`` has no rows, the
        result has shape ``(0, target_samples.shape[0])``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        # A non-positive step would either crash in ``range`` (0) or skip
        # the loop entirely and return all-zero votes (negative).
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_samples = len(pool_samples)
    num_targets = target_samples.shape[0]
    all_votes = []
    for k in range(victim_scores.shape[0]):
        votes = jnp.zeros(num_targets)
        # A shorter final chunk has a different shape, so jit traces
        # ``process_batch`` once more for it.
        for i in range(0, num_samples, batch_size):
            batch_end = min(i + batch_size, num_samples)
            pool_batch = pool_samples[i:batch_end]
            scores_batch = victim_scores[k, i:batch_end]
            votes = votes + process_batch(pool_batch, scores_batch, target_samples)
        all_votes.append(votes)

    if not all_votes:
        # ``jnp.stack`` rejects an empty list. Return an empty, correctly
        # shaped result instead.
        return jnp.zeros((0, num_targets))
    return jnp.stack(all_votes)
