from functools import partial

import flax.nnx as nnx
import flax.struct as struct
import jax
import jax.numpy as jnp
import tqdm

from divgfn.policies import GFlowNet, take
from divgfn.replay_buffer import ReplayBuffer
from divgfn.replay_buffer import push as push_to_buffer
from divgfn.replay_buffer import sample as sample_from_buffer
from divgfn.training import create_opt
from divgfn.utils import HasState, ModeQueue, RunOutput, merge, push


class EvalMetricCallable:
    """Base class for per-iteration evaluation hooks used by the ``run_*`` loops.

    The loops call ``eval_metric(gfn, gfn_aux)`` once per iteration, where
    ``gfn_aux`` is the teacher / sibling network, or ``None`` when the run has
    none (``run_gfn``). After training, the loops return ``self.metrics`` as
    ``RunOutput.extra_metrics``.
    """

    def __init__(self, freq: int):
        # Presumably how often (in iterations) subclasses should evaluate;
        # not read by this base class.
        self.freq = freq
        # Counter available to subclasses; not advanced by this base class.
        self.counter = 1
        # Values collected by subclasses; returned by the loops as ``extra_metrics``.
        self.metrics = []

    def eval_metric(self, gfn: "GFlowNet", gfn_teach: "GFlowNet" = None):
        """Evaluate the current networks; the base implementation does nothing.

        Args:
            gfn: Main (student / behavior) GFlowNet.
            gfn_teach: Auxiliary network (teacher or sibling), or ``None``.
                Every run loop passes this second argument, so the base
                signature must accept it.
        """
        return None


def stable_tb(
    log_pf_over_pb: jax.Array,
    logz: jax.Array,
    log_rewards: jax.Array,
    epsilon: float = 1e-3,
):
    """Return the epsilon-smoothed trajectory-balance residual in log space.

    Computes, elementwise,
    ``log(epsilon + Z * P_F/P_B) - log(epsilon + R)``,
    i.e. the usual TB residual ``log Z + log(P_F/P_B) - log R`` but with both
    sides floored by ``epsilon``. Neither log term can drop below
    ``log(epsilon)``.

    Args:
        log_pf_over_pb: Summed ``log P_F - log P_B`` per trajectory.
        logz: Log partition estimate; broadcast against ``log_pf_over_pb``.
        log_rewards: Target log-rewards, same (broadcastable) shape.
        epsilon: Additive floor applied inside both logs.

    Returns:
        Array with the broadcast shape of the inputs (any rank, including 0-d).
    """
    log_eps = jnp.log(epsilon)
    # logaddexp(log a, log b) == log(a + b), computed stably. Unlike stacking
    # along axis 1 and reducing, it broadcasts over inputs of any rank.
    log_num_eps = jnp.logaddexp(log_eps, log_pf_over_pb + logz)
    log_den_eps = jnp.logaddexp(log_eps, log_rewards)
    return log_num_eps - log_den_eps


@struct.dataclass
class Config:
    """Training hyper-parameters shared by the ``run_*`` entry points.

    Fields declared without ``pytree_node=False`` are pytree leaves (default
    ``struct.field`` semantics), so inside ``nnx.jit`` they are traced values
    rather than Python constants. ``buffer_freq`` is the exception.
    """

    # Input / output widths of each GFlowNet (first and third ``GFlowNet`` args).
    din: int
    dout: int

    # PRNG key seeding ``nnx.Rngs`` (network initialisation and the sampling key).
    key: jax.Array

    # Forwarded to ``create_opt`` for the GFlowNet optimizers.
    use_scheduler: bool = True

    # Number of training iterations (loop length; also passed to ``create_opt``).
    iterations: int = 512
    # NOTE(review): not read in this file; batch size appears to come from ``s_o``.
    bs: int = 64
    # Hidden width and number of layers of each GFlowNet.
    dmid: int = 64
    nlayers: int = 2

    # ``lr`` / ``logz_lr`` forwarded to ``create_opt`` for the main network
    # (presumably network weights and logZ respectively).
    lr: float = 1e-3
    logz_lr: float = 1e-1

    # Same, for the auxiliary network (teacher in div/teacher runs, sibling in SA).
    lr_teacher: float = 1e-3
    logz_lr_teacher: float = 1e-1

    # Forwarded to ``create_opt``.
    should_clip: bool = False

    # NOTE(review): not read in this file; presumably the replay-buffer capacity.
    buffer_size: int = 256
    # Mixing weight of the replay-buffer loss. Replay only runs when this is
    # > 0, so the negative default disables it. Static so that
    # ``config.buffer_freq > 0`` is a plain Python bool at trace time.
    buffer_freq: float = struct.field(pytree_node=False, default=-0.15)


@struct.dataclass
class EnvConfig:
    """Environment callables shared by the rollout and training functions.

    Every field is static (``pytree_node=False``), so an ``EnvConfig`` can be
    passed to jitted functions without its callables being treated as traced
    leaves.
    """

    # Forward ``lax.scan`` step: called as ``fstep((state, key), _, gfn=...)``
    # (``rollout`` also passes ``eps=...``); returns
    # ``((state, key), (flogits, blogits, actions))``.
    fstep: callable = struct.field(pytree_node=False)
    # Backward ``lax.scan`` step from terminal states:
    # ``bstep((state, key), _, gfn=...)`` -> ``((state, key), (flogits, blogits, _))``.
    bstep: callable = struct.field(pytree_node=False)
    # Forward policy: ``fpol(state, gfn)``; indexed with ``take(..., actions)``.
    fpol: callable = struct.field(pytree_node=False)
    # Backward policy: ``bpol(state, gfn)``, evaluated on the post-action state.
    bpol: callable = struct.field(pytree_node=False)

    # Transition: ``fapply(state, actions)`` -> next state.
    fapply: callable = struct.field(pytree_node=False)


