import flashbax as fbx
import haiku as hk
import jax
import jax.numpy as jnp
import optax
import pgx
from jax_tqdm import scan_tqdm  # type: ignore
from pgx.experimental import auto_reset
from pydantic import BaseModel

import _mctx as mctx
import src.lib.util as util
import wandb
from src.baselines.alphazero.network import make_network_apply_fns
from src.baselines.alphazero.util import (
    ExItTransition,
    get_train_targets,
    init_model_and_optim,
    make_buffer,
)
from src.util import make_env

# Devices addressable by this process and how many there are.
devices = jax.local_devices()
num_devices = jax.local_device_count()


class Config(BaseModel):
    """Hyperparameters for the AlphaZero-style training script.

    ``num_actions`` and ``is_state_vector`` are placeholders: the module-level
    setup overwrites them with the values returned by ``make_env``.
    """

    seed: int = 0  # Seed for the top-level PRNG key
    env_name: pgx.EnvId = "minatar-breakout"  # type: ignore
    use_legal_actions: bool = False  # Forwarded to make_env
    num_hidden: int = 256
    discount: float = 0.99  # Per-step discount used inside the search tree

    # Training
    num_simulations: int = 32  # MCTS simulations per search
    lr: float = 5e-3
    min_lr: float = 1e-3  # Minimum learning rate
    lr_linear_decay: bool = True  # Whether to linearly decay the learning rate
    lr_anneal_iterations: int = 200  # Number of iterations to decay the learning rate
    max_grad_norm: float = 5.0
    optim_eps: float = 1e-5
    n_step: int = 5
    target_tau: float = 1.0  # Target update mix: tau * online + (1 - tau) * target
    target_update_interval: int = 5  # Update the target every X learning epochs
    gumbel_start: float = 1.0
    gumbel_end: float = 1.0
    gumbel_anneal_iterations: int = 30  # Number of iterations to anneal gumbel scale

    # Buffer
    eval_num_actors: int = 1024  # Parallel environments used by evaluation
    selfplay_batch_size: int = 32  # Parallel environments used by self-play
    train_batch_size: int = 1024
    train_epochs_per_iter: int = 20  # For mountain car, this should be 100+
    max_num_steps: int = 128  # Self-play steps collected per iteration
    # selfplay_batch_size * 8 * max_num_steps; the factor 8 presumably means
    # "keep roughly 8 iterations of self-play data" — confirm with make_buffer.
    total_buffer_size: int = 32 * 8 * 128

    # Placeholders for dynamic values
    num_actions: int = -1
    is_state_vector: bool = False

    # Logging
    eval_interval: int = 5  # Evaluate every N training iterations
    max_num_iters: int = 200  # Number of training iterations


# --- Environment ---------------------------------------------------------
args = Config()
env, is_state_vector, num_actions = make_env(
    env_name=args.env_name,
    use_legal_actions=args.use_legal_actions,
)
# Fill in the config placeholders that depend on the environment.
args.is_state_vector = is_state_vector
args.num_actions = num_actions

# --- Model: network, optimizer and a separate target-parameter copy -------
prediction_apply = make_network_apply_fns(args)
params, optimizer, opt_state = init_model_and_optim(env, prediction_apply, args=args)
target_params = jax.tree.map(jnp.copy, params)

# --- Trajectory buffer ---------------------------------------------------
buffer_fn, buffer_state = make_buffer(env, args)


@jax.jit
def root_fn(params: optax.Params, env_state: pgx.State):
    """Build the batched MCTS root from the current environment states.

    Self-play auto-resets environments, but ``recurrent_fn`` does not. So a
    state used as a root may still carry ``terminated``/``truncated`` flags (and
    rewards) from the step that ended its previous episode. Those are cleared
    here before the state becomes the root embedding.

    Args:
        params: Network parameters for ``prediction_apply``.
        env_state: Batched ``pgx.State``.

    Returns:
        ``mctx.RootFnOutput`` holding the network's prior logits and value for
        the observations, with the flag-cleared state as the embedding.
    """
    # Envs whose previous step ended an episode.
    finished = env_state.terminated | env_state.truncated  # (b,)

    # Unfinished envs already have both flags False, so clearing the flags on
    # every row equals resetting only the finished rows. Rewards are zeroed
    # only where the episode had finished.
    root_state = env_state.replace(  # type: ignore
        terminated=jnp.zeros_like(env_state.terminated),
        truncated=jnp.zeros_like(env_state.truncated),
        rewards=jnp.where(
            finished[:, None],
            jnp.zeros_like(env_state.rewards),
            env_state.rewards,
        ),
    )

    prior_logits, root_value = prediction_apply.apply(params, root_state.observation)
    return mctx.RootFnOutput(
        prior_logits=prior_logits,  # type: ignore
        value=root_value,  # type: ignore
        embedding=root_state,  # type: ignore
    )


