from evosax import Strategy
from typing import Optional
from chex import Array, ArrayTree, PRNGKey
from evosax.strategies.cma_es import get_cma_elite_weights
from flax.struct import dataclass
import jax.numpy as jnp
import jax

from sves.kernels import Kernel
from evosax.utils.eigen_decomp import full_eigen_decomp
from evosax.strategies.cma_es import update_p_c, update_p_sigma, sample, update_sigma, update_covariance


@dataclass
class EvoState:
    """Strategy state carried between the `ask` and `tell` calls of `SV_CMA_BB`.

    Per-search arrays carry a leading axis of length `npop`, one entry per
    parallel CMA-ES search (see `SV_CMA_BB.initialize_strategy`).
    """

    # Step-size evolution paths; zeros of shape (npop, num_dims) at init.
    p_sigma: Array
    # Covariance evolution paths; zeros of shape (npop, num_dims) at init.
    p_c: Array
    # Covariance matrices; npop stacked identity matrices at init.
    C: Array
    # Eigen-decomposition outputs of C as returned by `full_eigen_decomp`;
    # None at init and refreshed in `ask_strategy`.
    D: Optional[Array]
    B: Optional[Array]
    # Search means, shape (npop, num_dims).
    mean: Array
    # Step sizes, shape (npop,).
    sigma: Array
    # Always set to the same array as `mean` by the code in this file.
    particles: Array
    # First output of `get_cma_elite_weights`; passed to `update_covariance`.
    weights: Array
    # Second output of `get_cma_elite_weights`; passed to `cmaes_grad`.
    weights_truncated: Array
    # arange(npop) at init; afterwards the argsort of each search's minimum
    # fitness in the last `tell_strategy` call.
    state_order: Array
    # Initialised to the first initial mean. Not updated in this file --
    # presumably maintained by the evosax `Strategy` base class (confirm).
    best_member: Array
    # Starts at the largest finite float32 value.
    best_fitness: float = jnp.finfo(jnp.float32).max
    # Generation counter read by ask/tell. Not incremented in this file --
    # presumably advanced by the `Strategy` base class (confirm).
    gen_counter: int = 0
    # Kernel bandwidth; set from `kernel.bandwidth` at init and passed to the
    # kernel in `tell_strategy`.
    bandwidth: float = 1.


@dataclass
class EvoParams:
    """Hyperparameters consumed by `SV_CMA_BB`.

    The fields without defaults are filled in by `SV_CMA_BB.params_strategy`.
    """

    # mu_eff, c_1 and c_mu are taken from `get_cma_elite_weights`.
    mu_eff: float
    c_1: float
    c_mu: float
    # c_sigma, d_sigma, c_c and chi_n are computed in `params_strategy` and
    # passed to the evosax path / covariance / step-size update helpers.
    c_sigma: float
    d_sigma: float
    c_c: float
    chi_n: float
    # Multiplier applied to the whole mean step in `tell_strategy`.
    c_m: float = 1.0
    # Initial step size; `params_strategy` overrides this default with the
    # `sigma_init` given to the `SV_CMA_BB` constructor.
    sigma_init: float = 0.065
    # Lower bound on the kernel-term weight `alpha` in `tell_strategy`.
    alpha_min: float = 1.
    # Bounds of the uniform distribution the initial means are drawn from.
    init_min: float = 0.0
    init_max: float = 0.0
    # Not read anywhere in this file -- presumably used by the evosax
    # `Strategy` base class to clip candidates (confirm).
    clip_min: float = -jnp.finfo(jnp.float32).max
    clip_max: float = jnp.finfo(jnp.float32).max


