from functools import partial

from flax import nnx
from jax import Array, jit, lax, numpy as jnp
from jax.random import fold_in, normal
from optax import squared_error

from offline.svr.modules import ActorFilter, SVRPolicy, SVRTrainState
from offline.modules.actor.base import DeterministicActor
from offline.modules.actor.utils import gaussian_log_likelihood
from offline.modules.base import TrainState
from offline.modules.critic import QCriticEnsemble
from offline.td3bc.core import critic_loss_fn as td_loss_fn
from offline.types import BoolArray


def actor_loss_fn(
    policy: SVRPolicy,
    alpha: float,
    observations: Array,
):
    """Compute the scale-normalized Q-maximization loss for the actor.

    The actor's actions for ``observations`` are scored by the critic, the
    critic output is reduced with a minimum over axis 0, and the negated
    mean of that minimum is scaled by ``alpha`` and by the inverse mean
    absolute Q-value.

    Args:
        policy: Object exposing ``actor(observations)`` (returning a pair
            whose first item is the actions) and
            ``critic(observations, actions)``.
        alpha: Multiplier applied to the normalized Q term.
        observations: Batch of observations fed to the actor and the critic.

    Returns:
        A ``(loss, metrics)`` pair, where ``metrics`` holds the loss, the
        normalization factor, and the mean minimum Q-value.
    """
    policy_actions, _ = policy.actor(observations)
    min_q = jnp.min(policy.critic(observations, policy_actions), axis=0)
    # The normalizer is wrapped in stop_gradient, so it acts as a constant
    # scale and contributes no gradient of its own.
    scale = 1 / jnp.mean(jnp.abs(lax.stop_gradient(min_q)))
    mean_q = jnp.mean(min_q)
    actor_loss = -alpha * mean_q * scale
    metrics = {
        "loss/actor": actor_loss,
        "loss/actor/lambda": scale,
        "loss/actor/Q": mean_q,
    }
    return actor_loss, metrics


def behavior_cloning_loss_fn(
    actor: DeterministicActor, actions: Array, observations: Array
):
    """Return the mean squared error between actor outputs and ``actions``.

    Args:
        actor: Callable returning a pair whose first item is the predicted
            actions for ``observations``.
        actions: Reference actions the actor output is compared against.
        observations: Batch of observations fed to the actor.

    Returns:
        Scalar mean of the element-wise squared error.
    """
    predicted_actions, _ = actor(observations)
    return jnp.mean(squared_error(predicted_actions, actions))


def critic_loss_fn(
    critic: QCriticEnsemble,
    actions: Array,
    noisy_actions: Array,
    observations: Array,
    regularizer_weight: float,
    targets: Array,
    targets_regularizer: float,
    weights: Array,
):
    """Compute the TD loss plus the weighted value-regularization term.

    Args:
        critic: Callable ``critic(observations, actions)`` whose output has
            a leading ensemble axis.
        actions: Actions paired with ``observations`` for the TD term.
        noisy_actions: Perturbed actions used for the regularization term.
        observations: Batch of observations.
        regularizer_weight: Multiplier on the regularization term.
        targets: TD targets; broadcast to the critic output's shape.
        targets_regularizer: Constant value both regularization errors are
            measured against.
        weights: Factor multiplying the behavior error before it is
            subtracted from the noise error.

    Returns:
        A ``(loss, metrics)`` pair with the total loss and logging scalars.
    """
    # Both critic outputs are laid out as [ensemble_size, ...].
    q_data = critic(observations, actions)
    q_noisy = critic(observations, noisy_actions)
    constant_target = jnp.full_like(q_noisy, fill_value=targets_regularizer)
    targets = jnp.broadcast_to(targets, q_data.shape)

    # TD term: squared errors are summed over the ensemble axis, then
    # averaged over the remaining axes.
    summed_td_error = jnp.sum(squared_error(q_data, targets), axis=0)
    td_loss = jnp.mean(summed_td_error)

    # Regularizer: squared distance to the constant target for noisy
    # actions, minus the weighted distance for the given actions.
    behavior_loss = squared_error(q_data, constant_target)
    noise_loss = squared_error(q_noisy, constant_target)
    regularizer_gap = noise_loss - weights * behavior_loss
    # The gap is bounded from below at -1e4 before averaging.
    regularizer_loss = jnp.mean(jnp.clip(regularizer_gap, min=-1e4))

    loss = td_loss + regularizer_weight * regularizer_loss
    metrics = {
        "loss/Q": loss,
        "loss/behavior": behavior_loss.mean(),
        "loss/noise": noise_loss.mean(),
        "loss/REG": regularizer_loss,
        "loss/TD": td_loss,
        "train/Q": q_data.mean(),
        "train/QN": q_noisy.mean(),
        "train/QT": targets.mean(),
        "train/weights": weights.mean(),
    }
    return loss, metrics