@jax.jit
def recurrent_fn(
    params: optax.Params,
    rng_key: jnp.ndarray,
    action: jnp.ndarray,
    obs_embedding: pgx.State,  # The env_state
):
    """Expand a batch of search nodes by stepping the environment.

    Each node's embedding is the ``pgx.State`` itself. Expansion steps the real
    environment (no auto-reset) and evaluates the child observation with the
    network.

    Args:
        params: Network parameters.
        rng_key: PRNG key; one sub-key per environment is derived for ``env.step``.
        action: Batched actions; ``action.shape[0]`` is the number of envs.
        obs_embedding: Batched parent environment states.

    Returns:
        ``(mctx.RecurrentFnOutput, next_state)``. Terminated children get value 0
        and discount 0; otherwise the discount is ``args.discount``. The reward
        is player 0's reward.
    """
    num_envs = action.shape[0]
    # Second half of a two-way split, then one key per environment.
    per_env_keys = jax.random.split(jax.random.split(rng_key)[1], num_envs)

    # In alphazero search, we step the environment on expansion.
    child_state = jax.vmap(env.step)(obs_embedding, action, per_env_keys)
    prior_logits, child_value = prediction_apply.apply(params, child_state.observation)

    is_terminal = child_state.terminated  # (b,)
    output = mctx.RecurrentFnOutput(
        prior_logits=prior_logits,  # type: ignore
        value=jnp.where(is_terminal, 0.0, child_value),  # type: ignore
        reward=child_state.rewards[:, 0],  # type: ignore
        discount=jnp.where(is_terminal, 0.0, args.discount),  # type: ignore
    )
    return output, child_state


# Define eval loop
@jax.jit
def evaluate(rng_key: jnp.ndarray, params: optax.Params):
    """Roll out the search policy without Gumbel noise and return per-env returns.

    Runs ``args.eval_num_actors`` environments in parallel. At every step a
    Gumbel MuZero search with ``gumbel_scale=0.0`` picks the action. The loop
    stops once every environment has terminated or been truncated. When
    ``args.env_name`` contains "grid", it also stops after at most 12 steps.

    Args:
        rng_key: PRNG key for env initialisation, search and env stepping.
        params: Network parameters used by the search.

    Returns:
        Array of shape ``(args.eval_num_actors,)`` with the sum of player-0
        rewards collected by each environment.
    """
    key, init_key = jax.random.split(rng_key)
    batch_size = args.eval_num_actors
    state = jax.vmap(env.init)(jax.random.split(init_key, batch_size))
    step_fn = jax.vmap(env.step)
    # Fixed horizon for "grid" environments; others run until every episode ends.
    max_steps = 12 if "grid" in args.env_name else None

    def cond_fn(
        tup: tuple[jax.Array, pgx.State, jax.Array, jax.Array],
    ) -> jax.Array:
        """Continue while some env is unfinished (and below max_steps, if set)."""
        _, state, _, step = tup
        # Count truncation as finished too. An env that is truncated but never
        # terminated would otherwise keep the while_loop running indefinitely.
        still_running = ~(state.terminated | state.truncated).all()
        if max_steps is None:
            return still_running
        return jnp.logical_and(still_running, step < max_steps)

    def body_fn(
        val: tuple[jax.Array, pgx.State, jax.Array, jax.Array],
    ) -> tuple[jax.Array, pgx.State, jax.Array, jax.Array]:
        """Search, step every env once, and accumulate player-0 rewards."""
        key, state, R, step = val
        # Distinct keys for the search and for env stepping (no key reuse).
        key, search_key, step_key = jax.random.split(key, 3)

        root = root_fn(params, state)
        search_output = mctx.gumbel_muzero_policy(
            params=params,
            rng_key=search_key,
            root=root,
            recurrent_fn=recurrent_fn,
            num_simulations=args.num_simulations,
            invalid_actions=~state.legal_action_mask,
            # Revisit: qtransform mix is very greedy
            qtransform=mctx.qtransform_by_parent_and_siblings,
            gumbel_scale=0.0,
            search_fn=mctx.search,
        )
        action = search_output.action  # (b,)

        state = step_fn(state, action, jax.random.split(step_key, batch_size))
        # NOTE(review): assumes env.step yields zero reward for envs whose
        # episode already ended, so finished envs add nothing here — confirm.
        R = R + state.rewards[:, 0]
        return key, state, R, step + 1

    _, _, R, _ = jax.lax.while_loop(
        cond_fn,
        body_fn,
        (key, state, jnp.zeros((batch_size,)), jnp.array(0)),
    )
    return R


