from types import MethodType
from functools import partial
from typing import Callable

import chex
import jax
import jax.numpy as jnp
from evosax import NetworkMapper, FitnessShaper, CMA_ES, OpenES
from brax.v1.envs import create

from sves.benchmarks import Benchmark
from sves.benchmarks.ablated_brax_step import ablated_step
from sves.strategies import MC_SVGD, ParallelOpenES, GF_SVGD, SV_CMA_BB, ParallelCMAES
from sves.kernels import RBF


class BraxControl(Benchmark):
    """Evaluation class to train a brax agent."""
    def __init__(self, env_name: str, max_episode_len: int = 1_000):
        self.env_name = env_name
        self.max_steps = max_episode_len
        brax_env_name = env_name.split("_")[0] if env_name.endswith("ablated") else env_name
        self.env = create(
            env_name=brax_env_name,
            episode_length=max_episode_len,
            legacy_spring=True,
        )
        super().__init__(
            lb=-jnp.finfo(jnp.float32).max,
            ub=jnp.finfo(jnp.float32).max,
            dim=self.env.observation_size,
            fglob=-100,
            name=env_name
        )

        # Monkey patch the step function which contains the reward
        if env_name.endswith("ablated"):
            match brax_env_name:
                case "hopper":
                    self.env.step = MethodType(ablated_step, self.env)
                case "walker2d":
                    self.env.step = MethodType(ablated_step, self.env)
                case _:
                    raise NotImplementedError(f"{brax_env_name}_ablated not implemented!")

    def brax_eval_fn(self, network, obs_norm, rng, params, obs_params):
        brax_state = self.env.reset(rng)
        valid_mask = jnp.ones((1,))
        acc_return = jnp.array([0.0])

        def env_step(state_input, tmp):
            brax_state, valid_mask, acc_return = state_input
            original_obs = brax_state.obs
            normed_obs = obs_norm(original_obs, obs_params)
            action = network.apply(params, normed_obs)
            brax_state = self.env.step(brax_state, action)
            acc_return = acc_return + brax_state.reward * valid_mask
            valid_mask = valid_mask * (1 - brax_state.done.ravel())
            return (brax_state, valid_mask, acc_return), (valid_mask, original_obs)

        carry_out, scan_out = jax.lax.scan(
            env_step, (brax_state, valid_mask, acc_return), (), self.max_steps
        )
        cum_reward = carry_out[2]
        mask_buffer = scan_out[0]
        obs_buffer = scan_out[1]
        return rng, cum_reward.squeeze(), mask_buffer, obs_buffer

    def train(
        self,
        rng: chex.PRNGKey,
        strategy_cls: Callable,
        num_layers: int = 0,
        num_hidden: int = 32,
        num_generations: int = 1_000,
        num_mc: int = 16
    ):
        """Run the experiment.

        rng: Random number generator seed.
        strategy_cls:
            Strategy class.
            Must be a function of 'placeholder_params' so it can be initialized with the network parameters.
        num_layers: Number of hidden layers in the neural network.
        num_hidden: Number of units per hidden layer.
        num_generations: Number of generations.
        num_mc: Number of MC samples that are used to estimate the expected return.
        """
        rng, network_rng, eval_rng = jax.random.split(rng, 3)

        # Instantiate network
        network = NetworkMapper["MLP"](
            num_hidden_units=num_hidden,
            num_hidden_layers=num_layers,
            num_output_units=self.env.action_size,
            hidden_activation="relu",
            output_activation="tanh",
        )
        pholder_params = network.init(network_rng, jnp.zeros((1, self.env.observation_size)))

        # Get strategy
        shaper = FitnessShaper(centered_rank=True)
        strategy = strategy_cls(pholder_params)

        # Instantiate eval functionalities
        normalize = partial(self.normalize_obs, obs_shape=(self.env.observation_size,))
        eval_fn = jax.jit(
            partial(self.brax_eval_fn, network, normalize)
        )
        batch_eval = jax.vmap(eval_fn, in_axes=(None, 0, None))

        def run_loop(rng: chex.PRNGKey, num_generations: int):
            if isinstance(strategy, GF_SVGD):
                init_val = jnp.sqrt(6 / (num_hidden + self.env.observation_size))
            else:
                init_val = 0
            es_params = strategy.default_params.replace(init_min=-init_val, init_max=init_val)
            if hasattr(strategy.default_params, "alpha_min"):
                es_params = es_params.replace(alpha_min=0.)
            obs_params = jnp.zeros(1 + jnp.prod(jnp.array(self.env.observation_size)) * 2)

            rng, rng_init = jax.random.split(rng)
            es_state = strategy.initialize(rng_init, es_params)
            fit_mean_hist, fit_max_hist = [], []

            for i in range(num_generations):
                rng, rng_a, rng_e = jax.random.split(rng, 3)
                x, es_state = strategy.ask(rng_a, es_state, es_params)

                rng_e, fitness, mask_buffer, obs_buffer = batch_eval(rng_e, x, obs_params)
                fitness = jnp.nan_to_num(fitness, nan=-1e8)  # adjust to not break when env has errors: https://github.com/google/brax/issues/467
                # Shape fitness for OpenAI-ES methods
                if isinstance(strategy, (MC_SVGD, OpenES, ParallelOpenES)):
                    shaped_fitness = shaper.apply(x, fitness)
                else:
                    shaped_fitness = fitness
                es_state = strategy.tell(x, -shaped_fitness, es_state, es_params)
                obs_params = self.update_obs_params(obs_buffer, mask_buffer, obs_params)

                # Evaluate performance
                if i == 0 or (i + 1) % 20 == 0:
                    rng, rng_val = jax.random.split(rng)
                    if isinstance(strategy, (SV_CMA_BB, MC_SVGD, GF_SVGD)):
                        means = strategy.param_reshaper.reshape(es_state.particles)  # SV
                    elif isinstance(strategy, (CMA_ES, OpenES)):
                        means = strategy.param_reshaper.reshape(es_state.mean.reshape((1, -1)))  # Single
                    else:
                        means = strategy.base_strategy.param_reshaper.reshape(es_state.mean)  # Parallel

                    # MC estimate over n seeds using vmap
                    rng_vals = jax.random.split(rng_val, num_mc)
                    mc_fitness = (jax.vmap(lambda rng: batch_eval(rng, means, obs_params)[1])(rng_vals)).mean(axis=0)
                    fitness_mean = mc_fitness.mean()
                    fitness_max = mc_fitness.max()

                    print(i + 1)
                    print(fitness_max)
                    fit_mean_hist.append(fitness_mean)
                    fit_max_hist.append(fitness_max)

            return (
                jnp.array(fit_mean_hist).squeeze(),
                jnp.array(fit_max_hist).squeeze(),
            )

        fitness_mean, fitness_max = run_loop(eval_rng, num_generations)
        return {
            "fitness_mean": fitness_mean,
            "fitness_max": fitness_max,
        }

    @staticmethod
    def normalize_obs(
        obs: jnp.ndarray,
        obs_params: jnp.ndarray,
        obs_shape: tuple,
        clip_value: float = 5.0,
        std_min_value: float = 1e-6,
        std_max_value: float = 1e6,
    ) -> jnp.ndarray:
        """Normalize the given observation."""

        obs_steps = obs_params[0]
        running_mean, running_var = jnp.split(obs_params[1:], 2)
        running_mean = running_mean.reshape(obs_shape)
        running_var = running_var.reshape(obs_shape)

        variance = running_var / (obs_steps + 1.0)
        variance = jnp.clip(variance, std_min_value, std_max_value)
        return jnp.clip((obs - running_mean) / jnp.sqrt(variance), -clip_value, clip_value)

    @staticmethod
    def update_obs_params(
        obs_buffer: jnp.ndarray, obs_mask: jnp.ndarray, obs_params: jnp.ndarray
    ) -> jnp.ndarray:
        """Update observation normalization parameters."""

        obs_steps = obs_params[0]
        running_mean, running_var = jnp.split(obs_params[1:], 2)
        if obs_mask.ndim != obs_buffer.ndim:
            obs_mask = obs_mask.reshape(
                obs_mask.shape + (1,) * (obs_buffer.ndim - obs_mask.ndim)
            )

        new_steps = jnp.sum(obs_mask)
        total_steps = obs_steps + new_steps

        input_to_old_mean = (obs_buffer - running_mean) * obs_mask
        mean_diff = jnp.sum(input_to_old_mean / total_steps, axis=(0, 1))
        new_mean = running_mean + mean_diff

        input_to_new_mean = (obs_buffer - new_mean) * obs_mask
        var_diff = jnp.sum(input_to_new_mean * input_to_old_mean, axis=(0, 1))
        new_var = running_var + var_diff

        return jnp.concatenate([jnp.ones(1) * total_steps, new_mean, new_var])


