import flashbax as fbx
import haiku as hk
import jax
import jax.numpy as jnp
import optax
import pgx
import wandb
from jax_tqdm import scan_tqdm  # type: ignore
from pgx.experimental import auto_reset
from pydantic import BaseModel

import _mctx as mctx
import src.lib.util as util
from src.baselines.muzero.networks import make_network_apply_fns
from src.baselines.muzero.util import (
    ExItTransition,
    get_train_targets,
    init_model_and_optim,
    make_buffer,
    scale_gradient,
)
from src.util import make_env


class Config(BaseModel):
    """Hyperparameters and run settings for this training script.

    One instance (``args``) is created at module level and handed to the
    network, optimizer, buffer and target helpers. Several fields are not read
    in this file at all; they are presumably consumed by those helpers, which
    receive the whole config object.
    """

    # PRNG seed for the top-level key created in the __main__ block
    seed: int = 0
    # Environment id passed to make_env; names containing "grid" also cap
    # evaluation episodes at 12 steps (see evaluate)
    env_name: pgx.EnvId = "grid"  # type: ignore
    # Forwarded to make_env
    use_legal_actions: bool = False
    # Not read in this file; presumably the network hidden width — confirm in networks
    num_hidden: int = 256
    # Discount reported to the search for every simulated transition
    discount: float = 0.99

    # Training
    # Number of search simulations per move (self-play and evaluation)
    num_simulations: int = 32
    # Weight of the value loss in the total loss
    vf_coeff: float = 0.25
    # Weight of the self-consistency loss in the total loss
    sc_coeff: float = 2.0
    # Initial learning rate (also passed to util.linear_schedule for logging)
    lr: float = 5e-3
    min_lr: float = 1e-3  # Minimum learning rate
    lr_linear_decay: bool = True  # Whether to linearly decay the learning rate
    lr_anneal_iterations: int = 200  # Number of iterations to decay the learning rate
    # The next three are not read in this file; presumably used when building
    # the optimizer / targets in the muzero util helpers — confirm there
    max_grad_norm: float = 5.0
    optim_eps: float = 1e-5
    n_step: int = 10
    # Interpolation weight for target updates: tau * online + (1 - tau) * target
    target_tau: float = 1.0
    target_update_interval: int = 5  # Update every X epochs
    # Gumbel scale is annealed linearly from gumbel_start to gumbel_end and
    # clipped to that range
    gumbel_start: float = 5.0
    gumbel_end: float = 1.0
    gumbel_anneal_iterations: int = 30  # Number of iterations to anneal gumbel scale

    # Not read in this file — TODO confirm where these are used
    huber_param: float = 1.0
    num_quantiles: int = 200

    # Buffer
    # Number of parallel environments used by evaluate
    eval_num_actors: int = 1024
    # Number of parallel environments used for self-play
    selfplay_batch_size: int = 32
    # Not read in this file; presumably the buffer sample batch size — confirm in make_buffer
    train_batch_size: int = 1024
    train_epochs_per_iter: int = 20  # For mountain car, this should be 100+
    # Number of steps the loss is unrolled over for each sampled sequence
    sample_sequence_length: int = 6
    # Number of environment steps collected per self-play call
    max_num_steps: int = 128
    # With the defaults above this equals selfplay_batch_size * 8 * max_num_steps
    # (the factor 8 is hard-coded here)
    total_buffer_size: int = 32 * 8 * 128

    # Placeholders for dynamic values
    # Both are overwritten at module level with the values returned by make_env
    num_actions: int = -1
    is_state_vector: bool = False

    # Logging
    # Evaluation runs on iterations where iteration % eval_interval == 0
    eval_interval: int = 5
    # Number of outer training iterations
    max_num_iters: int = 200


# Build the run configuration and the environment it names
args = Config()
env, is_state_vector, num_actions = make_env(
    use_legal_actions=args.use_legal_actions,
    env_name=args.env_name,
)
# Fill in the config placeholders with the values reported by make_env so
# that the helpers receiving `args` below can see them
args.is_state_vector = is_state_vector
args.num_actions = num_actions