# Define selfplay loop
def selfplay(
    rng_key: jax.Array,
    params: optax.Params,
    buffer_state: fbx.trajectory_buffer.TrajectoryBufferState,
    gumbel_scale: jnp.ndarray,
    env_state: pgx.State,
    episode_stats: dict[str, jnp.ndarray],
) -> tuple[
    pgx.State,
    dict[str, jnp.ndarray],
    fbx.trajectory_buffer.TrajectoryBufferState,
    ExItTransition,
]:
    """Collect ``args.max_num_steps`` steps of MCTS self-play and store them.

    Args:
        rng_key: PRNG key for this self-play phase.
        params: Network parameters used by the search.
        buffer_state: Trajectory buffer state the new data is added to.
        gumbel_scale: Gumbel noise scale passed to the search.
        env_state: Batched env states carried over from the previous phase.
        episode_stats: Running ``episode_return`` / ``episode_length`` /
            ``is_terminal_step`` per environment.

    Returns:
        ``(env_state, episode_stats, buffer_state, traj_batch)``. ``traj_batch``
        is the collected ``ExItTransition`` with leaves laid out (batch, time, ...).
    """

    @scan_tqdm(args.max_num_steps)
    def step_fn(
        carry: tuple[
            pgx.State,
            dict[str, jnp.ndarray],
        ],
        iter_data: tuple[jnp.ndarray, jax.Array],
    ) -> tuple[tuple[pgx.State, dict[str, jnp.ndarray]], ExItTransition]:
        """Run one search + env step for every environment."""
        state, episode_stats = carry
        # iter_data = (step index used by scan_tqdm, PRNG key for this step)
        _, key = iter_data
        search_key, step_key = jax.random.split(key)
        observation = state.observation

        root = root_fn(params, state)
        search_output = mctx.gumbel_muzero_policy(
            params=params,
            rng_key=search_key,
            root=root,
            recurrent_fn=recurrent_fn,
            num_simulations=args.num_simulations,
            invalid_actions=~state.legal_action_mask,
            # Revisit: qtransform mix is very greedy
            qtransform=mctx.qtransform_by_parent_and_siblings,
            gumbel_scale=gumbel_scale,
            search_fn=mctx.search,
        )
        action = search_output.action  # (b,)
        search_policy = search_output.action_weights

        step_keys = jax.random.split(step_key, observation.shape[0])
        state = jax.vmap(auto_reset(env.step, env.init))(state, action, step_keys)

        # Player-0 reward: the same slice recurrent_fn and evaluate use.
        reward = state.rewards[:, 0]  # (b,)
        # An episode ends on termination OR truncation (root_fn treats both as
        # ended; auto_reset is expected to restart on either — confirm). Stats
        # must reset on the same condition or they leak across episodes.
        episode_end = state.terminated | state.truncated  # (b,)

        # Build a new dict instead of mutating the carry in place.
        episode_stats = {
            "episode_return": episode_stats["episode_return"] + reward,
            "episode_length": episode_stats["episode_length"] + 1,
            "is_terminal_step": episode_end,
        }

        transition = ExItTransition(
            # NOTE(review): only true termination is marked `done`; truncation
            # boundaries are not recorded here — confirm get_train_targets
            # handles them as intended.
            done=state.terminated,  # (b,)
            action=jnp.asarray(action),  # (b,)
            reward=reward,  # (b,)
            search_policy=jnp.asarray(search_policy),  # (b, num_actions)
            obs=observation,  # (b, *obs_shape)
            info=episode_stats,
        )

        # Zero the running stats of environments whose episode just ended.
        episode_stats = jax.tree_util.tree_map(
            lambda x: jnp.where(episode_end, jnp.zeros_like(x), x),
            episode_stats,
        )
        return (state, episode_stats), transition

    # One PRNG key per scan step.
    _, sub_key = jax.random.split(rng_key)
    key_seq = jax.random.split(sub_key, args.max_num_steps)

    (env_state, episode_stats), traj_batch = jax.lax.scan(
        step_fn,  # type: ignore
        (env_state, episode_stats),
        (jnp.arange(args.max_num_steps), key_seq),  # type: ignore
    )

    # scan stacks along time first; the buffer receives (batch, time, ...).
    traj_batch = jax.tree_util.tree_map(
        lambda x: jnp.swapaxes(x, 0, 1), traj_batch
    )  # (b, t, ...)
    buffer_state = buffer_fn.add(
        buffer_state,
        traj_batch,
    )  # type: ignore

    return env_state, episode_stats, buffer_state, traj_batch


