from functools import partial

from flax import nnx
import jax
from jax import Array, numpy as jnp
from jax.lax import stop_gradient
from jax.random import fold_in, normal
from optax import squared_error

from offline.modules.actor.base import DeterministicActor
from offline.modules.critic import QCriticEnsemble
from offline.td3bc.modules import ActorFilter, TD3BCPolicy, TD3BCTrainState
from offline.types import ArrayLike, BoolArray, FloatArray


def actor_loss_fn(
    policy: TD3BCPolicy,
    actions: FloatArray,
    alpha: float,
    observations: FloatArray,
):
    """Compute the actor loss and a dictionary of logging metrics.

    The loss always contains a squared-error term between the actor's
    output and the dataset ``actions``. When ``alpha`` is positive, a
    Q-value term weighted by ``alpha`` is added on top of it.

    Args:
        policy: Object exposing ``actor`` and ``critic`` callables. Only the
            first element of the actor's return value is used.
        actions: Dataset actions the actor output is regressed towards.
        alpha: Weight of the Q-value term. It is tested with a Python
            ``if``, so it has to be a concrete number (not a traced value).
        observations: Observations fed to both the actor and the critic.

    Returns:
        A ``(loss, metrics)`` pair, where ``metrics`` maps metric names to
        scalar arrays.
    """
    predicted = policy.actor(observations)[0]
    bc_loss = squared_error(predicted, actions).mean()

    # Pure regression objective: skip the critic entirely.
    if not alpha > 0:
        return bc_loss, {"loss/actor": bc_loss, "loss/actor/bc": bc_loss}

    # Reduce the critic output over axis 0 (presumably the ensemble axis)
    # by taking the minimum.
    min_q = jnp.min(policy.critic(observations, predicted), axis=0)
    # Normalising factor 1 / mean(|Q|); stop_gradient keeps it out of the
    # backward pass so it only rescales the Q term.
    lambda_ = 1 / jnp.mean(jnp.abs(stop_gradient(min_q)))
    mean_q = jnp.mean(min_q)
    td3_loss = -mean_q * lambda_
    total = bc_loss + alpha * td3_loss
    metrics = {
        "loss/actor": total,
        "loss/actor/bc": bc_loss,
        "loss/actor/td3": td3_loss,
        "loss/actor/lambda": lambda_,
        "loss/actor/Q": mean_q,
    }
    return total, metrics


def critic_loss_fn(
    critic: QCriticEnsemble,
    actions: FloatArray,
    observations: FloatArray,
    targets: Array,
):
    """Compute the mean squared error between critic outputs and targets.

    Args:
        critic: Callable invoked as ``critic(observations, actions)``.
        actions: Actions passed to the critic.
        observations: Observations passed to the critic.
        targets: Regression targets; broadcast to the shape of the critic
            output before the error is computed.

    Returns:
        A ``(loss, metrics)`` pair. ``metrics`` holds the loss together
        with the mean prediction and the mean (broadcast) target.
    """
    q_pred = critic(observations, actions)
    # Expand the targets to the full prediction shape so every critic
    # output is compared element-wise against a target.
    q_target = jnp.broadcast_to(targets, q_pred.shape)
    loss = jnp.mean(squared_error(q_pred, q_target))
    metrics = {
        "loss/Q": loss,
        "train/Q": jnp.mean(q_pred),
        "train/QT": jnp.mean(q_target),
    }
    return loss, metrics


def compute_values(
    actor: DeterministicActor,
    critic: QCriticEnsemble,
    key: Array,
    key_data: Array | int,
    noise_clip: float,
    observations: ArrayLike,
    policy_noise: float,
) -> Array:
    """Evaluate the critic on (optionally noise-perturbed) actor actions.

    Args:
        actor: Callable whose first return element is the action array.
        critic: Callable invoked as ``critic(observations, actions)``.
        key: Base PRNG key; only used when ``policy_noise`` is positive.
        key_data: Value folded into ``key`` to derive the noise key.
        noise_clip: Symmetric bound applied to the sampled noise.
        observations: Inputs for both the actor and the critic.
        policy_noise: Scale of the Gaussian noise. Tested with a Python
            ``if``, so it must be a concrete number; non-positive values
            disable the noise path.

    Returns:
        The minimum of the critic output over axis 0, with that axis kept
        as a singleton dimension.
    """
    actions = actor(observations)[0]
    if policy_noise > 0:
        noise_key = fold_in(key, key_data)
        noise = jnp.clip(
            normal(noise_key, actions.shape) * policy_noise,
            -noise_clip,
            noise_clip,
        )
        # Perturbed actions are forced back into the [-1, 1] range.
        actions = jnp.clip(actions + noise, -1, 1)
    return jnp.min(critic(observations, actions), axis=0, keepdims=True)


def train_actor_step(
    actions: FloatArray,
    actor_optimizer: nnx.Optimizer,
    alpha: float,
    observations: FloatArray,
    policy: TD3BCPolicy,
) -> dict[str, Array]:
    """Run one optimizer update on the actor and return the loss metrics.

    Args:
        actions: Dataset actions forwarded to ``actor_loss_fn``.
        actor_optimizer: Optimizer that receives the computed gradients.
        alpha: Weight forwarded to ``actor_loss_fn``.
        observations: Observations forwarded to ``actor_loss_fn``.
        policy: Policy passed as the first argument of ``actor_loss_fn``.

    Returns:
        The auxiliary metrics dictionary produced by ``actor_loss_fn``.
    """
    # Differentiate only w.r.t. the part of argument 0 (the policy) that
    # ``ActorFilter`` selects.
    actor_grad = nnx.grad(
        actor_loss_fn,
        argnums=nnx.DiffState(0, ActorFilter),
        has_aux=True,
    )
    grads, metrics = actor_grad(policy, actions, alpha, observations)
    actor_optimizer.update(grads)
    return metrics


