import chex
import flashbax as fbx
import jax
import jax.numpy as jnp
import optax
import pgx
from jax_tqdm import scan_tqdm  # type: ignore
from pydantic import BaseModel

import src.lib.util as util
from src.baselines.graph_dqn.auto_reset_with_iteration import auto_reset
from src.baselines.graph_dqn.network import create_graph_q_network
from src.baselines.graph_dqn.util import (
    Transition,
    calc_eps,
    init_model_and_optim,
    make_buffer,
)
from src.util import make_env


class Config(BaseModel):
    """Hyperparameters and run settings for the graph-DQN training script."""

    # --- Run / environment -------------------------------------------------
    seed: int = 23  # Seed of the root PRNG key
    env_name: str = "max-ind-set"  # Forwarded to make_env
    max_num_steps: int = 32  # Environment steps per self-play call
    max_num_iters: int = 5_000  # Number of training iterations
    eval_interval: int = 1  # Evaluate when iteration % eval_interval == 0
    eval_num_actors: int = 256  # Parallel environments used for evaluation
    selfplay_batch_size: int = 32  # Parallel environments used for self-play

    # --- Network -----------------------------------------------------------
    hidden_size: int = 64  # Forwarded to create_graph_q_network

    # --- Replay buffer -----------------------------------------------------
    buffer_batch_size: int = 256  # Batch size of one learning step
    buffer_size: int = 32**3  # presumably buffer capacity; consumed by make_buffer

    # --- Exploration -------------------------------------------------------
    epsilon_start: float = 1.0
    epsilon_finish: float = 0.1
    epsilon_anneal_time: int = 3_000
    learning_start: int = 32  # Iters of max_num_steps to prefill the buffer

    # --- Optimisation ------------------------------------------------------
    gamma: float = 1.0  # Discount applied to the bootstrapped target
    lr: float = 1e-3
    min_lr: float = 1e-4
    optim_eps: float = 1e-5  # not read in this file; presumably optimizer epsilon
    lr_linear_decay: bool = True  # not read in this file
    lr_anneal_iterations: int = 3_000
    max_grad_norm: float = 0.5  # not read in this file; presumably a clipping norm
    target_tau: float = 1.0  # Weight of the online params in a target update
    target_update_interval: int = 500  # Learner steps between target updates
    train_epochs_per_iter: int = 20  # Learning steps per training iteration

    # Placeholders for dynamic values
    num_actions: int = -1
    is_state_vector: bool = False


# Build the environment and record its action count on the config
args = Config()
_env, is_state_vector, num_actions = make_env(
    env_name=args.env_name,
    use_legal_actions=False,
)
args.num_actions = num_actions  # type: ignore
# NOTE(review): `is_state_vector` is returned here but never written back to
# `args.is_state_vector`; confirm that the helpers below do not depend on it.

# Online network, optimizer state, and an independent copy of the weights
# that serves as the target network
_model = create_graph_q_network(args.hidden_size)
_params, _optimizer, _opt_state = init_model_and_optim(_env, _model, args)
_target_params = jax.tree.map(jnp.copy, _params)

# Replay buffer and its initial state
_buffer_fn, _buffer_state = make_buffer(_env, args)


