import os
from functools import partial
from typing import Optional

import chex
import jax
import jax.numpy as jnp
import orbax.checkpoint as obcheckpoint
import yaml
from evosax import Strategy
from flax import struct
from flax.training.train_state import TrainState

from dataset_gen.dataset_gen import gen_dataset
from envs.rollout import rollout
from models.agents import ActorCritic as Actor
from models.mirror_map import MPO_map
from models.optim import create_optimizer, create_es_strategy
from rlhf_agents.train import make_train as make_rlhf_train
from envs.env_utils import get_env
from utils.jax import mini_batch_pmap


@struct.dataclass
class OptParams:
    """Hyperparameters of the optimiser used inside the ES strategy.

    `create_mmap_train_state` rebuilds this from a restored checkpoint via
    ``OptParams(**saved_dict)``, so field names must match the saved keys.
    NOTE(review): presumably mirrors evosax's own ``OptParams`` layout so that
    checkpoints round-trip; confirm against the installed evosax version.
    """

    # Learning-rate schedule (initial value, decay factor, limit); exact
    # semantics are defined by the evosax optimiser — TODO confirm.
    lrate_init: float = 0.01
    lrate_decay: float = 0.999
    lrate_limit: float = 0.001
    # Optimiser-specific coefficients. They default to None, presumably
    # meaning "unused by the selected optimiser" — confirm.
    momentum: Optional[float] = None
    beta_1: Optional[float] = None
    beta_2: Optional[float] = None
    beta_3: Optional[float] = None
    eps: Optional[float] = None
    max_speed: Optional[float] = None


@struct.dataclass
class EvoParams:
    """Hyperparameters of the evolution strategy.

    `create_mmap_train_state` rebuilds this from a restored checkpoint. It
    constructs `opt_params` separately and passes the remaining saved keys as
    keyword arguments.
    NOTE(review): presumably mirrors evosax's ``EvoParams`` layout — confirm.
    """

    # Optimiser hyperparameters (required; no default).
    opt_params: OptParams
    # Perturbation-scale schedule (initial value, decay factor, limit);
    # semantics per evosax — TODO confirm.
    sigma_init: float = 0.04
    sigma_decay: float = 0.999
    sigma_limit: float = 0.01
    # Range for the initial mean; both 0.0 by default.
    init_min: float = 0.0
    init_max: float = 0.0
    # Clipping bounds default to +/- the largest finite float32, i.e.
    # presumably "no clipping" in practice.
    clip_min: float = -jnp.finfo(jnp.float32).max
    clip_max: float = jnp.finfo(jnp.float32).max


@struct.dataclass
class OptState:
    """Optimiser state stored in ``EvoState.opt_state``.

    `create_mmap_train_state` rebuilds this from a checkpoint via
    ``OptState(**saved_dict)``.
    NOTE(review): presumably mirrors evosax's ``OptState`` — confirm.
    """

    # Current learning rate.
    lrate: float
    # Optimiser moment/accumulator buffers. Only `m` is required; the others
    # are presumably populated only by optimisers that need them — confirm.
    m: chex.Array
    v: Optional[chex.Array] = None
    n: Optional[chex.Array] = None
    last_grads: Optional[chex.Array] = None
    # Number of optimiser updates performed so far (presumably).
    gen_counter: int = 0


@struct.dataclass
class EvoState:
    """State of the evolution strategy.

    `create_mmap_train_state` rebuilds this from a checkpoint via
    ``EvoState(**saved_dict)`` and then replaces `opt_state` with an
    ``OptState`` instance.
    NOTE(review): presumably mirrors evosax's ``EvoState`` — confirm.
    """

    # Search-distribution mean (the current parameter estimate) and scale.
    mean: chex.Array
    sigma: chex.Array
    # Optimiser state used to update `mean`.
    opt_state: OptState
    # Best candidate seen so far and its fitness. The fitness defaults to the
    # largest float32, which suggests lower fitness is treated as better — confirm.
    best_member: chex.Array
    best_fitness: float = jnp.finfo(jnp.float32).max
    # Number of ES generations completed (presumably).
    gen_counter: int = 0


