from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from chex import Array, ArrayTree, PRNGKey
from evosax import Strategy
from evosax.core import GradientOptimizer
from jax.scipy.special import logsumexp

from sves.kernels import Kernel
from sves.strategies.svgd import EvoState, EvoParams


class GF_SVGD(Strategy):
    """Gradient-free version of SVGD from Han & Liu (2018).

    Implements Annealed GF-SVGD, i.e., Algorithm 3 from the paper.

    Particles are moved towards the annealed target density
    ``p_t(x)`` proportional to ``p_0(x)**(1 - alpha_t) * exp(-fitness(x))**alpha_t``.
    Here ``p_0`` is an isotropic Gaussian prior. ``alpha_t`` is non-decreasing
    over generations and eventually equals 1 (see ``tell_strategy``).
    Fitness is therefore treated as an energy: lower is better.

    Source: https://arxiv.org/pdf/1806.02775.

    Args:
        popsize: Number of particles.
        kernel: Callable ``kernel(a, b, bandwidth)`` returning a scalar. It
            must be differentiable w.r.t. ``a`` with ``jax.grad``. Its
            ``bandwidth`` attribute seeds ``EvoState.bandwidth``.
        num_dims, pholder_params, mean_decay, n_devices, **fitness_kwargs:
            Forwarded unchanged to ``evosax.Strategy``.
        sigma_init: Scale of the isotropic covariance ``sigma_init * I``. It is
            used both for the initial particle distribution and for the prior
            ``p_0``. It multiplies the covariance, so it acts as a variance,
            not a standard deviation.
        num_iters: Annealing horizon. ``alpha_t`` equals 1 once
            ``log(num_iters / (gen_counter + 1)) <= 1``.
        opt_name: One of ``"sgd"``, ``"adam"``, ``"rmsprop"``, ``"clipup"``,
            ``"adan"``.
        lrate_init, lrate_decay, lrate_limit: Learning-rate settings written
            into the optimizer's default parameters.
    """

    def __init__(
        self,
        popsize: int,
        kernel: Kernel,
        num_dims: Optional[int] = None,
        pholder_params: Optional[ArrayTree | Array] = None,
        sigma_init: float = 1.0,
        num_iters: int = 100,
        opt_name: str = "adam",
        lrate_init: float = 1e-1,
        lrate_decay: float = 1.0,
        lrate_limit: float = 0.001,
        mean_decay: float = 0.0,
        n_devices: Optional[int] = None,
        **fitness_kwargs: bool | int | float,
    ):
        super().__init__(
            popsize,
            num_dims,
            pholder_params,
            mean_decay,
            n_devices,
            **fitness_kwargs,
        )
        self.strategy_name = "GF_SVGD"
        self.kernel = kernel
        assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"], (
            f"Unsupported optimizer: {opt_name}"
        )
        self.optimizer = GradientOptimizer[opt_name](self.num_dims)
        self.lrate_init = lrate_init
        self.lrate_decay = lrate_decay
        self.lrate_limit = lrate_limit
        self.num_iters = num_iters
        self.npop = popsize  # kept as an alias of popsize for existing callers
        self.sigma_init = sigma_init

    @property
    def params_strategy(self) -> EvoParams:
        """Return default parameters of evolution strategy."""
        opt_params = self.optimizer.default_params.replace(
            lrate_init=self.lrate_init,
            lrate_decay=self.lrate_decay,
            lrate_limit=self.lrate_limit,
        )
        return EvoParams(
            opt_params=opt_params,
        )

    def _prior_moments(self, params: EvoParams) -> Tuple[Array, Array]:
        """Return ``(mean, covariance)`` of the isotropic Gaussian prior ``p_0``.

        The mean is the midpoint of ``[init_min, init_max]`` in every
        dimension. The covariance is ``sigma_init * I``.
        """
        mu = jnp.ones(self.num_dims) * (params.init_max + params.init_min) / 2
        cov = self.sigma_init * jnp.eye(self.num_dims)
        return mu, cov

    def initialize_strategy(
        self, rng: PRNGKey, params: EvoParams
    ) -> EvoState:
        """`initialize` the evolution strategy.

        Draws ``popsize`` particles i.i.d. from the prior ``p_0``.
        """
        mu, cov = self._prior_moments(params)
        x_init = jax.random.multivariate_normal(
            rng,
            mean=mu,
            cov=cov,
            shape=(self.popsize,),
        )
        return EvoState(
            particles=x_init,
            opt_state=self.optimizer.initialize(params.opt_params),
            best_member=x_init[0],
            bandwidth=self.kernel.bandwidth,
        )

    def ask_strategy(
        self, rng: PRNGKey, state: EvoState, params: EvoParams
    ) -> Tuple[Array, EvoState]:
        """Return the current particles for evaluation.

        No sampling happens here, so ``rng`` is unused.
        """
        return state.particles, state

    def tell_strategy(
        self,
        x: Array,
        fitness: Array,
        state: EvoState,
        params: EvoParams,
    ) -> EvoState:
        """Move the particles by one annealed GF-SVGD step.

        Args:
            x: Evaluated particles (as returned by ``ask_strategy``).
            fitness: Fitness of each particle in ``x``; lower is better.
            state: Current strategy state.
            params: Strategy parameters.

        Returns:
            ``state`` with updated (and clipped) particles and optimizer state.
        """
        mu, cov = self._prior_moments(params)
        log_prior = jax.scipy.stats.multivariate_normal.logpdf(x, mu, cov)

        # Annealing weight. It is non-decreasing in gen_counter and equals 1
        # once log(num_iters / (gen_counter + 1)) <= 1.
        alpha = 1 / jnp.maximum(
            jnp.log(self.num_iters / (state.gen_counter + 1)), 1.0
        )

        # Geometric interpolation between prior and target:
        #   log p_t = (1 - alpha) * log p_0 + alpha * log p,  with log p = -fitness.
        log_pt = (1 - alpha) * log_prior - alpha * fitness
        # Subtracting logsumexp before exponentiating keeps values finite and
        # self-normalises them over the population (they sum to 1).
        pt = jnp.exp(log_pt - logsumexp(log_pt))

        # svgd_step returns the direction the particles should move in. The
        # optimizer minimises (steps against the gradient), so negate it.
        gradients = -self.svgd_step(x, pt, state, params)
        particles, opt_state = self.optimizer.step(
            state.particles, gradients, state.opt_state, params.opt_params
        )
        opt_state = self.optimizer.update(opt_state, params.opt_params)
        particles = jnp.clip(particles, params.clip_min, params.clip_max)

        return state.replace(particles=particles, opt_state=opt_state)

    def svgd_step(
        self,
        x: Array,
        target_densities: Array,
        state: EvoState,
        params: EvoParams,
        eps: float = 1e-10,
    ) -> Array:
        """Compute the annealed gradient-free SVGD update direction.

        Algorithm 3 from Han & Liu (2018): https://arxiv.org/pdf/1806.02775.

        For every particle ``x_i`` this evaluates

            phi(x_i) = sum_j w_j * [k(x_j, x_i) * grad log rho(x_j)
                                    + grad_{x_j} k(x_j, x_i)] / sum_j w_j

        It uses the surrogate ``rho(z) = sum_j k(z, x_j) * p_t(x_j)`` and the
        importance weights ``w_j = rho(x_j) / p_t(x_j)``.

        Args:
            x: Particles, one per row.
            target_densities: ``p_t`` evaluated at each particle of ``x``.
            state: Strategy state; only ``state.bandwidth`` is read.
            params: Unused; kept for interface compatibility.
            eps: Small constant. It keeps the log finite and guards the
                weight denominator and the final normalisation against zero.

        Returns:
            ``phi`` for every particle, stacked in the same layout as ``x``.
        """
        bandwidth = state.bandwidth

        def rho(z: Array) -> Array:
            """KDE-style surrogate density at ``z`` (eq. 24)."""
            weighted = jax.vmap(
                lambda xj, pj: self.kernel(z, xj, bandwidth) * pj
            )(x, target_densities)
            return jnp.sum(weighted)

        def log_rho(z: Array) -> Array:
            # eps keeps the log finite where the surrogate underflows to 0.
            return jnp.log(rho(z) + eps)

        rhos = jax.vmap(rho)(x)
        # 3rd last line in Algorithm 3; eps avoids division by zero.
        importance_weights = rhos / jnp.maximum(target_densities, eps)

        # grad log rho(x_j) does not depend on the evaluation point x_i, so it
        # is computed once per particle instead of inside the pairwise term.
        grad_log_rho = jax.vmap(jax.grad(log_rho))(x)
        grad_kernel = jax.grad(self.kernel)  # gradient w.r.t. the first argument

        def phi(xi: Array) -> Array:
            terms = jax.vmap(
                lambda xj, wj, gj: wj * (
                    self.kernel(xj, xi, bandwidth) * gj
                    + grad_kernel(xj, xi, bandwidth)
                )
            )(x, importance_weights, grad_log_rho)
            return jnp.sum(terms, axis=0)

        total_weight = jnp.sum(importance_weights)
        steps = jax.vmap(phi)(x)

        # Normalise by the total importance weight. The lower clip only guards
        # against division by (near) zero; it does not bound the step size.
        return steps / jnp.clip(total_weight, eps, None)
