from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm


@partial(jax.jit, static_argnums=(2,))
def give_measurements_fast(preds, keep, k=1):
    """Return the ``k`` column indices whose IN/OUT row means differ most.

    The rows of ``preds`` are split by ``keep``. For every column, the mean
    over the ``keep`` rows is compared with the mean over the remaining rows.
    The indices of the ``k`` largest absolute differences are returned in the
    order produced by ``jax.lax.top_k`` (largest difference first).

    Args:
        preds: Array whose leading axis (rows) aligns with ``keep``;
            presumably 2-D ``(n_rows, n_columns)`` — confirm with callers.
        keep: Per-row mask (presumably boolean) selecting the IN rows.
        k: Number of column indices to return; static under ``jit``.

    Returns:
        Integer index array of shape ``(k,)``.

    Note:
        If ``keep`` is all-True or all-False, one of the means averages zero
        rows and becomes NaN. NOTE(review): confirm callers never pass such a mask.
    """
    # Trailing axis so the per-row mask broadcasts across every column.
    rows_in = keep[:, None]
    rows_out = (~keep)[:, None]

    gap = jnp.abs(
        jnp.mean(preds, axis=0, where=rows_in)
        - jnp.mean(preds, axis=0, where=rows_out)
    )
    _, top_indices = jax.lax.top_k(gap, k=k)
    return top_indices


@jax.jit
def _lira_attack_1d(i, shadow_preds, shadow_keep, victim_preds):
    """Univariate LiRA score of every victim for target ``i``.

    Selects the single column of ``shadow_preds`` whose IN/OUT shadow means
    differ most (``give_measurements_fast`` with k=1). It then fits one normal
    distribution to the IN shadow values of that column and one to the OUT
    values. Each victim is scored by the log-likelihood ratio
    ``log N(x | IN) - log N(x | OUT)``.

    Args:
        i: Column of ``shadow_keep`` (target index) being attacked; may be traced.
        shadow_preds: Shadow outputs whose rows align with the rows of
            ``shadow_keep``; presumably ``(n_shadow, n_features)``.
        shadow_keep: Membership mask; ``shadow_keep[r, i]`` marks shadow
            row ``r`` as IN for target ``i`` (presumably boolean).
        victim_preds: Victim outputs sharing the column layout of
            ``shadow_preds``; presumably ``(n_victims, n_features)``.

    Returns:
        1-D array with one score per row of ``victim_preds``.
    """
    mask_in = shadow_keep[:, i]
    mask_out = ~shadow_keep[:, i]

    # Most discriminative column; k=1, so measurement_idx has shape (1,).
    measurement_idx = give_measurements_fast(shadow_preds, mask_in)
    col = measurement_idx[0]

    # That column for every shadow row, flattened to 1-D.
    pred_measurements = jnp.take(shadow_preds, measurement_idx, axis=1)[:, 0]

    mu_in = jnp.mean(pred_measurements, where=mask_in)
    mu_out = jnp.mean(pred_measurements, where=mask_out)

    # Population std (jnp.std default ddof=0) plus an epsilon so the scale
    # passed to the normal is strictly positive.
    std_in = jnp.std(pred_measurements, where=mask_in) + 1e-6
    std_out = jnp.std(pred_measurements, where=mask_out) + 1e-6

    # Score every victim in one vectorized call. A Python loop here would be
    # unrolled at trace time, emitting one op per victim into the jitted graph.
    victim_measurements = victim_preds[:, col]
    log_likelihood_in = norm.logpdf(victim_measurements, mu_in, std_in)
    log_likelihood_out = norm.logpdf(victim_measurements, mu_out, std_out)
    return log_likelihood_in - log_likelihood_out


@partial(jax.jit, static_argnums=(4,))
def lira_attack_nd(i, shadow_preds, shadow_keep, victim_preds, k=5):
    """Multivariate LiRA score of every victim for target ``i``.

    Selects the ``k`` columns of ``shadow_preds`` whose IN/OUT shadow means
    differ most (``give_measurements_fast``). It then fits a multivariate
    normal to the IN shadow rows and another to the OUT shadow rows over
    those ``k`` columns. Each victim row is scored by the log-likelihood
    ratio ``log N(x | IN) - log N(x | OUT)``.

    Args:
        i: Column of ``shadow_keep`` (target index) being attacked; may be traced.
        shadow_preds: Shadow outputs whose rows align with the rows of
            ``shadow_keep``; presumably ``(n_shadow, n_features)``.
        shadow_keep: Membership mask; ``shadow_keep[r, i]`` marks shadow
            row ``r`` as IN for target ``i`` (presumably boolean).
        victim_preds: Victim outputs sharing the column layout of
            ``shadow_preds``; presumably ``(n_victims, n_features)``.
        k: Number of columns to model jointly; static under ``jit``.

    Returns:
        1-D array with one score per row of ``victim_preds``.
    """
    mask_in = shadow_keep[:, i]
    mask_out = ~shadow_keep[:, i]

    # The k most discriminative columns, shape (k,).
    measurement_idx = give_measurements_fast(shadow_preds, mask_in, k=k)

    # (n_shadow, k): the selected columns for every shadow row.
    pred_measurements = jnp.take(shadow_preds, measurement_idx, axis=1)

    # `where` must broadcast against (n_shadow, k), so the per-row mask needs
    # a trailing axis; a bare (n_shadow,) mask would align with the k axis.
    mu_in = jnp.mean(pred_measurements, axis=0, where=mask_in[:, None])  # (k,)
    mu_out = jnp.mean(pred_measurements, axis=0, where=mask_out[:, None])  # (k,)

    # jnp.cov has no `where`/`axis` arguments. Integer 0/1 frequency weights
    # restrict the estimate to the IN (or OUT) rows instead. The stabilizing
    # jitter goes on the diagonal only, because adding a scalar would shift
    # every entry. Adding a (k, k) matrix also restores a (1, 1) shape when
    # k == 1, where jnp.cov returns a 0-d result.
    jitter = 1e-6 * jnp.eye(k)
    cov_in = (
        jnp.cov(pred_measurements, rowvar=False, fweights=mask_in.astype(jnp.int32))
        + jitter
    )  # (k, k)
    cov_out = (
        jnp.cov(pred_measurements, rowvar=False, fweights=mask_out.astype(jnp.int32))
        + jitter
    )  # (k, k)

    # (n_victims, k): score each victim on all k selected columns in a single
    # batched call (multivariate_normal.logpdf accepts x of shape (..., k)).
    victim_measurements = jnp.take(victim_preds, measurement_idx, axis=1)
    log_likelihood_in = multivariate_normal.logpdf(victim_measurements, mu_in, cov_in)
    log_likelihood_out = multivariate_normal.logpdf(
        victim_measurements, mu_out, cov_out
    )
    return log_likelihood_in - log_likelihood_out