@jit
def behavior_cloning_step(
    actions: Array,
    graphdef: nnx.GraphDef[TrainState[DeterministicActor]],
    graphstate: nnx.GraphState | nnx.VariableState,
    observations: Array,
):
    """Run one jitted behavior-cloning update on a split train state.

    Args:
        actions: Reference actions for the behavior-cloning loss.
        graphdef: Graph definition of the train state.
        graphstate: State matching ``graphdef``.
        observations: Batch of observations.

    Returns:
        A ``(graphstate, metrics)`` pair: the state after the optimizer
        update and a dict holding the loss under ``"loss/BC"``.
    """
    # Rebuild the stateful object so the update can be applied in place.
    state = nnx.merge(graphdef, graphstate)
    bc_loss, bc_grads = nnx.value_and_grad(behavior_cloning_loss_fn)(
        state.model, actions, observations
    )
    state.optimizer.update(bc_grads)
    # Only the state half of the split is handed back to the caller.
    _, updated_graphstate = nnx.split(state)
    return updated_graphstate, {"loss/BC": bc_loss}


@jit
def compute_batch_log_likelihoods(
    actions: Array,
    graphdef: nnx.GraphDef[DeterministicActor],
    graphstate: nnx.GraphState | nnx.VariableState,
    observations: Array,
    sample_std: float,
):
    """Score ``actions`` under a Gaussian centred on the actor's output.

    Args:
        actions: Samples whose log-likelihood is evaluated.
        graphdef: Graph definition of the actor.
        graphstate: State matching ``graphdef``.
        observations: Batch of observations fed to the actor.
        sample_std: Standard deviation passed to the likelihood helper.

    Returns:
        The result of ``gaussian_log_likelihood`` for the batch.
    """
    model = nnx.merge(graphdef, graphstate)
    # The first item returned by the actor is used as the Gaussian mean.
    predicted_means, _ = model(observations)
    return gaussian_log_likelihood(
        means=predicted_means, samples=actions, stds=sample_std
    )


def compute_targets(
    critic: QCriticEnsemble,
    dones: BoolArray,
    gamma: float,
    next_actions: Array,
    next_observations: Array,
    rewards: Array,
):
    """Compute bootstrapped targets from the minimum critic output.

    Args:
        critic: Callable ``critic(next_observations, next_actions)`` whose
            output is reduced with a minimum over axis 0.
        dones: Flags that zero out the bootstrapped term where true.
        gamma: Discount factor applied to the bootstrapped term.
        next_actions: Actions evaluated at ``next_observations``.
        next_observations: Observations the critic is evaluated on.
        rewards: Rewards added to the discounted term.

    Returns:
        ``gamma * min_q * (1 - dones) + rewards``, with the reduced axis
        kept as a leading axis of size one.
    """
    ensemble_q = critic(next_observations, next_actions)
    # keepdims leaves a size-one leading axis in place of the reduced one.
    min_q = jnp.min(ensemble_q, axis=0, keepdims=True)
    continuing = 1 - dones
    discounted = gamma * min_q * continuing
    return discounted + rewards