### ---- Wrappers to run code in jupyter
def run_cfg_cma(
    key: chex.PRNGKey,
    npart: int,
    subpopsize: int,
    er: float,
    sig: float,
    kw: float,
    nrep: int,
    num_generations: int,
    env_name: str,
    num_layers: int = 2,
    num_hidden: int = 16
) -> dict[str, chex.Array]:
    """Wrapper to run a brax experiment for the SV-CMA-ES algorithm.

    Args:
        key: Random number generator seed.
        npart: Number of particles / ES populations.
        subpopsize: Size of each subpopulation.
        er: Elite ratio for CMA-ES.
        sig: Sigma for CMA-ES.
        kw: Kernel bandwidth.
        nrep: Number of repetitions of the experiment. All repetitions are parallelized.
        num_generations: Number of generations each experiment is run for.
        env_name:
            The name of the environment. For the paper we use
            'hopper', 'walker2d', 'halfcheetah', 'hopper_ablated', and 'walker2d_ablated'.
        num_layers: Number of hidden layers in the neural network.
        num_hidden: Number of units per hidden layer.

    Returns:
        The dictionary returned by `BraxControl.train` ("fitness_mean" and
        "fitness_max"), with the `nrep` repetitions stacked along the leading axis.
    """

    bench = BraxControl(env_name)

    def strategy_cls(ph):
        # Built lazily: the placeholder parameters only exist inside `train`.
        return SV_CMA_BB(
            npart, subpopsize, RBF(kw), pholder_params=ph,
            elite_ratio=er, sigma_init=sig, num_iters=num_generations, n_devices=1
        )

    # One independent seed per repetition; all repetitions run under one vmap.
    seeds = jax.random.split(key, nrep)
    results = jax.vmap(
        lambda seed: bench.train(seed, strategy_cls, num_layers, num_hidden, num_generations)
    )(seeds)

    return results


