from functools import partial

import jax
import jax.numpy as jnp
import jax.random
from jax.tree_util import tree_map

from numpyro import handlers
from numpyro.contrib.einstein import SteinVI
from numpyro.contrib.einstein.steinvi import _numel
from numpyro.contrib.einstein.util import batch_ravel_pytree
from numpyro.distributions.transforms import IdentityTransform
from numpyro.util import ravel_pytree


class SteinVIForces(SteinVI):
    """SteinVI variant that exposes the two parts of the Stein update separately.

    ``_svgd_forces`` follows the same steps as a Stein gradient computation, but
    returns the kernel-weighted ("attractive") term and the kernel-gradient
    ("repulsive") term as two separate parameter dictionaries instead of a
    single combined gradient.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to the parent ``SteinVI`` constructor."""
        super().__init__(*args, **kwargs)

    def _svgd_forces(self, rng_key, unconstr_params, *args, **kwargs):
        """Compute the attractive and repulsive Stein terms for every parameter.

        :param rng_key: random key passed to ``self.loss.loss`` on every loss
            evaluation (the same key is reused for all particles).
        :param unconstr_params: dict mapping parameter names to unconstrained
            values. Names that are not guide parameters, or for which
            ``self.classic_guide_params_fn`` is truthy, are treated as "classic"
            parameters; the rest are treated as Stein particles.
        :param args: positional arguments forwarded to ``self.loss.loss``.
        :param kwargs: keyword arguments forwarded to ``self.loss.loss``.
        :return: tuple ``(attractive, repulsive)`` of dicts keyed by parameter
            name. Stein parameters hold the respective term divided by
            ``self.num_particles``. Classic parameters hold the gradient of the
            negated loss averaged over axis 0 of the per-particle results.
            No sign flip is applied to any returned value.

        NOTE(review): the classic-parameter gradients are merged into *both*
        returned dicts, so adding the two results together counts them twice;
        confirm this is what callers expect.
        """
        # 0. Separate model and guide parameters, since only guide parameters are updated
        # using Stein
        classic_uparams = {
            p: v
            for p, v in unconstr_params.items()
            if p not in self.guide_param_names or self.classic_guide_params_fn(p)
        }
        stein_uparams = {
            p: v for p, v in unconstr_params.items() if p not in classic_uparams
        }

        # 1. Collect each guide parameter into monolithic particles that capture correlations
        # between parameter values across each individual particle
        stein_particles, unravel_pytree, unravel_pytree_batched = batch_ravel_pytree(
            stein_uparams, nbatch_dims=1
        )
        particle_info, _ = self._calc_particle_info(
            stein_uparams, stein_particles.shape[0]
        )

        # 2. Calculate gradients of the (negated, temperature-scaled) loss
        def scaled_loss(rng_key, classic_params, stein_params):
            params = {**classic_params, **stein_params}
            loss_val = self.loss.loss(
                rng_key,
                params,
                handlers.scale(self._inference_model, self.loss_temperature),
                self.guide,
                *args,
                **kwargs,
            )
            # Negated so that gradients of this function point towards lower loss.
            return -loss_val

        def kernel_particle_loss_fn(ps):
            return scaled_loss(
                rng_key,
                self.constrain_fn(classic_uparams),
                self.constrain_fn(unravel_pytree(ps)),
            )

        def particle_transform_fn(particle):
            params = unravel_pytree(particle)
            tparams = self.particle_transform_fn(params)
            tparticle, _ = ravel_pytree(tparams)
            return tparticle

        tstein_particles = jax.vmap(particle_transform_fn)(stein_particles)

        # Only the gradient is needed here; the loss value itself is not returned.
        particle_ljp_grads = jax.vmap(jax.grad(kernel_particle_loss_fn))(
            tstein_particles
        )

        def classic_grad_for_particle(ps):
            # Differentiate w.r.t. the classic parameters with one particle held fixed.
            def loss_wrt_classic(cps):
                return scaled_loss(
                    rng_key,
                    self.constrain_fn(cps),
                    self.constrain_fn(unravel_pytree(ps)),
                )

            return jax.grad(loss_wrt_classic)(classic_uparams)

        classic_param_grads = tree_map(
            partial(jnp.mean, axis=0),
            jax.vmap(classic_grad_for_particle)(stein_particles),
        )

        # 3. Calculate kernel on monolithic particle
        kernel = self.kernel_fn.compute(
            stein_particles, particle_info, kernel_particle_loss_fn
        )

        # 4. Calculate the attractive force and repulsive force on the monolithic particles.
        # For every particle y, both terms are summed over all particles x.
        attractive_force = jax.vmap(
            lambda y: jnp.sum(
                jax.vmap(
                    lambda x, x_ljp_grad: self._apply_kernel(kernel, x, y, x_ljp_grad)
                )(tstein_particles, particle_ljp_grads),
                axis=0,
            )
        )(tstein_particles)

        repulsive_force = jax.vmap(
            lambda y: jnp.sum(
                jax.vmap(
                    lambda x: self.repulsion_temperature
                    * self._kernel_grad(kernel, x, y)
                )(tstein_particles),
                axis=0,
            )
        )(tstein_particles)

        # 5. Map both forces through the Jacobian of the inverse particle transforms
        def split_particle_grads(particle, attr_force, rep_force):
            def _nontrivial_jac(var_name, var):
                # Identity transforms need no correction; ``None`` marks "skip".
                if isinstance(self.particle_transforms[var_name], IdentityTransform):
                    return None
                return jax.jacfwd(self.particle_transforms[var_name].inv)(var)

            def _apply_jac(var_force, jac):
                if jac is None:
                    return var_force
                # Flatten the Jacobian to 2-D, using the first half of its axes as rows.
                jac_matrix = jac.reshape(
                    (_numel(jac.shape[: len(jac.shape) // 2]), -1)
                )
                return (var_force.reshape(-1) @ jac_matrix).reshape(var_force.shape)

            # Built once per particle and shared by both forces.
            reparam_jac = {
                name: tree_map(partial(_nontrivial_jac, name), variables)
                for name, variables in unravel_pytree(particle).items()
            }

            def _pull_back(force):
                jac_params = tree_map(_apply_jac, unravel_pytree(force), reparam_jac)
                jac_particle, _ = ravel_pytree(jac_params)
                return jac_particle

            return _pull_back(attr_force), _pull_back(rep_force)

        attractive_particle_grads, repulsive_particle_grads = jax.vmap(
            split_particle_grads
        )(stein_particles, attractive_force, repulsive_force)
        attractive_particle_grads = attractive_particle_grads / self.num_particles
        repulsive_particle_grads = repulsive_particle_grads / self.num_particles

        # 6. Decompose the monolithic particle forces back to concrete parameter values and
        # return them together with the classic-parameter gradients
        attractive_param_grads = {
            **classic_param_grads,
            **unravel_pytree_batched(attractive_particle_grads),
        }
        repulsive_param_grads = {
            **classic_param_grads,
            **unravel_pytree_batched(repulsive_particle_grads),
        }
        return attractive_param_grads, repulsive_param_grads
