from typing import Literal

import chex
import jax
import jax.numpy as jnp
from evosax import NetworkMapper, ProblemMapper, FitnessShaper, OpenES, CMA_ES

from sves.benchmarks import Benchmark
from sves.kernels import RBF
from sves.strategies import MC_SVGD, ParallelOpenES, SV_CMA_BB, GF_SVGD, ParallelCMAES


class ClassicControl(Benchmark):
    """Benchmark that evolves MLP policies on a Gymnax environment.

    Two ``ProblemMapper["Gymnax"]`` evaluators are created for the same
    environment: one with ``test=False`` that scores candidates during
    training, and one with ``test=True`` that is used for the periodic
    evaluation of the strategy's current solution(s).
    """

    def __init__(self, env_name: str, max_episode_len: int = 1_000):
        """Build both evaluators and initialise the ``Benchmark`` base.

        Args:
            env_name: The name of the environment.
            max_episode_len: Forwarded to both evaluators.
        """
        self.env_name = env_name
        make_evaluator = ProblemMapper["Gymnax"]
        self.train_evaluator = make_evaluator(env_name, max_episode_len, n_devices=1, test=False)
        self.test_evaluator = make_evaluator(env_name, max_episode_len, n_devices=1, test=True)
        self.output_activation = self._activation_mapper(self.env_name)

        # Bounds span the full finite float32 range.
        float32_max = jnp.finfo(jnp.float32).max
        evaluator = self.train_evaluator
        obs_space = evaluator.env.observation_space(evaluator.env_params)
        super().__init__(
            lb=-float32_max,
            ub=float32_max,
            dim=obs_space.shape,
            fglob=-100,
            name=env_name,
        )

    @staticmethod
    def _activation_mapper(env_name: str) -> Literal["categorical", "tanh", "identity"]:
        """Pick the network output activation that matches an environment.

        Args:
            env_name: The name of the environment.

        Returns:
            ``"categorical"`` or ``"tanh"`` for the environments listed below,
            ``"identity"`` for every other name.
        """
        activation_table = (
            (("CartPole-v1", "Acrobot-v1"), "categorical"),
            (("MountainCarContinuous-v0", "DeepSea-bsuite"), "tanh"),
        )
        for env_names, activation in activation_table:
            if env_name in env_names:
                return activation
        return "identity"

    def train(
        self,
        rng,
        strategy_cls,
        num_layers=0,
        num_hidden=32,
        num_generations=1_000,
    ):
        """Evolve an MLP policy with the given strategy and track test fitness.

        Args:
            rng: PRNG key; split into a network-initialisation key and a key
                that drives the evolution loop.
            strategy_cls: Callable that receives the placeholder network
                parameters and returns the strategy instance to run.
            num_layers: Number of hidden layers of the MLP.
            num_hidden: Number of hidden units of the MLP.
            num_generations: Number of ask/evaluate/tell iterations.

        Returns:
            A dict with the keys ``"fitness_mean"`` and ``"fitness_max"``,
            holding the mean and the maximum test fitness recorded after the
            first generation and after every fifth generation.
        """
        _, network_rng, eval_rng = jax.random.split(rng, 3)
        env = self.train_evaluator.env

        # Instantiate the policy network and hand its apply function to both evaluators.
        network = NetworkMapper["MLP"](
            num_hidden_units=num_hidden,
            num_hidden_layers=num_layers,
            num_output_units=env.num_actions,
            hidden_activation="relu",
            output_activation=self.output_activation,
        )
        pholder_params = network.init(network_rng, jnp.zeros(self.dim), rng=network_rng)
        for evaluator in (self.train_evaluator, self.test_evaluator):
            evaluator.set_apply_fn(network.apply)

        # Get strategy
        shaper = FitnessShaper(centered_rank=True)
        strategy = strategy_cls(pholder_params)

        # Only the OpenAI-ES based methods receive rank-shaped fitness values.
        uses_shaping = isinstance(strategy, (MC_SVGD, OpenES, ParallelOpenES))

        def current_means(state):
            """Reshape the strategy's current solution(s) into network parameters."""
            if isinstance(strategy, (SV_CMA_BB, MC_SVGD, GF_SVGD)):
                return strategy.param_reshaper.reshape(state.particles)  # SV
            if isinstance(strategy, (CMA_ES, OpenES)):
                return strategy.param_reshaper.reshape(state.mean.reshape((1, -1)))  # Single
            return strategy.base_strategy.param_reshaper.reshape(state.mean)  # Parallel

        # GF_SVGD gets a symmetric init range derived from the layer sizes;
        # every other strategy starts from a zero-width range.
        if isinstance(strategy, GF_SVGD):
            obs_dim = env.observation_space(self.train_evaluator.env_params).shape[0]
            init_val = jnp.sqrt(6 / (num_hidden + obs_dim))
        else:
            init_val = 0
        es_params = strategy.default_params.replace(init_min=-init_val, init_max=init_val)
        if hasattr(strategy.default_params, "alpha_min"):
            es_params = es_params.replace(alpha_min=0.)

        loop_key, init_key = jax.random.split(eval_rng)
        es_state = strategy.initialize(init_key, es_params)
        mean_history = []
        max_history = []

        for generation in range(1, num_generations + 1):
            loop_key, ask_key, rollout_key = jax.random.split(loop_key, 3)
            candidates, es_state = strategy.ask(ask_key, es_state, es_params)

            fitness = self.train_evaluator.rollout(rollout_key, candidates).mean(axis=1)
            fitness = jnp.nan_to_num(fitness, nan=-1e8)  # some envs have errors and return nan
            if uses_shaping:
                shaped_fitness = shaper.apply(candidates, fitness)
            else:
                shaped_fitness = fitness
            # The fitness is negated before it is handed to the strategy.
            es_state = strategy.tell(candidates, -shaped_fitness, es_state, es_params)

            # Evaluate after the first generation and after every fifth one.
            if generation != 1 and generation % 5 != 0:
                continue

            loop_key, val_key = jax.random.split(loop_key)
            means = current_means(es_state)
            mc_fitness = self.test_evaluator.rollout(val_key, means).mean(axis=1)
            best_fitness = mc_fitness.max()
            print(generation)
            print(best_fitness)
            mean_history.append(mc_fitness.mean())
            max_history.append(best_fitness)

        return {
            "fitness_mean": jnp.array(mean_history).squeeze(),
            "fitness_max": jnp.array(max_history).squeeze(),
        }