# Define eval loop
@jax.jit
def evaluate(rng_key: jnp.ndarray, params: optax.Params):
    """Roll out the greedy policy and return the mean approximation ratio.

    `args.eval_num_actors` environments are initialised on randomly permuted
    instance indices (cycling through them when there are fewer instances
    than actors) and stepped with the arg-max action of the masked Q-values
    until every environment has terminated or the step budget is exhausted.

    Args:
        rng_key: PRNG key used for environment initialisation, instance
            selection, and environment steps.
        params: Parameters handed to `_model.apply`.

    Returns:
        Mean of the accumulated reward divided by `state._x.optimal_value`.
    """
    batch_size = args.eval_num_actors
    rng_key, init_key = jax.random.split(rng_key)
    init_keys = jax.random.split(init_key, batch_size)

    # Shuffle the instance indices and hand one to each environment
    num_instances = _env._instances.num_nodes.shape[0]  # type: ignore
    rng_key, permutation_key = jax.random.split(rng_key)
    shuffled_instances = jax.random.permutation(
        permutation_key,
        jnp.arange(num_instances, dtype=jnp.int32),
        independent=True,
    )
    iteration = shuffled_instances[:batch_size]
    # With fewer instances than actors the slice is short; cycle through it
    iteration = iteration[jnp.arange(batch_size) % iteration.shape[0]]

    offset = jnp.zeros(batch_size, dtype=jnp.int32)
    state = jax.vmap(_env.init_v2, in_axes=(0, 0, 0, None, None))(
        init_keys, iteration, offset, 1, 1
    )

    step_fn = jax.vmap(_env.step)
    # The accumulator mirrors the shape of `state.rewards`.
    # NOTE(review): selfplay indexes rewards as `[:, 0]`, so this is presumably
    # (num_eval_envs, num_players); confirm that the division by
    # `optimal_value` below broadcasts per environment as intended.
    ep_return = jnp.zeros_like(state.rewards)
    step = jnp.array(0)
    # Step budget: node count of the first instance (an episode is assumed to
    # last at most one step per vertex).
    # NOTE(review): only instance 0 is consulted — confirm all instances share
    # the same node count.
    max_steps = _env._instances.num_nodes[0].astype(jnp.int32)  # type: ignore

    def cond_fn(
        tup: tuple[pgx.State, jnp.ndarray, jax.Array, jnp.ndarray],
    ) -> jax.Array:
        """Continue while some env is still running and the budget remains."""
        state, _, _, step = tup
        still_running = ~state.terminated.all()
        return jnp.logical_and(still_running, step < max_steps).all()

    def loop_fn(
        tup: tuple[pgx.State, jnp.ndarray, jax.Array, jnp.ndarray],
    ) -> tuple[pgx.State, jnp.ndarray, jax.Array, jnp.ndarray]:
        """Take one greedy step in every environment and accumulate rewards."""
        state, R, rng_key, step = tup
        q_vals = _model.apply(
            params,
            state.observation["node_features"],  # type: ignore
            state.observation["senders"],  # type: ignore
            state.observation["receivers"],  # type: ignore
            state.observation["aux"],  # type: ignore
        )

        # Illegal actions get -inf so they can never win the arg-max, however
        # negative the legal Q-values are.
        masked_q_vals = jnp.where(state.legal_action_mask, q_vals, -jnp.inf)
        greedy_action = jnp.argmax(masked_q_vals, axis=-1)

        rng_key, step_key = jax.random.split(rng_key)
        keys = jax.random.split(
            step_key, state.observation["node_features"].shape[0]
        )
        state = step_fn(state, greedy_action, keys)
        return state, R + state.rewards, rng_key, step + 1

    # Loop until all environments are done or the step budget is used up
    state, ep_return, _, _ = jax.lax.while_loop(
        cond_fn, loop_fn, (state, ep_return, rng_key, step)
    )
    opt = state._x.optimal_value  # type: ignore
    approximation_ratio = ep_return / opt  # type: ignore
    return approximation_ratio.mean()


