import jax
import chex
import wandb
from typing import Any
import jax.numpy as jnp

from meta.meta import ESTrainState
from models.mirror_map import param_tuner as param_tuner_init


def _split_population_across_devices(candidate_params, popsize, n_devices):
    """Lay the ES population out as ``(n_devices, pairs_per_device, 2, ...)``.

    Every leaf of ``candidate_params`` is reshaped from ``(popsize, ...)`` in
    Fortran order, so candidate ``i`` and candidate ``i + popsize // 2`` land
    in the same pair (last of the three leading axes). This matches the
    first-half / second-half split used by the per-pair rank transform.

    A single Fortran-order reshape is used; it is equivalent to first
    reshaping to ``(popsize // 2, 2, ...)`` and then to the per-device layout,
    because consecutive reshapes with the same order compose.
    """
    pairs_per_device = popsize // n_devices // 2
    # `jax.tree_util.tree_map` works across JAX versions, unlike the
    # deprecated top-level `jax.tree_map` alias.
    return jax.tree_util.tree_map(
        lambda leaf: leaf.reshape(
            (n_devices, pairs_per_device, 2) + leaf.shape[1:],
            order="F",
        ),
        candidate_params,
    )


def _shape_fitness(fitness, rank_transform):
    """Turn raw fitness (1-D, in ``ask`` order) into the values fed to ``tell``.

    With ``rank_transform == 1`` each pair ``(i, i + len // 2)`` is ranked
    against itself: the member with the strictly greater fitness gets 1 and
    the other gets 0; on a tie the first-half member gets 0 and the
    second-half member gets 1. Otherwise the fitness is z-scored over the
    whole population.
    """
    if rank_transform == 1:
        first_half, second_half = jnp.split(fitness, 2)
        first_wins = jnp.greater(first_half, second_half).astype(fitness.dtype)
        return jnp.concatenate([first_wins, 1.0 - first_wins]).astype(
            fitness.dtype
        )
    return (fitness - jnp.mean(fitness)) / (jnp.std(fitness) + 1e-6)


def es_train_step(
    rng: chex.PRNGKey,
    mmap_train_state: ESTrainState,
    n_devices: int,
    contrained_net: bool = True,
    candidate_fitness_fun: Any = None,
    args: Any = None,
):
    """
    Train a batch of agents with BORPO, then update the mirror map with ES.
    Uses antithetic task sampling, meaning each antithetic pair of ES candidates is evaluated on the same dataset.

    Args:
        rng: PRNG key; split once for ``strategy.ask``, the remainder is
            split into one evaluation key per candidate pair.
        mmap_train_state: Train state providing ``strategy``, ``es_state``
            and ``es_params``.
        n_devices: Size of the leading device axis of the evaluation batch.
            ``strategy.popsize`` must be divisible by ``2 * n_devices``.
        contrained_net: If true, the reshaped candidates are passed through
            ``param_tuner_init(..., single=False)`` before evaluation.
        candidate_fitness_fun: Callable ``(eval_rngs, candidate_params)``
            returning ``(fitness, base_returns)``. Required.
        args: Config object; ``rank_transform`` and ``es_sigma_init`` are
            read from it. Required.

    Returns:
        Tuple ``(mmap_train_state, metrics)`` where the train state carries
        the updated ES state and ``metrics`` is a nested dict of fitness
        statistics (including a ``wandb.Histogram``).

    Raises:
        ValueError: If ``candidate_fitness_fun`` or ``args`` is missing, or
            if the population cannot be split evenly into per-device pairs.
    """
    if candidate_fitness_fun is None:
        raise ValueError("es_train_step requires `candidate_fitness_fun`.")
    if args is None:
        raise ValueError(
            "es_train_step requires `args` "
            "(providing `rank_transform` and `es_sigma_init`)."
        )

    strategy = mmap_train_state.strategy
    popsize = strategy.popsize
    if popsize % (2 * n_devices) != 0:
        raise ValueError(
            f"strategy.popsize ({popsize}) must be divisible by "
            f"2 * n_devices ({2 * n_devices})."
        )
    pairs_per_device = popsize // n_devices // 2

    # --- Generate ES candidates ---
    rng, ask_rng = jax.random.split(rng)
    candidate_params, es_state = jax.jit(strategy.ask)(
        ask_rng, mmap_train_state.es_state, mmap_train_state.es_params
    )

    # --- Reshaping for multi-device evaluation ---
    candidate_params_reshaped = _split_population_across_devices(
        candidate_params, popsize, n_devices
    )
    if contrained_net:
        candidate_params_reshaped = param_tuner_init(
            candidate_params_reshaped, single=False
        )

    # --- Evaluate LPMD candidates ---
    # One key per candidate pair (not per candidate).
    eval_rngs = jax.random.split(rng, (n_devices, pairs_per_device))
    fitness, base_returns = candidate_fitness_fun(
        eval_rngs, candidate_params_reshaped
    )

    # --- Reshaping back to single device ---
    # Assumes `fitness` comes back as (n_devices, pairs_per_device, 2), in
    # which case the Fortran-order flatten restores `ask` order - TODO confirm
    # against the fitness function.
    fitness = fitness.reshape((-1,), order="F")

    shaped_fitness = _shape_fitness(fitness, args.rank_transform)

    # --- Update and return ES state ---
    # `tell` receives the untouched candidates as returned by `ask`, not the
    # reshaped / tuned ones used for evaluation.
    new_es_state = jax.jit(strategy.tell)(
        candidate_params, shaped_fitness, es_state, mmap_train_state.es_params
    )

    # Pass the updated mean through the parameter tuner (single-candidate
    # mode, floor at -es_sigma_init) and write it back in flat form.
    # NOTE(review): this runs even when `contrained_net` is False, unlike the
    # candidate tuning above - confirm this asymmetry is intended.
    reshaper = strategy.param_reshaper
    tuned_mean = param_tuner_init(
        reshaper.reshape_single(new_es_state.mean),
        single=True,
        floor=-args.es_sigma_init,
    )
    new_es_state = new_es_state.replace(
        mean=reshaper.flatten_single(tuned_mean)
    )
    mmap_train_state = mmap_train_state.replace(es_state=new_es_state)

    # --- Log metrics ---
    metrics = {
        "fitness": {
            "mean": jnp.mean(fitness),
            "base_return": base_returns.mean(),
            "min": jnp.min(fitness),
            "max": jnp.max(fitness),
            "distr": wandb.Histogram(fitness),
        },
    }
    return mmap_train_state, metrics