from functools import partial

from flax import nnx
from jax import Array, jit, lax, numpy as jnp
from jax.random import fold_in, normal
from optax import squared_error

from offline.bppo.modules import BPPOPolicy
from offline.modules.actor.base import GaussianActor
from offline.modules.actor.utils import gaussian_log_likelihood
from offline.modules.base import TrainState, TrainStateWithTarget
from offline.modules.critic import QCritic
from offline.types import BoolArray


# Small positive constant added to the advantage standard deviation in
# `train_step` so that the normalisation never divides by exactly zero.
EPS = 1e-10


def bppo_loss_fn(
    actor: GaussianActor,
    actions: Array,
    advantages: Array,
    clip_epsilon: float,
    entropy_weight: float,
    observations: Array,
    old_log_likelihoods: Array,
):
    """Compute the clipped-ratio surrogate loss for the actor.

    The actor is evaluated on ``observations`` to obtain Gaussian means and
    standard deviations; the likelihood ratio between the resulting
    log-likelihood of ``actions`` and ``old_log_likelihoods`` is clipped to
    ``[1 - clip_epsilon, 1 + clip_epsilon]``, and the loss is the negated
    mean of the element-wise minimum of the clipped and unclipped
    ratio-weighted advantages.

    Args:
        actor: Callable returning ``(means, stds)`` for a batch of observations.
        actions: Actions whose likelihood ratio is evaluated.
        advantages: Advantage estimates multiplied with the ratios.
        clip_epsilon: Half-width of the ratio clipping interval around 1.
        entropy_weight: Coefficient of the log-std bonus; the bonus is skipped
            unless this is strictly positive.
        observations: Observations fed to ``actor``.
        old_log_likelihoods: Log-likelihoods of ``actions`` under the policy
            that the ratio is taken against.

    Returns:
        A ``(loss, metrics)`` pair, where ``metrics`` holds the final loss
        under ``"loss/actor"`` and the fraction of ratios falling outside the
        clipping interval under ``"train/clipfrac"``.
    """
    lower_bound = 1 - clip_epsilon
    upper_bound = 1 + clip_epsilon

    new_means, new_stds = actor(observations)
    new_log_likelihoods = gaussian_log_likelihood(
        means=new_means, samples=actions, stds=new_stds
    )
    ratios = jnp.exp(new_log_likelihoods - old_log_likelihoods)

    # Diagnostic only: share of samples whose ratio left the clip interval.
    outside_interval = jnp.logical_or(
        ratios > upper_bound, ratios < lower_bound
    )
    clip_fraction = jnp.mean(outside_interval)

    unclipped_objective = ratios * advantages
    clipped_objective = (
        lax.clamp(lower_bound, ratios, upper_bound) * advantages
    )
    # Pessimistic (element-wise minimum) objective, negated to form a loss.
    actor_loss = -jnp.mean(jnp.minimum(unclipped_objective, clipped_objective))

    # This is a Python-level branch, so `entropy_weight` must be a concrete
    # value (not a traced array) whenever this function runs under `jit`.
    if entropy_weight > 0:
        # For a diagonal Gaussian the summed log-stds equal the entropy up to
        # an additive constant, so subtracting them rewards wider policies.
        summed_log_stds = jnp.sum(jnp.log(new_stds), axis=-1)
        actor_loss = actor_loss - entropy_weight * jnp.mean(summed_log_stds)

    metrics = {"loss/actor": actor_loss, "train/clipfrac": clip_fraction}
    return actor_loss, metrics


def sarsa_loss_fn(
    qcritic: QCritic, actions: Array, observations: Array, targets: Array
):
    """Mean squared error between the critic's predictions and ``targets``.

    Args:
        qcritic: Critic called as ``qcritic(observations, actions)``.
        actions: Actions passed to the critic.
        observations: Observations passed to the critic.
        targets: Regression targets for the critic's predictions.

    Returns:
        A ``(loss, metrics)`` pair; ``metrics`` reports the loss
        (``"loss/Q"``), the mean prediction (``"train/Q"``) and the mean
        target (``"train/QT"``).
    """
    q_predictions = qcritic(observations, actions)
    mse = jnp.mean(squared_error(q_predictions, targets))
    metrics = {
        "loss/Q": mse,
        "train/Q": jnp.mean(q_predictions),
        "train/QT": jnp.mean(targets),
    }
    return mse, metrics


def _sarsa_step(
    actions: Array,
    dones: BoolArray,
    gamma: float,
    next_actions: Array,
    next_observations: Array,
    observations: Array,
    rewards: Array,
    train_state: TrainStateWithTarget[QCritic],
):
    """Run one critic update against bootstrapped targets.

    The targets are ``rewards + gamma * (1 - dones) * Q_target(s', a')``,
    where ``Q_target`` is ``train_state.target.model`` and ``a'`` are the
    supplied ``next_actions``. ``train_state.optimizer`` is updated in place
    with the gradients of ``sarsa_loss_fn`` taken with respect to
    ``train_state.model``.

    Returns:
        The metrics dictionary produced by ``sarsa_loss_fn``.
    """
    bootstrap_values = train_state.target.model(
        next_observations, next_actions
    )
    # `1 - dones` zeroes the bootstrap term on terminal transitions.
    not_done = 1 - dones
    targets = bootstrap_values * gamma * not_done + rewards

    # `nnx.grad` differentiates with respect to its first argument only, so
    # the targets enter the loss as constants.
    grads, metrics = nnx.grad(sarsa_loss_fn, has_aux=True)(
        train_state.model, actions, observations, targets
    )
    train_state.optimizer.update(grads)
    return metrics


