from __future__ import annotations

import argparse
from typing import Callable

import jax
import flax.linen as nn
import jax.numpy as jnp
from chex import Array, ArrayTree

from dataset import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from modeling import ViT, Zoomer
from training_util import CRITERION_COLLECTION, patch_selection, upsample_patch_embeds

def compute_log_prob_without_replacement(logits, selected_indices, do_mean):
    """Score an ordered selection drawn without replacement, per example.

    Step i scores ``selected_indices[b, i]`` under a softmax over ``logits[b]``
    restricted to the indices not picked at steps 0..i-1. This is the
    Plackett-Luce log-likelihood of the ordered selection.

    Args:
        logits: [B, N] array of logits.
        selected_indices: [B, K] integer array; row b lists the indices drawn
            for example b, in draw order.
        do_mean: if True, average the K per-step log-probs; otherwise sum them.

    Returns:
        [B] array with the summed (or averaged) log-probs of each example.
    """
    reduce_steps = jnp.mean if do_mean else jnp.sum

    def score_row(row_logits, row_selected):  # (N,), (K,)
        def draw(available, idx):
            # Indices already drawn get -inf, i.e. zero probability.
            step_log_probs = jax.nn.log_softmax(
                jnp.where(available, row_logits, -jnp.inf)
            )  # (N,)
            return available.at[idx].set(False), step_log_probs[idx]

        all_available = jnp.ones(row_logits.shape[0], dtype=bool)  # (N,)
        _, per_step = jax.lax.scan(draw, all_available, row_selected)  # (K,)
        return reduce_steps(per_step)  # scalar

    return jax.vmap(score_row)(logits, selected_indices)  # (B,)


def compute_log_prob(logits, keep_patch_indices, is_naive=False, do_mean=False):
    """Log-probability of the selected patches under ``logits``, per example.

    Args:
        logits: [B, N] array of logits.
        keep_patch_indices: [B, K] integer array of selected indices.
        is_naive: if True, score every selected index under one softmax over
            all N entries. This ignores that the selection is made without
            replacement, and a repeated index is counted once. If False,
            defer to ``compute_log_prob_without_replacement``.
        do_mean: if True, average over the selected indices instead of
            summing them.

    Returns:
        [B] array of log-probs.
    """
    if not is_naive:
        return compute_log_prob_without_replacement(logits, keep_patch_indices, do_mean)

    log_probs = jax.nn.log_softmax(logits)
    batch_ids = jnp.arange(log_probs.shape[0])[:, None]
    mask = (
        jnp.zeros_like(log_probs, dtype=bool)
        .at[batch_ids, keep_patch_indices]
        .set(True)
    )
    # `where` instead of `log_probs * mask`: a -inf logit would give
    # -inf * 0 = NaN and poison the whole row.
    total = jnp.sum(jnp.where(mask, log_probs, 0.0), axis=-1)
    if not do_mean:
        return total
    # Average over the selected entries only, matching the without-replacement
    # path, rather than over all N positions (where unselected ones add 0).
    num_selected = jnp.maximum(jnp.sum(mask, axis=-1), 1)
    return total / num_selected