class SV_CMA_BB(Strategy):
    """Stein Variational CMA-ES.

    Maintains ``npop`` CMA-ES search distributions side by side. Every
    generation each distribution samples ``subpopsize`` candidates, and the
    mean update combines the per-distribution CMA-ES step with a
    kernel-gradient term computed over all current means.
    """

    def __init__(
        self,
        npop: int,
        subpopsize: int,
        kernel: Kernel,
        num_dims: Optional[int] = None,
        pholder_params: Optional[ArrayTree | Array] = None,
        num_iters: int = 100,
        elite_ratio: float = 0.5,
        sigma_init: float = 1.0,
        mean_decay: float = 0.0,
        n_devices: Optional[int] = None,
        **fitness_kwargs: bool | int | float
    ):
        """Set up the strategy.

        Args:
            npop: Number of parallel search distributions.
            subpopsize: Candidates sampled per distribution and generation.
                The total population size is ``npop * subpopsize``.
            kernel: Callable evaluated as ``kernel(x_j, x_i, bandwidth)``. It
                is differentiated w.r.t. its first argument and must expose
                a ``bandwidth`` attribute.
            num_dims: Forwarded to the ``Strategy`` base class.
            pholder_params: Forwarded to the ``Strategy`` base class.
            num_iters: Horizon used to anneal the weight of the kernel term
                in ``tell_strategy``.
            elite_ratio: Fraction of each subpopulation treated as elite; at
                least one member is always kept.
            sigma_init: Initial step size of every distribution.
            mean_decay: Forwarded to the ``Strategy`` base class.
            n_devices: Forwarded to the ``Strategy`` base class.
            **fitness_kwargs: Forwarded to the ``Strategy`` base class.
        """
        self.npop = npop
        self.subpopsize = subpopsize
        popsize = int(npop * subpopsize)
        super().__init__(
            popsize,
            num_dims,
            pholder_params,
            mean_decay,
            n_devices,
            **fitness_kwargs
        )
        # NOTE: elite_ratio is not range-checked here.
        self.elite_ratio = elite_ratio
        self.elite_popsize = max(1, int(self.subpopsize * self.elite_ratio))
        self.strategy_name = "BB_SVGD_ES"

        # Set core kwargs es_params
        self.sigma_init = sigma_init
        self.kernel = kernel
        self.num_iters = num_iters

        # Robustness for int32 - squaring in hyperparameter calculations
        self.max_dims_sq = jnp.minimum(self.num_dims, 40000)

    @property
    def params_strategy(self) -> EvoParams:
        """Return default parameters of evolution strategy.

        The learning rates are derived from the subpopulation size, the elite
        size and the problem dimension; ``sigma_init`` is the value given to
        the constructor.
        """
        _, _, mu_eff, c_1, c_mu = get_cma_elite_weights(
            self.subpopsize, self.elite_popsize, self.num_dims, self.max_dims_sq
        )

        # lrate for cumulation of step-size control and rank-one update
        c_sigma = (mu_eff + 2) / (self.num_dims + mu_eff + 5)
        d_sigma = (
            1
            + 2
            * jnp.nan_to_num(jnp.maximum(0, jnp.sqrt((mu_eff - 1) / (self.num_dims + 1)) - 1), nan=0.)  # There can be numerical errors here
            + c_sigma
        )
        c_c = (4 + mu_eff / self.num_dims) / (
                self.num_dims + 4 + 2 * mu_eff / self.num_dims
        )
        chi_n = jnp.sqrt(self.num_dims) * (
                1.0
                - (1.0 / (4.0 * self.num_dims))
                + 1.0 / (21.0 * (self.max_dims_sq ** 2))
        )

        params = EvoParams(
            mu_eff=mu_eff,
            c_1=c_1,
            c_mu=c_mu,
            c_sigma=c_sigma,
            d_sigma=d_sigma,
            c_c=c_c,
            chi_n=chi_n,
            sigma_init=self.sigma_init,
        )
        return params

    def initialize_strategy(
        self, rng: PRNGKey, params: EvoParams
    ) -> EvoState:
        """`initialize` the evolution strategy.

        Draws ``npop`` initial means uniformly from
        ``[params.init_min, params.init_max)``. With the default bounds (both
        0.0) every mean starts at the origin.

        Args:
            rng: PRNG key used to draw the initial means.
            params: Strategy hyperparameters.

        Returns:
            The initial state: zero evolution paths, identity covariances,
            step sizes equal to ``params.sigma_init`` and no cached
            eigen-decomposition.
        """
        weights, weights_truncated, _, _, _ = get_cma_elite_weights(
            self.subpopsize, self.elite_popsize, self.num_dims, self.max_dims_sq
        )
        # Initialize evolution paths & covariance matrix
        initialization = jax.random.uniform(
            rng,
            (self.npop, self.num_dims),
            minval=params.init_min,
            maxval=params.init_max,
        )

        state = EvoState(
            p_sigma=jnp.zeros((self.npop, self.num_dims)),
            p_c=jnp.zeros((self.npop, self.num_dims)),
            sigma=jnp.ones(self.npop) * params.sigma_init,
            mean=initialization,
            particles=initialization,
            C=jnp.tile(jnp.eye(self.num_dims), (self.npop, 1, 1)),
            D=None,
            B=None,
            weights=weights,
            weights_truncated=weights_truncated,
            best_member=initialization[0],  # Take any random member of the means
            state_order=jnp.arange(self.npop),
            bandwidth=self.kernel.bandwidth
        )
        return state

    def ask_strategy(
        self, rng: PRNGKey, state: EvoState, params: EvoParams
    ) -> tuple[Array, EvoState]:
        """`ask` for new parameter candidates to evaluate next.

        Args:
            rng: PRNG key; split into one key per search distribution.
            state: Current strategy state.
            params: Strategy hyperparameters (not read here).

        Returns:
            A pair ``(x, state)``. ``x`` has shape ``(popsize, num_dims)``
            with the candidates of each distribution stored contiguously;
            ``state`` carries the refreshed ``C``, ``B`` and ``D``.
        """
        # Refresh the eigen-decomposition of every covariance matrix
        Cs, Bs, Ds = jax.vmap(full_eigen_decomp, (0, 0, 0, None))(
            state.C, state.B, state.D, state.gen_counter
        )
        # One independent key per search distribution
        keys = jax.random.split(rng, num=self.npop)
        x = jax.vmap(sample, (0, 0, 0, 0, 0, None, None))(
            keys,
            state.mean,
            state.sigma,
            Bs,
            Ds,
            self.num_dims,
            self.subpopsize,
        )

        # Reshape for evaluation
        x = x.reshape(self.popsize, self.num_dims)

        return x, state.replace(C=Cs, B=Bs, D=Ds)

    def tell_strategy(
        self,
        x: Array,
        fitness: Array,
        state: EvoState,
        params: EvoParams,
    ) -> EvoState:
        """`tell` performance data for strategy state update.

        Args:
            x: Evaluated candidates, reshaped internally to
                ``(npop, subpopsize, num_dims)``.
            fitness: Fitness of every candidate (lower ranks first),
                reshaped internally to ``(npop, subpopsize)``.
            state: Current strategy state.
            params: Strategy hyperparameters.

        Returns:
            The state with updated means, evolution paths, covariances and
            step sizes. ``state_order`` ranks the distributions by their best
            candidate's fitness, and ``particles`` mirrors the new means.
        """
        x = x.reshape(self.npop, self.subpopsize, self.num_dims)
        fitness = fitness.reshape(self.npop, self.subpopsize)
        # Rank the search distributions by their best (minimum) fitness
        best_subpop_fitness = jnp.min(fitness, axis=-1)
        sorted_indices = best_subpop_fitness.argsort()

        # Compute grads
        y_ks, y_ws = jax.vmap(cmaes_grad, (0, 0, 0, 0, None))(
            x,
            fitness,
            state.mean,
            state.sigma,
            state.weights_truncated
        )

        # Compute kernel grads: for every mean x_i, average over all means
        # x_j the gradient of kernel(x_j, x_i) w.r.t. x_j
        bandwidth = state.bandwidth
        kernel_grads = jax.vmap(
            lambda xi: jnp.mean(
                jax.vmap(lambda xj: jax.grad(self.kernel)(xj, xi, bandwidth))(state.mean),
                axis=0
            )
        )(state.mean)

        # Update means using the kernel gradients. alpha shrinks as the
        # generation counter grows and is floored at params.alpha_min.
        alpha = jnp.maximum(jnp.log(self.num_iters / (state.gen_counter + 1)), params.alpha_min)
        # The kernel term is divided by sigma so that multiplying the whole
        # step by sigma below leaves it scaled by c_m * alpha only
        projected_steps = y_ws + alpha * kernel_grads / state.sigma[:, None]
        means = state.mean + params.c_m * state.sigma[:, None] * projected_steps

        # Sigma update
        p_sigmas, C_2s, Cs, Bs, Ds = jax.vmap(update_p_sigma, (0, 0, 0, 0, 0, None, None, None))(
            state.C,
            state.B,
            state.D,
            state.p_sigma,
            projected_steps,
            params.c_sigma,
            params.mu_eff,
            state.gen_counter,
        )

        p_cs, norms_p_sigma, h_sigmas = jax.vmap(update_p_c, (0, 0, 0, None, 0, None, None, None, None))(
            means,
            p_sigmas,
            state.p_c,
            state.gen_counter + 1,
            projected_steps,
            params.c_sigma,
            params.chi_n,
            params.c_c,
            params.mu_eff,
        )

        Cs = jax.vmap(update_covariance, (0, 0, 0, 0, 0, 0, None, None, None, None))(
            means,
            p_cs,
            Cs,
            y_ks,
            h_sigmas,
            C_2s,
            state.weights,
            params.c_c,
            params.c_1,
            params.c_mu
        )

        sigmas = jax.vmap(update_sigma, (0, 0, None, None, None))(
            state.sigma,
            norms_p_sigma,
            params.c_sigma,
            params.d_sigma,
            params.chi_n,
        )
        # Clip for numerical stability: keep the previous sigma where the
        # update fell below 1e-8, then clamp everything to [1e-8, 1e8]
        sigmas = jnp.clip(jnp.where(sigmas < 1e-8, state.sigma, sigmas), 1e-8, 1e8)

        return state.replace(
            mean=means, p_sigma=p_sigmas, C=Cs, B=Bs, D=Ds, p_c=p_cs, sigma=sigmas, state_order=sorted_indices, particles=means
        )