# Define training loop
def learning_step(params, target_params, opt_state, buffer_state, key):
    """Perform one gradient step on a batch sampled from the trajectory buffer.

    Targets come from ``get_train_targets`` evaluated with ``target_params``.
    The loss is the batch-mean softmax cross-entropy against the policy target
    plus the batch-mean ``optax.l2_loss`` against the value target. It is
    computed on the observation at time index 0 of each sampled sequence.

    Args:
        params: Online network parameters being optimised.
        target_params: Parameters used only to build the training targets.
        opt_state: Optimizer state for ``optimizer``.
        buffer_state: Trajectory buffer state to sample from.
        key: PRNG key for buffer sampling.

    Returns:
        ``(new_params, new_opt_state, (loss, losses))`` where ``losses`` has
        keys ``"actor_loss"`` and ``"value_loss"``.
    """

    def loss_fn(
        params: hk.Params,
        obs: jnp.ndarray,
        policy_target: jnp.ndarray,
        v_target: jnp.ndarray,
    ):
        """Return (total loss, per-term losses) for one batch."""
        logits, value = prediction_apply.apply(params, obs)
        policy_loss = jnp.mean(
            optax.softmax_cross_entropy(logits=logits, labels=policy_target)
        )
        value_loss = jnp.mean(optax.l2_loss(value, v_target))

        return policy_loss + value_loss, {
            "actor_loss": policy_loss,
            "value_loss": value_loss,
        }

    # Sample a batch from the buffer and compute targets
    batch = buffer_fn.sample(buffer_state, key).experience
    policy_targets, value_targets = get_train_targets(
        batch,
        target_params,
        prediction_apply,
        args,
    )
    # NOTE(review): the targets are paired with obs[:, 0] below, so they are
    # presumably (b, num_actions) and (b,) — confirm against get_train_targets.
    obs = batch.obs[:, 0]

    (loss, losses), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        params,
        obs,
        policy_targets,
        value_targets,
    )
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params=params, updates=updates)
    return new_params, new_opt_state, (loss, losses)