@partial(jax.jit, static_argnames=("trajectories", "bstep"))
def compute_state_prob(
    samples: HasState,
    gfn: GFlowNet,
    key: jax.Array,
    log_rewards: jax.Array,
    s_o: HasState,
    trajectories: int,
    bstep: callable,
):
    """Estimate batch-normalised sampler and target probabilities.

    For each terminal state in ``samples``, the sampler's marginal ``p_T(x)`` is
    estimated by importance sampling over ``trajectories`` backward
    trajectories drawn with ``bstep``:
    ``log p_T(x) ~= logmeanexp_k [ sum_t (log P_F - log P_B) ]``.
    That estimate and the target ``pi ∝ exp(log_rewards)`` are then both
    renormalised over the batch.

    Returns:
        ``(log_pt, log_pi)``, each of shape ``[B]``, whose exponentials sum
        to one over the batch.
    """
    backward_step = partial(bstep, gfn=gfn)

    def one_trajectory(carry_key: jax.Array, _):
        carry_key, traj_key = jax.random.split(carry_key, 2)
        _, (fwd, bwd, _) = jax.lax.scan(
            f=backward_step,
            init=(samples, traj_key),
            length=s_o.L,
        )
        return carry_key, (fwd, bwd)

    # Stacked logits have shape [trajectories, L, B].
    _, (fwd, bwd) = jax.lax.scan(f=one_trajectory, init=key, length=trajectories)

    per_traj = jnp.sum(fwd - bwd, axis=1)  # [trajectories, B]
    log_mean = jax.nn.logsumexp(per_traj, axis=0) - jnp.log(trajectories)  # [B]

    log_pt = log_mean - jax.nn.logsumexp(log_mean, axis=0)
    log_pi = log_rewards - jax.nn.logsumexp(log_rewards, axis=0)
    return log_pt, log_pi


def tv_on_batch(
    key: jax.Array,
    _,
    s_o: HasState,
    gfn: GFlowNet,
    logr: nnx.Module,
    env: EnvConfig,
    trajectories: int = 32,
):
    """Estimate the total-variation distance on one batch (a ``lax.scan`` body).

    A batch of terminal states is drawn with ``rollout(..., eps=1)``. On those
    states, the sampler's batch-normalised marginal ``p_T`` is estimated with
    ``compute_state_prob`` and compared with the batch-normalised target
    ``pi ∝ exp(log R)``.

    NOTE(review): ``eps=1`` presumably makes ``fstep`` explore fully, so the
    evaluation states do not depend on the learned policy. Confirm against
    ``fstep``.

    Returns:
        ``(key, tv)`` with ``tv = 0.5 * sum |p_T - pi|`` over the batch.
    """
    _, samples, key, log_rewards, _ = rollout(gfn, logr, s_o, key, env.fstep, eps=1)
    # Give the backward-trajectory estimator its own subkey. The key returned
    # as the scan carry (used by the next batch's rollout) must never also be
    # consumed here.
    key, prob_key = jax.random.split(key)
    log_pt, log_pi = compute_state_prob(
        samples,
        gfn,
        prob_key,
        log_rewards,
        s_o,
        trajectories=trajectories,
        bstep=env.bstep,
    )

    tv = 0.5 * jnp.abs(jnp.exp(log_pt) - jnp.exp(log_pi)).sum()
    return key, tv


@partial(jax.jit, static_argnames=("iterations",))
def fcs(
    s_o: HasState,
    gfn: GFlowNet,
    logr: nnx.Module,
    key: jax.Array,
    iterations: int,
    env: EnvConfig,
):
    """Average ``tv_on_batch`` over ``iterations`` independently drawn batches.

    The PRNG key is threaded through a ``lax.scan``, so each batch uses the
    key produced by the previous one.

    Returns:
        Scalar mean of the per-batch total-variation estimates.
    """
    batch_tv = partial(tv_on_batch, s_o=s_o, gfn=gfn, logr=logr, env=env)
    _, tvs = jax.lax.scan(batch_tv, init=key, length=iterations)
    return jnp.mean(tvs)


def run_div_gfn(
    config: Config,
    env: EnvConfig,
    logr: nnx.Module,
    s_o: HasState,
    queue: ModeQueue,
    buffer: ReplayBuffer = None,
    eval_metric: EvalMetricCallable = None,
):
    """Train a student GFlowNet jointly with a teacher via ``train_step_div``.

    Both networks draw their initial weights from one ``nnx.Rngs`` stream
    seeded with ``config.key``. Each iteration runs one jitted joint update,
    pushes the produced samples into ``queue``, and records the student's
    logZ together with the queue's ``hlogr`` / ``hr`` statistics.

    Args:
        config: Hyper-parameters (sizes, learning rates, iteration count).
        env: Environment callables forwarded to the training step.
        logr: Log-reward model.
        s_o: Initial batched state every trajectory starts from.
        queue: Mode-tracking queue updated with every batch of samples.
        buffer: Optional replay buffer (only used when ``config.buffer_freq > 0``).
        eval_metric: Optional hook invoked as ``eval_metric(gfn, gfn_teach)``.

    Returns:
        ``RunOutput`` holding both networks, the final queue and the histories.
    """
    rngs = nnx.Rngs(config.key)

    def build_gfn():
        return GFlowNet(
            config.din, config.dmid, config.dout, nlayers=config.nlayers, rngs=rngs
        )

    def build_opt(model, lr, logz_lr):
        return create_opt(
            model,
            config.iterations,
            lr=lr,
            use_scheduler=config.use_scheduler,
            logz_lr=logz_lr,
            should_clip=config.should_clip,
        )

    # Construction order matters: both networks consume the same RNG stream.
    gfn = build_gfn()
    gfn_teach = build_gfn()

    opt = build_opt(gfn, config.lr, config.logz_lr)
    opt_teach = build_opt(gfn_teach, config.lr_teacher, config.logz_lr_teacher)

    key = rngs.default()
    logzs, hlogr, hr = [], [], []

    pbar = tqdm.trange(config.iterations)
    for idx in pbar:
        loss, loss_teach, samples, log_rewards, key, buffer = train_step_div(
            gfn,
            gfn_teach,
            logr,
            s_o,
            key,
            opt,
            opt_teach,
            env,
            config,
            idx,
            buffer,
        )
        queue = push(queue, samples, log_rewards)

        if eval_metric is not None:
            eval_metric.eval_metric(gfn, gfn_teach)

        pbar.set_postfix(
            loss=f"{loss:.2e}",
            loss_teach=f"{loss_teach:.2e}",
            logz=f"{gfn.logz.get_value():.3e}",
            logz_teach=f"{gfn_teach.logz.get_value():.3e}",
        )
        logzs.append(gfn.logz.copy().item())
        hlogr.append(queue.hlogr.item())
        hr.append(queue.hr.item())

    return RunOutput(
        s_o=s_o,
        gfn=gfn,
        gfn_teach=gfn_teach,
        logzs=logzs,
        queue=queue,
        hlogr=hlogr,
        hr=hr,
        extra_metrics=[] if eval_metric is None else eval_metric.metrics,
    )