@jax.jit
def selfplay(
    rng_key: jax.Array,
    params: optax.Params,
    buffer_state: fbx.trajectory_buffer.TrajectoryBufferState,
    env_state: pgx.State,
    episode_stats: dict[str, jnp.ndarray],
    step_num: int,
):
    """Collect `args.max_num_steps` epsilon-greedy steps into the replay buffer.

    Args:
        rng_key: PRNG key; one sub-key is derived per environment step.
        params: Online network parameters handed to `_model.apply`.
        buffer_state: Replay buffer state that receives every transition.
        env_state: Batched environment state to continue from.
        episode_stats: Running per-environment statistics with the keys
            "episode_return", "episode_length" and "is_terminal_step".
        step_num: Counter fed to `calc_eps` to obtain the current epsilon.

    Returns:
        `(env_state, episode_stats, buffer_state, traj_batch)` where
        `traj_batch` stacks the `Transition` of every step along axis 0.
    """

    def step_fn(
        carry: tuple[
            pgx.State,
            dict[str, jnp.ndarray],
            fbx.trajectory_buffer.TrajectoryBufferState,
        ],
        iter_data: jnp.ndarray,
    ) -> tuple[
        tuple[
            pgx.State,
            dict[str, jnp.ndarray],
            fbx.trajectory_buffer.TrajectoryBufferState,
        ],
        Transition,
    ]:
        state, episode_stats, buffer_state = carry
        _, key = iter_data
        # Three independent keys: random-action noise, the explore/exploit
        # draw, and the environment step. Reusing one key for two of these
        # would correlate the draws.
        noise_key, explore_key, env_key = jax.random.split(key, 3)
        observation = state.observation

        q_vals = _model.apply(
            params,
            observation["node_features"],
            observation["senders"],
            observation["receivers"],
            observation["aux"],
        )  # (b, num_nodes)

        # Greedy action: illegal actions get -inf so they can never win the
        # arg-max, however negative the legal Q-values are.
        masked_q_vals = jnp.where(state.legal_action_mask, q_vals, -jnp.inf)
        greedy_action = jnp.argmax(masked_q_vals, axis=-1)  # (b,)

        # Random legal action: arg-max over uniform noise in [0, 1), with
        # illegal entries pushed below the noise range.
        noise = jax.random.uniform(noise_key, q_vals.shape)  # (b, num_nodes)
        masked_noise = jnp.where(
            state.legal_action_mask, noise, jnp.full_like(noise, -1.0)
        )  # (b, num_nodes)
        random_action = jnp.argmax(masked_noise, axis=-1)  # (b,)

        epsilon = calc_eps(
            step_num,
            args.epsilon_start,
            args.epsilon_finish,
            args.epsilon_anneal_time,
        )
        # A single draw decides for the whole batch: in a given step either
        # every environment explores or every environment acts greedily.
        action = jax.lax.cond(
            jax.random.uniform(explore_key) < epsilon,
            lambda: random_action,
            lambda: greedy_action,
        )

        keys = jax.random.split(env_key, observation["node_features"].shape[0])
        state = jax.vmap(auto_reset(_env.step, _env.init_v2))(state, action, keys)

        # Update episode stats (fresh dict; the incoming one is left untouched)
        # NOTE(review): the return uses rewards[:, -1] while the stored
        # transition uses rewards[:, 0]; identical only if there is a single
        # reward column — confirm.
        episode_stats = {
            **episode_stats,
            "episode_return": episode_stats["episode_return"]
            + state.rewards[:, -1] / state._x.optimal_value,  # type: ignore
            "episode_length": episode_stats["episode_length"] + 1,
            "is_terminal_step": state.terminated,
        }

        # The transition stores the observation seen *before* the step
        transition = Transition(
            done=state.terminated,  # (b,)
            action=jnp.asarray(action),  # (b,)
            reward=state.rewards[:, 0],  # (b,)
            obs=observation,  # (b, *obs_shape)
            info=episode_stats,
        )
        buffer_state = _buffer_fn.add(buffer_state, transition)

        # Zero the stats of environments whose episode just ended
        episode_stats = jax.tree_util.tree_map(
            lambda x: jnp.where(state.terminated, jnp.zeros_like(x), x),
            episode_stats,
        )
        return (state, episode_stats, buffer_state), transition

    # Run self-play for max_num_steps per batch
    rng_key, sub_key = jax.random.split(rng_key)
    key_seq = jax.random.split(sub_key, args.max_num_steps)

    (env_state, episode_stats, buffer_state), traj_batch = jax.lax.scan(
        step_fn,  # type: ignore
        (env_state, episode_stats, buffer_state),
        (jnp.arange(args.max_num_steps), key_seq),  # type: ignore
    )

    return env_state, episode_stats, buffer_state, traj_batch


@jax.jit
def selfplay_scan_fn(carry, iteration):
    """Scan body that runs one self-play call to prefill the buffer.

    Args:
        carry: Tuple `(rng_key, buffer_state, opt_state, params,
            target_params, env_state, episode_stats)`.
        iteration: Index of the current scan step; forwarded to `selfplay`
            as its `step_num`.

    Returns:
        The carry with an advanced `rng_key` and updated buffer, environment
        state and episode stats, plus the transitions collected in this call.
        Optimizer state and both parameter sets pass through unchanged.
    """
    (
        rng_key,
        buffer_state,
        opt_state,
        params,
        target_params,
        env_state,
        episode_stats,
    ) = carry

    # Advance the carried key so each scan step gets fresh randomness;
    # passing the carried key through unchanged would replay the same random
    # stream in every prefill iteration.
    rng_key, selfplay_key = jax.random.split(rng_key)

    env_state, episode_stats, buffer_state, traj_batch = selfplay(
        selfplay_key, params, buffer_state, env_state, episode_stats, iteration
    )

    next_carry = (
        rng_key,
        buffer_state,
        opt_state,
        params,
        target_params,
        env_state,
        episode_stats,
    )
    return next_carry, traj_batch