class ESTrainState(struct.PyTreeNode):
    """Bundle of a Flax TrainState with the evosax strategy, params and state.

    `train_state`, `es_params` and `es_state` are pytree children, so JAX
    transformations trace them. `strategy` is declared with
    ``pytree_node=False`` and is therefore carried as static metadata rather
    than traced. Fields are positional in the order declared below.
    """

    train_state: TrainState = struct.field(pytree_node=True)
    strategy: Strategy = struct.field(pytree_node=False)
    es_params: EvoParams = struct.field(pytree_node=True)
    es_state: EvoState = struct.field(pytree_node=True)


def create_mmap_train_state(rng, args):
    """
    Initialise an LPMD instance: the mirror-map network plus its ES strategy.

    Args:
        rng: JAX PRNG key. It is split into independent keys for the
            network initialisation and the ES initialisation.
        args: Experiment configuration. Reads mmap_net_width,
            temporally_aware, parametrised_reward_model, add_logsimoid_bias,
            add_sft_bias, add_dpo_bias, sft_term, lpmd_opt, es_lrate_init,
            lpmd_max_grad_norm and es_checkpoint. `args` is also forwarded
            whole to `create_es_strategy`.

    Returns:
        ESTrainState. When `args.es_checkpoint` is set, `es_state`, the
        `train_state` fields and `es_params` are restored from that Orbax
        PyTree checkpoint, overriding the fresh initialisation.
    """
    # Each consumer of randomness gets its own key; reusing one key for two
    # consumers correlates their random draws.
    rng_init, rng_es = jax.random.split(rng)

    mirror_map = MPO_map(
        num_hidden_units=args.mmap_net_width,
        temporally_aware=args.temporally_aware,
        parametrised_reward_model=args.parametrised_reward_model,
        add_logsimoid_bias=args.add_logsimoid_bias,
        add_sft_bias=args.add_sft_bias,
        add_dpo_bias=args.add_dpo_bias,
        sft_term=args.sft_term,
    )
    params = mirror_map.init(rng_init, 1.0)["params"]

    tx = create_optimizer(args.lpmd_opt, args.es_lrate_init, args.lpmd_max_grad_norm)
    train_state = TrainState.create(apply_fn=mirror_map.apply, params=params, tx=tx)

    es_strategy = create_es_strategy(args, train_state.params)
    # ES optimisation hyperparameters. These are the strategy's
    # `default_params`. Whether they reflect values taken from `args` depends
    # on how `create_es_strategy` builds the strategy.
    # TODO(review): confirm whether these are our settings or library defaults.
    es_params = es_strategy.default_params
    es_state = es_strategy.initialize(rng_es, es_params, init_mean=train_state.params)

    if args.es_checkpoint is not None:
        checkpoint = obcheckpoint.PyTreeCheckpointer().restore(args.es_checkpoint)

        saved_es_state = checkpoint["es_state"]
        # The restored nested opt_state is a plain mapping; rebuild it as OptState.
        es_state = EvoState(**saved_es_state).replace(
            opt_state=OptState(**saved_es_state["opt_state"])
        )

        train_state = train_state.replace(**checkpoint["train_state"])

        saved_es_params = checkpoint["es_params"]
        # Build EvoParams without mutating the restored checkpoint dict.
        es_params = EvoParams(
            opt_params=OptParams(**saved_es_params["opt_params"]),
            **{k: v for k, v in saved_es_params.items() if k != "opt_params"},
        )

    return ESTrainState(train_state, es_strategy, es_params, es_state)


