import jax
import optax
import jax.numpy as jnp
from utils import (
    modified_lamb,
    upsample_grid_nn,
    gumbel_softmax_topk,
)


def _bce_criterion(x, y):
    """Sigmoid binary cross-entropy of logits ``x`` against ``y > 0``, averaged over the last axis."""
    binary_targets = y > 0
    return optax.sigmoid_binary_cross_entropy(x, binary_targets).mean(-1)


def _mse_criterion(x, y):
    """Element-wise squared error between ``x`` and ``y``, averaged over the last axis."""
    return optax.squared_error(x, y).mean(-1)


def _kl_criterion(x, y):
    """KL divergence between the log-softmax of ``x`` and the log-softmax of ``y``."""
    # Normalize both inputs so each one is a valid log-probability distribution.
    log_predictions = jax.nn.log_softmax(x)
    log_targets = jax.nn.log_softmax(y)
    divergence = optax.losses.kl_divergence_with_log_targets(
        log_predictions, log_targets
    )
    # NOTE(review): per the optax docs the divergence is already reduced over the
    # class axis, so this mean runs over the next remaining axis -- confirm intended.
    return divergence.mean(-1)


# Loss functions selectable by name; every entry is called as ``fn(x, y)``.
CRITERION_COLLECTION = {
    "ce": optax.softmax_cross_entropy,
    "bce": _bce_criterion,
    "mse": _mse_criterion,
    "kl": _kl_criterion,
}

# Optimizer factories selectable by name.
OPTIMIZER_COLLECTION = dict(
    adamw=optax.adamw,
    lamb=modified_lamb,
)


def patch_selection(
    args, rng, zoom_map, attn_maps, temperature=0.1, noise_coeff=0.1, is_baseline=False
):
    """Returns the indices of the patches to keep (optionally conditioned on the zoom_map)."""
    bs, num_patches = zoom_map.shape
    top_k_size = args.top_k_range.max if args.top_k_range is not None else args.top_k

    patch_selection_method = args.patch_selection_method
    if is_baseline:
        # The baseline for reinforce gumbel topk is simple greedy
        # if patch_selection_method == "reinforcement:reinforce-gumbel-topk":
        patch_selection_method = "topk-oracle"

    mask = None
    match patch_selection_method:
        case "random":
            # Randomly select args.top_k patches to keep without replacement.
            keep_patch_indices = jax.vmap(
                lambda rng_i: jax.random.choice(
                    rng_i, num_patches, shape=(top_k_size,), replace=False
                )
            )(jax.random.split(rng, bs))  # uniformly sample k indices

        case "topk-zoomer":
            # Select the top args.top_k patches based on zoom_map.
            keep_patch_indices = jnp.argsort(zoom_map, axis=1)[
                :, -top_k_size:
            ]  # (bs, k)
            keep_patch_indices = keep_patch_indices[:, ::-1]  # (bs, k)

        case "topk-oracle":
            # Select the top args.top_k patches based on zoom_map.
            keep_patch_indices = jnp.argsort(attn_maps, axis=1)[
                :, -top_k_size:
            ]  # (bs, k)
            keep_patch_indices = keep_patch_indices[:, ::-1]  # (bs, k)

        case "gumbel-topk":
            # Select the top args.top_k patches differentiably based on zoom_map.
            mask, keep_patch_indices = gumbel_softmax_topk(
                rng, zoom_map, top_k_size, temperature, noise_coeff
            )

        case "reinforcement:reinforce":
            probs = jax.nn.softmax(zoom_map / temperature, axis=1)

            def sample_fn(k, p):
                return jax.random.choice(
                    k, p.shape[0], shape=(top_k_size,), replace=False, p=p
                )

            rngs = jax.random.split(rng, zoom_map.shape[0])
            keep_patch_indices = jax.vmap(sample_fn)(rngs, probs)

        case "reinforcement:reinforce-gauss":
            # probs = jax.nn.softmax(zoom_map / temperature, axis=1)

            # def sample_fn(k, p):
            #     return jax.random.choice(
            #         k, p.shape[0], shape=(top_k_size,), replace=False, p=p
            #     )

            # rngs = jax.random.split(rng, zoom_map.shape[0])
            # keep_patch_indices = jax.vmap(sample_fn)(rngs, probs)
            keep_patch_indices = jnp.argsort(zoom_map, axis=1)[
                :, -top_k_size:
            ]  # (bs, k)
            keep_patch_indices = keep_patch_indices[:, ::-1]  # (bs, k)

        case _:
            raise ValueError(f"Unknown patch selection type: {type}")
    return mask, keep_patch_indices  # (bs, num_patches), (bs, k)


def upsample_patch_embeds(
    args, keep_patch_indices, student_patch_embeds, masked_ids=None, grid_size=37
):
    """Upsample the kept patch embeddings back onto the full patch grid.

    Args:
        args: Config object. Reads ``upsample_features.type`` (only "NN" is
            supported) and, for "NN", ``upsample_features.K`` and
            ``upsample_features.distance_power``.
        keep_patch_indices: Indices of the kept patches, forwarded to
            ``upsample_grid_nn`` as ``all_keep_ids``.
        student_patch_embeds: Rank-3 array whose first and last axes are the
            batch size and the embedding dimension; forwarded as
            ``all_keep_values``.
        masked_ids: Optional value forwarded unchanged to ``upsample_grid_nn``.
        grid_size: Grid size forwarded to ``upsample_grid_nn``. Defaults to 37,
            the value previously hard-coded for DINO-V2.

    Returns:
        The output of ``upsample_grid_nn`` reshaped to (bs, -1, dim);
        presumably (bs, grid_size * grid_size, dim) -- TODO confirm against
        ``upsample_grid_nn``.

    Raises:
        ValueError: If ``args.upsample_features.type`` is not "NN".
    """
    if args.upsample_features.type != "NN":
        raise ValueError(f"Unknown upsample features type: {args.upsample_features}")

    bs, _, dim = student_patch_embeds.shape
    upsampled = upsample_grid_nn(
        all_keep_ids=keep_patch_indices,
        all_keep_values=student_patch_embeds,
        grid_size=grid_size,
        K=args.upsample_features.K,
        distance_power=args.upsample_features.distance_power,
        masked_ids=masked_ids,
    )
    # Flatten everything between the batch and embedding axes into one axis.
    return upsampled.reshape(bs, -1, dim)