def run_teacher_gfn(
    config: Config,
    env: EnvConfig,
    logr: nnx.Module,
    s_o: HasState,
    queue: ModeQueue,
    _: ReplayBuffer = None,
    eval_metric: EvalMetricCallable = None,
):
    """Train a student GFlowNet with an adaptive teacher via ``train_step_teacher``.

    The replay-buffer argument is accepted for signature parity with the other
    ``run_*`` functions but ignored. Both networks share one ``nnx.Rngs``
    stream seeded with ``config.key``. The teacher uses the ``*_teacher``
    learning rates.

    Returns:
        ``RunOutput`` with the student, the teacher, the final queue and the
        per-iteration histories.
    """
    rngs = nnx.Rngs(config.key)

    def build_gfn():
        return GFlowNet(
            config.din, config.dmid, config.dout, nlayers=config.nlayers, rngs=rngs
        )

    def build_opt(model, lr, logz_lr):
        return create_opt(
            model,
            config.iterations,
            use_scheduler=config.use_scheduler,
            lr=lr,
            logz_lr=logz_lr,
            should_clip=config.should_clip,
        )

    gfn = build_gfn()
    gfn_teach = build_gfn()

    opt = build_opt(gfn, config.lr, config.logz_lr)
    opt_teach = build_opt(gfn_teach, config.lr_teacher, config.logz_lr_teacher)

    key = rngs.default()
    logzs, hlogr, hr = [], [], []

    pbar = tqdm.trange(config.iterations)
    for _ in pbar:
        loss_teach, loss, samples, log_rewards, key = train_step_teacher(
            gfn,
            gfn_teach,
            logr,
            s_o,
            key,
            opt,
            opt_teach,
            env,
        )
        queue = push(queue, samples, log_rewards)

        if eval_metric is not None:
            eval_metric.eval_metric(gfn, gfn_teach)

        pbar.set_postfix(
            loss=f"{loss:.2e}",
            loss_teach=f"{loss_teach:.2e}",
            logz=f"{gfn.logz.get_value():.3e}",
            logz_teach=f"{gfn_teach.logz.get_value():.3e}",
        )
        logzs.append(gfn.logz.copy().item())
        hlogr.append(queue.hlogr.item())
        hr.append(queue.hr.item())

    return RunOutput(
        s_o=s_o,
        gfn=gfn,
        gfn_teach=gfn_teach,
        logzs=logzs,
        queue=queue,
        hlogr=hlogr,
        hr=hr,
        extra_metrics=[] if eval_metric is None else eval_metric.metrics,
    )


@nnx.jit
def eval_from_buffer(
    env: EnvConfig,
    config: Config,
    gfn: GFlowNet,
    s_o: HasState,
    logr: nnx.Module,
    optimizer: nnx.Optimizer,
    key: jax.Array,
    iteration: int,
    eps: float = 5e-2,
    buffer: ReplayBuffer = None,
):
    """Perform one trajectory-balance update of ``gfn``, optionally with replay.

    Despite its name, this is a training step: it computes the TB loss on a
    fresh ``eps``-exploratory rollout and applies ``optimizer.update``. When a
    buffer is given and ``config.buffer_freq > 0``, the fresh samples are
    pushed into the buffer, a batch is drawn back from it and scored with
    backward trajectories, and the two losses are mixed with weight
    ``config.buffer_freq``.

    Returns:
        ``(loss, samples, log_rewards, key, new_buffer)``; ``new_buffer`` is
        ``None`` when replay is disabled.
    """
    batch_size, _ = s_o.state.shape
    use_buffer = buffer is not None and config.buffer_freq > 0

    def objective(model: GFlowNet):
        rollout_key, buffer_key = jax.random.split(key, 2)

        loss, samples, out_key, log_rewards, _ = rollout(
            model, logr, s_o, rollout_key, eps=eps, fstep=env.fstep
        )
        if not use_buffer:
            return loss, (samples, log_rewards, out_key, None)

        updated_buffer = push_to_buffer(buffer, samples, log_rewards, iteration)
        replay_samples, replay_log_rewards, out_key = sample_from_buffer(
            updated_buffer, batch_size, buffer_key
        )

        # Score the replayed terminal states along sampled backward trajectories.
        (_, out_key), (replay_flogits, replay_blogits, _) = jax.lax.scan(
            f=partial(env.bstep, gfn=model),
            length=s_o.L,
            init=(replay_samples, out_key),
        )
        replay_delta = (
            model.logz
            + (replay_flogits - replay_blogits).sum(axis=0)
            - replay_log_rewards
        )
        replay_loss = (replay_delta**2).mean()

        mixed = (1 - config.buffer_freq) * loss + config.buffer_freq * replay_loss
        return mixed, (samples, log_rewards, out_key, updated_buffer)

    grad_fn = nnx.value_and_grad(objective, has_aux=True, argnums=0)
    (loss, (samples, log_rewards, key, new_buffer)), grads = grad_fn(gfn)

    optimizer.update(gfn, grads)
    return loss, samples, log_rewards, key, new_buffer


def run_gfn(
    config: Config,
    env: EnvConfig,
    logr: nnx.Module,
    s_o: HasState,
    queue: ModeQueue,
    buffer: ReplayBuffer = None,
    eval_metric: EvalMetricCallable = None,
):
    """Train a single GFlowNet with trajectory balance via ``eval_from_buffer``.

    Each iteration runs one jitted update (optionally mixing in replay from
    ``buffer``), pushes the samples into ``queue`` and records logZ plus the
    queue statistics. ``eval_metric`` is called as ``eval_metric(gfn, None)``.

    Returns:
        ``RunOutput`` with ``gfn_teach=None``.
    """
    rngs = nnx.Rngs(config.key)
    gfn = GFlowNet(
        config.din, config.dmid, config.dout, nlayers=config.nlayers, rngs=rngs
    )
    opt = create_opt(
        gfn,
        config.iterations,
        use_scheduler=config.use_scheduler,
        lr=config.lr,
        logz_lr=config.logz_lr,
        should_clip=config.should_clip,
    )

    key = rngs.default()
    logzs, hlogr, hr = [], [], []

    pbar = tqdm.trange(config.iterations)
    for idx in pbar:
        loss, samples, log_rewards, key, buffer = eval_from_buffer(
            env,
            config,
            gfn,
            s_o,
            logr,
            opt,
            key,
            idx,
            buffer=buffer,
        )
        queue = push(queue, samples, log_rewards)

        if eval_metric is not None:
            eval_metric.eval_metric(gfn, None)

        pbar.set_postfix(
            loss=f"{loss:.2e}",
            logz=f"{gfn.logz.get_value():.3e}",
        )
        logzs.append(gfn.logz.copy().item())
        hlogr.append(queue.hlogr.item())
        hr.append(queue.hr.item())

    return RunOutput(
        s_o=s_o,
        gfn=gfn,
        gfn_teach=None,
        logzs=logzs,
        queue=queue,
        hlogr=hlogr,
        hr=hr,
        extra_metrics=[] if eval_metric is None else eval_metric.metrics,
    )


