import jax
import jax.numpy as jnp
import chex
from evosax import Strategy, CMA_ES, EvoParams, ParameterReshaper, OpenES
from typing import Optional, Union, Tuple
from flax.struct import dataclass
from evosax.strategies.cma_es import get_cma_elite_weights

from sves.strategies.mc_svgd import OESParams, OESState

@dataclass
class EvoState:
    """Search state of one CMA-ES instance (stacked along a leading axis by ParallelCMAES).

    Field order and defaults mirror what evosax's ``CMA_ES`` reads and writes,
    plus an extra ``particles`` field exposed to callers of the parallel wrapper.
    """
    # Presumably the CMA-ES step-size evolution path; starts at zeros(num_dims).
    p_sigma: chex.Array
    # Presumably the CMA-ES covariance evolution path; starts at zeros(num_dims).
    p_c: chex.Array
    # Covariance matrix of the search distribution; starts as the identity.
    C: chex.Array
    # D / B start as None; presumably filled with eigendecomposition factors of C
    # by the base CMA_ES — confirm against evosax.
    D: Optional[chex.Array]
    B: Optional[chex.Array]
    # Mean of the search distribution.
    mean: chex.Array
    # Global step size; initialised from params.sigma_init.
    sigma: chex.Array
    # Copy of the mean exposed to callers of the parallel wrapper.
    particles: chex.Array
    # Recombination weights as returned by get_cma_elite_weights.
    weights: chex.Array
    weights_truncated: chex.Array
    # Best candidate seen so far and its fitness. The float32-max default assumes
    # minimisation, so any finite fitness replaces it — TODO confirm with fitness_kwargs.
    best_member: chex.Array
    best_fitness: float = jnp.finfo(jnp.float32).max
    # Number of completed generations.
    gen_counter: int = 0



