from typing import Any

from flax import nnx
from jax import Array, jit, lax, numpy as jnp

from offline.hdr.modules import ActorFilter, HDRPolicy, HDRTrainState
from offline.modules.actor.utils import Actor, compute_actions
from offline.modules.critic import QCriticEnsemble
from offline.td3bc.core import critic_loss_fn
from offline.types import BoolArray


def actor_loss_fn(
    policy: HDRPolicy,
    key: Array,
    key_data: Any,
    means: Array,
    observations: Array,
    stds: Array,
    threshold: float,
):
    """Compute the scalar actor loss together with its logging metrics.

    The actor produces unsquashed offsets that are added to ``means`` to form
    the actions scored by the critic. For every sample, the squared offset
    (divided element-wise by ``stds`` and summed over the last axis) is
    compared with ``threshold``; samples above it contribute a penalty term,
    the remaining samples contribute the negated critic value.

    Args:
        policy: Object exposing ``actor`` and ``critic``.
        key: Forwarded to ``compute_actions``.
        key_data: Forwarded to ``compute_actions``.
        means: Added to the actor offsets to obtain the actions.
        observations: Inputs to both the actor and the critic.
        stds: Scale used to normalise the offsets.
        threshold: Cut-off on the normalised squared offset.

    Returns:
        A ``(loss, metrics)`` pair, where ``metrics`` holds the loss under
        ``"loss/actor"`` and the mean of the threshold mask under
        ``"train/mask"``.
    """
    offsets = compute_actions(
        actor=policy.actor,
        key=key,
        key_data=key_data,
        observations=observations,
        squash=False,
    )

    # Squared size of the offsets measured in units of ``stds``.
    squared_distance = jnp.sum(jnp.square(offsets / stds), axis=-1)
    # The selector is wrapped in stop_gradient so it acts as a constant.
    beyond_threshold = lax.stop_gradient(squared_distance > threshold)

    # NOTE(review): the log-std term is subtracted here, whereas a Gaussian
    # negative log-likelihood would add it. ``stds`` is an input rather than
    # an actor output, so this only shifts the reported loss value — confirm
    # whether the sign is intended.
    log_std_sum = jnp.sum(jnp.log(stds), axis=-1)
    penalty = 0.5 * squared_distance - log_std_sum

    # Minimum over axis 0 of the critic output (presumably the ensemble axis).
    critic_outputs = policy.critic(observations, means + offsets)
    min_qvalues = jnp.min(critic_outputs, axis=0)

    per_sample = lax.select(beyond_threshold, penalty, -min_qvalues)
    loss = jnp.mean(per_sample)
    metrics = {"loss/actor": loss, "train/mask": jnp.mean(beyond_threshold)}
    return loss, metrics


def compute_targets(
    actor: Actor,
    critic: QCriticEnsemble,
    dones: BoolArray,
    gamma: float,
    key: Array,
    key_data: Any,
    next_means: Array,
    next_observations: Array,
    next_stds: Array,
    next_values: Array,
    rewards: Array,
    threshold: float,
):
    """Build the bootstrapped regression targets for the critic.

    Next actions are ``next_means`` plus unsquashed actor offsets. Where the
    squared offset (normalised by ``next_stds`` and summed over the last
    axis) exceeds ``threshold``, the bootstrap value is taken from
    ``next_values``; elsewhere it is the minimum over axis 0 of the critic
    output for the next state-action pair.

    Args:
        actor: Actor used to sample the next-step offsets.
        critic: Critic evaluated on the next observations and actions.
        dones: Termination flags; the bootstrap term is multiplied by
            ``1 - dones``.
        gamma: Discount factor applied to the bootstrap term.
        key: Forwarded to ``compute_actions``.
        key_data: Forwarded to ``compute_actions``.
        next_means: Added to the offsets to obtain the next actions.
        next_observations: Inputs to the actor and the critic.
        next_stds: Scale used to normalise the offsets.
        next_values: Bootstrap values used where the threshold is exceeded.
        rewards: Added to the discounted bootstrap term.
        threshold: Cut-off on the normalised squared offset.

    Returns:
        The targets with a new leading axis of size one.
    """
    next_offsets = compute_actions(
        actor=actor,
        key=key,
        key_data=key_data,
        observations=next_observations,
        squash=False,
    )

    normalised = next_offsets / next_stds
    beyond_threshold = jnp.sum(jnp.square(normalised), axis=-1) > threshold

    critic_outputs = critic(next_observations, next_means + next_offsets)
    min_next_qvalues = jnp.min(critic_outputs, axis=0)

    bootstrap = lax.select(beyond_threshold, next_values, min_next_qvalues)
    # Finished transitions keep only the immediate reward.
    not_done = 1 - dones
    one_step_return = gamma * bootstrap * not_done + rewards
    # Leading axis added so the targets line up with the critic output's axis 0
    # (presumably the ensemble axis — confirm against ``critic_loss_fn``).
    return jnp.expand_dims(one_step_return, axis=0)