def compute_weights(
    actions: Array, behavior_log_likelihoods: Array, means: Array, std: float
):
    """Return likelihood ratios of ``actions``, bounded below by 0.1.

    Args:
        actions: Samples whose likelihood ratio is computed.
        behavior_log_likelihoods: Log-likelihoods subtracted from the
            policy log-likelihoods before exponentiation.
        means: Means passed to the Gaussian likelihood helper.
        std: Standard deviation passed to the Gaussian likelihood helper.

    Returns:
        ``exp(policy_log_likelihood - behavior_log_likelihoods)`` with
        every value below 0.1 raised to 0.1.
    """
    log_ratio = (
        gaussian_log_likelihood(means=means, samples=actions, stds=std)
        - behavior_log_likelihoods
    )
    ratio = jnp.exp(log_ratio)
    # Only a lower bound is applied; large ratios pass through unchanged.
    return jnp.clip(ratio, min=0.1)


def train_actor_step(
    actor_optimizer: nnx.Optimizer,
    alpha: float,
    observations: Array,
    policy: SVRPolicy,
) -> dict[str, Array]:
    """Apply one optimizer update driven by ``actor_loss_fn``.

    Args:
        actor_optimizer: Optimizer that receives the computed gradients.
        alpha: Forwarded to ``actor_loss_fn``.
        observations: Forwarded to ``actor_loss_fn``.
        policy: Differentiated argument; gradients are restricted to the
            state selected by ``ActorFilter``.

    Returns:
        The metrics dict produced by ``actor_loss_fn``.
    """
    # Differentiate argument 0 (the policy), limited by ActorFilter.
    actor_only = nnx.DiffState(0, ActorFilter)
    actor_grad_fn = nnx.grad(actor_loss_fn, argnums=actor_only, has_aux=True)
    actor_grads, metrics = actor_grad_fn(
        policy, alpha=alpha, observations=observations
    )
    actor_optimizer.update(actor_grads)
    return metrics


@partial(jit, static_argnames=("regularizer_weight",))
def train_critic_step(
    actions: Array,
    dones: BoolArray,
    gamma: float,
    graphdef: nnx.GraphDef[SVRTrainState],
    graphstate: nnx.GraphState | nnx.VariableState,
    log_likelihoods: Array,
    next_observations: Array,
    observations: Array,
    regularizer_weight: float,
    rewards: Array,
    sample_std: float,
    step: int,
    targets_regularizer: float,
    train_key: Array,
):
    """Run one jitted critic update on a split train state.

    The bootstrap target is built from the target policy: its actor picks
    the action for ``next_observations`` and its critic scores that pair.
    When ``regularizer_weight`` is positive the critic is trained with
    ``critic_loss_fn`` (TD loss plus regularizer); otherwise with the plain
    TD loss imported as ``td_loss_fn``.

    Args:
        actions: Actions paired with ``observations``.
        dones: Flags passed to ``compute_targets``.
        gamma: Discount factor passed to ``compute_targets``.
        graphdef: Graph definition of the train state.
        graphstate: State matching ``graphdef``.
        log_likelihoods: Behavior log-likelihoods of ``actions``, used to
            compute the regularizer weights.
        next_observations: Observations used to build the TD targets.
        observations: Observations for the critic update.
        regularizer_weight: Static (compile-time) multiplier on the
            regularizer; also selects which loss is used.
        rewards: Rewards passed to ``compute_targets``.
        sample_std: Standard deviation of the noise added to the policy
            means and used for the likelihood ratio.
        step: Folded into ``train_key`` to derive the noise key.
        targets_regularizer: Constant target of the regularizer.
        train_key: Base PRNG key.

    Returns:
        A ``(graphstate, metrics)`` pair: the state after the critic
        optimizer update and the metrics of the loss that was used.
    """
    train_state = nnx.merge(graphdef, graphstate)
    target_policy = train_state.target_policy.model
    # The next action must come from the next observation so that the
    # target critic scores a consistent (s', a') pair.
    next_actions, _ = target_policy.actor(next_observations)
    targets = compute_targets(
        critic=target_policy.critic,
        dones=dones,
        gamma=gamma,
        next_actions=next_actions,
        next_observations=next_observations,
        rewards=rewards,
    )
    # regularizer_weight is a static argument, so this is a Python-level
    # branch resolved at trace time.
    if regularizer_weight > 0:
        # The policy means are computed in eval mode; train mode is
        # restored right after.
        train_state.policy.eval()
        means, _ = train_state.policy.actor(observations)
        train_state.policy.train()
        # A per-step key keeps the noise different across steps.
        white_noise = normal(fold_in(train_key, step), means.shape)
        noisy_actions = jnp.clip(
            means + sample_std * white_noise, max=1, min=-1
        )
        weights = compute_weights(
            actions=actions,
            behavior_log_likelihoods=log_likelihoods,
            means=means,
            std=sample_std,
        )
        grad_fn = nnx.grad(critic_loss_fn, has_aux=True)
        grads, results = grad_fn(
            train_state.policy.critic,
            actions=actions,
            noisy_actions=noisy_actions,
            observations=observations,
            regularizer_weight=regularizer_weight,
            targets=targets,
            targets_regularizer=targets_regularizer,
            weights=weights,
        )
    else:
        grad_fn = nnx.grad(td_loss_fn, has_aux=True)
        grads, results = grad_fn(
            train_state.policy.critic, actions, observations, targets
        )
    train_state.critic_optimizer.update(grads)
    _, graphstate = nnx.split(train_state)
    return graphstate, results


