from typing import Optional, Tuple

import jax
import chex
import jax.numpy as jnp
from flax.struct import dataclass
from chex import Array, ArrayTree
from evosax import Strategy
from evosax.core import OptState, OptParams, GradientOptimizer

from sves.kernels import Kernel


@dataclass
class EvoParams:
    """Hyperparameters consumed by the OG_SVGD strategy."""

    # Parameters of the evosax gradient optimizer that applies the SVGD direction.
    opt_params: OptParams
    # Floor of the annealed repulsion weight `alpha` computed in OG_SVGD.svgd_step.
    alpha_min: float = 1.
    # Lower / upper bound of the uniform draw used to initialise the particles.
    # NOTE(review): with the defaults (0.0, 0.0) every particle starts at the
    # same point; callers presumably override these — confirm, since coincident
    # particles get no repulsion from kernels whose gradient vanishes at zero
    # distance.
    init_min: float = 0.0
    init_max: float = 0.0
    # Bounds the particles are clipped to after every optimizer step; the
    # defaults (+/- largest float32) effectively disable clipping.
    clip_min: float = -jnp.finfo(jnp.float32).max
    clip_max: float = jnp.finfo(jnp.float32).max

@dataclass
class EvoState:
    """State carried between generations by the OG_SVGD strategy."""

    # Current particle positions, one row per particle.
    particles: Array
    # State of the evosax gradient optimizer.
    opt_state: OptState
    # Best particle found so far (presumably maintained by the evosax base class).
    best_member: Array
    # Best fitness found so far. Declared once, as an Array: OG_SVGD initialises
    # it with one entry per dimension. Defaults to the largest float32.
    best_fitness: Array = jnp.finfo(jnp.float32).max
    # Generation index; read by the repulsion annealing schedule in svgd_step.
    gen_counter: int = 0
    # Kernel bandwidth passed to every kernel evaluation in svgd_step.
    bandwidth: float = 1.


