import chex
import flashbax as fbx
import jax
import jax.numpy as jnp
import optax
import pgx
from jax_tqdm import loop_tqdm  # type: ignore
from pydantic import BaseModel

import src.lib.util as util
from datasets.stochastic_mis.cvar_estimation import estimate_cvar
from src.baselines.graph_dqn.auto_reset_with_iteration import auto_reset
from src.baselines.graph_tql_o.network import create_graph_networks
from src.baselines.graph_tql_o.util import (
    Transition,
    calc_eps,
    init_model_and_optim,
    make_buffer,
)
from src.baselines.qrdqn.util import quantile_huber_loss
from src.util import make_env


class Config(BaseModel):
    """Settings for one graph-TQL training run.

    Fields are grouped as: run control, network / return distribution, replay
    buffer, exploration schedule, optimisation, and placeholders that are
    filled in at runtime.
    """

    # --- Run control ---
    seeds: list[int] = []  # not read anywhere in this file
    seed: int = 23  # seeds the PRNG key for the whole run
    env_name: str = "stochastic-max-ind-set-1"  # forwarded to make_env
    max_num_steps: int = 32  # environment steps per self-play call
    max_num_iters: int = 2000  # number of training iterations
    eval_interval: int = 25  # evaluate when iteration % eval_interval == 0
    eval_num_actors: int = 1024  # not read anywhere in this file
    selfplay_batch_size: int = 32  # number of parallel self-play environments
    num_quantile_samples: int = 1024  # not read anywhere in this file

    # --- Network / return distribution ---
    hidden_size: int = 64  # presumably consumed by create_graph_networks
    num_quantiles: int = 64  # number of quantile midpoints (tau_hats)
    alpha_cvar: float = 0.25  # not read in this file (evaluation hard-codes 0.25)
    huber_param: float = 1.0  # passed to quantile_huber_loss

    # --- Replay buffer ---
    buffer_batch_size: int = 256  # expected size of a sampled learning batch
    buffer_size: int = 65_536  # 32 * 32 * 64

    # --- Exploration schedule (arguments of calc_eps) ---
    epsilon_start: float = 0.5
    epsilon_finish: float = 0.1
    epsilon_anneal_time: int = 1000  # original note: 5000
    learning_start: int = 64  # Iters of max_num_steps to prefill the buffer

    # --- Optimisation ---
    gamma: float = 1.0  # discount applied to the bootstrapped quantiles
    lr: float = 1e-4
    min_lr: float = 1e-5
    optim_eps: float = 1e-5
    lr_linear_decay: bool = True
    lr_anneal_iterations: int = 2000  # original note: 5000
    max_grad_norm: float = 0.5
    target_tau: float = 1.0  # 1.0 copies the online params into the target
    target_update_interval: int = 500  # in learning steps; original note: 1000
    train_epochs_per_iter: int = 20  # learning steps per training iteration

    # Placeholders for dynamic values
    num_actions: int = -1  # overwritten from make_env's result
    is_state_vector: bool = False