@partial(jit, static_argnames=("regularizer_weight",))
def train_actor_critic_step(
    actions: Array,
    alpha: float,
    dones: BoolArray,
    gamma: float,
    graphdef: nnx.GraphDef[SVRTrainState],
    graphstate: nnx.GraphState | nnx.VariableState,
    log_likelihoods: Array,
    next_observations: Array,
    observations: Array,
    regularizer_weight: float,
    rewards: Array,
    sample_std: float,
    step: int,
    targets_regularizer: float,
    tau: float,
    train_key: Array,
):
    """Run a critic update, then an actor update and a target update.

    Args:
        actions: Forwarded to ``train_critic_step``.
        alpha: Forwarded to ``train_actor_step``.
        dones: Forwarded to ``train_critic_step``.
        gamma: Forwarded to ``train_critic_step``.
        graphdef: Graph definition of the train state.
        graphstate: State matching ``graphdef``.
        log_likelihoods: Forwarded to ``train_critic_step``.
        next_observations: Forwarded to ``train_critic_step``.
        observations: Used by both the critic and the actor update.
        regularizer_weight: Static argument forwarded to
            ``train_critic_step``.
        rewards: Forwarded to ``train_critic_step``.
        sample_std: Forwarded to ``train_critic_step``.
        step: Forwarded to ``train_critic_step``.
        targets_regularizer: Forwarded to ``train_critic_step``.
        tau: Passed to ``target_policy.update`` (presumably the
            soft-update rate; confirm against that method).
        train_key: Forwarded to ``train_critic_step``.

    Returns:
        A ``(graphstate, metrics)`` pair. ``metrics`` merges the actor and
        critic results; critic entries win on key collisions.
    """
    critic_graphstate, critic_metrics = train_critic_step(
        graphdef=graphdef,
        graphstate=graphstate,
        observations=observations,
        actions=actions,
        rewards=rewards,
        next_observations=next_observations,
        dones=dones,
        gamma=gamma,
        log_likelihoods=log_likelihoods,
        regularizer_weight=regularizer_weight,
        targets_regularizer=targets_regularizer,
        sample_std=sample_std,
        step=step,
        train_key=train_key,
    )

    # Rebuild the train state from the critic step's output state.
    state = nnx.merge(graphdef, critic_graphstate)
    # The critic is switched to eval mode before it scores actor actions.
    state.policy.critic.eval()
    actor_metrics = train_actor_step(
        policy=state.policy,
        actor_optimizer=state.actor_optimizer,
        observations=observations,
        alpha=alpha,
    )

    state.target_policy.update(model=state.policy, tau=tau)
    _, final_graphstate = nnx.split(state)
    return final_graphstate, actor_metrics | critic_metrics