def get_candidate_fitness_fun(args, mmap_train_state, dataset):
    """Build the (mini-batched, pmapped) fitness function for LPMD candidates.

    The returned function trains one RLHF agent per candidate mirror-map
    parameter set, with the config's LOSS_TYPE forced to "mpo". It then
    evaluates each trained agent with a rollout and reports a fitness per
    candidate.

    Args:
        args: Experiment configuration. Reads env_name, indeces,
            main_folder_path, num_data_points, n_stages, reference_agent,
            num_eval_agents and num_mini_batches.
        mmap_train_state: Mirror-map train state. It is passed unchanged
            (not vmapped) to the RLHF training function.
        dataset: Preference dataset that is resampled on every call. It must
            expose `obs` (leading axis = number of entries) and `tot_rewards`.

    Returns:
        `mini_batch_pmap` applied to a function mapping
        `(rng, candidate_params_pair)` to `(fitness, mean_dataset_reward)`.
    """
    # Environment steps per evaluation rollout; also used to rescale the
    # mean reward into the reported fitness (see below).
    num_eval_steps = 1000

    # --- Initialize env ---
    env, env_params, config_env = get_env(
        args.env_name, backend="positional", indeces=args.indeces
    )

    # --- Load config ---
    # os.path.join works whether or not main_folder_path ends in a slash.
    path = os.path.join(args.main_folder_path, "atari_rlhf/config.yaml")
    with open(path, "r") as file:
        rlhf_config = yaml.safe_load(file)
    # rlhf_config is a plain dict. dict.update has no `allow_val_change`
    # option (that belongs to the wandb config API), so passing it would only
    # insert a spurious "allow_val_change" key.
    rlhf_config.update(
        {
            "ACTIVATION": config_env["ACTIVATION"],
            "DATASET_SIZE": args.num_data_points,
            "LOSS_TYPE": "mpo",
            "TRACKING": False,
            "STAGES": args.n_stages,
        }
    )

    # --- Load reference agent (optional) ---
    # NOTE(review): allow_pickle=True can execute arbitrary code while
    # loading; only point `reference_agent` at trusted files.
    if args.reference_agent is None:
        ref_agent = None
    else:
        ref_agent = jnp.load(args.reference_agent, allow_pickle=True)

    # --- Training function ---
    # Bind the reference agent here so it is shared (not mapped) across
    # candidates. jax.vmap maps keyword arguments over axis 0, and its
    # in_axes tuple must have exactly one entry per positional argument.
    agent_train_fn = partial(
        make_rlhf_train(config=rlhf_config, env=env, env_params=env_params),
        start_agent=ref_agent,
    )
    # Args: (rngs, preferences, mmap_train_state, candidate params);
    # only the keys and the candidate params are batched.
    batched_train_fn = jax.vmap(agent_train_fn, in_axes=(0, None, None, 0))

    # --- Evaluation function ---
    network = Actor(env.num_actions, activation=config_env["ACTIVATION"])

    def rollout_eval(agent_params, rng):
        """Roll out `agent_params` for `num_eval_steps` steps and return rewards."""
        return rollout(
            agent_params=agent_params,
            rng=rng,
            num_envs=args.num_eval_agents,
            num_steps=num_eval_steps,
            env=env,
            env_params=env_params,
            return_reward=True,
            without_restart=True,
            network=network,
        )

    # One rollout per trained agent, all sharing the same evaluation key.
    batched_rollout_eval = jax.vmap(rollout_eval, in_axes=(0, None))

    def sample_dataset(rng, dataset, num_samples=1):
        """Draw `num_samples` entries uniformly, with replacement, from `dataset`."""
        indices = jax.random.randint(rng, (num_samples,), 0, dataset.obs.shape[0])
        return jax.tree_util.tree_map(lambda x: x[indices], dataset)

    def _compute_candidate_fitness(rng, candidate_params_pair, preferences):
        """Train and evaluate one agent per LPMD parameter candidate.

        `candidate_params_pair` is a pytree whose leaves share a leading
        candidate axis (a pair in the original usage). Returns
        `(fitness, mean_reward)`. `fitness` has one entry per candidate;
        `mean_reward` is the mean `tot_rewards` of the sampled preference batch.
        """
        rng_data, rng_train, rng_eval = jax.random.split(rng, 3)

        # Generate this call's training batch by resampling the preferences.
        batch = sample_dataset(rng_data, preferences, args.num_data_points)

        # --- Train one agent per candidate ---
        # The size is static under tracing; for a pair this equals the
        # default jax.random.split(rng_train).
        num_candidates = jax.tree_util.tree_leaves(candidate_params_pair)[0].shape[0]
        train_rngs = jax.random.split(rng_train, num_candidates)
        train_state, _ = batched_train_fn(
            train_rngs, batch, mmap_train_state, candidate_params_pair
        )

        # --- Compute return of trained agents ---
        _, _, all_rewards, _ = batched_rollout_eval(train_state.params, rng_eval)

        # Mean over all non-candidate axes, multiplied by the rollout length.
        # NOTE(review): assumes axes 1-2 are (steps, envs), so this is a
        # per-rollout total reward — confirm against `rollout`.
        fitness = all_rewards.mean(axis=(1, 2)) * num_eval_steps
        return fitness, batch.tot_rewards.mean()

    compute_candidate_fitness = partial(_compute_candidate_fitness, preferences=dataset)

    return mini_batch_pmap(compute_candidate_fitness, args.num_mini_batches)