class OG_SVGD(Strategy):
    """Original gradient-based SVGD from Liu & Wang (2016).

    Additional annealing following later papers added: the repulsive
    (kernel-gradient) term is scaled by ``alpha``. This weight starts at
    ``log(num_iters)`` and shrinks as ``state.gen_counter`` grows, but never
    drops below ``params.alpha_min``.

    NOTE(review): ``tell`` passes ``fitness`` straight into ``svgd_step`` as
    the per-particle "scores". For this to be the SVGD update of the paper
    they should be gradients of the log-target, shape ``(popsize, num_dims)``
    — confirm against callers.
    """

    _SUPPORTED_OPTIMIZERS = ("sgd", "adam", "rmsprop", "clipup", "adan")

    def __init__(
        self,
        popsize: int,
        kernel: Kernel,
        num_dims: Optional[int] = None,
        pholder_params: Optional[ArrayTree | Array] = None,
        num_iters: int = 100,
        opt_name: str = "adam",
        lrate_init: float = 1e-1,
        lrate_decay: float = 1.0,
        lrate_limit: float = 0.001,
        mean_decay: float = 0.0,
        n_devices: Optional[int] = None,
        **fitness_kwargs: bool | int | float
    ):
        """Set up the SVGD strategy.

        Args:
            popsize: Number of particles.
            kernel: Stein kernel. It is called as ``kernel(x, y, bandwidth)``,
                differentiated w.r.t. its first argument, and must expose a
                ``bandwidth`` attribute (used as the initial bandwidth).
            num_dims: Forwarded to ``Strategy``.
            pholder_params: Forwarded to ``Strategy``.
            num_iters: Horizon of the repulsion annealing schedule.
            opt_name: Name of the evosax gradient optimizer; one of
                ``"sgd"``, ``"adam"``, ``"rmsprop"``, ``"clipup"``, ``"adan"``.
            lrate_init: Initial optimizer learning rate.
            lrate_decay: Optimizer learning-rate decay factor.
            lrate_limit: Optimizer learning-rate limit.
            mean_decay: Forwarded to ``Strategy``.
            n_devices: Forwarded to ``Strategy``.
            **fitness_kwargs: Forwarded to ``Strategy``.

        Raises:
            AssertionError: If ``opt_name`` is not a supported optimizer.
        """
        super().__init__(
            popsize,
            num_dims,
            pholder_params,
            mean_decay,
            n_devices,
            **fitness_kwargs
        )
        self.strategy_name = "OG_SVGD"
        self.kernel = kernel
        assert opt_name in self._SUPPORTED_OPTIMIZERS, (
            f"opt_name must be one of {self._SUPPORTED_OPTIMIZERS}, "
            f"got {opt_name!r}"
        )
        self.optimizer = GradientOptimizer[opt_name](self.num_dims)
        self.lrate_init = lrate_init
        self.lrate_decay = lrate_decay
        self.lrate_limit = lrate_limit
        self.num_iters = num_iters
        self.npop = popsize

    @property
    def params_strategy(self) -> EvoParams:
        """Return default parameters of evolution strategy.

        The optimizer's defaults are overridden with the learning-rate
        schedule passed to ``__init__``. All other ``EvoParams`` fields keep
        their declared defaults.
        """
        opt_params = self.optimizer.default_params.replace(
            lrate_init=self.lrate_init,
            lrate_decay=self.lrate_decay,
            lrate_limit=self.lrate_limit,
        )
        return EvoParams(
            opt_params=opt_params,
        )

    def initialize_strategy(
        self, rng: chex.PRNGKey, params: EvoParams
    ) -> EvoState:
        """`initialize` the evolution strategy.

        Particles are drawn uniformly between ``params.init_min`` and
        ``params.init_max``. NOTE(review): with the default bounds (0, 0)
        every particle starts at the same point — confirm callers widen the
        range.
        """
        x_init = jax.random.uniform(
            rng,
            (self.popsize, self.num_dims),
            minval=params.init_min,
            maxval=params.init_max
        )
        state = EvoState(
            particles=x_init,
            opt_state=self.optimizer.initialize(params.opt_params),
            best_member=x_init[0],
            # One "worst possible" entry per dimension. Presumably this matches
            # per-particle fitness vectors of length num_dims — verify.
            best_fitness=jnp.ones(self.num_dims,) * jnp.finfo(jnp.float32).max,
            bandwidth=self.kernel.bandwidth
        )
        return state

    def ask_strategy(
        self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
    ) -> Tuple[chex.Array, EvoState]:
        """Return the current particles unchanged (no sampling; ``rng`` unused)."""
        return state.particles, state

    def tell_strategy(
        self,
        x: chex.Array,
        fitness: chex.Array,
        state: EvoState,
        params: EvoParams,
    ) -> EvoState:
        """Move the particles one optimizer step along the SVGD direction.

        Args:
            x: Particles that were evaluated; used to build the SVGD direction.
            fitness: Per-particle scores, passed to ``svgd_step``.
            state: Current state. The step is applied to ``state.particles``.
            params: Strategy parameters (optimizer settings, clip bounds).

        Returns:
            ``state`` with updated ``particles`` and ``opt_state``.
        """
        # Flip the sign for minimisation: phi points uphill on the target,
        # while the optimizer step is assumed to descend along its gradient.
        gradients = -self.svgd_step(x, fitness, state, params)
        particles, opt_state = self.optimizer.step(
            state.particles, gradients, state.opt_state, params.opt_params
        )
        # Presumably advances the optimizer's learning-rate schedule.
        opt_state = self.optimizer.update(opt_state, params.opt_params)
        particles = jnp.clip(particles, params.clip_min, params.clip_max)
        return state.replace(particles=particles, opt_state=opt_state)

    def svgd_step(self, x: Array, scores: Array, state: EvoState, params: EvoParams) -> Array:
        """Compute the annealed SVGD direction ``phi`` for every particle.

        See Alg 1 in orig. paper::

            phi(x_i) = mean_j [k(x_j, x_i) * score_j
                               + alpha * grad_{x_j} k(x_j, x_i)]

        Args:
            x: Particle positions, stacked along the leading axis.
            scores: Per-particle scores, stacked along the leading axis like ``x``.
            state: Supplies ``bandwidth`` and ``gen_counter``.
            params: Supplies ``alpha_min``, the floor of the repulsion weight.

        Returns:
            One direction per particle, stacked along the leading axis.
        """
        bandwidth = state.bandwidth
        # Repulsion weight: log(num_iters) at generation 0, shrinking as the
        # generation counter grows, floored at alpha_min.
        alpha = jnp.maximum(
            jnp.log(self.num_iters / (state.gen_counter + 1)), params.alpha_min
        )
        # Gradient of the kernel w.r.t. its first argument x_j.
        grad_kernel = jax.grad(self.kernel)

        def pairwise_term(xj, scorej, xi):
            # Driving term (kernel-weighted score) plus weighted repulsion.
            driving = self.kernel(xj, xi, bandwidth) * scorej
            repulsive = alpha * grad_kernel(xj, xi, bandwidth)
            return driving + repulsive

        def phi(xi):
            # Average the pairwise terms over all particles x_j.
            terms = jax.vmap(pairwise_term, in_axes=(0, 0, None))(x, scores, xi)
            return jnp.mean(terms, axis=0)

        return jax.vmap(phi)(x)