# Create the network apply functions, then initialise parameters and optimizer
(
    init_model,
    representation_apply,
    projection_apply,
    policy_apply,
    critic_apply,
    recurrent_inference,
) = make_network_apply_fns(args)
params, optimizer, opt_state = init_model_and_optim(env, init_model, args=args)
# The target network starts as an independent copy of the online parameters
target_params = jax.tree.map(jnp.copy, params)

# Create the trajectory buffer and its initial state
buffer_fn, buffer_state = make_buffer(env, args)


# Define eval loop
@jax.jit
def evaluate(rng_key: jnp.ndarray, params: optax.Params):
    """Play a batch of fresh episodes with search-based actions and return the rewards.

    ``args.eval_num_actors`` environments are initialised and stepped in
    lockstep. Each step runs ``mctx.gumbel_muzero_policy`` with a Gumbel scale
    of 0.0 and applies the selected action. The loop stops once every
    environment reports ``terminated``; for environments whose name contains
    "grid" it also stops after 12 steps.

    Args:
        rng_key: PRNG key for environment initialisation, search and stepping.
        params: Network parameters used by the search.

    Returns:
        Array of shape ``(args.eval_num_actors,)`` with the summed
        ``rewards[:, 0]`` of each environment.
    """
    key, subkey = jax.random.split(rng_key)
    batch_size = args.eval_num_actors
    keys = jax.random.split(subkey, batch_size)
    state = jax.vmap(env.init)(keys)
    step_fn = jax.vmap(env.step)
    step = jnp.array(0)
    # Static (trace-time) cap on the episode length for grid environments
    max_steps = 12 if "grid" in args.env_name else None

    def cond_fn(
        tup: tuple[jax.Array, pgx.State, jax.Array, jax.Array],
    ) -> jax.Array:
        """Loop while not all envs are done and the step cap is not reached."""
        _, state, _, step = tup
        still_running = ~state.terminated.all()
        if max_steps is None:
            return still_running
        return jnp.logical_and(still_running, step < max_steps)

    def body_fn(val) -> tuple[jax.Array, pgx.State, jax.Array, jax.Array]:
        """Run one search + environment step for the whole batch."""
        key, state, R, step = val

        # Derive independent keys up front: a key that has been consumed by
        # the search must not also be split to produce the env-step keys.
        key, search_key, step_key = jax.random.split(key, 3)

        # Initialize the root
        root = root_fn(params, state.observation)

        # Run MCTS search
        search_output = mctx.gumbel_muzero_policy(
            params=params,
            rng_key=search_key,
            root=root,
            recurrent_fn=recurrent_fn,
            num_simulations=args.num_simulations,
            invalid_actions=~state.legal_action_mask,
            # Revisit: qtransform mix is very greedy
            qtransform=mctx.qtransform_by_parent_and_siblings,
            gumbel_scale=0.0,
            search_fn=mctx.search,
        )
        action = search_output.action  # (b,)

        keys = jax.random.split(step_key, batch_size)
        state = step_fn(state, action, keys)

        R = R + state.rewards[:, 0]
        return key, state, R, step + 1

    _, _, R, _ = jax.lax.while_loop(
        cond_fn,
        body_fn,
        (key, state, jnp.zeros((batch_size,)), step),
    )
    return R


@jax.jit
def root_fn(params: optax.Params, obs: jnp.ndarray):
    """Build the search root from raw observations.

    The observation is encoded by the representation network; the policy and
    critic heads are then evaluated on that embedding. The prior logits are
    shifted so their maximum along the last axis is zero.
    """
    embedding = representation_apply.apply(params, obs)
    root_value = critic_apply.apply(params, embedding)
    raw_logits = policy_apply.apply(params, embedding)
    shifted_logits = raw_logits - jnp.max(raw_logits, axis=-1, keepdims=True)
    return mctx.RootFnOutput(
        prior_logits=shifted_logits,  # type: ignore
        value=root_value,  # type: ignore
        embedding=embedding,  # type: ignore
    )