class ParallelCMAES(Strategy):
    """Wrapper for CMA-ES with multiple seeds in parallel.

    Runs ``npop`` independent CMA-ES instances that share the hyperparameters
    of ``self.base_strategy`` but each keep their own mean, covariance and
    evolution paths. Every field of the state carries a leading axis of size
    ``npop``. ``ask`` concatenates the samples of all instances into a single
    population of ``npop * popsize`` members in instance-major order (all
    members of instance 0 first, then instance 1, ...). ``tell`` expects the
    fitness values in that same order.
    """

    def __init__(
        self,
        npop: int,
        popsize: int,
        num_dims: Optional[int] = None,
        pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
        elite_ratio: float = 0.5,
        sigma_init: float = 1.0,
        mean_decay: float = 0.0,
        n_devices: Optional[int] = None,
        **fitness_kwargs: Union[bool, int, float]
    ):
        """Create the parallel wrapper.

        Args:
            npop: Number of independent CMA-ES instances run in parallel.
            popsize: Population size of each individual instance.
            num_dims: Search-space dimensionality (used when no placeholder
                tree is given).
            pholder_params: Optional placeholder parameter tree. When given,
                ``ask``/``tell`` exchange candidates in that tree format.
            elite_ratio, sigma_init, mean_decay, fitness_kwargs: Forwarded
                unchanged to the underlying ``CMA_ES``.
            n_devices: Accepted for interface compatibility but ignored. The
                base strategy is always built with ``n_devices=1``.
        """
        self.use_param_reshaper = pholder_params is not None
        if self.use_param_reshaper:
            # NOTE(review): n_devices=npop presumably makes the reshaper treat
            # the extra leading npop axis added by vmap like a device axis when
            # flattening — confirm against evosax's ParameterReshaper.
            self.param_reshaper = ParameterReshaper(pholder_params, n_devices=npop)
        self.base_strategy = CMA_ES(
            popsize,
            num_dims,
            pholder_params,
            elite_ratio,
            sigma_init,
            mean_decay,
            1,  # n_devices=1: parallelism over instances comes from vmap.
            **fitness_kwargs
        )
        self.npop = npop
        self.strategy_name = "Parallel CMA-ES"

    @property
    def default_params(self) -> EvoParams:
        """Return default parameters of evolution strategy (shared by all instances)."""
        params = self.base_strategy.params_strategy
        return params

    def initialize_strategy(self, rng: chex.PRNGKey, params: EvoParams) -> EvoState:
        """Build the stacked initial state of all ``npop`` instances.

        Args:
            rng: PRNG key used to draw the initial means.
            params: Strategy parameters (``init_min``/``init_max``/``sigma_init``).

        Returns:
            An ``EvoState`` whose fields all have a leading ``npop`` axis.
        """
        num_dims = self.base_strategy.num_dims
        # Draw every initial mean from ``rng`` in a single call, matching the
        # other methods so that all start from the same points.
        init_particles = jax.random.uniform(
            rng,
            (self.npop, num_dims),
            minval=params.init_min,
            maxval=params.init_max,
        )
        weights, weights_truncated, _, _, _ = get_cma_elite_weights(
            self.base_strategy.popsize,
            self.base_strategy.elite_popsize,
            num_dims,
            self.base_strategy.max_dims_sq,
        )

        def init_state(init_mean: chex.Array) -> EvoState:
            # Values not derived from ``init_mean`` are broadcast by vmap to a
            # leading npop axis.
            return EvoState(
                p_sigma=jnp.zeros(num_dims),
                p_c=jnp.zeros(num_dims),
                C=jnp.eye(num_dims),
                D=None,
                B=None,
                mean=init_mean,
                sigma=params.sigma_init,
                particles=init_mean,
                weights=weights,
                weights_truncated=weights_truncated,
                # Keep best_member consistent with the starting mean.
                best_member=init_mean,
            )

        return jax.vmap(init_state)(init_particles)

    def ask(
        self,
        rng: chex.PRNGKey,
        state: EvoState,
        params: Optional[EvoParams] = None,
    ) -> Tuple[Union[chex.Array, chex.ArrayTree], EvoState]:
        """Sample one population from every instance.

        Overwrites the ask of the Strategy class intentionally: the base
        strategy's ``ask`` is vmapped over the stacked state.

        Args:
            rng: PRNG key. It is split into one key per instance.
            state: Stacked state with a leading ``npop`` axis.
            params: Strategy parameters shared by all instances.

        Returns:
            ``(x, states)``. ``x`` holds ``npop * popsize`` candidates in
            instance-major order, as a parameter tree when a placeholder tree
            was supplied. ``states`` is the stacked state after asking.
        """
        # One key per instance. Reusing a single key would make every instance
        # draw exactly the same perturbations, so the "independent" runs would
        # be perfectly correlated.
        keys = jax.random.split(rng, self.npop)
        x, states = jax.vmap(self.base_strategy.ask, (0, 0, None))(keys, state, params)
        if self.use_param_reshaper:
            x = self.param_reshaper.flatten(x)                # Makes an array out of the dict
            x = self.base_strategy.param_reshaper.reshape(x)  # Base strategy assumes single pop.
        else:
            x = x.reshape(-1, self.base_strategy.num_dims)
        return x, states

    def tell(
        self,
        x: Union[chex.Array, chex.ArrayTree],
        fitness: chex.Array,
        state: EvoState,
        params: Optional[EvoParams] = None,
    ) -> chex.ArrayTree:
        """Update every instance with its slice of the evaluated population.

        Overwrites the tell of the Strategy class intentionally: the base
        strategy's ``tell`` is vmapped over the stacked state.

        Args:
            x: The ``npop * popsize`` candidates returned by ``ask``.
            fitness: Their fitness values, in the same (instance-major) order.
            state: Stacked state returned by ``ask``.
            params: Strategy parameters shared by all instances.

        Returns:
            The updated stacked state, with ``particles`` set to the new means.
        """
        popsize = self.base_strategy.popsize
        num_dims = self.base_strategy.num_dims
        if self.use_param_reshaper:
            # Parameter tree -> flat candidate array, reshaped below.
            x = self.base_strategy.param_reshaper.flatten(x)
        # Split the concatenated population back into one slice per instance.
        x = jax.tree_util.tree_map(
            lambda xi: xi.reshape(self.npop, popsize, num_dims),
            x,
        )
        fitness = fitness.reshape(self.npop, popsize)
        states = jax.vmap(self.base_strategy.tell, (0, 0, 0, None))(x, fitness, state, params)
        # Expose the *updated* means as particles. The incoming ``state.mean``
        # is the pre-update mean and would lag one generation behind.
        return states.replace(particles=states.mean)