if __name__ == "__main__":

    def gumbel_scale_at(iteration: jax.Array) -> jax.Array:
        """Linearly anneal the Gumbel scale from gumbel_start to gumbel_end.

        The result is clamped to the interval spanned by the two endpoints, so
        annealing upwards (gumbel_end > gumbel_start) also works.
        """
        start, end = args.gumbel_start, args.gumbel_end
        # Guard against a zero-length schedule (would divide by zero).
        horizon = max(args.gumbel_anneal_iterations, 1)
        raw = start - (start - end) * (iteration / horizon)
        # Positional bounds: jnp.clip's a_min/a_max keywords are deprecated.
        return jnp.clip(raw, min(start, end), max(start, end))

    def train_loop_body(carry, iteration):
        """Run one training iteration: optional eval, self-play, learning epochs.

        Returns the updated carry and the mean evaluation reward for this
        iteration (NaN on iterations where evaluation is skipped).
        """
        (
            rng_key,
            buffer_state,
            opt_state,
            params,
            target_params,
            env_state,
            episode_stats,
        ) = carry
        # Independent keys so evaluation, self-play and training do not share
        # (and thus correlate) random streams.
        rng_key, eval_key, selfplay_key, train_key = jax.random.split(rng_key, 4)

        def eval_fn():
            R = evaluate(eval_key, params)
            jax.debug.print(
                "Iter {i} / {max_num_iters}, Eval Reward: {r}",
                i=iteration,
                max_num_iters=args.max_num_iters,
                r=R.mean(),
            )
            return R.mean().astype(jnp.float32)

        # Run evaluation conditionally; both branches return a float32 scalar.
        eval_R = jax.lax.cond(
            iteration % args.eval_interval == 0,
            eval_fn,
            lambda: jnp.full((), jnp.nan, dtype=jnp.float32),
        )

        # Print/logging (must be done outside JAX or via `jax.debug.print`)
        gumbel_scale = gumbel_scale_at(iteration)
        jax.debug.print(
            "Iteration {i} started. Gumbel scale: {gumbel_scale}, LR: {lr}",
            i=iteration,
            gumbel_scale=gumbel_scale,
            lr=util.linear_schedule(
                iteration,
                args.lr,
                args.lr_anneal_iterations,
                args.min_lr,
            ),
        )

        # Self-play data collection
        env_state, episode_stats, buffer_state, traj_batch = selfplay(
            selfplay_key, params, buffer_state, gumbel_scale, env_state, episode_stats
        )

        # Training stats: only steps that ended an episode contribute.
        is_terminal = traj_batch.info["is_terminal_step"]  # type: ignore
        episode_returns = traj_batch.info["episode_return"] * is_terminal.astype(
            jnp.float32
        )  # type: ignore
        episode_lengths = traj_batch.info["episode_length"] * is_terminal.astype(
            jnp.int32
        )  # type: ignore
        total_terminations = jnp.sum(is_terminal.astype(jnp.int32))
        # The epsilon avoids 0/0 when no episode ended during this phase.
        average_return = jnp.sum(episode_returns) / (total_terminations + 1e-8)
        average_length = jnp.sum(episode_lengths) / (total_terminations + 1e-8)
        jax.debug.print(
            "Iter {i} / {max_num_iters}, Train Avg Return: {avg_return:.2f}, Avg Length: {avg_length:.2f}",
            i=iteration,
            max_num_iters=args.max_num_iters,
            avg_return=average_return,
            avg_length=average_length,
        )

        # Learning step for epochs
        def scan_fn(carry, epoch):
            params, target_params, opt_state, rng_key = carry
            rng_key, subkey = jax.random.split(rng_key)
            params, opt_state, (loss, losses) = learning_step(
                params, target_params, opt_state, buffer_state, subkey
            )

            # Mix online params into the target every target_update_interval epochs.
            target_params = jax.lax.cond(
                epoch % args.target_update_interval == 0,
                lambda: jax.tree_util.tree_map(
                    lambda target, online: args.target_tau * online
                    + (1 - args.target_tau) * target,
                    target_params,
                    params,
                ),
                lambda: target_params,
            )

            return (params, target_params, opt_state, rng_key), (loss, losses)

        (params, target_params, opt_state, _), (loss, losses) = jax.lax.scan(
            scan_fn,
            (params, target_params, opt_state, train_key),
            jnp.arange(args.train_epochs_per_iter),
        )
        # Log the losses
        jax.debug.print(
            "Iter {i} / {max_num_iters}, Loss: {loss:.4f}, Actor Loss: {actor_loss:.4f}, "
            "Value Loss: {value_loss:.4f}",
            i=iteration,
            max_num_iters=args.max_num_iters,
            loss=jnp.mean(loss),
            actor_loss=jnp.mean(losses["actor_loss"]),
            value_loss=jnp.mean(losses["value_loss"]),
        )

        carry = (
            rng_key,
            buffer_state,
            opt_state,
            params,
            target_params,
            env_state,
            episode_stats,
        )
        return carry, eval_R

    # Run scan
    rng_key = jax.random.PRNGKey(seed=args.seed)
    init_rng_key, sub_key = jax.random.split(rng_key)
    keys = jax.random.split(sub_key, args.selfplay_batch_size)
    env_state = jax.vmap(env.init)(keys)

    episode_stats_init = {
        "episode_return": jnp.zeros((args.selfplay_batch_size,)),
        "episode_length": jnp.zeros((args.selfplay_batch_size,), dtype=jnp.int32),
        "is_terminal_step": jnp.zeros((args.selfplay_batch_size,), dtype=bool),
    }

    initial_carry = (
        init_rng_key,
        buffer_state,
        opt_state,
        params,
        target_params,
        env_state,
        episode_stats_init,
    )
    iterations = jnp.arange(args.max_num_iters)

    # eval_rewards: (max_num_iters,) mean eval reward, NaN where eval was skipped.
    final_carry, eval_rewards = jax.lax.scan(train_loop_body, initial_carry, iterations)