def run_cfg_cma(key, npart, subpopsize, er, sig, kw, nrep, num_generations, env_name, ep_len, num_layers=2, num_hidden=16):
    """Train ``SV_CMA_BB`` on a ``ClassicControl`` benchmark for several seeds.

    Args:
        key: PRNG key that is split into ``nrep`` per-run seeds.
        npart: First positional argument of ``SV_CMA_BB``.
        subpopsize: Second positional argument of ``SV_CMA_BB``.
        er: Passed to the strategy as ``elite_ratio``.
        sig: Passed to the strategy as ``sigma_init``.
        kw: Argument of the ``RBF`` kernel.
        nrep: Number of seeds the training is vmapped over.
        num_generations: Number of generations; also passed as ``num_iters``.
        env_name: Environment name for ``ClassicControl``.
        ep_len: Episode length for ``ClassicControl``.
        num_layers: Number of hidden layers of the policy network.
        num_hidden: Number of hidden units of the policy network.

    Returns:
        The output of ``ClassicControl.train`` vmapped over the ``nrep`` seeds.
    """
    benchmark = ClassicControl(env_name, ep_len)

    def make_strategy(ph):
        return SV_CMA_BB(
            npart,
            subpopsize,
            RBF(kw),
            pholder_params=ph,
            elite_ratio=er,
            sigma_init=sig,
            num_iters=num_generations,
            n_devices=1,
        )

    def train_one(seed):
        return benchmark.train(seed, make_strategy, num_layers, num_hidden, num_generations)

    run_keys = jax.random.split(key, nrep)
    return jax.vmap(train_one)(run_keys)