class RNDNet(nnx.Module):
    """Random Network Distillation novelty model.

    A trainable bias-free MLP (``learner``) predicts the output of a randomly
    initialised MLP (``target_net``). The per-sample squared prediction error
    is returned as the novelty score.
    """

    def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
        # Trainable predictor: three bias-free linear layers with leaky-ReLU.
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs, use_bias=False)
        self.linear2 = nnx.Linear(dmid, dmid, rngs=rngs, use_bias=False)
        self.linear3 = nnx.Linear(dmid, dout, rngs=rngs, use_bias=False)
        self.learner = nnx.Sequential(
            self.linear1, nnx.leaky_relu, self.linear2, nnx.leaky_relu, self.linear3
        )

        # Random target weights, stored as raw arrays rather than nnx.Param.
        # Presumably this keeps them out of Param-based gradients / optimizer
        # updates, so the target stays fixed; confirm with create_opt.
        init_fn = nnx.initializers.lecun_normal()
        self.linear_rnd1 = init_fn(rngs(), shape=(din, dmid))
        self.linear_rnd2 = init_fn(rngs(), shape=(dmid, dmid))
        self.linear_rnd3 = init_fn(rngs(), shape=(dmid, dout))

    def target_net(self, x: jax.Array):
        """Apply the random target MLP (leaky-ReLU between layers) to ``x``."""
        hidden = nnx.leaky_relu(x @ self.linear_rnd1)
        hidden = nnx.leaky_relu(hidden @ self.linear_rnd2)
        return hidden @ self.linear_rnd3

    def __call__(self, x: HasState):
        """Return the squared predictor/target error per row of ``x.state``."""
        features = x.state
        residual = self.learner(features) - self.target_net(features)
        return (residual**2).sum(axis=1)


def run_sa_gfn(
    config: Config,
    env: EnvConfig,
    logr: nnx.Module,
    s_o: HasState,
    queue: ModeQueue,
    _: ReplayBuffer = None,
    eval_metric: EvalMetricCallable = None,
):
    """Train a sibling-augmented GFlowNet (behavior + sibling + RND novelty net).

    The replay-buffer argument is accepted for signature parity and ignored.
    The behavior network uses ``lr`` / ``logz_lr``; the sibling uses the
    ``*_teacher`` rates; the RND predictor uses a fixed learning rate of
    1e-2 with the scheduler enabled. ``SAConfig`` defaults set the reward
    exponents.

    Returns:
        ``RunOutput`` whose ``gfn_teach`` slot holds the sibling network.
    """
    rngs = nnx.Rngs(config.key)

    def build_gfn():
        return GFlowNet(
            config.din, config.dmid, config.dout, nlayers=config.nlayers, rngs=rngs
        )

    def build_opt(model, lr, logz_lr):
        return create_opt(
            model,
            config.iterations,
            use_scheduler=config.use_scheduler,
            logz_lr=logz_lr,
            lr=lr,
            should_clip=config.should_clip,
        )

    # Construction order matters: all three models consume the same RNG stream.
    gfn = build_gfn()
    gfn_sibling = build_gfn()
    rnd_net = RNDNet(config.din, config.dmid, 1, rngs=rngs)

    opt = build_opt(gfn, config.lr, config.logz_lr)
    opt_sibling = build_opt(gfn_sibling, config.lr_teacher, config.logz_lr_teacher)
    opt_rnd = create_opt(
        rnd_net,
        config.iterations,
        should_clip=config.should_clip,
        use_scheduler=True,
        lr=1e-2,
    )

    sa_config = SAConfig()
    key = rngs.default()
    logzs, hlogr, hr = [], [], []

    pbar = tqdm.trange(config.iterations)
    for _ in pbar:
        loss, loss_sibling, loss_rnd, samples, log_rewards, key = train_step_sa(
            gfn_sibling=gfn_sibling,
            gfn=gfn,
            rnd_net=rnd_net,
            logr=logr,
            s_o=s_o,
            key=key,
            opt=opt,
            opt_sibling=opt_sibling,
            opt_rnd=opt_rnd,
            env=env,
            sa_config=sa_config,
        )
        queue = push(queue, samples, log_rewards)

        if eval_metric is not None:
            eval_metric.eval_metric(gfn, gfn_sibling)

        pbar.set_postfix(
            loss=f"{loss:.2e}",
            loss_sibling=f"{loss_sibling:.2e}",
            loss_rnd=f"{loss_rnd:.2e}",
            logz=f"{gfn.logz.get_value():.3e}",
        )
        logzs.append(gfn.logz.copy().item())
        hlogr.append(queue.hlogr.item())
        hr.append(queue.hr.item())

    return RunOutput(
        s_o=s_o,
        gfn=gfn,
        gfn_teach=gfn_sibling,
        queue=queue,
        logzs=logzs,
        hlogr=hlogr,
        hr=hr,
        extra_metrics=[] if eval_metric is None else eval_metric.metrics,
    )


def run_random_sampler(
    config: Config, env: EnvConfig, logr: nnx.Module, s_o: HasState, queue: ModeQueue
):
    """Baseline: repeatedly sample from a freshly initialised, untrained GFlowNet.

    No parameters are updated. Each batch (default ``eps`` of ``rollout``) is
    pushed into ``queue``, and the queue's ``hlogr`` / ``hr`` are recorded.

    Returns:
        ``RunOutput`` with ``gfn_teach=None`` and an empty ``logzs`` history.
    """
    rngs = nnx.Rngs(config.key)
    gfn = GFlowNet(
        config.din, config.dmid, config.dout, nlayers=config.nlayers, rngs=rngs
    )
    key = rngs.default()
    hlogr, hr = [], []

    sample_batch = jax.jit(partial(rollout, fstep=env.fstep))
    for _ in tqdm.trange(config.iterations):
        _, samples, key, log_rewards, _ = sample_batch(gfn, logr, s_o, key)
        queue = push(queue, samples, log_rewards)
        hlogr.append(queue.hlogr.item())
        hr.append(queue.hr.item())

    return RunOutput(
        s_o=s_o, gfn=gfn, gfn_teach=None, queue=queue, logzs=[], hlogr=hlogr, hr=hr
    )