def cmaes_grad(
    x: Array,
    fitness: Array,
    mean: Array,
    sigma: float,
    weights_truncated: Array,
) -> tuple[Array, Array]:
    """Compute the fitness-ranked CMA-ES steps and their weighted sum.

    Args:
        x: Candidates of one search distribution, one per row.
        fitness: Fitness of each candidate; lower values rank first.
        mean: Mean the candidates were sampled around.
        sigma: Step size the candidates were sampled with.
        weights_truncated: Recombination weights, aligned with the ranking
            (first weight applies to the best candidate).

    Returns:
        A pair ``(y_k, grad)``: ``y_k`` holds the candidates sorted by
        ascending fitness, centred on ``mean`` and divided by ``sigma``;
        ``grad`` is the ``weights_truncated``-weighted sum of those rows.
    """
    # Rank on the fitness values themselves. Sorting a concatenation of
    # fitness and x would first cast fitness to their common dtype, which
    # can merge distinct values into ties, and needlessly copies x.
    order = jnp.argsort(fitness)
    x_k = x[order]  # ~ N(m, σ^2 C)
    y_k = (x_k - mean) / sigma  # ~ N(0, C)
    grad = jnp.dot(weights_truncated.T, y_k)  # y_w can be seen as score of CMA-ES

    return y_k, grad