@jax.jit
def learning_step(rng_key, params, target_params, opt_state, buffer_state):
    """Sample a batch from the buffer and apply one TD update to `params`.

    Args:
        rng_key: PRNG key used to sample from the buffer.
        params: Online network parameters (differentiated and updated).
        target_params: Parameters used to compute the bootstrapped target.
        opt_state: State of `_optimizer`.
        buffer_state: Replay buffer state to sample from.

    Returns:
        `(params, opt_state, loss)` after one optimizer step.
    """

    def q_values(net_params, obs):
        # Q-values of every action for a batch of graph observations
        return _model.apply(
            net_params,
            obs["node_features"],
            obs["senders"],
            obs["receivers"],
            obs["aux"],
        )

    experience = _buffer_fn.sample(buffer_state, rng_key).experience
    current, following = experience.first, experience.second

    # Bootstrapped target from the target network, cut off at episode ends.
    # NOTE(review): the max runs over all actions without a legal-action mask.
    next_q_max = jnp.max(q_values(target_params, following.obs), axis=-1)  # (b,)
    target = current.reward + (1 - current.done) * args.gamma * next_q_max  # (b,)

    def _loss_fn(params: optax.Params) -> jnp.ndarray:
        # Q-value of the action that was actually taken in each sample
        picked_q = jnp.take_along_axis(
            q_values(params, current.obs),
            current.action[:, jnp.newaxis],
            axis=-1,
        ).squeeze(-1)

        per_sample = 0.5 * jnp.square(picked_q - target)  # (b,)
        chex.assert_shape(per_sample, (args.buffer_batch_size,))
        return jnp.mean(per_sample)

    loss, grads = jax.value_and_grad(_loss_fn)(params)
    updates, opt_state = _optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)

    return new_params, opt_state, loss


# Scan body for the learner
@jax.jit
def learn_scan_fn(carry, epoch):
    """Run one learning step and, when due, refresh the target parameters.

    Args:
        carry: Tuple `(rng_key, params, target_params, opt_state,
            buffer_state)`.
        epoch: Step counter. The target update fires only when
            `(epoch + 1) % args.target_update_interval == 0`, so the caller
            decides what this counter measures.

    Returns:
        The updated carry and the loss of this learning step.
    """
    rng_key, params, target_params, opt_state, buffer_state = carry
    rng_key, sample_key = jax.random.split(rng_key)

    params, opt_state, loss = learning_step(
        sample_key, params, target_params, opt_state, buffer_state
    )

    def _blend_into_target():
        # Move the target towards the online params with weight target_tau
        return jax.tree_util.tree_map(
            lambda target, online: args.target_tau * online
            + (1 - args.target_tau) * target,
            target_params,
            params,
        )

    def _keep_target():
        return target_params

    update_due = (epoch + 1) % args.target_update_interval == 0
    target_params = jax.lax.cond(update_due, _blend_into_target, _keep_target)

    return (rng_key, params, target_params, opt_state, buffer_state), loss