@jax.jit
def recurrent_fn(
    params: optax.Params,
    _rng_key: jnp.ndarray,
    action: jnp.ndarray,
    obs_embedding: jnp.ndarray,
):
    """Advance the learned model by one action inside the search.

    Returns a pair ``(mctx.RecurrentFnOutput, next_embedding)``. The PRNG key
    argument is accepted but not used. The discount has the shape of the
    predicted reward and is filled with ``args.discount``.
    """
    next_embedding, predicted_reward = recurrent_inference.apply(
        params, obs_embedding, action
    )
    next_value = critic_apply.apply(params, next_embedding)
    raw_logits = policy_apply.apply(params, next_embedding)
    # Shift so the largest logit along the last axis is zero
    shifted_logits = raw_logits - jnp.max(raw_logits, axis=-1, keepdims=True)
    step_discount = jnp.ones_like(predicted_reward) * args.discount

    model_output = mctx.RecurrentFnOutput(
        prior_logits=shifted_logits,  # type: ignore
        value=next_value,  # type: ignore
        reward=predicted_reward,  # type: ignore
        discount=step_discount,  # type: ignore
    )
    return model_output, next_embedding


# Define selfplay loop
def selfplay(
    rng_key: jax.Array,
    params: optax.Params,
    buffer_state: fbx.trajectory_buffer.TrajectoryBufferState,
    gumbel_scale: jnp.ndarray,
    env_state: pgx.State,
    episode_stats: dict[str, jnp.ndarray],
) -> tuple[
    pgx.State,
    dict[str, jnp.ndarray],
    fbx.trajectory_buffer.TrajectoryBufferState,
    ExItTransition,
]:
    """Collect ``args.max_num_steps`` steps of batched self-play into the buffer.

    Every step runs a Gumbel MuZero search from the current observations,
    applies the chosen actions through an auto-resetting environment step, and
    records one ``ExItTransition``. The collected trajectories are transposed
    to batch-major layout and added to the buffer.

    Returns:
        ``(env_state, episode_stats, buffer_state, traj_batch)``: the final
        environment states, the running episode statistics (zeroed where the
        last step terminated), the updated buffer state and the batch-major
        trajectories that were added.
    """
    num_steps = args.max_num_steps

    @scan_tqdm(num_steps)
    def step_fn(
        carry: tuple[
            pgx.State,
            dict[str, jnp.ndarray],
        ],
        iter_data: jnp.ndarray,
    ) -> tuple[tuple[pgx.State, dict[str, jnp.ndarray]], ExItTransition]:
        prev_state, running_stats = carry
        _, step_key = iter_data
        search_key, env_key = jax.random.split(step_key)  # (2,), (2,)
        obs = prev_state.observation

        # Search from the current observations
        search_output = mctx.gumbel_muzero_policy(
            params=params,
            rng_key=search_key,
            root=root_fn(params, obs),
            recurrent_fn=recurrent_fn,
            num_simulations=args.num_simulations,
            invalid_actions=~prev_state.legal_action_mask,
            # Revisit: qtransform mix is very greedy
            qtransform=mctx.qtransform_by_parent_and_siblings,
            gumbel_scale=gumbel_scale,
            search_fn=mctx.search,
        )
        action = search_output.action  # (b,)
        search_policy = search_output.action_weights

        # Step every environment with its own key, resetting finished ones
        env_keys = jax.random.split(env_key, obs.shape[0])
        batched_step = jax.vmap(auto_reset(env.step, env.init))
        next_state = batched_step(prev_state, action, env_keys)
        terminated = next_state.terminated

        # Running statistics including the step just taken
        stats_now = {
            **running_stats,
            "episode_return": running_stats["episode_return"]
            + jnp.sum(next_state.rewards, axis=1),
            "episode_length": running_stats["episode_length"] + 1,
            "is_terminal_step": terminated,
        }

        # The transition stores the pre-step observation and the statistics
        # before they are cleared for finished episodes
        transition = ExItTransition(
            done=terminated,  # (b,)
            action=jnp.asarray(action),  # (b,)
            reward=next_state.rewards[:, -1],  # (b,)
            search_policy=jnp.asarray(search_policy),  # (b, num_actions)
            obs=obs,  # (b, *obs_shape)
            info=stats_now,
        )

        # Zero every statistic for environments whose episode just ended
        stats_next = jax.tree_util.tree_map(
            lambda leaf: jnp.where(terminated, jnp.zeros_like(leaf), leaf),
            stats_now,
        )
        return (next_state, stats_next), transition

    # One PRNG key per self-play step
    _, seq_key = jax.random.split(rng_key)
    step_keys = jax.random.split(seq_key, num_steps)
    scan_inputs = (jnp.arange(num_steps), step_keys)

    (env_state, episode_stats), traj_batch = jax.lax.scan(
        step_fn,  # type: ignore
        (env_state, episode_stats),
        scan_inputs,  # type: ignore
    )

    # scan stacks along time first; the buffer receives (b, t, ...)
    traj_batch = jax.tree_util.tree_map(
        lambda leaf: jnp.swapaxes(leaf, 0, 1), traj_batch
    )
    buffer_state = buffer_fn.add(
        buffer_state,
        traj_batch,
    )  # type: ignore

    return env_state, episode_stats, buffer_state, traj_batch