def run_experiment(args):
    """Train the graph TQL agent and return the per-iteration training logs.

    Builds the environment, networks, optimiser and replay buffer from
    ``args``, prefills the buffer with ``args.learning_start`` self-play
    iterations, then runs ``args.max_num_iters`` training iterations under
    ``jax.lax.scan``.

    Args:
        args: Run configuration (see ``Config``). ``args.num_actions`` is
            overwritten with the value reported by ``make_env``.

    Returns:
        The ``json_log`` dictionaries of all training iterations, stacked by
        ``jax.lax.scan`` along a leading iteration axis.
    """
    # Make env
    # NOTE(review): `is_state_vector` is returned here but never stored on
    # `args`, although Config declares a placeholder for it — confirm whether
    # the network factory needs it.
    _env, is_state_vector, num_actions = make_env(
        env_name=args.env_name, use_legal_actions=False
    )
    args.num_actions = num_actions  # type: ignore

    # Make and initialize the model
    _init_model, _qr_model, _tql_model = create_graph_networks(args)
    _params, _optimizer, _opt_state = init_model_and_optim(_env, _init_model, args)
    # The target network starts as an independent copy of the online parameters
    _target_params = jax.tree.map(lambda x: jnp.copy(x), _params)

    # Make trajectory buffer
    _buffer_fn, _buffer_state = make_buffer(_env, args)

    # Make tau_hats for quantile loss
    tau_hats = util.make_tau_hats(args.num_quantiles)

    # Define eval loop
    @jax.jit
    def evaluate(rng_key: jnp.ndarray, params: optax.Params):
        """Evaluate the greedy TQL policy and report mean return and CVaR.

        The evaluation instances are processed in batches of 512. For every
        batch the policy is rolled out greedily (arg-max over the masked
        Q-values), the chosen node sequence is recorded, and ``estimate_cvar``
        scores that sequence at level 0.25 (CVaR) and at level 1.0 (reported
        as the mean).

        Args:
            rng_key: PRNG key driving environment resets, environment steps
                and the sampling inside ``estimate_cvar``.
            params: Network parameters passed to ``_tql_model.apply``.

        Returns:
            ``(R_mean, R_cvar, R_opt_cvar)``: the batch means of the estimated
            mean, the estimated CVaR and the instances' stored
            ``optimal_cvar_value``, each averaged over all batches.
        """
        rng_key, subkey = jax.random.split(rng_key)
        batch_size = 512
        keys = jax.random.split(subkey, batch_size)

        # The environment is bound to `_env` in the enclosing scope; the bare
        # name `env` is not defined in this module and raised a NameError.
        # NOTE(review): only half of the stored instances are evaluated —
        # presumably the other half is used elsewhere; confirm.
        total_instances = _env._instances.num_nodes.shape[0] // 2  # type: ignore
        num_instances_per_batch = 512
        num_batches = total_instances // num_instances_per_batch
        assert total_instances % num_instances_per_batch == 0

        # Static cap on the rollout length. It is also the length of the
        # action log, whose shape must be known at trace time.
        max_steps = 40
        step_fn = jax.vmap(_env.step)

        @loop_tqdm(num_batches)
        def eval_one_batch(batch_idx, carry):
            rng_key, acc_mean, acc_cvar, acc_opt_cvar = carry

            # Instances for this batch
            iteration = (
                jnp.arange(num_instances_per_batch, dtype=jnp.int32)
                + batch_idx * num_instances_per_batch
            )
            # Repeat each instance index so that the flattened vector has
            # `batch_size` entries (a factor of 1 with the current constants).
            iteration = iteration.reshape((-1, 1))
            iteration = jnp.repeat(iteration, batch_size // iteration.shape[0], axis=1)
            iteration = iteration.reshape((-1,))

            offset = jnp.zeros(batch_size, dtype=jnp.int32)
            state = jax.vmap(_env.init_v2, in_axes=(0, 0, 0, None, None))(
                keys, iteration, offset, 1, 1
            )

            # ---- optimal CVaR for this batch ----
            opt_cvar_batch = jnp.mean(state._x.optimal_cvar_value)  # type: ignore

            def cond_fn(tup):
                # Continue while any episode is still running and the step
                # budget is not exhausted.
                state, _, _, step, _ = tup
                still_running = ~state.terminated.all()
                return jnp.logical_and(still_running, step < max_steps)

            def loop_fn(tup):
                state, R, rng_key, step, actions = tup
                tql_out = _tql_model.apply(
                    params,
                    state.observation["node_features"],  # type: ignore
                    state.observation["senders"],  # type: ignore
                    state.observation["receivers"],  # type: ignore
                    state.observation["aux"],  # type: ignore
                )

                # Mask out illegal actions if needed
                q_vals: jnp.ndarray = jnp.where(
                    state.legal_action_mask,
                    tql_out.q_values,
                    jnp.full_like(tql_out.q_values, -1000.0),
                )  # type: ignore
                greedy_action = jnp.argmax(q_vals, axis=-1)

                rng_key, _rng = jax.random.split(rng_key)
                keys = jax.random.split(
                    _rng, state.observation["node_features"].shape[0]
                )
                state = step_fn(state, greedy_action, keys)

                # Write action into buffer at position `step`
                actions = actions.at[step].set(greedy_action)
                return state, R + state.rewards, rng_key, step + 1, actions

            ep_return = jnp.zeros_like(state.rewards)
            step = jnp.array(0)
            # -1 marks "no action taken at this step"
            actions_init = jnp.zeros((max_steps, batch_size), dtype=jnp.int32) - 1
            carry_init = (state, ep_return, rng_key, step, actions_init)
            # The realised rollout return is discarded: the reported
            # statistics come from `estimate_cvar` on the chosen nodes.
            state, _, rng_key, _, actions_taken = jax.lax.while_loop(
                cond_fn, loop_fn, carry_init
            )
            actions_taken = actions_taken.swapaxes(0, 1)  # (b, max_steps)
            node_types = jnp.take_along_axis(
                state.observation["node_types"], actions_taken, axis=1
            )
            # If an action is -1 (not taken), set node type to 2 (0 reward)
            node_types = jnp.where(actions_taken == -1, 2, node_types)

            # Both estimates deliberately share the same per-instance keys, as
            # before; the keys are now derived once instead of twice.
            estimate_keys = jax.random.split(rng_key, batch_size)
            batched_estimate = jax.vmap(estimate_cvar, in_axes=(0, 0, None, None))
            R_cvars = batched_estimate(estimate_keys, node_types, 10000, 0.25)
            R = batched_estimate(estimate_keys, node_types, 10000, 1.0)
            R_mean = jnp.mean(R)
            R_cvar = jnp.mean(R_cvars)

            return (
                rng_key,
                acc_mean + R_mean,
                acc_cvar + R_cvar,
                acc_opt_cvar + opt_cvar_batch,
            )

        # Run across all batches
        rng_key, total_mean, total_cvar, total_opt_cvar = jax.lax.fori_loop(
            0, num_batches, eval_one_batch, (rng_key, 0.0, 0.0, 0.0)
        )

        # Average across batches
        R_mean = total_mean / num_batches
        R_cvar = total_cvar / num_batches
        R_opt_cvar = total_opt_cvar / num_batches
        return R_mean, R_cvar, R_opt_cvar

    @jax.jit
    def selfplay(
        rng_key: jax.Array,
        params: optax.Params,
        buffer_state: fbx.trajectory_buffer.TrajectoryBufferState,
        env_state: pgx.State,
        episode_stats: dict[str, jnp.ndarray],
        step_num: int,
    ):
        """Collect ``args.max_num_steps`` epsilon-greedy steps into the buffer.

        Args:
            rng_key: PRNG key for this self-play call.
            params: Network parameters passed to ``_tql_model.apply``.
            buffer_state: Replay buffer state that receives the transitions.
            env_state: Batched environment state to continue from.
            episode_stats: Running ``episode_return`` / ``episode_length`` /
                ``is_terminal_step`` arrays, one entry per environment.
            step_num: Value handed to ``calc_eps`` to obtain epsilon.

        Returns:
            ``(env_state, episode_stats, buffer_state, traj_batch)`` where
            ``traj_batch`` stacks the per-step ``Transition`` objects.
        """

        def step_fn(
            carry: tuple[
                pgx.State,
                dict[str, jnp.ndarray],
                fbx.trajectory_buffer.TrajectoryBufferState,
            ],
            key: jnp.ndarray,
        ) -> tuple[
            tuple[
                pgx.State,
                dict[str, jnp.ndarray],
                fbx.trajectory_buffer.TrajectoryBufferState,
            ],
            Transition,
        ]:
            state, episode_stats, buffer_state = carry
            # Three independent streams: random-action noise, the exploration
            # coin flip and the environment step. Previously the coin flip and
            # the step keys were both derived from the same key.
            noise_key, coin_key, env_key = jax.random.split(key, 3)
            observation = state.observation  # observation *before* the step

            tql_out = _tql_model.apply(
                params,
                state.observation["node_features"],  # type: ignore
                state.observation["senders"],  # type: ignore
                state.observation["receivers"],  # type: ignore
                state.observation["aux"],  # type: ignore
            )

            # Mask out illegal actions if needed
            q_vals: jnp.ndarray = jnp.where(
                state.legal_action_mask,
                tql_out.q_values,
                jnp.full_like(tql_out.q_values, -1000.0),
            )  # type: ignore
            greedy_action = jnp.argmax(q_vals, axis=-1)

            # Uniform choice among legal actions: arg-max of uniform noise with
            # illegal entries pushed below the noise range.
            noise = jax.random.uniform(noise_key, q_vals.shape)  # (b, num_nodes)
            masked_noise = jnp.where(
                state.legal_action_mask, noise, jnp.full_like(noise, -1.0)
            )  # (b, num_nodes)
            random_action = jnp.argmax(masked_noise, axis=-1)  # (b)

            epsilon = calc_eps(
                step_num,
                args.epsilon_start,
                args.epsilon_finish,
                args.epsilon_anneal_time,
            )
            # A single coin flip per step: the whole batch either explores or
            # acts greedily (lax.cond needs a scalar predicate).
            action = jax.lax.cond(
                jax.random.uniform(coin_key) < epsilon,
                lambda: random_action,
                lambda: greedy_action,
            )

            keys = jax.random.split(
                env_key, state.observation["node_features"].shape[0]
            )
            state = jax.vmap(auto_reset(_env.step, _env.init_v2))(state, action, keys)

            # Update episode stats (built as a new dict rather than mutating
            # the carried one).
            initial = episode_stats["episode_length"] == 0
            # NOTE(review): the return accumulates rewards[:, -1] while the
            # transition stores rewards[:, 0]; identical only if rewards has a
            # single column — confirm.
            episode_stats = {
                **episode_stats,
                "episode_return": episode_stats["episode_return"]
                + state.rewards[:, -1],
                "episode_length": episode_stats["episode_length"] + 1,
                "is_terminal_step": state.terminated,
            }

            # Create transition
            transition = Transition(
                accumulated_return=episode_stats["episode_return"],
                done=state.terminated,  # (b,)
                initial=initial,  # (b,)
                action=jnp.asarray(action),  # (b,)
                reward=state.rewards[:, 0],  # (b,)
                obs=observation,  # (b, *obs_shape)
                info=episode_stats,
            )
            buffer_state = _buffer_fn.add(
                buffer_state,
                transition,
            )  # Add transition to the buffer

            # Reset stats for terminal steps
            episode_stats = jax.tree_util.tree_map(
                lambda x: jnp.where(state.terminated, jnp.zeros_like(x), x),
                episode_stats,
            )
            return (state, episode_stats, buffer_state), transition

        # Run self-play for max_num_steps per batch, one key per step
        _, sub_key = jax.random.split(rng_key)
        key_seq = jax.random.split(sub_key, args.max_num_steps)

        (env_state, episode_stats, buffer_state), traj_batch = jax.lax.scan(
            step_fn,  # type: ignore
            (env_state, episode_stats, buffer_state),
            key_seq,
        )

        return env_state, episode_stats, buffer_state, traj_batch

    @jax.jit
    def selfplay_scan_fn(carry, iteration):
        """Scan body that runs one self-play call to prefill the buffer.

        Args:
            carry: ``(rng_key, buffer_state, opt_state, params, target_params,
                env_state, episode_stats, eval_R)``. Only the key, buffer,
                environment state and episode stats change here.
            iteration: Prefill iteration index, forwarded to ``selfplay`` as
                its ``step_num``.

        Returns:
            ``(carry, traj_batch)`` with the updated carry and the collected
            transitions.
        """
        (
            rng_key,
            buffer_state,
            opt_state,
            params,
            target_params,
            env_state,
            episode_stats,
            eval_R,
        ) = carry

        # Advance the carried key. Previously the same key was handed to
        # `selfplay` on every iteration (it never returns an updated key), so
        # all prefill iterations reused identical randomness.
        rng_key, selfplay_key = jax.random.split(rng_key)

        # Run self-play
        env_state, episode_stats, buffer_state, traj_batch = selfplay(
            selfplay_key, params, buffer_state, env_state, episode_stats, iteration
        )

        return (
            rng_key,
            buffer_state,
            opt_state,
            params,
            target_params,
            env_state,
            episode_stats,
            eval_R,
        ), traj_batch

    @jax.jit
    def learning_step(rng_key, params, target_params, opt_state, buffer_state):
        """Sample a batch of transition pairs and apply one optimiser update.

        Two quantile-Huber losses are summed:

        * QR loss: the QR head at ``(first.obs, first.action)`` regresses onto
          ``first.reward + (1 - done) * gamma * Z_target(second.obs, a*)``.
        * TQL loss: the TQL head at ``(first.obs, first.action)`` regresses
          onto ``first.accumulated_return + (1 - done) * gamma *
          Z_target(second.obs, a*)``.

        ``a*`` is the arg-max of the online TQL Q-values at ``second.obs`` and
        ``Z_target`` is the QR head evaluated with ``target_params``.

        Args:
            rng_key: PRNG key used to sample from the replay buffer.
            params: Online network parameters (differentiated).
            target_params: Target network parameters (held fixed).
            opt_state: Optimiser state.
            buffer_state: Replay buffer state to sample from.

        Returns:
            ``(params, opt_state, loss)`` after one update.
        """
        learn_batch = _buffer_fn.sample(buffer_state, rng_key).experience

        # Target-network quantiles at the next observation
        qr_next_out = _qr_model.apply(
            target_params,
            learn_batch.second.obs["node_features"],
            learn_batch.second.obs["senders"],
            learn_batch.second.obs["receivers"],
            learn_batch.second.obs["aux"],
        )
        # Online TQL head at the next observation, used only to pick a*
        tql_next_out = _tql_model.apply(
            params,
            learn_batch.second.obs["node_features"],
            learn_batch.second.obs["senders"],
            learn_batch.second.obs["receivers"],
            learn_batch.second.obs["aux"],
        )

        greedy_actions = jnp.argmax(tql_next_out.q_values, axis=-1)  # (b,)
        # Index the last (action) axis with a (b, 1, 1) index, then drop it.
        # Presumably q_dist is (b, num_quantiles, num_actions), giving
        # (b, num_quantiles) here — TODO confirm.
        qr_next_target = jnp.take_along_axis(
            qr_next_out.q_dist,
            jnp.expand_dims(greedy_actions, axis=(-1, -2)),
            axis=-1,
        ).squeeze(-1)

        # Bootstrapped part is zeroed for terminal transitions
        bootstrap = (
            (1 - learn_batch.first.done[:, None]) * args.gamma * qr_next_target
        )
        tql_next_target = learn_batch.first.accumulated_return[:, None] + bootstrap
        qr_next_target = learn_batch.first.reward[:, None] + bootstrap

        def _loss_fn(params: optax.Params) -> jnp.ndarray:
            # QR Loss
            qr_out = _qr_model.apply(
                params,
                learn_batch.first.obs["node_features"],
                learn_batch.first.obs["senders"],
                learn_batch.first.obs["receivers"],
                learn_batch.first.obs["aux"],
            )
            chosen_action_q_dists = jnp.take_along_axis(
                qr_out.q_dist, learn_batch.first.action[:, None, None], axis=-1
            ).squeeze(-1)

            qr_losses = jax.vmap(quantile_huber_loss, in_axes=(0, None, 0, None, None))(
                chosen_action_q_dists,
                tau_hats,
                qr_next_target,
                args.huber_param,
                True,
            )
            chex.assert_shape(qr_losses, (args.buffer_batch_size,))
            qr_loss = jnp.mean(qr_losses)

            # TQL Loss
            # NOTE(review): the prediction is taken at first.obs so that it
            # matches first.action. The earlier version evaluated the network
            # on second.obs while still indexing with first.action, which
            # pairs an action with the observation that follows it — confirm
            # this is the intended fix.
            tql_pred = _tql_model.apply(
                params,
                learn_batch.first.obs["node_features"],
                learn_batch.first.obs["senders"],
                learn_batch.first.obs["receivers"],
                learn_batch.first.obs["aux"],
            )
            chosen_action_q_dists = jnp.take_along_axis(
                tql_pred.q_dist, learn_batch.first.action[:, None, None], axis=-1
            ).squeeze(-1)
            rh_losses = jax.vmap(quantile_huber_loss, in_axes=(0, None, 0, None, None))(
                chosen_action_q_dists,
                tau_hats,
                tql_next_target,
                args.huber_param,
                True,
            )
            chex.assert_shape(rh_losses, (args.buffer_batch_size,))
            rh_loss = jnp.mean(rh_losses)
            return qr_loss + rh_loss

        loss, grads = jax.value_and_grad(_loss_fn)(params)
        updates, opt_state = _optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        return params, opt_state, loss

    # Define scan function for training
    @jax.jit
    def learn_scan_fn(carry, epoch):
        """Scan body: one learner update, then a conditional target sync.

        Args:
            carry: ``(rng_key, params, target_params, opt_state, buffer_state)``.
            epoch: Global learning-step counter; the target network is updated
                when ``(epoch + 1)`` is a multiple of
                ``args.target_update_interval``.

        Returns:
            ``(carry, loss)`` with the updated carry and this step's loss.
        """
        key, online, target, optim_state, replay_state = carry
        key, step_key = jax.random.split(key)

        # Gradient update of the online parameters
        online, optim_state, loss = learning_step(
            step_key, online, target, optim_state, replay_state
        )

        def _blend_into_target():
            # target <- tau * online + (1 - tau) * target
            return jax.tree_util.tree_map(
                lambda tgt, onl: args.target_tau * onl
                + (1 - args.target_tau) * tgt,
                target,
                online,
            )

        def _keep_target():
            return target

        is_sync_step = (epoch + 1) % args.target_update_interval == 0
        new_target = jax.lax.cond(is_sync_step, _blend_into_target, _keep_target)

        return (key, online, new_target, optim_state, replay_state), loss

    @jax.jit
    def eval_fn(subkey, iteration, params):
        """Run ``evaluate``, print its three statistics, return mean and CVaR.

        Args:
            subkey: PRNG key forwarded to ``evaluate``.
            iteration: Current training iteration (only used in the message).
            params: Network parameters to evaluate.

        Returns:
            ``(mean_reward, mean_cvar)``; the optimal CVaR is printed only.
        """
        raw_reward, raw_cvar, raw_opt_cvar = evaluate(subkey, params)
        mean_reward = raw_reward.mean()
        mean_cvar = raw_cvar.mean()
        mean_opt_cvar = raw_opt_cvar.mean()

        jax.debug.print(
            "Iter {i} / {max_num_iters}, Eval Reward: {r}, Eval CVaR: {r_cvar}, Eval Opt CVaR: {r_opt_cvar}",
            i=iteration,
            max_num_iters=args.max_num_iters,
            r=mean_reward,
            r_cvar=mean_cvar,
            r_opt_cvar=mean_opt_cvar,
        )
        return mean_reward, mean_cvar

    @jax.jit
    def train_loop_body(carry, iteration):
        """Scan body for one training iteration.

        Steps, in order: conditional evaluation, one self-play call,
        ``args.train_epochs_per_iter`` learning steps, then logging.

        Args:
            carry: ``(rng_key, buffer_state, opt_state, params, target_params,
                env_state, episode_stats, last_eval_reward)`` where
                ``last_eval_reward`` is the ``(reward, cvar)`` pair from the
                most recent evaluation.
            iteration: Index of the current training iteration.

        Returns:
            ``(carry, json_log)`` with the updated carry and a dictionary of
            scalar metrics for this iteration.
        """
        (
            rng_key,
            buffer_state,
            opt_state,
            params,
            target_params,
            env_state,
            episode_stats,
            last_eval_reward,
        ) = carry

        # Separate keys for evaluation and self-play; previously both
        # consumed the same subkey.
        rng_key, eval_key, selfplay_key = jax.random.split(rng_key, 3)

        # Run evaluation conditionally; otherwise carry the last result along
        eval_R = jax.lax.cond(
            iteration % args.eval_interval == 0,
            eval_fn,
            lambda *_: last_eval_reward,
            eval_key,
            iteration,
            params,
        )

        # Self-play data collection
        env_state, episode_stats, buffer_state, traj_batch = selfplay(
            selfplay_key,
            params,
            buffer_state,
            env_state,
            episode_stats,
            iteration,
        )

        # Log the training stats: only steps that ended an episode contribute
        is_terminal = traj_batch.info["is_terminal_step"]  # type: ignore
        episode_returns = traj_batch.info["episode_return"] * is_terminal.astype(
            jnp.float32
        )  # type: ignore
        episode_lengths = traj_batch.info["episode_length"] * is_terminal.astype(
            jnp.int32
        )  # type: ignore
        total_terminations = jnp.sum(is_terminal.astype(jnp.int32))
        # The small constant avoids division by zero when no episode finished
        average_return = jnp.sum(episode_returns) / (total_terminations + 1e-8)
        average_length = jnp.sum(episode_lengths) / (total_terminations + 1e-8)

        # Perform learning steps; the scanned values are global step indices
        initial_carry = (rng_key, params, target_params, opt_state, buffer_state)
        (rng_key, params, target_params, opt_state, buffer_state), losses = (
            jax.lax.scan(
                learn_scan_fn,
                initial_carry,
                iteration * args.train_epochs_per_iter
                + jnp.arange(args.train_epochs_per_iter),
            )
        )
        mean_loss = jnp.mean(losses)

        # Schedule values recomputed here for logging only
        current_lr = util.linear_schedule(
            iteration * args.train_epochs_per_iter,
            args.lr,
            args.lr_anneal_iterations * args.train_epochs_per_iter,
            args.min_lr,
        )
        current_epsilon = calc_eps(
            iteration,
            args.epsilon_start,
            args.epsilon_finish,
            args.epsilon_anneal_time,
        )
        jax.debug.print(
            "Iter {i} / {max_num_iters}, Loss: {loss:.4f} (LR: {lr:.6f}, EPS: {eps:.4f}), Train Avg Return: {avg_return:.2f}, Avg Length: {avg_length:.2f}",
            i=iteration,
            max_num_iters=args.max_num_iters,
            loss=mean_loss,
            lr=current_lr,
            eps=current_epsilon,
            avg_return=average_return,
            avg_length=average_length,
        )

        carry = (
            rng_key,
            buffer_state,
            opt_state,
            params,
            target_params,
            env_state,
            episode_stats,
            eval_R,
        )
        R, R_cvar = eval_R
        json_log = {
            "num_steps": (iteration + 1) * args.max_num_steps,
            "training_steps": (iteration + 1) * args.train_epochs_per_iter,
            "iteration": iteration,
            "loss": mean_loss,
            "train_average_return": average_return,
            "train_average_length": average_length,
            "epsilon": current_epsilon,
            "learning_rate": current_lr,
            "last_eval_reward": R,
            "last_eval_cvar": R_cvar,
        }
        return carry, json_log

    # Run scan
    rng_key = jax.random.PRNGKey(seed=args.seed)
    init_rng_key, sub_key = jax.random.split(rng_key)
    keys = jax.random.split(sub_key, args.selfplay_batch_size)

    iteration = jnp.zeros(args.selfplay_batch_size, dtype=jnp.int32)
    offset = jnp.arange(args.selfplay_batch_size, dtype=jnp.int32)
    env_state = jax.vmap(_env.init_v2, in_axes=(0, 0, 0, None, None))(
        keys, iteration, offset, args.selfplay_batch_size, 0
    )
    episode_stats_init = {
        "episode_return": jnp.zeros((args.selfplay_batch_size,)),
        "episode_length": jnp.zeros((args.selfplay_batch_size,), dtype=jnp.int32),
        "is_terminal_step": jnp.zeros((args.selfplay_batch_size,), dtype=bool),
    }
    initial_carry = (
        init_rng_key,
        _buffer_state,
        _opt_state,
        _params,
        _target_params,
        env_state,
        episode_stats_init,
        (jnp.array(0.0), jnp.array(0.0)),  # last_eval_reward
    )
    # Self-play to fill the buffer
    initial_carry, traj_batch = jax.lax.scan(
        selfplay_scan_fn, initial_carry, jnp.arange(args.learning_start)
    )

    iterations = jnp.arange(args.max_num_iters)
    final_carry, json_logs = jax.lax.scan(train_loop_body, initial_carry, iterations)

    return json_logs


if __name__ == "__main__":
    import os

    import numpy as onp

    def save_logs(logs):
        """Write ``logs`` to disk with ``numpy.save``.

        The parent directory is created if needed. ``numpy.save`` appends
        ``.npy`` to a file name that lacks it, so the data ends up in
        ``logger.log.npy``. ``logs`` is wrapped with ``numpy.array`` first.
        """
        destination = "./logs/graph_tql/s-mis-easy/logger.log"
        parent_dir = os.path.dirname(destination)
        os.makedirs(parent_dir, exist_ok=True)
        onp.save(destination, onp.array(logs))

    # Run with the default configuration and persist the resulting logs
    json_logs = run_experiment(Config())
    save_logs(json_logs)
