from flax import nnx
from jax import Array, jit, lax, numpy as jnp

from offline.bppo.tc.core import high_level_loss_fn as qcritic_loss_fn
from offline.bppo.tc.iql.types import HighLevelTrainState
from offline.modules.base import TargetModel
from offline.modules.critic import VCritic
from offline.modules.mlp import MLP
from offline.types import BoolArray


def vcritic_loss_fn(
    vcritic: VCritic, expectile: float, observations: Array, qvalues: Array
):
    """Asymmetric-L2 (expectile) regression loss of V(s) towards Q-value targets.

    Each residual ``qvalues - V(observations)`` is squared. Positive residuals
    are weighted by ``expectile`` and the rest by ``1 - expectile``. The loss
    is the mean of the weighted squared residuals.

    Args:
        vcritic: Value network being trained; called on ``observations``.
        expectile: Weight for positive residuals (``1 - expectile`` is used
            for non-positive ones).
        observations: Inputs to ``vcritic``.
        qvalues: Regression targets. Must hold exactly one value per element
            of ``vcritic(observations)``.

    Returns:
        ``(loss, metrics)``. ``metrics`` contains the loss under ``"loss/HLV"``
        and the mean predicted value under ``"pretrain/HLV"``.
    """
    values = vcritic(observations)
    # Match V(s) to the targets element-wise. Otherwise a trailing singleton
    # axis on one side (e.g. [B] vs [B, 1]) silently broadcasts the
    # subtraction to [B, B] and corrupts the loss. No-op when shapes agree.
    advantages = qvalues - jnp.reshape(values, qvalues.shape)
    weights = jnp.where(advantages > 0, expectile, 1 - expectile)
    loss = jnp.mean(weights * jnp.square(advantages))
    return loss, {"loss/HLV": loss, "pretrain/HLV": jnp.mean(values)}


def _qcritic_train_step(
    assignments: Array,
    dones: BoolArray,
    gamma: float,
    next_observations: Array,
    observations: Array,
    optimizer: nnx.Optimizer,
    qcritic: MLP,
    rewards: Array,
    vcritic: VCritic,
):
    """Apply one optimizer step to the Q-critic using V-bootstrapped targets.

    The regression target is ``gamma * V(next_observations) * (1 - dones) +
    rewards``. ``V`` is evaluated with ``vcritic`` in eval mode, and
    ``vcritic`` is left in train mode afterwards.

    Args:
        assignments: Passed through to ``qcritic_loss_fn``.
        dones: Mask that removes the bootstrap term where set.
        gamma: Discount factor applied to ``V(next_observations)``.
        next_observations: Inputs for the bootstrap value.
        observations: Inputs for the Q-critic.
        optimizer: Optimizer that receives the Q-critic gradients.
        qcritic: Network differentiated by ``nnx.grad``.
        rewards: Immediate rewards added to the discounted bootstrap value.
        vcritic: Value network providing the bootstrap value.

    Returns:
        The auxiliary metrics returned by ``qcritic_loss_fn``.
    """
    # Query the bootstrap value in eval mode, then restore train mode so the
    # V-critic's own updates run in training mode.
    vcritic.eval()
    next_values = vcritic(next_observations)
    vcritic.train()

    not_done = 1 - dones
    td_targets = gamma * next_values * not_done + rewards

    compute_grads = nnx.grad(qcritic_loss_fn, has_aux=True)
    q_grads, metrics = compute_grads(qcritic, assignments, observations, td_targets)
    optimizer.update(q_grads)
    return metrics


def _vcritic_train_step(
    assignments: Array,
    expectile: float,
    observations: Array,
    optimizer: nnx.Optimizer,
    target_qcritic: TargetModel[MLP],
    vcritic: VCritic,
):
    """Apply one optimizer step to the V-critic by expectile regression.

    The targets are the target Q-critic's outputs at the entries selected by
    ``assignments`` along the last axis, one per observation.

    Args:
        assignments: Integer indices into the last axis of the target
            Q-critic's output.
        expectile: Expectile weight forwarded to ``vcritic_loss_fn``.
        observations: Inputs for both the target Q-critic and the V-critic.
        optimizer: Optimizer that receives the V-critic gradients.
        target_qcritic: Wrapper whose ``.model`` produces the Q targets.
        vcritic: Network differentiated by ``nnx.grad``.

    Returns:
        The metrics dict produced by ``vcritic_loss_fn``.
    """
    # NOTE(review): unlike the V-critic in `_qcritic_train_step`, the target
    # model is not switched to eval mode here. Confirm that TargetModel
    # keeps its model in eval mode if it has dropout or other mode-dependent
    # layers.
    # [..., codebook_size]
    all_qvalues = target_qcritic.model(observations)
    # take_along_axis keeps the gathered axis ([..., 1]). Drop it so the
    # targets are [...] and line up element-wise with V(s).
    # [...]
    qvalues = jnp.squeeze(
        jnp.take_along_axis(
            all_qvalues, jnp.expand_dims(assignments, -1), axis=-1
        ),
        axis=-1,
    )
    grad_fn = nnx.grad(vcritic_loss_fn, has_aux=True)
    grads, results = grad_fn(vcritic, expectile, observations, qvalues)
    optimizer.update(grads)
    return results