# Define training loop
def learning_step(params, target_params, opt_state, buffer_state, key):
    """Sample a batch from the buffer and apply one optimizer update.

    Args:
        params: Online network parameters; differentiated and updated.
        target_params: Parameters handed to ``get_train_targets`` when the
            training targets are computed. They are not modified here.
        opt_state: State of the module-level ``optimizer``.
        buffer_state: Trajectory buffer state to sample from.
        key: PRNG key used for sampling the buffer.

    Returns:
        ``(new_params, new_opt_state, (loss, losses))`` where ``loss`` is the
        weighted total and ``losses`` is a dict holding the ``actor_loss``,
        ``value_loss``, ``reward_loss`` and ``self_consistency_loss`` terms.
    """

    def loss_fn(
        params: hk.Params,
        init_obs: jnp.ndarray,
        s_target: jnp.ndarray,
        policy_target: jnp.ndarray,
        a_seq: jnp.ndarray,
        r_target: jnp.ndarray,
        v_target: jnp.ndarray,
        dt: jnp.ndarray,
    ):
        """Unroll the model from ``init_obs`` along ``a_seq`` and sum the losses.

        All sequence arguments arrive batch-major and are transposed to
        time-major before the scan. ``dt`` holds the done flags that switch
        off the per-sample mask for the remaining steps. Returns
        ``(total_loss, losses)`` so it can be used with ``has_aux=True``.
        """
        # Generate root embeddings
        root_embeddings = representation_apply.apply(params, init_obs)

        @jax.remat  # type: ignore
        def unroll_fn(carry, targets):
            """One unroll step: score the current embedding, then apply the action."""
            total_loss, obs_embedding, mask = carry
            # `step` is unpacked but not used below
            a, r_target, s_target, policy_target, v_target, done, step = targets

            # ACTOR LOSS
            policy_logits_pred = policy_apply.apply(params, obs_embedding)
            policy_loss = optax.softmax_cross_entropy(
                logits=policy_logits_pred, labels=policy_target
            )
            # Zero the policy loss for absorbing states
            policy_loss = policy_loss * mask
            # jax.debug.print(
            #     "policy_target: {policy_target}, policy_logits_pred: {policy_logits_pred}, ",
            #     policy_target=policy_target[:4],
            #     policy_logits_pred=policy_logits_pred[:4],
            # )

            # CRITIC LOSS
            # Zero the value targets after done (absorbing state); the loss
            # itself is not masked, so the critic is trained towards 0 there
            v_target = v_target * mask
            v_pred = critic_apply.apply(params, obs_embedding)
            value_loss = optax.l2_loss(v_pred, v_target)
            # jax.debug.print(
            #     "v_target: {v_target}, v_pred: {v_pred}, ",
            #     v_target=v_target[:4],
            #     v_pred=v_pred[:4],
            # )

            # Scale the gradients of the obs_embedding
            # (applied after the policy/value heads above have used the
            # unscaled embedding, so only the dynamics path sees the scaling)
            obs_embedding = scale_gradient(obs_embedding, 0.5)
            next_obs_embedding, reward_pred = recurrent_inference.apply(
                params, obs_embedding, a
            )

            # REWARD LOSS
            # Zero reward targets after done as absorbing
            r_target = r_target * mask
            reward_loss = optax.l2_loss(reward_pred, r_target)
            # jax.debug.print(
            #     "r_target: {r_target}, reward_pred: {reward_pred}, ",
            #     r_target=r_target[:4],
            #     reward_pred=reward_pred[:4],
            # )

            # SELF-CONSISTENCY LOSS
            # We want to ensure that the next state embedding is consistent with the target state embedding
            next_obs_proj = projection_apply.apply(
                params, next_obs_embedding
            )  # (b, num_hidden)
            # Both vectors are normalised without an epsilon, so a zero-norm
            # vector would produce NaN here
            unit_next_obs_proj = next_obs_proj / jnp.linalg.norm(
                next_obs_proj, axis=-1, keepdims=True
            )  # (b, num_hidden)
            unit_s_target = s_target / jnp.linalg.norm(s_target, axis=-1, keepdims=True)
            # 1 - cosine similarity, per sample
            sc_loss = 1 - jax.vmap(jnp.dot)(unit_next_obs_proj, unit_s_target)
            sc_loss = sc_loss * mask  # No self-consistency loss for absorbing states

            curr_loss = {
                "actor_loss": policy_loss,
                "value_loss": value_loss,
                "reward_loss": reward_loss,
                "self_consistency_loss": sc_loss,
            }
            # Accumulate the batch mean of every term into the running totals
            total_loss = jax.tree_util.tree_map(
                lambda x, y: x + y.mean(), total_loss, curr_loss
            )

            # Update the mask
            # (takes effect from the NEXT step: the step on which `done` is
            # set is still fully counted)
            mask = jnp.where(done, jnp.zeros_like(mask), mask)
            return (total_loss, next_obs_embedding, mask), curr_loss

        targets = (
            a_seq,
            r_target,
            s_target,
            policy_target,
            v_target,
            dt,
        )
        targets = jax.tree_util.tree_map(
            lambda x: jnp.swapaxes(x, 0, 1), targets
        )  # (t, b, ...)
        # Append the step index as the last scanned input
        targets = (*targets, jnp.arange(0, args.sample_sequence_length))

        init_total_loss = {
            "actor_loss": jnp.array(0.0),
            "value_loss": jnp.array(0.0),
            "reward_loss": jnp.array(0.0),
            "self_consistency_loss": jnp.array(0.0),
        }
        # One mask entry per sample; 1 until the sample's episode is done
        init_mask = jnp.ones((v_target.shape[0],))
        # The per-step losses returned by the scan are discarded; only the
        # accumulated totals in the carry are used
        (losses, _, _), _ = jax.lax.scan(
            unroll_fn,
            (init_total_loss, root_embeddings, init_mask),
            targets,
        )
        # Divide by the number of unrolled steps to ensure
        # a consistent scale across different unroll lengths
        # NOTE(review): the scan above runs sample_sequence_length steps (the
        # step index is jnp.arange(0, args.sample_sequence_length)) while the
        # divisor is sample_sequence_length - 1 — confirm which is intended.
        losses = jax.tree_util.tree_map(
            lambda x: x / (args.sample_sequence_length - 1), losses
        )
        return (
            losses["actor_loss"]
            + args.vf_coeff * losses["value_loss"]
            + losses["reward_loss"]
            + args.sc_coeff * losses["self_consistency_loss"],
            losses,
        )

    # Sample a batch from the buffer and compute targets
    batch = buffer_fn.sample(buffer_state, key).experience
    (policy_targets, reward_targets, value_targets, s_targets) = get_train_targets(
        batch,
        target_params,
        representation_apply,
        critic_apply,
        args,
    )

    # The state targets are shifted by one step relative to the actions: the
    # embedding predicted after action t is compared with s_targets[:, t + 1]
    (loss, losses), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        params,
        batch.obs[:, 0],
        s_targets[:, 1 : args.sample_sequence_length + 1],
        policy_targets,
        batch.action[:, : args.sample_sequence_length],
        reward_targets,
        value_targets,
        batch.done[:, : args.sample_sequence_length],
    )
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params=params, updates=updates)
    return new_params, new_opt_state, (loss, losses)