if __name__ == "__main__":

    @jax.jit
    def eval_fn(subkey, iteration, params):
        """Evaluate `params`, print the score, and return it as float32."""
        R = evaluate(subkey, params)
        jax.debug.print(
            "Iter {i} / {max_num_iters}, Eval Reward: {r}",
            i=iteration,
            max_num_iters=args.max_num_iters,
            r=R.mean(),
        )
        # Fixed dtype so the result matches the placeholder of the skip branch
        return R.mean().astype(jnp.float32)

    @jax.jit
    def train_loop_body(carry, iteration):
        """One training iteration: evaluate, collect data, learn, and log.

        Args:
            carry: Tuple `(rng_key, buffer_state, opt_state, params,
                target_params, env_state, episode_stats)`.
            iteration: Index of the training iteration.

        Returns:
            The updated carry and the evaluation score of this iteration
            (NaN when evaluation was skipped).
        """
        (
            rng_key,
            buffer_state,
            opt_state,
            params,
            target_params,
            env_state,
            episode_stats,
        ) = carry

        # Separate keys for evaluation and self-play so that the two do not
        # consume the same random stream.
        rng_key, eval_key, selfplay_key = jax.random.split(rng_key, 3)

        # Run evaluation conditionally. Both branches must return the same
        # type, so skipped iterations yield a float32 NaN placeholder.
        eval_R = jax.lax.cond(
            iteration % args.eval_interval == 0,
            eval_fn,
            lambda *_: jnp.array(jnp.nan, dtype=jnp.float32),
            eval_key,
            iteration,
            params,
        )

        # Self-play data collection
        env_state, episode_stats, buffer_state, traj_batch = selfplay(
            selfplay_key,
            params,
            buffer_state,
            env_state,
            episode_stats,
            iteration,
        )

        # Training stats: average over the episodes that ended in this batch
        is_terminal = traj_batch.info["is_terminal_step"]
        episode_returns = traj_batch.info["episode_return"] * is_terminal.astype(
            jnp.float32
        )  # type: ignore
        episode_lengths = traj_batch.info["episode_length"] * is_terminal.astype(
            jnp.int32
        )  # type: ignore
        total_terminations = jnp.sum(is_terminal.astype(jnp.int32))  # type: ignore
        # The small constant avoids a division by zero when no episode ended
        average_return = jnp.sum(episode_returns) / (total_terminations + 1e-8)
        average_length = jnp.sum(episode_lengths) / (total_terminations + 1e-8)

        # Perform learning steps. The scan is fed the global update index so
        # that the periodic target-network update can actually fire; a
        # per-iteration index (0 .. train_epochs_per_iter - 1) never reaches
        # target_update_interval when the interval is larger.
        learner_steps = iteration * args.train_epochs_per_iter + jnp.arange(
            args.train_epochs_per_iter
        )
        initial_carry = (rng_key, params, target_params, opt_state, buffer_state)
        (rng_key, params, target_params, opt_state, buffer_state), losses = (
            jax.lax.scan(learn_scan_fn, initial_carry, learner_steps)
        )

        # Log losses together with the scheduled learning rate and epsilon
        current_lr = util.linear_schedule(
            iteration * args.train_epochs_per_iter,
            args.lr,
            args.lr_anneal_iterations * args.train_epochs_per_iter,
            args.min_lr,
        )
        current_epsilon = calc_eps(
            iteration,
            args.epsilon_start,
            args.epsilon_finish,
            args.epsilon_anneal_time,
        )
        jax.debug.print(
            "Iter {i} / {max_num_iters}, Loss: {loss:.4f} (LR: {lr:.6f}, EPS: {eps:.4f}), Train Avg Return: {avg_return:.2f}, Avg Length: {avg_length:.2f}",
            i=iteration,
            max_num_iters=args.max_num_iters,
            loss=jnp.mean(losses),
            lr=current_lr,
            eps=current_epsilon,
            avg_return=average_return,
            avg_length=average_length,
        )

        carry = (
            rng_key,
            buffer_state,
            opt_state,
            params,
            target_params,
            env_state,
            episode_stats,
        )
        return carry, eval_R

    # Root key: one half seeds the training carry, the other the env inits
    rng_key = jax.random.PRNGKey(seed=args.seed)
    init_rng_key, sub_key = jax.random.split(rng_key)
    keys = jax.random.split(sub_key, args.selfplay_batch_size)

    # Initial self-play environments
    iteration = jnp.zeros(args.selfplay_batch_size, dtype=jnp.int32)
    offset = jnp.arange(args.selfplay_batch_size, dtype=jnp.int32)
    env_state = jax.vmap(_env.init_v2, in_axes=(0, 0, 0, None, None))(
        keys, iteration, offset, args.selfplay_batch_size, 0
    )
    episode_stats_init = {
        "episode_return": jnp.zeros((args.selfplay_batch_size,)),
        "episode_length": jnp.zeros((args.selfplay_batch_size,), dtype=jnp.int32),
        "is_terminal_step": jnp.zeros((args.selfplay_batch_size,), dtype=bool),
    }
    initial_carry = (
        init_rng_key,
        _buffer_state,
        _opt_state,
        _params,
        _target_params,
        env_state,
        episode_stats_init,
    )
    # Self-play to fill the buffer before any learning step is taken
    initial_carry, traj_batch = jax.lax.scan(
        selfplay_scan_fn, initial_carry, jnp.arange(args.learning_start)
    )

    # Main training loop; eval_rewards holds one score per iteration
    iterations = jnp.arange(args.max_num_iters)
    final_carry, eval_rewards = jax.lax.scan(train_loop_body, initial_carry, iterations)