@jit
def sarsa_step(
    actions: Array,
    dones: BoolArray,
    gamma: float,
    graphdef: nnx.GraphDef[TrainStateWithTarget[QCritic]],
    graphstate: nnx.GraphState | nnx.VariableState,
    next_actions: Array,
    next_observations: Array,
    observations: Array,
    rewards: Array,
):
    """Jitted critic update that leaves the target network untouched.

    The train state is rebuilt from ``graphdef`` and ``graphstate``, updated
    by ``_sarsa_step``, and split again so that only array state crosses the
    ``jit`` boundary.

    Returns:
        A ``(graphstate, metrics)`` pair with the updated state and the
        metrics reported by ``_sarsa_step``.
    """
    state = nnx.merge(graphdef, graphstate)
    metrics = _sarsa_step(
        train_state=state,
        observations=observations,
        actions=actions,
        rewards=rewards,
        dones=dones,
        next_observations=next_observations,
        next_actions=next_actions,
        gamma=gamma,
    )
    # Only the state half of the split is returned; callers keep `graphdef`.
    updated_graphstate = nnx.split(state)[1]
    return updated_graphstate, metrics


@jit
def sarsa_step_with_target_update(
    actions: Array,
    dones: BoolArray,
    gamma: float,
    graphdef: nnx.GraphDef[TrainStateWithTarget[QCritic]],
    graphstate: nnx.GraphState | nnx.VariableState,
    next_actions: Array,
    next_observations: Array,
    observations: Array,
    rewards: Array,
    tau: float,
):
    """Jitted critic update followed by a target-network update.

    Performs the same update as ``_sarsa_step`` and then calls
    ``train_state.target.update(train_state.model, tau)`` on the freshly
    updated model.

    Returns:
        A ``(graphstate, metrics)`` pair with the updated state and the
        metrics reported by ``_sarsa_step``.
    """
    state = nnx.merge(graphdef, graphstate)
    metrics = _sarsa_step(
        train_state=state,
        observations=observations,
        actions=actions,
        rewards=rewards,
        dones=dones,
        next_observations=next_observations,
        next_actions=next_actions,
        gamma=gamma,
    )
    # Runs after the gradient step, so the target sees the updated model.
    # NOTE(review): presumably a soft (Polyak) update with rate `tau` --
    # confirm against the TrainStateWithTarget target implementation.
    state.target.update(state.model, tau)
    updated_graphstate = nnx.split(state)[1]
    return updated_graphstate, metrics


@partial(jit, static_argnames=("entropy_weight", "omega"))
def train_step(
    clip_epsilon: float,
    entropy_weight: float,
    graphdef: nnx.GraphDef[TrainState[BPPOPolicy]],
    graphstate: nnx.GraphState | nnx.VariableState,
    means: Array,
    observations: Array,
    omega: float,
    stds: Array,
    step: int,
    train_key: Array,
    values: Array,
):
    """Run one jitted actor update.

    Actions are sampled as ``means + stds * noise`` with noise drawn from a
    key derived from ``train_key`` and ``step``. Advantages are the critic's
    value of those actions minus ``values``, standardised over the batch and
    optionally re-weighted by ``omega``. The actor is then updated with the
    gradients of ``bppo_loss_fn``.

    ``entropy_weight`` and ``omega`` are static ``jit`` arguments because
    both drive Python-level branches; each distinct value triggers a
    separate compilation.

    Args:
        clip_epsilon: Ratio clipping half-width forwarded to ``bppo_loss_fn``.
        entropy_weight: Entropy coefficient forwarded to ``bppo_loss_fn``.
        graphdef: Static half of the split train state.
        graphstate: Array half of the split train state.
        means: Means of the Gaussian the actions are sampled from.
        observations: Observations for the critic and the actor.
        omega: Weight given to positive advantages; non-positive ones get
            ``1 - omega``. A value of exactly 0.5 disables the weighting.
        stds: Standard deviations of the sampling Gaussian.
        step: Step counter folded into ``train_key``.
        train_key: Base PRNG key.
        values: Baseline subtracted from the critic's values.

    Returns:
        A ``(graphstate, metrics)`` pair with the updated state and the
        metrics reported by ``bppo_loss_fn``.
    """
    state = nnx.merge(graphdef, graphstate)

    # Folding the step into the key gives a different noise draw per step
    # without the caller having to thread split keys through.
    step_key = fold_in(train_key, step)
    actions = means + stds * normal(step_key, shape=stds.shape)
    old_log_likelihoods = gaussian_log_likelihood(
        means=means, samples=actions, stds=stds
    )

    raw_advantages = state.model.critic(observations, actions) - values
    centered = raw_advantages - jnp.mean(raw_advantages)
    # EPS keeps the division finite when all advantages are identical.
    advantages = centered / (jnp.std(raw_advantages) + EPS)

    if omega != 0.5:
        # Asymmetric weighting: `omega` for positive advantages, `1 - omega`
        # for the rest.
        omega_array = jnp.broadcast_to(omega, advantages.shape)
        is_positive = advantages > 0
        weights = lax.select(is_positive, omega_array, 1 - omega_array)
        advantages = weights * advantages

    # Gradients are taken with respect to the actor only; everything computed
    # above is passed in as constant data.
    grads, metrics = nnx.grad(bppo_loss_fn, has_aux=True)(
        state.model.actor,
        actions=actions,
        advantages=advantages,
        clip_epsilon=clip_epsilon,
        entropy_weight=entropy_weight,
        observations=observations,
        old_log_likelihoods=old_log_likelihoods,
    )
    state.optimizer.update(grads)
    updated_graphstate = nnx.split(state)[1]
    return updated_graphstate, metrics