def eval_policy(
    state: HasState,
    actions: jax.Array,
    gfn: GFlowNet,
    fapply: callable,
    fpol: callable,
    bpol: callable,
):
    """Scan body that scores one step of a fixed trajectory under ``gfn``.

    Applies ``actions`` to ``state``. Returns the forward logit of those
    actions (``fpol`` at ``state``) and the backward logit of undoing them
    (``bpol`` at the resulting state), each picked with ``take``. Typical use::

        jax.lax.scan(
            partial(eval_policy, gfn=..., fapply=..., fpol=..., bpol=...),
            init=initial_state,
            xs=actions,
        )

    Steps taken from a state whose ``fstopped`` flag is set contribute 0 to
    both outputs.

    Returns:
        ``(next_state, (flogits, blogits))``.
    """
    is_active = ~state.fstopped
    next_state = fapply(state, actions)

    forward = jnp.where(is_active, take(fpol(state, gfn), actions), 0)
    backward = jnp.where(is_active, take(bpol(next_state, gfn), actions), 0)
    return next_state, (forward, backward)


def rollout(
    gfn: nnx.Module,
    log_reward: nnx.Module,
    s_o: HasState,
    key: jax.Array,
    fstep: callable,
    eps: float = 0.0,
):
    """Sample one batch of trajectories from ``gfn`` and compute its TB loss.

    Runs ``fstep`` for ``s_o.L`` steps from ``s_o``. It then evaluates
    ``log_reward`` on the terminal states and returns the mean squared
    trajectory-balance residual
    ``sum_t(log P_F - log P_B) - log R + log Z``.

    Returns:
        ``(loss, samples, key, log_rewards, actions)``; ``key`` is the PRNG key
        advanced by the scan.
    """
    step = partial(fstep, gfn=gfn, eps=eps)
    carry, per_step = jax.lax.scan(f=step, init=(s_o, key), length=s_o.L)
    samples, key = carry
    flogits, blogits, actions = per_step

    log_rewards = log_reward(samples)

    residual = jnp.sum(flogits - blogits, axis=0) - log_rewards + gfn.logz
    loss = jnp.mean(residual**2)
    return loss, samples, key, log_rewards, actions


def rollout_div(
    gfn_teach: nnx.Module,
    gfn: nnx.Module,
    log_reward: nnx.Module,
    s_o: HasState,
    key: jax.Array,
    fstep: callable,
    fapply: callable,
    fpol: callable,
    bpol: callable,
    bstep: callable,
    alpha: float,
    beta: float,
):
    """Roll out the student and teacher GFlowNets and compute their losses.

    The student (``gfn``) and the teacher (``gfn_teach``) each sample one batch
    of trajectories from ``s_o``. The student is also scored on the teacher's
    terminal states along backward trajectories drawn with ``bstep``.

    Teacher loss: the epsilon-smoothed TB residual (``stable_tb``) toward the
    tempered target ``beta * log R`` on the teacher's own samples. On samples
    where the student's residual exceeds ``log(alpha)``, the teacher residual
    passes through ``softplus`` before squaring, so there essentially only
    positive teacher residuals are penalised.

    Returns:
        ``(loss_teach, loss, loss_from_teach, all_samples, key, all_log_rewards)``:
        the teacher loss; the student TB loss on its own samples; the student TB
        loss on the teacher's samples; student and teacher samples combined
        leaf-wise with ``merge`` (student first); the advanced PRNG key; and
        the log-rewards in the same student-then-teacher order.
    """
    # Sample trajectories from the student (exploitation) GFlowNet.
    (samples, key), (flogits, blogits, actions) = jax.lax.scan(
        f=partial(fstep, gfn=gfn),
        init=(s_o, key),
        length=s_o.L,
    )

    # Replay the student's actions under the teacher. A replay must start from
    # the initial state ``s_o`` (as in rollout_teacher / rollout_sa), not from
    # the terminal ``samples``. Only the disabled alternative objective below
    # consumes these logits.
    _, (flogits_on_student, blogits_on_student) = jax.lax.scan(
        f=partial(eval_policy, gfn=gfn_teach, fapply=fapply, bpol=bpol, fpol=fpol),
        init=s_o,
        xs=actions,
        length=s_o.L,
    )

    # Sample trajectories from the teacher GFlowNet.
    (samples_teach, key), (flogits_teach, blogits_teach, _) = jax.lax.scan(
        f=partial(fstep, gfn=gfn_teach),
        init=(s_o, key),
        length=s_o.L,
    )

    # Score the student on the teacher's terminal states via backward trajectories.
    (_, key), (flogits_on_teach, blogits_on_teach, _) = jax.lax.scan(
        f=partial(bstep, gfn=gfn),
        init=(samples_teach, key),
        length=s_o.L,
    )

    # Compute the log reward of the samples
    log_rewards_on_student = log_reward(samples)
    log_rewards_on_teach = log_reward(samples_teach)

    all_samples = jax.tree_util.tree_map(merge, samples, samples_teach)
    all_log_rewards = jnp.hstack([log_rewards_on_student, log_rewards_on_teach])

    def teacher_loss_function(
        flogits_teach: jax.Array,
        blogits_teach: jax.Array,
        flogits_student: jax.Array,
        blogits_student: jax.Array,
        log_rewards: jax.Array,
    ):
        """Teacher loss, gated by the student's residual on the same terminal states."""
        r_hat_minus_r = (
            gfn.logz + (flogits_student - blogits_student).sum(axis=0) - log_rewards
        )
        # Boolean gate; no gradient flows through the threshold test.
        rho = jax.lax.stop_gradient(r_hat_minus_r > jnp.log(alpha))
        log_pf_over_pb = (flogits_teach - blogits_teach).sum(axis=0)
        r_hat_minus_r_teach = stable_tb(
            log_pf_over_pb,
            logz=gfn_teach.logz,
            log_rewards=beta * log_rewards,
        )
        # Unsmoothed variant: log_pf_over_pb - beta * log_rewards + gfn_teach.logz
        loss_teach = (
            jnp.where(rho, nnx.softplus(r_hat_minus_r_teach), r_hat_minus_r_teach) ** 2
        )
        return loss_teach.mean()

    loss_teach = teacher_loss_function(
        flogits_teach,
        blogits_teach,
        flogits_on_teach,
        blogits_on_teach,
        log_rewards_on_teach,
    )

    # Disabled alternative: also train the teacher on the student's trajectories
    # and mix both teacher losses with a logZ-based weight.
    # loss_teach_on_student = teacher_loss_function(
    #     flogits_on_student,
    #     blogits_on_student,
    #     flogits_on_teach,
    #     blogits_on_teach,
    #     log_rewards_on_student,
    # )
    # w = nnx.sigmoid(gfn.logz - gfn_teach.logz)
    # w = jax.lax.stop_gradient(w)
    # loss_teach = w * loss_teach + (1 - w) * loss_teach_on_student

    # Student TB losses on its own samples and on the teacher's samples.
    loss = (gfn.logz + (flogits - blogits).sum(axis=0) - log_rewards_on_student) ** 2
    loss_from_teach = (
        gfn.logz
        + (flogits_on_teach - blogits_on_teach).sum(axis=0)
        - log_rewards_on_teach
    ) ** 2

    return (
        loss_teach,
        loss.mean(),
        loss_from_teach.mean(),
        all_samples,
        key,
        all_log_rewards,
    )