@partial(jax.jit, static_argnames=("policy_noise",))
def train_critic_step(
    actions: FloatArray,
    dones: BoolArray,
    gamma: float,
    graphdef: nnx.GraphDef[TD3BCTrainState],
    graphstate: nnx.GraphState | nnx.VariableState,
    next_observations: FloatArray,
    noise_clip: float,
    observations: FloatArray,
    policy_noise: float,
    rewards: FloatArray,
    step: int,
    train_key: Array,
) -> tuple[nnx.GraphState | nnx.VariableState, dict[str, Array]]:
    """Run one jitted critic update and return the new state and metrics.

    The train state is rebuilt from ``graphdef`` and ``graphstate``, the
    critic is updated against targets computed with the target policy, and
    the state is split again so that only the state part is returned.

    Args:
        actions: Dataset actions for the critic loss.
        dones: Termination flags; ``1 - dones`` masks the bootstrap term.
        gamma: Discount factor applied to the next-step values.
        graphdef: Static graph definition of the train state.
        graphstate: State to merge with ``graphdef``.
        next_observations: Observations used to compute next-step values.
        noise_clip: Noise bound forwarded to ``compute_values``.
        observations: Observations for the critic loss.
        policy_noise: Noise scale forwarded to ``compute_values``. Marked
            static for ``jax.jit`` because ``compute_values`` branches on
            it in Python.
        rewards: Rewards added to the discounted next-step values.
        step: Folded into ``train_key`` by ``compute_values``.
        train_key: Base PRNG key for the target-action noise.

    Returns:
        A ``(graphstate, metrics)`` pair with the updated state and the
        metrics dictionary from ``critic_loss_fn``.
    """
    state = nnx.merge(graphdef, graphstate)

    # Next-step values come from the target policy's actor and critic.
    target_model = state.target_policy.model
    next_q = compute_values(
        actor=target_model.actor,
        critic=target_model.critic,
        key=train_key,
        key_data=step,
        noise_clip=noise_clip,
        observations=next_observations,
        policy_noise=policy_noise,
    )
    not_done = 1 - dones
    targets = gamma * next_q * not_done + rewards

    critic_grad = nnx.grad(critic_loss_fn, has_aux=True)
    grads, metrics = critic_grad(
        state.policy.critic, actions, observations, targets
    )
    state.critic_optimizer.update(grads)

    # Only the state half of the split is handed back to the caller.
    return nnx.split(state)[1], metrics


@partial(jax.jit, static_argnames=("alpha", "policy_noise"))
def train_actor_critic_step(
    actions: FloatArray,
    actions_actor: FloatArray,
    alpha: float,
    dones: BoolArray,
    gamma: float,
    graphdef: nnx.GraphDef[TD3BCTrainState],
    graphstate: nnx.GraphState | nnx.VariableState,
    next_observations: FloatArray,
    noise_clip: float,
    observations: FloatArray,
    observations_actor: FloatArray,
    policy_noise: float,
    rewards: FloatArray,
    step: int,
    tau: float,
    train_key: Array,
) -> tuple[nnx.GraphState | nnx.VariableState, dict[str, Array]]:
    """Run a critic update, then an actor update and a target update.

    Args:
        actions: Dataset actions for the critic update.
        actions_actor: Dataset actions for the actor update.
        alpha: Weight forwarded to ``train_actor_step``. Static for
            ``jax.jit`` because the actor loss branches on it in Python.
        dones: Termination flags for the critic update.
        gamma: Discount factor for the critic update.
        graphdef: Static graph definition of the train state.
        graphstate: State to merge with ``graphdef``.
        next_observations: Next observations for the critic update.
        noise_clip: Noise bound for the critic update.
        observations: Observations for the critic update.
        observations_actor: Observations for the actor update.
        policy_noise: Noise scale for the critic update (static).
        rewards: Rewards for the critic update.
        step: Step value forwarded to the critic update.
        tau: Forwarded to ``target_policy.update``; presumably the
            soft-update coefficient (defined outside this file).
        train_key: Base PRNG key forwarded to the critic update.

    Returns:
        A ``(graphstate, metrics)`` pair. ``metrics`` merges the actor and
        critic metrics; on a key clash the critic value wins.
    """
    graphstate, critic_metrics = train_critic_step(
        actions=actions,
        dones=dones,
        gamma=gamma,
        graphdef=graphdef,
        graphstate=graphstate,
        next_observations=next_observations,
        noise_clip=noise_clip,
        observations=observations,
        policy_noise=policy_noise,
        rewards=rewards,
        step=step,
        train_key=train_key,
    )

    # Rebuild the train state from the critic-updated graphstate.
    state = nnx.merge(graphdef, graphstate)
    online_policy = state.policy
    # The critic is put into evaluation mode before it is used inside the
    # actor loss.
    online_policy.critic.eval()
    actor_metrics = train_actor_step(
        actions=actions_actor,
        actor_optimizer=state.actor_optimizer,
        alpha=alpha,
        observations=observations_actor,
        policy=online_policy,
    )
    state.target_policy.update(model=online_policy, tau=tau)

    # Only the state half of the split is handed back to the caller.
    return nnx.split(state)[1], actor_metrics | critic_metrics
