from typing import Optional, Callable

from evosax import Strategy, FitnessShaper, OpenES
import jax
import jax.numpy as jnp
from chex import Array

from sves.benchmarks import Benchmark
from sves.kernels import RBF
from sves.strategies import SV_CMA_BB, MC_SVGD, OG_SVGD, GF_SVGD, ParallelOpenES


def eval_sampling(
        rng: jax.random.PRNGKey,
        strategy: Strategy,
        bench: Benchmark,
        n_iter: int,
        cb_freq: int = 5,
        plot_cb: Optional[Callable] = None,
) -> Array:
    """Evaluate the sampling of a strategy on a benchmark where GT samples can be drawn.

    Runs ``n_iter`` ask / evaluate / tell steps of ``strategy`` on ``bench``
    and records ``state.particles`` after every step.

    Args:
        rng: JAX PRNG key driving initialization and every ``ask`` call.
        strategy: The strategy to evaluate. ``OG_SVGD`` instances are fed the
            second function returned by ``bench.get_objective_derivative()``
            (the score) instead of objective values. ``MC_SVGD``, ``OpenES``
            and ``ParallelOpenES`` are fed centered-rank shaped fitness. All
            other strategies get the raw objective values.
        bench: Benchmark providing the init bounds and objective/score functions.
        n_iter: Number of ask/tell iterations to run.
        cb_freq: Print progress (and call ``plot_cb``) every ``cb_freq``
            iterations, starting with the first one. ``0`` disables both.
        plot_cb: Optional callable invoked with ``state.particles`` at each
            progress step.

    Returns:
        ``state.particles`` of every iteration, stacked along a new leading
        axis of length ``n_iter``.
    """
    # Only ``rng_init`` is consumed here. The 3-way split is kept (third key
    # discarded) so the derived keys stay identical for a given seed.
    rng, rng_init, _ = jax.random.split(rng, 3)
    es_params = strategy.default_params.replace(
        init_min=bench.lower_bounds,
        init_max=bench.upper_bounds,
        # The clipping box is the init box widened by 2 on every side.
        clip_min=bench.lower_bounds - 2,
        clip_max=bench.upper_bounds + 2,
    )
    state = strategy.initialize(rng_init, es_params)
    shaper = FitnessShaper(centered_rank=True)

    objective_fn, score_fn = bench.get_objective_derivative()

    # Both decisions depend only on the strategy type, so make them once.
    # Gradient-based SVGD (OG_SVGD) consumes the score, not objective values.
    eval_fn = score_fn if isinstance(strategy, OG_SVGD) else objective_fn
    use_shaping = isinstance(strategy, (MC_SVGD, OpenES, ParallelOpenES))

    samples = []
    for t in range(n_iter):
        rng, rng_gen = jax.random.split(rng)
        x, state = strategy.ask(rng_gen, state, es_params)
        fitness = eval_fn(x)
        shaped_fitness = shaper.apply(x, fitness) if use_shaping else fitness
        state = strategy.tell(x, shaped_fitness, state, es_params)

        samples.append(state.particles)

        # ``cb_freq == 0`` disables progress output (previously ZeroDivisionError).
        if cb_freq and t % cb_freq == 0:
            print(t + 1, fitness.min())
            if plot_cb is not None:
                plot_cb(state.particles)

    return jnp.array(samples)


def run_cfg_cma(
    key: jax.random.PRNGKey,
    npart: int,
    nsubpop: int,
    er: float,
    sig: float,
    kw: float,
    nrep: int,
    num_generations: int,
    bench: Benchmark,
    cb_freq: int = 500,
) -> Array:
    """Run a benchmark sampling experiment with the SV-CMA-ES algorithm.

    Args:
        key: JAX PRNG key. It is split into ``nrep`` independent seeds.
        npart: Number of particles / ES populations.
        nsubpop: Size of each subpopulation.
        er: Elite ratio for CMA-ES.
        sig: Initial sigma for CMA-ES.
        kw: Bandwidth of the RBF kernel.
        nrep: Number of repetitions of the experiment. All repetitions are
            run in parallel with ``jax.vmap``.
        num_generations: Number of generations each experiment runs for.
        bench: The benchmark to sample from.
        cb_freq: Progress/callback frequency forwarded to ``eval_sampling``.

    Returns:
        The ``eval_sampling`` output of every repetition, stacked along a
        leading axis of length ``nrep``.
    """
    strategy = SV_CMA_BB(
        npart,
        nsubpop,
        RBF(kw),
        num_dims=bench.dim,
        elite_ratio=er,
        sigma_init=sig,
        num_iters=num_generations,
    )
    seeds = jax.random.split(key, nrep)
    # One independent run per seed, vectorized over the leading axis of ``seeds``.
    return jax.vmap(
        lambda seed: eval_sampling(seed, strategy, bench, num_generations, cb_freq=cb_freq)
    )(seeds)

