import pathlib
from functools import partial

import flax.nnx as nnx
import jax
import jax.numpy as jnp
import jax.tree_util as tree_util
import orbax.checkpoint as ocp

from nais.gflownet import GFlowNet, GFlowNetState
from nais.gym.base import EnvState, LogRewardBase


class Checkpointer:
    """Persist and restore the ``nnx`` state of a GFlowNet with Orbax.

    Each checkpoint is stored under ``<ckpt_dir>/state-at-<step>``.
    """

    def __init__(self, ckpt_dir: str):
        """Set up the checkpoint directory and the Orbax checkpointer.

        Args:
            ckpt_dir: Directory that holds the checkpoints. An existing
                directory is reused as is; any other path is handed to
                ``ocp.test_utils.erase_and_create_empty`` and the value it
                returns is stored.
        """
        # NOTE(review): ``erase_and_create_empty`` lives in Orbax's test
        # utilities; confirm it is the intended way to create the directory.
        existing = pathlib.Path(ckpt_dir)
        self.ckpt_dir = (
            existing
            if existing.is_dir()
            else ocp.test_utils.erase_and_create_empty(ckpt_dir)
        )
        self.checkpointer = ocp.StandardCheckpointer()

    def _state_path(self, step: int):
        """Return the checkpoint location used for ``step``."""
        return self.ckpt_dir / f"state-at-{step}"

    def save(self, gflownet: GFlowNet, step: int = -1):
        """Write the state of ``gflownet`` to the checkpoint for ``step``.

        Only the state half of ``nnx.split`` is saved. The call blocks until
        the checkpointer reports that it has finished.

        Args:
            gflownet: Model whose state is saved.
            step: Value used in the checkpoint name (``-1`` by default).
        """
        model_path = self._state_path(step)
        print(f"Saving model to {model_path}")
        _graph_def, gfn_state = nnx.split(gflownet)
        self.checkpointer.save(model_path, gfn_state)
        self.checkpointer.wait_until_finished()

    def load(self, gflownet: GFlowNet, step: int = -1) -> GFlowNet:
        """Restore the checkpoint for ``step`` into a model like ``gflownet``.

        Args:
            gflownet: Model that is split into graph definition and state;
                its current state is passed to ``restore`` as the target.
            step: Value used in the checkpoint name (``-1`` by default).

        Returns:
            The model obtained by merging the graph definition of
            ``gflownet`` with the restored state.
        """
        model_path = self._state_path(step)
        print(f"Loading model from {model_path}")
        graph_def, template_state = nnx.split(gflownet)
        restored = self.checkpointer.restore(model_path, template_state)
        return nnx.merge(graph_def, restored)


@partial(jax.vmap, in_axes=(1, 0), out_axes=1)
def roll_column(col, shift):
    """Rotate one column towards index 0 by ``shift`` and then reverse it.

    ``jax.vmap`` maps this function over axis 1 of the first argument and
    axis 0 of the second, and stacks the results back along axis 1.

    Args:
        col: One column of the mapped input (axis 1 already removed).
        shift: Number of positions the column is rotated by.

    Returns:
        The rotated and reversed column, or ``col`` unchanged when it is not
        a ``jax.Array`` with at least one dimension.
    """
    if not isinstance(col, jax.Array) or col.ndim < 1:
        return col
    # Rolling by ``-shift`` moves the first ``shift`` entries to the tail;
    # the flip then puts them at the front in reversed order.
    return jnp.flip(jnp.roll(col, -shift, axis=0), axis=0)


# Arrays are annotated as `jax.Array`; `jnp.ndarray` is not the idiomatic type hint.
def roll_by_last_active(x: jax.Array, last_active_idx: jax.Array):
    """Apply ``roll_column`` to every column of ``x`` that has one.

    Args:
        x: Array to reorder. Arrays with fewer than two dimensions (e.g.
            attributes without a batch dimension, such as a shared step
            size) are returned unchanged.
        last_active_idx: Per-column shift forwarded to ``roll_column``.

    Returns:
        ``x`` itself when it has fewer than two dimensions, otherwise the
        result of ``roll_column`` (which is vmapped over axis 1 of ``x``).
    """
    if len(x.shape) >= 2:
        return roll_column(x, last_active_idx)
    return x