def run_cfg_oes(
    key, npart, subpopsize, lr, sig, kw, nrep, num_generations, env_name, num_layers=2, num_hidden=16
) -> dict[str, chex.Array]:
    """Run `nrep` vmapped repetitions of a brax experiment with `MC_SVGD`.

    Args:
        key: PRNG key; split into one seed per repetition.
        npart: First positional argument of `MC_SVGD`.
        subpopsize: Second positional argument of `MC_SVGD`.
        lr: Passed to the strategy as `lrate_init`.
        sig: Passed to the strategy as `sigma_init`.
        kw: Argument of the `RBF` kernel handed to the strategy.
        nrep: Number of repetitions.
        num_generations: Passed to the strategy as `num_iters` and to `train`.
        env_name: Environment name used to build the `BraxControl` benchmark.
        num_layers: Forwarded to `BraxControl.train`.
        num_hidden: Forwarded to `BraxControl.train`.

    Returns:
        The result of `BraxControl.train`, vmapped over the `nrep` seeds.
    """
    benchmark = BraxControl(env_name)

    def make_strategy(pholder_params):
        # Built lazily: the placeholder parameters only exist inside `train`.
        return MC_SVGD(
            npart,
            subpopsize,
            RBF(kw),
            pholder_params=pholder_params,
            lrate_init=lr,
            sigma_init=sig,
            num_iters=num_generations,
            n_devices=1,
        )

    def run_single(seed):
        return benchmark.train(seed, make_strategy, num_layers, num_hidden, num_generations)

    rep_seeds = jax.random.split(key, nrep)
    return jax.vmap(run_single)(rep_seeds)