if __name__ == "__main__":

    def train_loop_body(carry, iteration):
        """Run one training iteration inside the outer ``jax.lax.scan``.

        An iteration optionally evaluates the current parameters, collects
        self-play data into the buffer, and then performs
        ``args.train_epochs_per_iter`` learning steps.

        Args:
            carry: ``(rng_key, buffer_state, opt_state, params, target_params,
                env_state, episode_stats)``.
            iteration: Index of the current iteration.

        Returns:
            ``(carry, eval_R)`` where ``eval_R`` is the mean evaluation reward
            as a float32 scalar, or NaN on iterations without evaluation.
        """
        (
            rng_key,
            buffer_state,
            opt_state,
            params,
            target_params,
            env_state,
            episode_stats,
        ) = carry
        # Evaluation and self-play each get their own key; feeding both from
        # one key would correlate their random streams.
        rng_key, eval_key, selfplay_key = jax.random.split(rng_key, 3)

        def eval_fn():
            """Evaluate, log, and return the mean reward over the eval batch."""
            mean_reward = evaluate(eval_key, params).mean()
            jax.debug.print(
                "Iter {i} / {max_num_iters}, Eval Reward: {r}",
                i=iteration,
                max_num_iters=args.max_num_iters,
                r=mean_reward,
            )
            # Fixed dtype so both cond branches return the same type
            return mean_reward.astype(jnp.float32)

        def skip_eval_fn():
            """Placeholder result for iterations on which no evaluation runs."""
            return jnp.full((), jnp.nan, dtype=jnp.float32)

        # Run evaluation conditionally; the result is emitted as the scan
        # output so the eval curve is available after training
        eval_R = jax.lax.cond(
            iteration % args.eval_interval == 0,
            eval_fn,
            skip_eval_fn,
        )

        # Linearly anneal the Gumbel scale from gumbel_start to gumbel_end.
        # The clip bounds are positional: the a_min/a_max keyword names are
        # deprecated in newer JAX releases, positional bounds work on all.
        gumbel_scale = jnp.clip(
            args.gumbel_start
            - (args.gumbel_start - args.gumbel_end)
            * (iteration / args.gumbel_anneal_iterations),
            args.gumbel_end,
            args.gumbel_start,
        )
        # Logging inside traced code has to go through `jax.debug.print`
        jax.debug.print(
            "Iteration {i} started. Gumbel scale: {gumbel_scale}, LR: {lr}",
            i=iteration,
            gumbel_scale=gumbel_scale,
            lr=util.linear_schedule(
                iteration,
                args.lr,
                args.lr_anneal_iterations,
                args.min_lr,
            ),
        )

        # Self-play data collection
        env_state, episode_stats, buffer_state, traj_batch = selfplay(
            selfplay_key, params, buffer_state, gumbel_scale, env_state, episode_stats
        )

        # Log the training stats: only entries recorded on terminal steps
        # contribute, so these are averages over the episodes that finished
        terminal_steps = traj_batch.info["is_terminal_step"]
        episode_returns = traj_batch.info["episode_return"] * terminal_steps.astype(
            jnp.float32
        )  # type: ignore
        episode_lengths = traj_batch.info["episode_length"] * terminal_steps.astype(
            jnp.int32
        )  # type: ignore
        total_terminations = jnp.sum(terminal_steps.astype(jnp.int32))  # type: ignore
        # The small constant avoids 0 / 0 when no episode finished
        average_return = jnp.sum(episode_returns) / (total_terminations + 1e-8)
        average_length = jnp.sum(episode_lengths) / (total_terminations + 1e-8)
        jax.debug.print(
            "Iter {i} / {max_num_iters}, Train Avg Return: {avg_return:.2f}, Avg Length: {avg_length:.2f}",
            i=iteration,
            max_num_iters=args.max_num_iters,
            avg_return=average_return,
            avg_length=average_length,
        )

        # Training step
        rng_key, train_key = jax.random.split(rng_key)

        # Learning step for epochs
        def scan_fn(carry, epoch):
            """One learning step followed by a periodic target-network update."""
            params, target_params, opt_state, rng_key = carry
            rng_key, subkey = jax.random.split(rng_key)
            params, opt_state, (loss, losses) = learning_step(
                params, target_params, opt_state, buffer_state, subkey
            )

            # Update target parameters every target_update_interval epochs
            # (epoch 0 included) by interpolating towards the online params
            target_params = jax.lax.cond(
                epoch % args.target_update_interval == 0,
                lambda: jax.tree_util.tree_map(
                    lambda target, online: args.target_tau * online
                    + (1 - args.target_tau) * target,
                    target_params,
                    params,
                ),
                lambda: target_params,
            )

            return (params, target_params, opt_state, rng_key), (loss, losses)

        (params, target_params, opt_state, _), (loss, losses) = jax.lax.scan(
            scan_fn,
            (params, target_params, opt_state, train_key),
            jnp.arange(args.train_epochs_per_iter),
        )
        # Log the losses, averaged over this iteration's epochs
        jax.debug.print(
            "Iter {i} / {max_num_iters}, Loss: {loss:.4f}, Actor Loss: {actor_loss:.4f}, "
            "Value Loss: {value_loss:.4f}, Self-consistency Loss: {self_consistency_loss:.4f}, Reward Loss: {reward_loss:.4f}",
            i=iteration,
            max_num_iters=args.max_num_iters,
            loss=jnp.mean(loss),
            actor_loss=jnp.mean(losses["actor_loss"]),
            value_loss=jnp.mean(losses["value_loss"]),
            self_consistency_loss=jnp.mean(losses["self_consistency_loss"]),
            reward_loss=jnp.mean(losses["reward_loss"]),
        )

        carry = (
            rng_key,
            buffer_state,
            opt_state,
            params,
            target_params,
            env_state,
            episode_stats,
        )
        return carry, eval_R

    # Run scan
    rng_key = jax.random.PRNGKey(seed=args.seed)
    init_rng_key, sub_key = jax.random.split(rng_key)
    keys = jax.random.split(sub_key, args.selfplay_batch_size)
    env_state = jax.vmap(env.init)(keys)

    # Per-environment running statistics carried across self-play calls
    episode_stats_init = {
        "episode_return": jnp.zeros((args.selfplay_batch_size,)),
        "episode_length": jnp.zeros((args.selfplay_batch_size,), dtype=jnp.int32),
        "is_terminal_step": jnp.zeros((args.selfplay_batch_size,), dtype=bool),
    }

    initial_carry = (
        init_rng_key,
        buffer_state,
        opt_state,
        params,
        target_params,
        env_state,
        episode_stats_init,
    )
    iterations = jnp.arange(args.max_num_iters)

    # eval_rewards has one entry per iteration (NaN where no evaluation ran)
    final_carry, eval_rewards = jax.lax.scan(train_loop_body, initial_carry, iterations)