def evaluate_on_transitions_bcw(
    carry: tuple[GFlowNet, int],
    _,
    actions: jax.Array,
    is_active: jax.Array,
    states: EnvState,
):
    """Scan body: evaluate the backward policy on one recorded transition.

    Args:
        carry: ``(gflownet, idx)`` where ``idx`` selects the step to evaluate.
        _: Unused per-step scan input.
        actions: Recorded actions, indexed by step along axis 0.
        is_active: Activity mask, indexed by step along axis 0.
        states: Recorded environment states; every ``jax.Array`` leaf with at
            least one dimension is indexed by step along axis 0.

    Returns:
        ``((gflownet, idx + 1), log_pb_t)`` where ``log_pb_t`` is the
        ``log_pb`` of the policy output, replaced by ``0.0`` wherever
        ``is_active[idx]`` is false.
    """
    gflownet, idx = carry
    next_idx = idx + 1

    def take_step(leaf):
        # Leaves that are not arrays with a leading axis pass through as is.
        return leaf[idx] if isinstance(leaf, jax.Array) and leaf.ndim >= 1 else leaf

    step_state = tree_util.tree_map(take_step, states)
    policy_out = gflownet.pb.sample_actions(step_state, actions=actions[idx])

    # Keep the step counter stored on the model state in sync with the carry.
    gflownet.state = gflownet.state.replace(idx=next_idx)

    masked_log_pb = jnp.where(is_active[idx], policy_out.log_pb, 0.0)
    return (gflownet, next_idx), masked_log_pb


def rollout_fwd(
    gflownet: GFlowNet, apply_fn: callable, log_reward: LogRewardBase, key: jax.Array
) -> tuple[tuple[GFlowNet, jax.Array], jax.Array]:
    """Sample forward trajectories and score them with the backward policy.

    Args:
        gflownet: Model to roll out; ``lazy_init_fwd`` is called on it first.
        apply_fn: Scan body for the forward rollout. It is scanned over the
            carry ``(gflownet, key)`` and must emit
            ``(actions, is_active, states)`` per step.
        log_reward: Callable applied to the final environment state.
        key: PRNG key threaded through the forward scan.

    Returns:
        ``((gflownet, key), log_rewards)``: the rolled-out model whose state
        carries the transposed backward log-probabilities in ``log_pb``, the
        key returned by the scan, and the gradient-stopped log-rewards.
    """
    gflownet.lazy_init_fwd()
    num_steps = gflownet.state.env_state.max_trajectory_length
    (res, key), (actions, is_active, states) = jax.lax.scan(
        apply_fn, (gflownet, key), length=num_steps
    )

    log_rewards = jax.lax.stop_gradient(log_reward(res.state.env_state))

    # Number of active steps per column; used to reorder every column.
    last_active_idx = jnp.sum(is_active, axis=0)

    def reorder(leaf):
        return roll_by_last_active(leaf, last_active_idx)

    rolled_actions = reorder(actions)
    rolled_is_active = reorder(is_active)
    rolled_states = tree_util.tree_map(reorder, states)

    # Evaluate the backward policy on the sampled states with the given actions.
    res.lazy_init_bcw()
    score_step = partial(
        evaluate_on_transitions_bcw,
        actions=rolled_actions,
        is_active=rolled_is_active,
        states=rolled_states,
    )
    _, log_pb = jax.lax.scan(
        score_step,
        init=(res, 0),
        xs=None,
        length=rolled_actions.shape[0],
    )

    res.state = res.state.replace(log_pb=log_pb.T)

    return (res, key), log_rewards


def evaluate_on_transitions_fwd(
    carry: tuple[GFlowNet, int],
    _,
    *,
    actions: jax.Array,
    is_active: jax.Array,
    states: EnvState,
):
    """Scan body: evaluate the forward policy on one recorded transition.

    Args:
        carry: ``(gflownet, idx)`` where ``idx`` selects the step to evaluate.
        _: Unused per-step scan input.
        actions: Recorded actions, indexed by step along axis 0.
        is_active: Activity mask, indexed by step along axis 0.
        states: Recorded environment states; every ``jax.Array`` leaf with at
            least one dimension is indexed by step along axis 0.

    Returns:
        ``((gflownet, idx + 1), log_pf_t)`` where ``log_pf_t`` is the
        ``log_pf`` of the policy output, replaced by ``0.0`` wherever
        ``is_active[idx]`` is false.
    """
    gflownet, idx = carry
    next_idx = idx + 1

    def take_step(leaf):
        # Leaves that are not arrays with a leading axis pass through as is.
        return leaf[idx] if isinstance(leaf, jax.Array) and leaf.ndim >= 1 else leaf

    step_state = tree_util.tree_map(take_step, states)
    policy_out = gflownet.pf.sample_actions(step_state, actions=actions[idx])

    # Keep the step counter stored on the model state in sync with the carry.
    gflownet.state = gflownet.state.replace(idx=next_idx)

    masked_log_pf = jnp.where(is_active[idx], policy_out.log_pf, 0.0)
    return (gflownet, next_idx), masked_log_pf