def run_cfg_oes(key, npart, subpopsize, lr, sig, kw, nrep, num_generations, env_name, ep_len, num_layers=2, num_hidden=16):
    """Train ``MC_SVGD`` on a ``ClassicControl`` benchmark for several seeds.

    Args:
        key: PRNG key that is split into ``nrep`` per-run seeds.
        npart: First positional argument of ``MC_SVGD``.
        subpopsize: Second positional argument of ``MC_SVGD``.
        lr: Passed to the strategy as ``lrate_init``.
        sig: Passed to the strategy as ``sigma_init``.
        kw: Argument of the ``RBF`` kernel.
        nrep: Number of seeds the training is vmapped over.
        num_generations: Number of generations; also passed as ``num_iters``.
        env_name: Environment name for ``ClassicControl``.
        ep_len: Episode length for ``ClassicControl``.
        num_layers: Number of hidden layers of the policy network.
        num_hidden: Number of hidden units of the policy network.

    Returns:
        The output of ``ClassicControl.train`` vmapped over the ``nrep`` seeds.
    """
    benchmark = ClassicControl(env_name, ep_len)

    def make_strategy(ph):
        return MC_SVGD(
            npart,
            subpopsize,
            RBF(kw),
            pholder_params=ph,
            lrate_init=lr,
            sigma_init=sig,
            num_iters=num_generations,
            n_devices=1,
        )

    def train_one(seed):
        return benchmark.train(seed, make_strategy, num_layers, num_hidden, num_generations)

    run_keys = jax.random.split(key, nrep)
    return jax.vmap(train_one)(run_keys)


def run_cfg_cma_parallel(key, npart, subpopsize, er, sig, nrep, env_name, ep_len, num_layers=2, num_hidden=16, num_generations=200):
    """Train ``ParallelCMAES`` on a ``ClassicControl`` benchmark for several seeds.

    Args:
        key: PRNG key that is split into ``nrep`` per-run seeds.
        npart: First positional argument of ``ParallelCMAES``.
        subpopsize: Second positional argument of ``ParallelCMAES``.
        er: Passed to the strategy as ``elite_ratio``.
        sig: Passed to the strategy as ``sigma_init``.
        nrep: Number of seeds the training is vmapped over.
        env_name: Environment name for ``ClassicControl``.
        ep_len: Episode length for ``ClassicControl``.
        num_layers: Number of hidden layers of the policy network.
        num_hidden: Number of hidden units of the policy network.
        num_generations: Number of generations to train for.

    Returns:
        The output of ``ClassicControl.train`` vmapped over the ``nrep`` seeds.
    """
    benchmark = ClassicControl(env_name, ep_len)

    def make_strategy(ph):
        return ParallelCMAES(
            npart,
            subpopsize,
            pholder_params=ph,
            elite_ratio=er,
            sigma_init=sig,
            n_devices=1,
        )

    def train_one(seed):
        return benchmark.train(seed, make_strategy, num_layers, num_hidden, num_generations)

    run_keys = jax.random.split(key, nrep)
    return jax.vmap(train_one)(run_keys)


def run_cfg_oes_parallel(key, npart, subpopsize, lr, sig, nrep, env_name, ep_len, num_layers=2, num_hidden=16, num_generations=200):
    """Train ``ParallelOpenES`` on a ``ClassicControl`` benchmark for several seeds.

    Args:
        key: PRNG key that is split into ``nrep`` per-run seeds.
        npart: First positional argument of ``ParallelOpenES``.
        subpopsize: Second positional argument of ``ParallelOpenES``.
        lr: Passed to the strategy as ``lrate_init``.
        sig: Passed to the strategy as ``sigma_init``.
        nrep: Number of seeds the training is vmapped over.
        env_name: Environment name for ``ClassicControl``.
        ep_len: Episode length for ``ClassicControl``.
        num_layers: Number of hidden layers of the policy network.
        num_hidden: Number of hidden units of the policy network.
        num_generations: Number of generations to train for.

    Returns:
        The output of ``ClassicControl.train`` vmapped over the ``nrep`` seeds.
    """
    benchmark = ClassicControl(env_name, ep_len)

    def make_strategy(ph):
        return ParallelOpenES(
            npart,
            subpopsize,
            pholder_params=ph,
            lrate_init=lr,
            sigma_init=sig,
            n_devices=1,
        )

    def train_one(seed):
        return benchmark.train(seed, make_strategy, num_layers, num_hidden, num_generations)

    run_keys = jax.random.split(key, nrep)
    return jax.vmap(train_one)(run_keys)