def rollout_teacher(
    gfn_teach: nnx.Module,
    gfn: nnx.Module,
    log_reward: nnx.Module,
    s_o: HasState,
    key: jax.Array,
    fstep: callable,
    bstep: callable,
    fapply: callable,
    fpol: callable,
    bpol: callable,
    alpha: float = 0.5,
):
    """Roll out either the student or the teacher and compute both losses.

    A coin flip (``jax.random.bernoulli``, p=0.5 by default) decides which
    network samples the batch (``sampler_gfn``). The other network
    (``evaluator_gfn``) replays the same actions from ``s_o`` via
    ``eval_policy``. The student (``gfn``) gets the trajectory-balance loss
    toward ``log_reward``. The teacher (``gfn_teach``) is trained toward a
    log-reward built from the student's weighted squared residual on backward
    trajectories plus ``alpha * log_reward``, following
    https://arxiv.org/pdf/2410.01432.

    Returns:
        ``(loss_teach, loss, samples, log_rewards, key)``.
    """
    # We randomly decide whether to sample from the student or teacher
    key, subkey = jax.random.split(key, 2)
    sample_from_student = jax.random.bernoulli(subkey) == 1

    # Pick the sampling network and the network that replays its actions.
    # lax.cond selects leaf-wise, so both modules must share one pytree structure.
    sampler_gfn = jax.lax.cond(sample_from_student, lambda: gfn, lambda: gfn_teach)
    evaluator_gfn = jax.lax.cond(sample_from_student, lambda: gfn_teach, lambda: gfn)

    # Generate samples from the sampler
    (samples, key), (sampler_flogits, sampler_blogits, actions) = jax.lax.scan(
        f=partial(fstep, gfn=sampler_gfn),
        init=(s_o, key),
        length=s_o.L,
    )

    # From the terminal states, sample backward trajectories and evaluate the
    # student on them; these logits build the teacher's reward below.
    (_, key), (student_flogits, student_blogits, _) = jax.lax.scan(
        f=partial(bstep, gfn=gfn),
        init=(samples, key),
        length=s_o.L,
    )

    # Replay the sampled actions from s_o under the evaluator network.
    _, (evaluator_flogits, evaluator_blogits) = jax.lax.scan(
        f=partial(eval_policy, gfn=evaluator_gfn, fapply=fapply, fpol=fpol, bpol=bpol),
        xs=actions,
        init=s_o,
        length=s_o.L,
    )

    # Map the (sampler, evaluator) logits back to (teacher, student) roles.
    flogits_teach, flogits, blogits_teach, blogits = jax.lax.cond(
        sample_from_student,
        lambda: (
            evaluator_flogits,
            sampler_flogits,
            evaluator_blogits,
            sampler_blogits,
        ),
        lambda: (
            sampler_flogits,
            evaluator_flogits,
            sampler_blogits,
            evaluator_blogits,
        ),
    )

    # Compute the log reward of the samples
    log_rewards = log_reward(samples)

    # Student trajectory-balance residual and loss on the sampled trajectories.
    r_hat_minus_r = gfn.logz + (flogits - blogits).sum(axis=0) - log_rewards
    loss = r_hat_minus_r**2
    loss = loss.mean()

    # Teacher reward; see https://arxiv.org/pdf/2410.01432
    # (Original note, translated from Portuguese: "Behold the machinations of
    # the Koreans.")
    # This residual equals -\delta(\tau ; \theta), the negated TB discrepancy.
    r_hat_minus_r_on_backward = (
        gfn.logz + (student_flogits - student_blogits).sum(axis=0) - log_rewards
    )
    loss_on_backward = r_hat_minus_r_on_backward**2
    # Trajectories with a negative residual (i.e. delta > 0) get weight 1 + C.
    weights = 1 + 19 * (r_hat_minus_r_on_backward < 0)  # C = 19
    # Teacher log-reward target: log of the weighted squared student residual
    # plus alpha * log R; 1e-10 guards against log(0). Held fixed (no gradient).
    log_r_teach = jax.lax.stop_gradient(
        jnp.log(1e-10 + weights * loss_on_backward) + alpha * log_rewards
    )

    # Teacher trajectory-balance loss toward log_r_teach.
    loss_teach = (
        (flogits_teach - blogits_teach).sum(axis=0) - log_r_teach + gfn_teach.logz
    ) ** 2
    loss_teach = loss_teach.mean()

    return loss_teach, loss, samples, log_rewards, key


@struct.dataclass
class SAConfig:
    """Reward-shaping exponents for the sibling-augmented (SA) GFlowNet.

    The values are declared as annotated dataclass fields so they can be
    overridden, e.g. ``SAConfig(beta_i=0.5)``. Without annotations the
    dataclass has no fields and such overrides raise ``TypeError``. The fields
    are static (``pytree_node=False``), so they stay Python constants when an
    ``SAConfig`` is passed to a jitted function. Class-level access
    (``SAConfig.beta_i``) still returns the defaults.
    """

    # Exponent on the extrinsic reward for the behavior network.
    # NOTE(review): not read by rollout_sa in this file.
    beta_e_bn: float = struct.field(pytree_node=False, default=1)
    # Exponent on the extrinsic reward inside the sibling's combined reward.
    beta_e_sn: float = struct.field(pytree_node=False, default=0.25)

    # Exponent on the RND intrinsic reward.
    beta_i: float = struct.field(pytree_node=False, default=1)
    # Outer exponent applied to the sibling's combined reward.
    beta_sn: float = struct.field(pytree_node=False, default=1)
    # Outer exponent for the behavior network's reward.
    # NOTE(review): not read by rollout_sa in this file.
    beta_bn: float = struct.field(pytree_node=False, default=1)