class ParallelOpenES(Strategy):
    """Wrapper for OpenES with multiple seeds in parallel.

    Runs ``npop`` independent OpenES instances that share the hyperparameters
    of ``self.base_strategy`` but each keep their own mean, step sizes and
    optimizer state. Every field of the state carries a leading axis of size
    ``npop``. ``ask`` concatenates the samples of all instances into a single
    population of ``npop * subpopsize`` members in instance-major order.
    ``tell`` expects the fitness values in that same order.
    """

    def __init__(
        self,
        npop: int,
        subpopsize: int,
        num_dims: Optional[int] = None,
        pholder_params: Optional[chex.ArrayTree | chex.Array] = None,
        use_antithetic_sampling: bool = True,
        opt_name: str = "adam",
        lrate_init: float = 0.05,
        lrate_decay: float = 1.0,
        lrate_limit: float = 0.001,
        sigma_init: float = 0.03,
        sigma_decay: float = 1.0,
        sigma_limit: float = 0.01,
        mean_decay: float = 0.0,
        n_devices: Optional[int] = None,
        **fitness_kwargs: bool | int | float
    ):
        """Create the parallel wrapper.

        Args:
            npop: Number of independent OpenES instances run in parallel.
            subpopsize: Population size of each individual instance.
            num_dims: Search-space dimensionality (used when no placeholder
                tree is given).
            pholder_params: Optional placeholder parameter tree. When given,
                ``ask``/``tell`` exchange candidates in that tree format.
            use_antithetic_sampling, opt_name, lrate_*, sigma_*, mean_decay,
            fitness_kwargs: Forwarded unchanged to the underlying ``OpenES``.
            n_devices: Accepted for interface compatibility but ignored. The
                base strategy is always built with ``n_devices=1``.
        """
        self.use_param_reshaper = pholder_params is not None
        if self.use_param_reshaper:
            # NOTE(review): n_devices=npop presumably makes the reshaper treat
            # the extra leading npop axis added by vmap like a device axis when
            # flattening — confirm against evosax's ParameterReshaper.
            self.param_reshaper = ParameterReshaper(pholder_params, n_devices=npop)
        self.base_strategy = OpenES(
            subpopsize,
            num_dims,
            pholder_params,
            use_antithetic_sampling,
            opt_name,
            lrate_init,
            lrate_decay,
            lrate_limit,
            sigma_init,
            sigma_decay,
            sigma_limit,
            mean_decay,
            1,  # n_devices=1: parallelism over instances comes from vmap.
            **fitness_kwargs
        )
        self.npop = npop
        # Previously copy-pasted as "Parallel MC SVGD", which mislabelled this strategy.
        self.strategy_name = "Parallel OpenES"

    @property
    def default_params(self) -> EvoParams:
        """Return default parameters of evolution strategy (shared by all instances)."""
        params = self.base_strategy.params_strategy
        return params

    def initialize_strategy(self, rng: chex.PRNGKey, params: OESParams) -> OESState:
        """Build the stacked initial state of all ``npop`` instances.

        Args:
            rng: PRNG key used to draw the initial means.
            params: Strategy parameters (``init_min``/``init_max``/
                ``sigma_init``/``opt_params``).

        Returns:
            An ``OESState`` whose fields all have a leading ``npop`` axis.
        """
        num_dims = self.base_strategy.num_dims
        # Draw every initial mean from ``rng`` in a single call, the same way
        # ParallelCMAES does, so both start from the same points for a given key.
        init_particles = jax.random.uniform(
            rng,
            (self.npop, num_dims),
            minval=params.init_min,
            maxval=params.init_max,
        )

        def init_state(init_mean: chex.Array) -> OESState:
            return OESState(
                mean=init_mean,
                sigma=jnp.ones(num_dims) * params.sigma_init,
                opt_state=self.base_strategy.optimizer.initialize(params.opt_params),
                best_member=init_mean,
                particles=init_mean,
            )

        return jax.vmap(init_state)(init_particles)

    def ask(
        self,
        rng: chex.PRNGKey,
        state: OESState,
        params: Optional[OESParams] = None,
    ) -> Tuple[Union[chex.Array, chex.ArrayTree], OESState]:
        """Sample one population from every instance.

        Overwrites the ask of the Strategy class intentionally: the base
        strategy's ``ask`` is vmapped over the stacked state.

        Args:
            rng: PRNG key. It is split into one key per instance.
            state: Stacked state with a leading ``npop`` axis.
            params: Strategy parameters shared by all instances.

        Returns:
            ``(x, states)``. ``x`` holds ``npop * subpopsize`` candidates in
            instance-major order, as a parameter tree when a placeholder tree
            was supplied. ``states`` is the stacked state after asking.
        """
        # One key per instance. Reusing a single key would make every instance
        # draw exactly the same perturbations, so the "independent" runs would
        # be perfectly correlated.
        keys = jax.random.split(rng, self.npop)
        x, states = jax.vmap(self.base_strategy.ask, (0, 0, None))(keys, state, params)
        if self.use_param_reshaper:
            x = self.param_reshaper.flatten(x)                # Makes an array out of the dict
            x = self.base_strategy.param_reshaper.reshape(x)  # Base strategy assumes single pop.
        else:
            x = x.reshape(-1, self.base_strategy.num_dims)
        return x, states

    def tell(
        self,
        x: Union[chex.Array, chex.ArrayTree],
        fitness: chex.Array,
        state: OESState,
        params: Optional[OESParams] = None,
    ) -> chex.ArrayTree:
        """Update every instance with its slice of the evaluated population.

        Overwrites the tell of the Strategy class intentionally: the base
        strategy's ``tell`` is vmapped over the stacked state.

        Args:
            x: The ``npop * subpopsize`` candidates returned by ``ask``.
            fitness: Their fitness values, in the same (instance-major) order.
            state: Stacked state returned by ``ask``.
            params: Strategy parameters shared by all instances.

        Returns:
            The updated stacked state, with ``particles`` set to the new means.
        """
        popsize = self.base_strategy.popsize
        num_dims = self.base_strategy.num_dims
        if self.use_param_reshaper:
            # Parameter tree -> flat candidate array, reshaped below.
            x = self.base_strategy.param_reshaper.flatten(x)
        # Split the concatenated population back into one slice per instance.
        x = jax.tree_util.tree_map(
            lambda xi: xi.reshape(self.npop, popsize, num_dims),
            x,
        )
        fitness = fitness.reshape(self.npop, popsize)
        states = jax.vmap(self.base_strategy.tell, (0, 0, 0, None))(x, fitness, state, params)
        # Expose the *updated* means as particles. The incoming ``state.mean``
        # is the pre-update mean and would lag one generation behind.
        return states.replace(particles=states.mean)