def run_cfg_cma_single(key, npart, subpopsize, er, sig, nrep, env_name, ep_len, num_layers=2, num_hidden=16, num_generations=200):
    """Train a single ``CMA_ES`` on a ``ClassicControl`` benchmark for several seeds.

    The strategy's first positional argument is ``npart * subpopsize``.

    Args:
        key: PRNG key that is split into ``nrep`` per-run seeds.
        npart: First factor of the value passed to ``CMA_ES``.
        subpopsize: Second factor of the value passed to ``CMA_ES``.
        er: Passed to the strategy as ``elite_ratio``.
        sig: Passed to the strategy as ``sigma_init``.
        nrep: Number of seeds the training is vmapped over.
        env_name: Environment name for ``ClassicControl``.
        ep_len: Episode length for ``ClassicControl``.
        num_layers: Number of hidden layers of the policy network.
        num_hidden: Number of hidden units of the policy network.
        num_generations: Number of generations to train for.

    Returns:
        The output of ``ClassicControl.train`` vmapped over the ``nrep`` seeds.
    """
    benchmark = ClassicControl(env_name, ep_len)

    def make_strategy(ph):
        return CMA_ES(
            npart * subpopsize,
            pholder_params=ph,
            elite_ratio=er,
            sigma_init=sig,
            n_devices=1,
        )

    def train_one(seed):
        return benchmark.train(seed, make_strategy, num_layers, num_hidden, num_generations)

    run_keys = jax.random.split(key, nrep)
    return jax.vmap(train_one)(run_keys)


def run_cfg_oes_single(key, npart, subpopsize, lr, sig, nrep, env_name, ep_len, num_layers=2, num_hidden=16, num_generations=200):
    """Train a single ``OpenES`` on a ``ClassicControl`` benchmark for several seeds.

    The strategy's first positional argument is ``npart * subpopsize``.

    Args:
        key: PRNG key that is split into ``nrep`` per-run seeds.
        npart: First factor of the value passed to ``OpenES``.
        subpopsize: Second factor of the value passed to ``OpenES``.
        lr: Passed to the strategy as ``lrate_init``.
        sig: Passed to the strategy as ``sigma_init``.
        nrep: Number of seeds the training is vmapped over.
        env_name: Environment name for ``ClassicControl``.
        ep_len: Episode length for ``ClassicControl``.
        num_layers: Number of hidden layers of the policy network.
        num_hidden: Number of hidden units of the policy network.
        num_generations: Number of generations to train for.

    Returns:
        The output of ``ClassicControl.train`` vmapped over the ``nrep`` seeds.
    """
    benchmark = ClassicControl(env_name, ep_len)

    def make_strategy(ph):
        return OpenES(
            npart * subpopsize,
            pholder_params=ph,
            lrate_init=lr,
            sigma_init=sig,
            n_devices=1,
        )

    def train_one(seed):
        return benchmark.train(seed, make_strategy, num_layers, num_hidden, num_generations)

    run_keys = jax.random.split(key, nrep)
    return jax.vmap(train_one)(run_keys)


def run_cfg_gf(key, npart, subpopsize, sig, lr, kw, nrep, env_name, ep_len, num_layers=2, num_hidden=16, num_generations=200):
    """Train ``GF_SVGD`` on a ``ClassicControl`` benchmark for several seeds.

    The strategy's first positional argument is ``npart * subpopsize``.

    Args:
        key: PRNG key that is split into ``nrep`` per-run seeds.
        npart: First factor of the value passed to ``GF_SVGD``.
        subpopsize: Second factor of the value passed to ``GF_SVGD``.
        sig: Passed to the strategy as ``sigma_init``.
        lr: Passed to the strategy as ``lrate_init``.
        kw: Argument of the ``RBF`` kernel.
        nrep: Number of seeds the training is vmapped over.
        env_name: Environment name for ``ClassicControl``.
        ep_len: Episode length for ``ClassicControl``.
        num_layers: Number of hidden layers of the policy network.
        num_hidden: Number of hidden units of the policy network.
        num_generations: Number of generations; also passed as ``num_iters``.

    Returns:
        The output of ``ClassicControl.train`` vmapped over the ``nrep`` seeds.
    """
    benchmark = ClassicControl(env_name, ep_len)

    def make_strategy(ph):
        return GF_SVGD(
            npart * subpopsize,
            RBF(kw),
            pholder_params=ph,
            sigma_init=sig,
            lrate_init=lr,
            num_iters=num_generations,
            n_devices=1,
        )

    def train_one(seed):
        return benchmark.train(seed, make_strategy, num_layers, num_hidden, num_generations)

    run_keys = jax.random.split(key, nrep)
    return jax.vmap(train_one)(run_keys)