def run_cfg_oes(
    key: jax.random.PRNGKey,
    npart: int,
    nsubpop: int,
    lr: float,
    sig: float,
    kw: float,
    nrep: int,
    num_generations: int,
    bench: Benchmark,
    cb_freq: int = 500,
) -> Array:
    """Run a benchmark sampling experiment with the ``MC_SVGD`` strategy.

    Args:
        key: JAX PRNG key. It is split into ``nrep`` independent seeds.
        npart: Number of particles (first positional argument of ``MC_SVGD``).
        nsubpop: Second positional argument of ``MC_SVGD``. Presumably the
            size of each subpopulation, as in ``run_cfg_cma`` (confirm).
        lr: Initial learning rate (``lrate_init``).
        sig: Initial sigma (``sigma_init``).
        kw: Bandwidth of the RBF kernel.
        nrep: Number of repetitions. All repetitions are run in parallel with
            ``jax.vmap``.
        num_generations: Number of generations each experiment runs for.
        bench: The benchmark to sample from.
        cb_freq: Progress/callback frequency forwarded to ``eval_sampling``.

    Returns:
        The ``eval_sampling`` output of every repetition, stacked along a
        leading axis of length ``nrep``.
    """
    strategy = MC_SVGD(
        npart,
        nsubpop,
        kernel=RBF(kw),
        num_iters=num_generations,
        num_dims=bench.dim,
        sigma_init=sig,
        lrate_init=lr,
    )
    seeds = jax.random.split(key, nrep)
    # One independent run per seed, vectorized over the leading axis of ``seeds``.
    return jax.vmap(
        lambda seed: eval_sampling(seed, strategy, bench, num_generations, cb_freq=cb_freq)
    )(seeds)

def run_cfg_og(
    key: jax.random.PRNGKey,
    npart: int,
    nsubpop: int,
    lr: float,
    kw: float,
    nrep: int,
    num_generations: int,
    bench: Benchmark,
    cb_freq: int = 500,
) -> Array:
    """Run a benchmark sampling experiment with gradient-based SVGD (``OG_SVGD``).

    Args:
        key: JAX PRNG key. It is split into ``nrep`` independent seeds.
        npart: Combined with ``nsubpop``; only their product is used.
        nsubpop: ``npart * nsubpop`` is passed as the first argument of
            ``OG_SVGD`` (presumably the particle count, so the run matches the
            total population of the subpopulation-based runs — confirm).
        lr: Initial learning rate of the Adam optimizer.
        kw: Bandwidth of the RBF kernel.
        nrep: Number of repetitions. All repetitions are run in parallel with
            ``jax.vmap``.
        num_generations: Number of iterations each experiment runs for.
        bench: The benchmark to sample from.
        cb_freq: Progress/callback frequency forwarded to ``eval_sampling``.

    Returns:
        The ``eval_sampling`` output of every repetition, stacked along a
        leading axis of length ``nrep``.
    """
    strategy = OG_SVGD(
        npart * nsubpop,
        RBF(kw),
        num_iters=num_generations,
        num_dims=bench.dim,
        opt_name="adam",
        lrate_init=lr,
    )
    seeds = jax.random.split(key, nrep)
    # One independent run per seed, vectorized over the leading axis of ``seeds``.
    return jax.vmap(
        lambda seed: eval_sampling(seed, strategy, bench, num_generations, cb_freq=cb_freq)
    )(seeds)

def run_cfg_gf(
    key: jax.random.PRNGKey,
    npart: int,
    nsubpop: int,
    sig: float,
    lr: float,
    kw: float,
    nrep: int,
    num_generations: int,
    bench: Benchmark,
    cb_freq: int = 500,
) -> Array:
    """Run a benchmark sampling experiment with the ``GF_SVGD`` strategy.

    Note that ``sig`` precedes ``lr`` here, unlike in ``run_cfg_oes``.

    Args:
        key: JAX PRNG key. It is split into ``nrep`` independent seeds.
        npart: Combined with ``nsubpop``; only their product is used.
        nsubpop: ``npart * nsubpop`` is passed as the first argument of
            ``GF_SVGD`` (presumably the particle count — confirm).
        sig: Initial sigma (``sigma_init``).
        lr: Initial learning rate of the Adam optimizer.
        kw: Bandwidth of the RBF kernel.
        nrep: Number of repetitions. All repetitions are run in parallel with
            ``jax.vmap``.
        num_generations: Number of iterations each experiment runs for.
        bench: The benchmark to sample from.
        cb_freq: Progress/callback frequency forwarded to ``eval_sampling``.

    Returns:
        The ``eval_sampling`` output of every repetition, stacked along a
        leading axis of length ``nrep``.
    """
    strategy = GF_SVGD(
        npart * nsubpop,
        RBF(kw),
        sigma_init=sig,
        num_iters=num_generations,
        num_dims=bench.dim,
        opt_name="adam",
        lrate_init=lr,
    )
    seeds = jax.random.split(key, nrep)
    # One independent run per seed, vectorized over the leading axis of ``seeds``.
    return jax.vmap(
        lambda seed: eval_sampling(seed, strategy, bench, num_generations, cb_freq=cb_freq)
    )(seeds)