class TrainModule(nn.Module):
    """Computes training losses and metrics for a student ViT and a zoomer.

    The teacher is the ``student`` architecture applied with the separately
    passed ``teacher_params``. Each call runs the student twice on the same
    images, with two patch selections produced by ``patch_selection``. The
    first is a "baseline" selection (``is_baseline=True``); the second is a
    "sampled" one (``is_baseline=False``). Both passes are distilled against
    the teacher. On top of that, a REINFORCE-style term pushes the zoomer's
    ``zoom_map`` towards selections whose (cls + patch) distillation loss
    beats the baseline's.
    """

    args: argparse.Namespace  # run configuration; fields are read in the methods below
    student: ViT
    zoomer: Zoomer
    criterion: Callable[[Array, Array], Array] = CRITERION_COLLECTION["ce"]
    zoom_map_criterion: Callable[[Array, Array], Array] = CRITERION_COLLECTION["bce"]

    def compute_distillation_loss(
        self,
        cls_embeds,
        reg_embeds,
        patch_embeds,
        teacher_cls_embeds,
        teacher_reg_embeds,
        teacher_patch_embeds,
        zoom_map,
        attn_maps,
    ):
        """Computes the distillation loss against the teacher outputs.

        One term is computed for every ``(key, weight)`` in
        ``args.distillation_losses``:

        * ``"cls"``: ``criterion`` on the CLS embeddings.
        * ``"reg"``: ``criterion`` on the register embeddings, averaged over
          the last axis of the criterion output.
        * ``"patch"``: ``criterion`` on the patch embeddings, averaged over
          the last axis of the criterion output.
        * ``"map"``: ``zoom_map_criterion`` between ``zoom_map`` and the
          teacher ``attn_maps``.

        Returns:
            ``(loss, rl_loss, metrics)``, where:

            * ``loss`` is the weighted sum of all terms.
            * ``rl_loss`` is the *unweighted* sum of the ``"cls"`` and
              ``"patch"`` terms; ``__call__`` uses it to build the REINFORCE
              advantage.
            * ``metrics`` maps ``"loss_<key>_distill"`` to each unweighted term.

        Raises:
            ValueError: if a key is not one of the four listed above.
        """
        loss = 0
        rl_loss = 0
        metrics = {}
        for key, weight in self.args.distillation_losses.items():
            if key == "cls":
                loss_item = self.criterion(cls_embeds, teacher_cls_embeds)
            elif key == "reg":
                loss_item = self.criterion(reg_embeds, teacher_reg_embeds).mean(-1)
            elif key == "patch":
                loss_item = self.criterion(patch_embeds, teacher_patch_embeds).mean(-1)
            elif key == "map":
                loss_item = self.zoom_map_criterion(zoom_map, attn_maps)
            else:
                raise ValueError(f"Unknown distillation loss type: {key}")
            if key in ("cls", "patch"):
                rl_loss += loss_item
            loss += weight * loss_item
            metrics[f"loss_{key}_distill"] = loss_item
        return loss, rl_loss, metrics

    def _student_distillation_pass(
        self,
        images,
        zoom_map,
        attn_maps,
        conditioning_tokens,
        k_patches,
        gumbel_temperature,
        gumbel_noise_coeff,
        teacher_cls_embeds,
        teacher_reg_embeds,
        teacher_patch_embeds,
        is_baseline,
    ):
        """Selects patches, runs the student on them and scores it vs. the teacher.

        ``__call__`` uses this for both the baseline branch
        (``is_baseline=True``) and the sampled branch (``is_baseline=False``).

        Returns:
            ``(loss, rl_loss, metrics, keep_patch_indices)``. The first three
            come from ``compute_distillation_loss``; ``keep_patch_indices`` is
            the selection returned by ``patch_selection``.
        """
        # A new key is drawn from the "patch_selection" stream on every call.
        rng = self.make_rng("patch_selection")
        gumbel_mask, keep_patch_indices = patch_selection(
            self.args,
            rng,
            zoom_map,
            attn_maps,
            temperature=gumbel_temperature,
            noise_coeff=gumbel_noise_coeff,
            is_baseline=is_baseline,
        )  # shape: (bs, k)

        # Forward pass for student.
        _, cls_embeds, reg_embeds, patch_embeds = self.student(
            images,
            k_patches=k_patches,
            mask=gumbel_mask,
            keep_patch_indices=keep_patch_indices,
            conditioning_tokens=conditioning_tokens,
        )  # (bs, dim), (bs, num_registers, dim), (bs, num_keep_patches, dim)

        # Build the ids to mask during upsampling. Positions [0, k_patches) of
        # the selection get the out-of-range sentinel `num_selected_patches`
        # (not masked); positions >= k_patches pass their own position as an
        # id to mask.
        num_selected_patches = keep_patch_indices.shape[1]
        positions = jnp.arange(num_selected_patches)
        indices_to_mask = jnp.where(
            positions >= k_patches,
            positions,
            num_selected_patches,  # Not masked
        )[None, :].repeat(zoom_map.shape[0], axis=0)
        patch_embeds = upsample_patch_embeds(
            self.args,
            keep_patch_indices,
            patch_embeds,
            masked_ids=indices_to_mask,
        )  # (bs, 37*37, dim)

        loss, rl_loss, metrics = self.compute_distillation_loss(
            cls_embeds,
            reg_embeds,
            patch_embeds,
            teacher_cls_embeds,
            teacher_reg_embeds,
            teacher_patch_embeds,
            zoom_map,
            attn_maps,
        )
        return loss, rl_loss, metrics, keep_patch_indices

    def __call__(
        self,
        images: Array,
        labels: Array,
        k_patches: Array,
        gumbel_temperature: Array,
        gumbel_noise_coeff: Array,
        rl_coeff: Array,
        kl_coeff: Array,
        teacher_params: ArrayTree,
    ) -> ArrayTree:
        """Computes all losses and metrics for one batch.

        Args:
            images: raw images. They are divided by 255, axis 1 is moved to
                the last position (presumably NCHW -> NHWC; confirm with the
                data loader), and they are normalized with the ImageNet
                mean/std.
            labels: not used by this method.
            k_patches: number of leading selected patches that are left
                unmasked when upsampling the student's patch embeddings.
            gumbel_temperature: forwarded to ``patch_selection``.
            gumbel_noise_coeff: forwarded to ``patch_selection``.
            rl_coeff: scale of the REINFORCE loss.
            kl_coeff: not used by this method.
            teacher_params: parameters applied to ``student`` to act as the
                teacher.

        Returns:
            The baseline pass's metrics dict, extended with:

            * ``"log_probs"``, ``"advantage"``, ``"baseline_loss"``,
              ``"sampled_loss"`` (batch means).
            * ``"reinforce_loss"``.
            * ``"loss"``: sampled + baseline distillation loss plus the
              REINFORCE loss.
            * every sampled-pass metric under a ``"sampled_"`` prefix.
        """
        # Preprocess images.
        images = jnp.moveaxis(images, 1, 3).astype(jnp.float32) / 0xFF
        images = (images - IMAGENET_DEFAULT_MEAN) / IMAGENET_DEFAULT_STD

        # Forward pass for teacher using its parameters.
        _, teacher_cls_embeds, teacher_reg_embeds, teacher_patch_embeds, attn_maps = (
            self.student.apply(
                {"params": teacher_params},
                images,
                attn_aggregate=self.args.attn_aggregate,
                zoom_map_criterion=self.args.zoom_map_criterion,
            )
        )  # (bs, dim), (bs, num_registers, dim), (bs, num_patches, dim)

        # Forward pass for zoomer.
        zoom_map, zoom_cls_token, zoom_reg_token, _patch_token = self.zoomer(
            images
        )  # (bs, num_patches * num_patches), (bs, dim), (bs, num_registers, dim), (bs, num_low_res_patches, dim)

        # Keep only the zoomer tokens named in `args.conditioning_tokens`;
        # they condition the student.
        zoom_cls_token = jnp.expand_dims(zoom_cls_token, axis=1)  # (bs, 1, dim)
        conditioning_tokens = {
            "cls_token": zoom_cls_token,
            "reg_token": zoom_reg_token,
        }
        conditioning_tokens = {
            key: conditioning_tokens[key] for key in self.args.conditioning_tokens
        }

        # The baseline pass must run before the sampled pass: each pass draws
        # the next "patch_selection" key, so swapping them changes the RNG use.
        shared_inputs = dict(
            images=images,
            zoom_map=zoom_map,
            attn_maps=attn_maps,
            conditioning_tokens=conditioning_tokens,
            k_patches=k_patches,
            gumbel_temperature=gumbel_temperature,
            gumbel_noise_coeff=gumbel_noise_coeff,
            teacher_cls_embeds=teacher_cls_embeds,
            teacher_reg_embeds=teacher_reg_embeds,
            teacher_patch_embeds=teacher_patch_embeds,
        )
        baseline_loss, baseline_rl_loss, baseline_metrics, _ = (
            self._student_distillation_pass(is_baseline=True, **shared_inputs)
        )
        sampled_loss, sampled_rl_loss, sampled_metrics, sampled_keep_patch_indices = (
            self._student_distillation_pass(is_baseline=False, **shared_inputs)
        )

        # Log-probability of the sampled selection under the zoomer's map.
        log_probs = compute_log_prob(
            zoom_map,
            sampled_keep_patch_indices,
            is_naive=self.args.is_naive,
            do_mean=self.args.do_mean,
        )
        # Positive when the sampled selection reached a lower (cls + patch)
        # loss than the baseline. It is treated as a constant reward.
        advantage = jax.lax.stop_gradient(baseline_rl_loss - sampled_rl_loss)

        # Standardize the advantage.
        if self.args.standardize_advantage:
            mean_advantage = jnp.mean(advantage)
            std_advantage = jnp.std(advantage)
            advantage = (advantage - mean_advantage) / (std_advantage + 1e-8)

        # Then clip
        if self.args.advantage_clip_coeff is not None:
            advantage = jnp.clip(
                advantage,
                -self.args.advantage_clip_coeff,
                self.args.advantage_clip_coeff,
            )
        if self.args.logprobs_min_prob is not None:
            log_probs = jnp.clip(
                log_probs, jnp.log(float(self.args.logprobs_min_prob)), 0
            )

        # REINFORCE: minimizing this raises the log-prob of selections with
        # positive advantage. jnp.mean reduces over the batch.
        reinforce_loss = -rl_coeff * jnp.mean(log_probs * advantage)

        baseline_metrics["log_probs"] = jnp.mean(log_probs)
        baseline_metrics["advantage"] = jnp.mean(advantage)
        baseline_metrics["baseline_loss"] = jnp.mean(baseline_loss)
        baseline_metrics["sampled_loss"] = jnp.mean(sampled_loss)

        baseline_metrics["reinforce_loss"] = reinforce_loss
        baseline_metrics["loss"] = sampled_loss + baseline_loss + reinforce_loss

        # Add sampled metrics to baseline metrics with prefix "sampled_"
        for key, value in sampled_metrics.items():
            baseline_metrics[f"sampled_{key}"] = value

        return baseline_metrics