def lira_1d(shadow_scores, shadow_keep, victim_scores, batch_size=2048):
    """Univariate LiRA scores for every target column of ``shadow_keep``.

    Runs ``_lira_attack_1d`` once per column of ``shadow_keep`` via
    ``jax.lax.map``, evaluating ``batch_size`` targets at a time.

    Args:
        shadow_scores: Shadow outputs passed through as ``shadow_preds``.
        shadow_keep: Membership mask with one column per target.
        victim_scores: Victim outputs passed through as ``victim_preds``.
        batch_size: Number of targets processed together by ``jax.lax.map``.

    Returns:
        Score matrix of shape ``(n_victims, n_targets)``.
    """
    n_targets = shadow_keep.shape[1]

    def score_target(target_idx):
        return _lira_attack_1d(target_idx, shadow_scores, shadow_keep, victim_scores)

    # lax.map stacks per-target results as (n_targets, n_victims).
    per_target = jax.lax.map(
        score_target, jnp.arange(n_targets), batch_size=batch_size
    )
    return jnp.transpose(per_target)  # (n_victims, n_targets)


def stable_logpmf(x, p, eps=1e-6):
    """Bernoulli log-likelihood ``x*log(p) + (1-x)*log(1-p)`` with clipped ``p``.

    ``p`` is clipped to ``[eps, 1 - eps]`` so neither logarithm sees 0,
    which keeps the result finite for ``p`` equal to 0 or 1. ``x`` is not
    restricted to {0, 1}; fractional values give the cross-entropy form.
    """
    p_safe = jnp.clip(p, eps, 1 - eps)
    log_p = jnp.log(p_safe)
    log_not_p = jnp.log(1 - p_safe)
    return x * log_p + (1 - x) * log_not_p


@jax.jit
def _binary_means(i, shadow_preds, shadow_keep, victim_preds):
    """Bernoulli-mean membership score of every victim for target ``i``.

    Selects the single column of ``shadow_preds`` whose IN/OUT shadow means
    differ most. The IN-group and OUT-group means of that column are used as
    Bernoulli probabilities. Each victim value ``x`` in that column is scored
    by ``stable_logpmf(x, p_in) - stable_logpmf(x, p_out)``. Only means are
    estimated; no variances are used.

    NOTE(review): treating the means as probabilities presumably assumes
    predictions lie in [0, 1]; ``stable_logpmf`` clips them in any case.

    Returns:
        1-D array with one score per row of ``victim_preds``.
    """
    in_mask = shadow_keep[:, i]

    best_col = give_measurements_fast(shadow_preds, in_mask)  # shape (1,)
    shadow_col = jnp.take(shadow_preds, best_col, axis=1)[:, 0]

    p_in = jnp.mean(shadow_col, where=in_mask)
    p_out = jnp.mean(shadow_col, where=~in_mask)

    victim_col = victim_preds[:, best_col[0]]
    return stable_logpmf(victim_col, p_in) - stable_logpmf(victim_col, p_out)


def binary_means(shadow_scores, shadow_keep, victim_scores, batch_size=2048):
    """Bernoulli-mean membership scores for every target column of ``shadow_keep``.

    Runs ``_binary_means`` once per column of ``shadow_keep`` via
    ``jax.lax.map``, evaluating ``batch_size`` targets at a time.

    Returns:
        Score matrix of shape ``(n_victims, n_targets)``.

    Raises:
        ValueError: if any score is NaN (for example, when a target has an
            empty IN or OUT shadow group).
    """
    target_ids = jnp.arange(shadow_keep.shape[1])
    per_target = jax.lax.map(
        lambda t: _binary_means(t, shadow_scores, shadow_keep, victim_scores),
        target_ids,
        batch_size=batch_size,
    )

    # Runs eagerly (this function is not jitted), so a Python check is fine.
    if bool(jnp.any(jnp.isnan(per_target))):
        raise ValueError("NaNs in scores")
    return per_target.T  # (n_victims, n_targets)