def run_cfg_cma_parallel(
    key, npart, subpopsize, er, sig, nrep, env_name, num_layers=2, num_hidden=16, num_generations=500
) -> dict[str, chex.Array]:
    """Run `nrep` vmapped repetitions of a brax experiment with `ParallelCMAES`.

    Args:
        key: PRNG key; split into one seed per repetition.
        npart: First positional argument of `ParallelCMAES`.
        subpopsize: Second positional argument of `ParallelCMAES`.
        er: Passed to the strategy as `elite_ratio`.
        sig: Passed to the strategy as `sigma_init`.
        nrep: Number of repetitions.
        env_name: Environment name used to build the `BraxControl` benchmark.
        num_layers: Forwarded to `BraxControl.train`.
        num_hidden: Forwarded to `BraxControl.train`.
        num_generations: Forwarded to `BraxControl.train`.

    Returns:
        The result of `BraxControl.train`, vmapped over the `nrep` seeds.
    """
    benchmark = BraxControl(env_name)

    def make_strategy(pholder_params):
        # Built lazily: the placeholder parameters only exist inside `train`.
        return ParallelCMAES(
            npart,
            subpopsize,
            pholder_params=pholder_params,
            elite_ratio=er,
            sigma_init=sig,
            n_devices=1,
        )

    def run_single(seed):
        return benchmark.train(seed, make_strategy, num_layers, num_hidden, num_generations)

    rep_seeds = jax.random.split(key, nrep)
    return jax.vmap(run_single)(rep_seeds)


def run_cfg_oes_parallel(
    key, npart, subpopsize, lr, sig, nrep, env_name, num_layers=2, num_hidden=16, num_generations=500
) -> dict[str, chex.Array]:
    """Run `nrep` vmapped repetitions of a brax experiment with `ParallelOpenES`.

    Args:
        key: PRNG key; split into one seed per repetition.
        npart: First positional argument of `ParallelOpenES`.
        subpopsize: Second positional argument of `ParallelOpenES`.
        lr: Passed to the strategy as `lrate_init`.
        sig: Passed to the strategy as `sigma_init`.
        nrep: Number of repetitions.
        env_name: Environment name used to build the `BraxControl` benchmark.
        num_layers: Forwarded to `BraxControl.train`.
        num_hidden: Forwarded to `BraxControl.train`.
        num_generations: Forwarded to `BraxControl.train`.

    Returns:
        The result of `BraxControl.train`, vmapped over the `nrep` seeds.
    """
    benchmark = BraxControl(env_name)

    def make_strategy(pholder_params):
        # Built lazily: the placeholder parameters only exist inside `train`.
        return ParallelOpenES(
            npart,
            subpopsize,
            pholder_params=pholder_params,
            lrate_init=lr,
            sigma_init=sig,
            n_devices=1,
        )

    def run_single(seed):
        return benchmark.train(seed, make_strategy, num_layers, num_hidden, num_generations)

    rep_seeds = jax.random.split(key, nrep)
    return jax.vmap(run_single)(rep_seeds)