def train_actor_step(
    actor_optimizer: nnx.Optimizer,
    key: Array,
    key_data: Any,
    means: Array,
    observations: Array,
    policy: HDRPolicy,
    stds: Array,
    threshold: float,
) -> dict[str, Array]:
    """Apply one optimizer update to the actor using ``actor_loss_fn``.

    Gradients are taken with respect to the part of ``policy`` selected by
    ``ActorFilter`` and handed to ``actor_optimizer``.

    Args:
        actor_optimizer: Optimizer that receives the gradients.
        key: Forwarded to ``actor_loss_fn``.
        key_data: Forwarded to ``actor_loss_fn``.
        means: Forwarded to ``actor_loss_fn``.
        observations: Forwarded to ``actor_loss_fn``.
        policy: Policy passed as the differentiated first argument.
        stds: Forwarded to ``actor_loss_fn``.
        threshold: Forwarded to ``actor_loss_fn``.

    Returns:
        The metrics dictionary produced by ``actor_loss_fn``.
    """
    loss_inputs = {
        "key": key,
        "key_data": key_data,
        "means": means,
        "observations": observations,
        "stds": stds,
        "threshold": threshold,
    }
    # Differentiate argument 0 (the policy), restricted by ActorFilter.
    actor_grad_fn = nnx.grad(
        actor_loss_fn,
        argnums=nnx.DiffState(0, ActorFilter),
        has_aux=True,
    )
    actor_grads, metrics = actor_grad_fn(policy, **loss_inputs)
    actor_optimizer.update(actor_grads)
    return metrics


@jit
def train_critic_step(
    actions: Array,
    dones: BoolArray,
    gamma: float,
    graphdef: nnx.GraphDef[HDRTrainState],
    graphstate: nnx.GraphState | nnx.VariableState,
    next_means: Array,
    next_observations: Array,
    next_stds: Array,
    next_values: Array,
    observations: Array,
    ood_threshold: float,
    rewards: Array,
    step: int,
    train_critic_key: Array,
):
    """Run one jitted critic update and return the new graph state.

    The train state is rebuilt from ``graphdef`` and ``graphstate``, targets
    are computed with the target policy's actor and critic, and the online
    critic is updated from the gradients of ``critic_loss_fn``.

    Args:
        actions: Forwarded to ``critic_loss_fn``.
        dones: Forwarded to ``compute_targets``.
        gamma: Forwarded to ``compute_targets``.
        graphdef: Graph definition used to rebuild the train state.
        graphstate: Graph state used to rebuild the train state.
        next_means: Forwarded to ``compute_targets``.
        next_observations: Forwarded to ``compute_targets``.
        next_stds: Forwarded to ``compute_targets``.
        next_values: Forwarded to ``compute_targets``.
        observations: Forwarded to ``critic_loss_fn``.
        ood_threshold: Passed to ``compute_targets`` as ``threshold``.
        rewards: Forwarded to ``compute_targets``.
        step: Passed to ``compute_targets`` as ``key_data``.
        train_critic_key: Passed to ``compute_targets`` as ``key``.

    Returns:
        A ``(graphstate, metrics)`` pair: the state split from the updated
        train state and the auxiliary output of ``critic_loss_fn``.
    """
    state = nnx.merge(graphdef, graphstate)

    # Targets come from the target policy's networks, not the online ones.
    td_targets = compute_targets(
        actor=state.target_policy.model.actor,
        critic=state.target_policy.model.critic,
        dones=dones,
        gamma=gamma,
        key=train_critic_key,
        key_data=step,
        next_means=next_means,
        next_observations=next_observations,
        next_stds=next_stds,
        next_values=next_values,
        rewards=rewards,
        threshold=ood_threshold,
    )

    critic_grad_fn = nnx.grad(critic_loss_fn, has_aux=True)
    critic_grads, metrics = critic_grad_fn(
        state.policy.critic,
        actions=actions,
        observations=observations,
        targets=td_targets,
    )
    state.critic_optimizer.update(critic_grads)

    # Only the state half of the split is returned; the graphdef is dropped.
    updated_graphstate = nnx.split(state)[1]
    return updated_graphstate, metrics