def rollout_bcw(
    gflownet: GFlowNet,
    backward_fn: callable,
    key: jax.Array,
) -> tuple[GFlowNet, jax.Array]:
    """Sample backward trajectories and score them with the forward policy.

    Args:
        gflownet: Model to roll out backwards.
        backward_fn: Scan body for the backward rollout. It is scanned over
            the carry ``(gflownet, key)`` and must emit
            ``(actions, is_active, states)`` per step.
        key: PRNG key threaded through the backward scan.

    Returns:
        ``(gflownet, key)``: the rolled-out model whose state carries the
        transposed forward log-probabilities in ``log_pf``, and the key
        returned by the scan.
    """
    # Backward pass to sample backward trajectories
    (res, key), scanned_output = jax.lax.scan(
        backward_fn,
        (gflownet, key),
        length=gflownet.state.env_state.max_trajectory_length,
    )
    actions, is_active, states = scanned_output

    # Number of active steps per column; used to reorder every column.
    last_active_idx = jnp.sum(is_active, axis=0)

    rolled_actions = roll_by_last_active(actions, last_active_idx)
    rolled_is_active = roll_by_last_active(is_active, last_active_idx)
    rolled_states = tree_util.tree_map(lambda x: roll_by_last_active(x, last_active_idx), states)

    # Forward pass to compute the forward probabilities
    res.lazy_init_fwd()
    _, log_pf = jax.lax.scan(
        partial(
            evaluate_on_transitions_fwd,
            actions=rolled_actions,
            is_active=rolled_is_active,
            states=rolled_states,
        ),
        init=(res, 0),
        xs=None,
        # Same convention as ``rollout_fwd``: the leading dimension of the
        # stacked scan output is a static Python int equal to the number of
        # steps sampled above, so it never depends on values read back from
        # the scanned carry.
        length=rolled_actions.shape[0],
    )

    res.state = res.state.replace(log_pf=log_pf.T)

    return res, key


def compute_state_prob(
    gflownet: GFlowNet,
    key: jax.Array,
    log_rewards: jax.Array,
    backward_fn: callable,
    num_trajectories: int,
):
    """Estimate normalised log-marginals from repeated backward rollouts.

    ``rollout_bcw`` is run ``num_trajectories`` times. For every rollout the
    difference ``log_pf - log_pb`` is summed over axis 2 of the stacked
    states, and the rollouts are averaged in log-space over axis 0.

    Args:
        gflownet: Model to roll out; ``lazy_init_bcw`` is called on it first.
        key: PRNG key threaded through the repeated rollouts.
        log_rewards: Log-rewards used as the reference distribution.
        backward_fn: Scan body forwarded to ``rollout_bcw``.
        num_trajectories: Number of backward rollouts.

    Returns:
        ``(marginal_prob, marginal_true)``. Both are log-values shifted by
        their own ``logsumexp`` over axis 0.
    """
    # We then compute the backward probability
    gflownet.lazy_init_bcw()

    @nnx.jit
    def sample_once(rng: jax.Array, _) -> tuple[jax.Array, GFlowNetState]:
        rolled_out, rng = rollout_bcw(gflownet, backward_fn, key=rng)
        return rng, rolled_out.state

    _, states = jax.lax.scan(sample_once, key, length=num_trajectories)

    log_ratio = states.log_pf - states.log_pb
    per_trajectory = log_ratio.sum(axis=2)
    # logsumexp minus log(N) is the log of the mean over the N rollouts.
    log_mean = jax.nn.logsumexp(per_trajectory, axis=0) - jnp.log(num_trajectories)

    marginal_prob = log_mean - jax.nn.logsumexp(log_mean, axis=0)
    marginal_true = log_rewards - jax.nn.logsumexp(log_rewards, axis=0)

    return marginal_prob, marginal_true