def _train_step(
    assignments: Array,
    assignments_vcritic: Array,
    dones: BoolArray,
    expectile: float,
    gamma: float,
    next_observations: Array,
    observations: Array,
    observations_vcritic: Array,
    rewards: Array,
    train_state: HighLevelTrainState,
):
    """Run the V-critic update and then the Q-critic update on one batch.

    The two updates consume separate inputs. ``assignments_vcritic`` and
    ``observations_vcritic`` feed the V-critic; the remaining arrays feed the
    Q-critic.

    Returns:
        The Q-critic metrics merged with the V-critic metrics. On a key clash,
        the V-critic value wins.
    """
    # Order matters: the Q-critic's bootstrap targets are computed after the
    # V-critic step. This presumably means they use the freshly updated V,
    # assuming optimizer_vcritic updates train_state.vcritic in place.
    vcritic_metrics = _vcritic_train_step(
        assignments=assignments_vcritic,
        expectile=expectile,
        observations=observations_vcritic,
        optimizer=train_state.optimizer_vcritic,
        target_qcritic=train_state.target_qcritic,
        vcritic=train_state.vcritic,
    )
    qcritic_metrics = _qcritic_train_step(
        assignments=assignments,
        dones=dones,
        gamma=gamma,
        next_observations=next_observations,
        observations=observations,
        optimizer=train_state.optimizer_qcritic,
        qcritic=train_state.qcritic,
        rewards=rewards,
        vcritic=train_state.vcritic,
    )
    merged_metrics = qcritic_metrics | vcritic_metrics
    return merged_metrics


@jit
def high_level_train_step(
    assignments: Array,
    assignments_vcritic: Array,
    dones: BoolArray,
    expectile: float,
    gamma: float,
    graphdef: nnx.GraphDef[HighLevelTrainState],
    graphstate: nnx.GraphState | nnx.VariableState,
    next_observations: Array,
    observations: Array,
    observations_vcritic: Array,
    rewards: Array,
):
    """JIT-compiled high-level step: V- then Q-critic update, no target sync.

    Rebuilds the train state from ``graphdef``/``graphstate``, runs
    ``_train_step``, and splits the state back out.

    Returns:
        ``(graphstate, metrics)``. Only the new state is returned; the
        re-split graphdef is discarded, which assumes the update does not
        change the module structure.
    """
    state = nnx.merge(graphdef, graphstate)
    metrics = _train_step(
        assignments=assignments,
        assignments_vcritic=assignments_vcritic,
        dones=dones,
        expectile=expectile,
        gamma=gamma,
        next_observations=next_observations,
        observations=observations,
        observations_vcritic=observations_vcritic,
        rewards=rewards,
        train_state=state,
    )
    updated_graphstate = nnx.split(state)[1]
    return updated_graphstate, metrics


@jit
def high_level_train_step_with_target_update(
    assignments: Array,
    assignments_vcritic: Array,
    dones: BoolArray,
    expectile: float,
    gamma: float,
    graphdef: nnx.GraphDef[HighLevelTrainState],
    graphstate: nnx.GraphState | nnx.VariableState,
    next_observations: Array,
    observations: Array,
    observations_vcritic: Array,
    rewards: Array,
):
    """JIT-compiled high-level step followed by a hard target-Q-critic update.

    Identical to the plain high-level step, except that after the critic
    updates ``target_qcritic.hard_update`` is called with the freshly trained
    Q-critic. That call presumably copies its parameters into the target.

    Returns:
        ``(graphstate, metrics)``. Only the new state is returned; the
        re-split graphdef is discarded, which assumes the update does not
        change the module structure.
    """
    state = nnx.merge(graphdef, graphstate)
    metrics = _train_step(
        assignments=assignments,
        assignments_vcritic=assignments_vcritic,
        dones=dones,
        expectile=expectile,
        gamma=gamma,
        next_observations=next_observations,
        observations=observations,
        observations_vcritic=observations_vcritic,
        rewards=rewards,
        train_state=state,
    )
    # Sync the target only after both critics have been updated this step.
    target = state.target_qcritic
    target.hard_update(state.qcritic)
    updated_graphstate = nnx.split(state)[1]
    return updated_graphstate, metrics