@jit
def train_actor_critic_step(
    actions: Array,
    dones: BoolArray,
    gamma: float,
    graphdef: nnx.GraphDef[HDRTrainState],
    graphstate: nnx.GraphState | nnx.VariableState,
    means_actor: Array,
    next_means: Array,
    next_observations: Array,
    next_stds: Array,
    next_values: Array,
    observations: Array,
    observations_actor: Array,
    ood_threshold: float,
    rewards: Array,
    stds_actor: Array,
    step: int,
    tau: float,
    train_actor_key: Array,
    train_critic_key: Array,
):
    """Run a critic update, then an actor update, then a target update.

    The critic step is delegated to ``train_critic_step``. The train state is
    then rebuilt from the returned graph state, the online critic is put in
    eval mode for the actor step, the target policy is updated from the
    online policy, and the critic is switched back to train mode before the
    state is split again.

    Args:
        actions: Forwarded to ``train_critic_step``.
        dones: Forwarded to ``train_critic_step``.
        gamma: Forwarded to ``train_critic_step``.
        graphdef: Graph definition used to rebuild the train state.
        graphstate: Graph state at the start of the step.
        means_actor: Passed to ``train_actor_step`` as ``means``.
        next_means: Forwarded to ``train_critic_step``.
        next_observations: Forwarded to ``train_critic_step``.
        next_stds: Forwarded to ``train_critic_step``.
        next_values: Forwarded to ``train_critic_step``.
        observations: Forwarded to ``train_critic_step``.
        observations_actor: Passed to ``train_actor_step`` as
            ``observations``.
        ood_threshold: Threshold used by both the critic and actor steps.
        rewards: Forwarded to ``train_critic_step``.
        stds_actor: Passed to ``train_actor_step`` as ``stds``.
        step: Used as ``key_data`` in the actor step and forwarded to the
            critic step.
        tau: Passed to ``target_policy.update``.
        train_actor_key: Random key for the actor step.
        train_critic_key: Random key for the critic step.

    Returns:
        A ``(graphstate, metrics)`` pair, where ``metrics`` is the actor
        metrics merged with the critic metrics (critic entries win on a key
        clash).
    """
    critic_graphstate, critic_metrics = train_critic_step(
        actions=actions,
        dones=dones,
        gamma=gamma,
        graphdef=graphdef,
        graphstate=graphstate,
        next_means=next_means,
        next_observations=next_observations,
        next_stds=next_stds,
        next_values=next_values,
        observations=observations,
        ood_threshold=ood_threshold,
        rewards=rewards,
        step=step,
        train_critic_key=train_critic_key,
    )

    state = nnx.merge(graphdef, critic_graphstate)
    online_policy = state.policy

    # The critic is only evaluated during the actor step, so it runs in eval
    # mode there and is switched back to train mode afterwards.
    online_policy.critic.eval()
    actor_metrics = train_actor_step(
        actor_optimizer=state.actor_optimizer,
        key=train_actor_key,
        key_data=step,
        means=means_actor,
        observations=observations_actor,
        stds=stds_actor,
        policy=online_policy,
        threshold=ood_threshold,
    )
    # Presumably a soft (Polyak-style) update controlled by ``tau`` — the
    # exact semantics live in ``target_policy.update``.
    state.target_policy.update(model=online_policy, tau=tau)
    online_policy.critic.train()

    final_graphstate = nnx.split(state)[1]
    return final_graphstate, actor_metrics | critic_metrics