def run_cfg_cma_single(
    key, npart, subpopsize, er, sig, nrep, env_name, num_layers=2, num_hidden=16, num_generations=500
) -> dict[str, chex.Array]:
    """Run `nrep` vmapped repetitions of a brax experiment with a single `CMA_ES`.

    Args:
        key: PRNG key; split into one seed per repetition.
        npart: Multiplied with `subpopsize`; the product is the first
            positional argument of `CMA_ES`.
        subpopsize: See `npart`.
        er: Passed to the strategy as `elite_ratio`.
        sig: Passed to the strategy as `sigma_init`.
        nrep: Number of repetitions.
        env_name: Environment name used to build the `BraxControl` benchmark.
        num_layers: Forwarded to `BraxControl.train`.
        num_hidden: Forwarded to `BraxControl.train`.
        num_generations: Forwarded to `BraxControl.train`.

    Returns:
        The result of `BraxControl.train`, vmapped over the `nrep` seeds.
    """
    benchmark = BraxControl(env_name)
    total_size = npart * subpopsize

    def make_strategy(pholder_params):
        # Built lazily: the placeholder parameters only exist inside `train`.
        return CMA_ES(
            total_size,
            pholder_params=pholder_params,
            elite_ratio=er,
            sigma_init=sig,
            n_devices=1,
        )

    def run_single(seed):
        return benchmark.train(seed, make_strategy, num_layers, num_hidden, num_generations)

    rep_seeds = jax.random.split(key, nrep)
    return jax.vmap(run_single)(rep_seeds)


def run_cfg_oes_single(
    key, npart, subpopsize, lr, sig, nrep, env_name, num_layers=2, num_hidden=16, num_generations=500
) -> dict[str, chex.Array]:
    """Run `nrep` vmapped repetitions of a brax experiment with a single `OpenES`.

    Args:
        key: PRNG key; split into one seed per repetition.
        npart: Multiplied with `subpopsize`; the product is the first
            positional argument of `OpenES`.
        subpopsize: See `npart`.
        lr: Passed to the strategy as `lrate_init`.
        sig: Passed to the strategy as `sigma_init`.
        nrep: Number of repetitions.
        env_name: Environment name used to build the `BraxControl` benchmark.
        num_layers: Forwarded to `BraxControl.train`.
        num_hidden: Forwarded to `BraxControl.train`.
        num_generations: Forwarded to `BraxControl.train`.

    Returns:
        The result of `BraxControl.train`, vmapped over the `nrep` seeds.
    """
    benchmark = BraxControl(env_name)
    total_size = npart * subpopsize

    def make_strategy(pholder_params):
        # Built lazily: the placeholder parameters only exist inside `train`.
        return OpenES(
            total_size,
            pholder_params=pholder_params,
            lrate_init=lr,
            sigma_init=sig,
            n_devices=1,
        )

    def run_single(seed):
        return benchmark.train(seed, make_strategy, num_layers, num_hidden, num_generations)

    rep_seeds = jax.random.split(key, nrep)
    return jax.vmap(run_single)(rep_seeds)


def run_cfg_gf(
    key, npart, subpopsize, lr, sig, kw, nrep, env_name, num_layers=2, num_hidden=16, num_generations=500
) -> dict[str, chex.Array]:
    """Run `nrep` vmapped repetitions of a brax experiment with `GF_SVGD`.

    Args:
        key: PRNG key; split into one seed per repetition.
        npart: Multiplied with `subpopsize`; the product is the first
            positional argument of `GF_SVGD`.
        subpopsize: See `npart`.
        lr: Passed to the strategy as `lrate_init`.
        sig: Passed to the strategy as `sigma_init`.
        kw: Argument of the `RBF` kernel handed to the strategy.
        nrep: Number of repetitions.
        env_name: Environment name used to build the `BraxControl` benchmark.
        num_layers: Forwarded to `BraxControl.train`.
        num_hidden: Forwarded to `BraxControl.train`.
        num_generations: Passed to the strategy as `num_iters` and to `train`.

    Returns:
        The result of `BraxControl.train`, vmapped over the `nrep` seeds.
    """
    benchmark = BraxControl(env_name)
    total_size = npart * subpopsize

    def make_strategy(pholder_params):
        # Built lazily: the placeholder parameters only exist inside `train`.
        return GF_SVGD(
            total_size,
            RBF(kw),
            sigma_init=sig,
            pholder_params=pholder_params,
            lrate_init=lr,
            num_iters=num_generations,
            n_devices=1,
        )

    def run_single(seed):
        return benchmark.train(seed, make_strategy, num_layers, num_hidden, num_generations)

    rep_seeds = jax.random.split(key, nrep)
    return jax.vmap(run_single)(rep_seeds)