def rnd(
    state: HasState,
    actions: jax.Array,
    rnd_net: nnx.Module,
    fapply: callable,
):
    """Scan body: apply ``actions`` and score the resulting state with ``rnd_net``.

    Intended use:
    ``jax.lax.scan(partial(rnd, rnd_net=..., fapply=...), init=s_o, xs=actions)``.

    Returns:
        ``(next_state, rnd_net(next_state))``.
    """
    next_state = fapply(state, actions)
    return next_state, rnd_net(next_state)


def rollout_sa(
    gfn_sibling: nnx.Module,
    gfn: nnx.Module,
    rnd_net: nnx.Module,
    log_reward: nnx.Module,
    s_o: HasState,
    key: jax.Array,
    fstep: callable,
    fapply: callable,
    fpol: callable,
    bpol: callable,
    sa_config: SAConfig,
):
    """Roll out the SA-GFlowNet's sibling and behavior networks and compute losses.

    * The sibling samples a batch. Its target log-reward is
      ``beta_sn * log(r_int ** beta_i + R ** beta_e_sn)``, where ``r_int`` is
      the RND novelty summed over the visited states. The target is held
      fixed (``stop_gradient``).
    * The behavior network (``gfn``) samples its own batch and is also scored
      on the sibling's trajectories. Its loss is the average of the two TB
      losses toward the extrinsic reward.
    * The RND loss is the mean summed novelty, which trains the RND predictor.

    Returns:
        ``(loss, loss_sibling, all_log_rewards, all_samples, loss_rnd, key)``;
        samples and log-rewards are ordered behavior first, then sibling.
    """
    # Sample trajectories from the sibling network.
    (samples_sibling, key), (sibling_flogits, sibling_blogits, sibling_actions) = (
        jax.lax.scan(
            f=partial(fstep, gfn=gfn_sibling),
            init=(s_o, key),
            length=s_o.L,
        )
    )

    # Replay the sibling's actions from s_o, scoring each visited state with RND.
    _, intrinsic_r = jax.lax.scan(
        f=partial(rnd, rnd_net=rnd_net, fapply=fapply),
        xs=sibling_actions,
        init=s_o,
        length=s_o.L,
    )
    # (T, B) -> (B,)
    intrinsic_r = intrinsic_r.sum(axis=0)
    log_intrinsic_r = sa_config.beta_i * jnp.log(intrinsic_r)  # (B,)
    # Extrinsic (task) log-rewards of the sibling's samples.
    log_rewards_sibling = log_reward(samples_sibling)

    # Combined sibling reward: log(a + b) from log a and log b, computed stably.
    log_sibling_r = sa_config.beta_sn * jnp.logaddexp(
        log_intrinsic_r,
        sa_config.beta_e_sn * log_rewards_sibling,
    )
    log_sibling_r = jax.lax.stop_gradient(log_sibling_r)

    # Sibling trajectory-balance loss.
    r_hat_minus_r_sibling = (
        gfn_sibling.logz
        + (sibling_flogits - sibling_blogits).sum(axis=0)
        - log_sibling_r
    )
    loss_sibling = (r_hat_minus_r_sibling**2).mean()

    # Sample trajectories from the behavior network.
    (samples, key), (flogits, blogits, _) = jax.lax.scan(
        f=partial(fstep, gfn=gfn),
        init=(s_o, key),
        length=s_o.L,
    )

    # Extrinsic log-rewards of the behavior network's samples.
    log_rewards = log_reward(samples)

    # Score the behavior network on the sibling's trajectories (replayed from s_o).
    _, (flogits_on_sibling, blogits_on_sibling) = jax.lax.scan(
        f=partial(
            eval_policy,
            gfn=gfn,
            fapply=fapply,
            fpol=fpol,
            bpol=bpol,
        ),
        xs=sibling_actions,
        init=s_o,
        length=s_o.L,
    )

    # Behavior TB losses on the sibling's and on its own trajectories.
    loss_on_sibling = (
        gfn.logz
        + (flogits_on_sibling - blogits_on_sibling).sum(axis=0)
        - log_rewards_sibling
    ) ** 2
    loss_on_behavior = (gfn.logz + (flogits - blogits).sum(axis=0) - log_rewards) ** 2

    loss = 0.5 * (loss_on_sibling + loss_on_behavior).mean()
    loss_rnd = intrinsic_r.mean()

    all_samples = jax.tree_util.tree_map(merge, samples, samples_sibling)
    all_log_rewards = jnp.hstack([log_rewards, log_rewards_sibling])

    return loss, loss_sibling, all_log_rewards, all_samples, loss_rnd, key


