import dataclasses
import functools

import chex
import jax
import jax.numpy as jnp
import optax
from clu import metrics as clu_metrics

from tabular_mvdrl.kernels import Kernel
from tabular_mvdrl.mmd import mmd2
from tabular_mvdrl.models import EWPModel
from tabular_mvdrl.state import MetricTrainState
from tabular_mvdrl.trainer import MVDRLTransferTrainer
from tabular_mvdrl.types import MRPTransitionBatch
from tabular_mvdrl.utils import jitpp
from tabular_mvdrl.utils.jitpp import Bind, Donate, Static

# Metric tag under which the trainer's mean `mmd2` training loss is recorded
# (used both to build the metrics collection and to report each step's loss).
LOSS_MMD = "loss__mmd"


@dataclasses.dataclass(frozen=True, kw_only=True)
class EWPTDTrainer(MVDRLTransferTrainer[MetricTrainState]):
    """Temporal-difference trainer for an ``EWPModel`` using an ``mmd2`` loss.

    The model's output for an observation is treated as a set of atoms with
    uniform weights (EWP presumably stands for "equally-weighted particles" --
    TODO confirm).

    Attributes:
        num_atoms: Number of atoms passed to ``EWPModel``; also used as the
            gradient scaling factor in ``train_step``.
        optim: Optax gradient transformation used to build the train state.
        kernel: Kernel forwarded to ``mmd2``.
        discount: Discount factor multiplying the bootstrapped next-step atoms.

    ``self.env`` and ``self.seed`` are read here but not defined in this
    class; presumably they are provided by ``MVDRLTransferTrainer``.
    """

    num_atoms: int
    optim: optax.GradientTransformation
    kernel: Kernel
    discount: float

    @property
    def identifier(self):
        """Return a short label for this trainer that includes the atom count."""
        return f"EWP-TD:{self.num_atoms}"

    @functools.cached_property
    def metrics(self) -> clu_metrics.Collection:
        """Build (once) a metrics collection that averages each tracked tag."""
        tracked_tags = (LOSS_MMD,)
        averages = {}
        for tag in tracked_tags:
            averages[tag] = clu_metrics.Average.from_output(tag)
        return clu_metrics.Collection.create(**averages)

    @functools.cached_property
    def state(self) -> MetricTrainState:
        """Build (once) the initial train state for a freshly initialised model."""
        network = EWPModel(
            self.env.num_states,
            self.env.reward_dim,
            self.num_atoms,
        )
        # Parameter initialisation uses a key derived from `seed + 1`.
        init_key = jax.random.PRNGKey(self.seed + 1)
        # Dummy int32 input of shape (2,), used only to initialise parameters.
        dummy_input = jnp.ones(2, dtype=jnp.int32)
        initial_params = network.init(init_key, dummy_input)
        return MetricTrainState.create(
            params=initial_params,
            apply_fn=network.apply,
            tx=self.optim,
            metrics=self.metrics.empty(),
        )

    @jitpp.jit
    @staticmethod
    @chex.assert_max_traces(1)
    def train_step(
        rng: chex.PRNGKey,
        state: Donate[MetricTrainState],
        batch: MRPTransitionBatch,
        *,
        kernel: Bind[Static[Kernel]],
        num_atoms: Bind[int],
        discount: Bind[float],
    ) -> MetricTrainState:
        """Apply one TD update that minimises the mean ``mmd2`` loss on a batch.

        The bootstrap target is evaluated with ``state.params`` rather than
        the differentiated parameters, so gradients flow only through the
        prediction at ``o_t``.

        Args:
            rng: PRNG key; not used by the current implementation.
            state: Current train state (annotated ``Donate``; presumably its
                buffers are donated by ``jitpp.jit`` -- TODO confirm).
            batch: Transition batch providing ``o_t``, ``r_t`` and ``o_tp1``,
                each assumed to have a leading batch axis.
            kernel: Kernel forwarded to ``mmd2``.
            num_atoms: Factor by which every gradient leaf is multiplied.
            discount: Discount factor applied to the next-step atoms.

        Returns:
            The result of ``state.apply_gradients`` given the scaled gradients
            and a metrics update holding this step's loss.
        """

        def _uniform_mmd(atoms_pred: chex.Array, atoms_target: chex.Array) -> chex.Scalar:
            # Both atom sets get uniform weights sized by the prediction's
            # leading axis.
            atom_count = atoms_pred.shape[0]
            weights = jnp.ones(atom_count) / atom_count
            return mmd2(kernel, atoms_pred, atoms_target, weights, weights)

        # Map the model over the batch axis of the observations while sharing
        # a single set of parameters.
        batched_apply = jax.vmap(state.apply_fn, in_axes=(None, 0))

        def _td_loss(params: chex.ArrayTree, transitions: MRPTransitionBatch):
            predicted = batched_apply(params, transitions.o_t)
            # Bootstrap from the current state's parameters, not from `params`,
            # so the target contributes no gradient.
            bootstrapped = batched_apply(state.params, transitions.o_tp1)
            # Insert an axis after the batch axis so the reward broadcasts
            # against the bootstrapped atoms.
            td_target = transitions.r_t[:, None, ...] + discount * bootstrapped
            per_transition = jax.vmap(_uniform_mmd)(predicted, td_target)
            return jnp.mean(per_transition)

        loss_value, raw_grads = jax.value_and_grad(_td_loss)(state.params, batch)
        scaled_grads = jax.tree_util.tree_map(lambda g: g * num_atoms, raw_grads)
        step_metrics = state.metrics.single_from_model_output(**{LOSS_MMD: loss_value})
        return state.apply_gradients(scaled_grads, step_metrics)