@nnx.jit
def train_step_div(
    gfn: GFlowNet,
    gfn_teach: GFlowNet,
    logr: nnx.Module,
    s_o: HasState,
    key: jax.Array,
    opt: nnx.Optimizer,
    opt_teach: nnx.Optimizer,
    env: EnvConfig,
    config: Config,
    iteration: int,
    buffer: ReplayBuffer = None,
    alpha: float = 0.3,
    beta: float = 0.25,
):
    """One joint optimizer update of the student (``gfn``) and teacher (``gfn_teach``).

    Losses come from ``rollout_div``. The student's on-policy and
    teacher-sample losses are mixed with ``w = sigmoid(logZ - logZ_teach)``.
    When a buffer is given and ``config.buffer_freq > 0``, the batch is pushed
    to the buffer, a replay batch of the same size is drawn and scored, and
    both losses are mixed with the replay losses using weight
    ``config.buffer_freq``.

    Returns:
        ``(loss, loss_teach, samples, log_rewards, key, new_buffer)``;
        ``new_buffer`` is ``None`` when replay is disabled.
    """

    def loss_fn(gfn: GFlowNet, gfn_teach: GFlowNet):
        loss_teach, loss, loss_from_teach, samples, new_key, log_rewards = rollout_div(
            gfn_teach,
            gfn,
            logr,
            s_o,
            key,
            fstep=env.fstep,
            bstep=env.bstep,
            fapply=env.fapply,
            fpol=env.fpol,
            bpol=env.bpol,
            alpha=alpha,
            beta=beta,
        )
        # Mixing weight between the student's two losses. It is a constant
        # for the gradient, so each model is only differentiated through
        # its own loss terms.
        w = nnx.sigmoid(gfn.logz - gfn_teach.logz)
        w = jax.lax.stop_gradient(w)

        loss = w * loss + (1 - w) * loss_from_teach

        if buffer is not None and config.buffer_freq > 0:
            # Push samples to buffer.
            new_buffer = push_to_buffer(buffer, samples, log_rewards, iteration)
            # Sample from buffer and evaluate the loss therein.
            new_key, subkey = jax.random.split(new_key, 2)
            buffer_samples, buffer_log_rewards, new_key = sample_from_buffer(
                new_buffer, len(log_rewards), subkey
            )

            # Evaluate the exploitation GFlowNet on these samples
            (_, new_key), (buffer_flogits, buffer_blogits, _) = jax.lax.scan(
                f=partial(env.bstep, gfn=gfn),
                init=(buffer_samples, new_key),
                length=s_o.L,
            )

            buffer_delta = (
                gfn.logz
                + (buffer_flogits - buffer_blogits).sum(axis=0)
                - buffer_log_rewards
            )
            buffer_loss = (buffer_delta**2).mean()
            loss = (1 - config.buffer_freq) * loss + config.buffer_freq * buffer_loss

            # NOTE(review): this teacher residual reuses the student's logits on
            # the buffer samples (with the teacher's logZ and the untempered
            # reward). rollout_div instead trains the teacher on its own logits
            # toward ``beta * log R`` via ``stable_tb``. Confirm this is intended.
            buffer_loss_teach = (
                gfn_teach.logz
                + (buffer_flogits - buffer_blogits).sum(axis=0)
                - buffer_log_rewards
            )
            # The gate threshold is log(alpha), matching the gate in rollout_div.
            rho = jax.lax.stop_gradient(buffer_delta > jnp.log(alpha))
            buffer_loss_teach = (
                jnp.where(rho, nnx.softplus(buffer_loss_teach), buffer_loss_teach) ** 2
            )
            buffer_loss_teach = buffer_loss_teach.mean()
            loss_teach = (
                1 - config.buffer_freq
            ) * loss_teach + config.buffer_freq * buffer_loss_teach
        else:
            new_buffer = None

        return loss + loss_teach, {
            "loss": loss,
            "loss_from_teach": loss_from_teach,
            "w": w,
            "loss_teach": loss_teach,
            "samples": samples,
            "new_key": new_key,
            "log_rewards": log_rewards,
            "new_buffer": new_buffer,
        }

    (_, out), (grads, grads_teach) = nnx.value_and_grad(
        loss_fn,
        has_aux=True,
        argnums=(0, 1),
    )(gfn, gfn_teach)

    opt.update(gfn, grads)
    opt_teach.update(gfn_teach, grads_teach)

    return (
        out["loss"],
        out["loss_teach"],
        out["samples"],
        out["log_rewards"],
        out["new_key"],
        out["new_buffer"],
    )


@nnx.jit
def train_step_teacher(
    gfn: GFlowNet,
    gfn_teach: GFlowNet,
    logr: nnx.Module,
    s_o: HasState,
    key: jax.Array,
    opt: nnx.Optimizer,
    opt_teach: nnx.Optimizer,
    env: EnvConfig,
):
    """One joint update of student and adaptive teacher using ``rollout_teacher``.

    Gradients of the summed losses are taken with respect to both networks,
    and each network is stepped with its own optimizer.

    Returns:
        ``(loss_teach, loss, samples, log_rewards, key)``.
    """

    def joint_loss(student: GFlowNet, teacher: GFlowNet):
        aux = rollout_teacher(
            teacher,
            student,
            logr,
            s_o,
            key,
            fstep=env.fstep,
            bstep=env.bstep,
            fapply=env.fapply,
            fpol=env.fpol,
            bpol=env.bpol,
        )
        teacher_loss, student_loss = aux[0], aux[1]
        return teacher_loss + student_loss, aux

    grad_fn = nnx.value_and_grad(joint_loss, has_aux=True, argnums=(0, 1))
    (_, aux), (student_grads, teacher_grads) = grad_fn(gfn, gfn_teach)

    opt.update(gfn, student_grads)
    opt_teach.update(gfn_teach, teacher_grads)

    loss_teach, loss, samples, log_rewards, new_key = aux
    return loss_teach, loss, samples, log_rewards, new_key


@nnx.jit
def train_step_sa(
    gfn_sibling: GFlowNet,
    gfn: GFlowNet,
    rnd_net: nnx.Module,
    logr: nnx.Module,
    s_o: HasState,
    key: jax.Array,
    opt: nnx.Optimizer,
    opt_sibling: nnx.Optimizer,
    opt_rnd: nnx.Optimizer,
    env: EnvConfig,
    sa_config: SAConfig,
):
    """One joint update of the SA-GFlowNet's behavior, sibling and RND networks.

    The three losses from ``rollout_sa`` are summed; gradients are taken with
    respect to all three models, and each is stepped with its own optimizer.

    Returns:
        ``(loss, loss_sibling, loss_rnd, samples, log_rewards, key)``.
    """

    def total_loss(sibling: GFlowNet, behavior: GFlowNet, rnd_model: nnx.Module):
        loss, loss_sibling, log_rewards, samples, loss_rnd, new_key = rollout_sa(
            gfn_sibling=sibling,
            gfn=behavior,
            rnd_net=rnd_model,
            log_reward=logr,
            s_o=s_o,
            key=key,
            fstep=env.fstep,
            fapply=env.fapply,
            fpol=env.fpol,
            bpol=env.bpol,
            sa_config=sa_config,
        )
        aux = (loss, loss_sibling, loss_rnd, samples, log_rewards, new_key)
        return loss + loss_sibling + loss_rnd, aux

    grad_fn = nnx.value_and_grad(total_loss, argnums=(0, 1, 2), has_aux=True)
    (_, aux), (sibling_grads, behavior_grads, rnd_grads) = grad_fn(
        gfn_sibling, gfn, rnd_net
    )

    opt.update(gfn, behavior_grads)
    opt_sibling.update(gfn_sibling, sibling_grads)
    opt_rnd.update(rnd_net, rnd_grads)

    return aux


@struct.dataclass
class MainOutput:
    """Bundle of ``RunOutput`` results, one per training variant.

    NOTE(review): the mapping is inferred from the field names. Presumably
    ``std`` comes from ``run_gfn``, ``div`` from ``run_div_gfn``, ``teacher``
    from ``run_teacher_gfn`` and ``sa`` from ``run_sa_gfn``; confirm with the
    caller that builds it.
    """

    std: RunOutput
    div: RunOutput
    teacher: RunOutput
    sa: RunOutput
