from functools import partial
from typing import Any

from flax import nnx
from jax import Array, jit, lax, numpy as jnp
from jax.random import fold_in, normal
from jax.scipy.special import erf
from optax import squared_error

from offline.lbp.modules import (
    ActorFilter,
    BehaviorState,
    LBPPolicy,
    LBPTrainState,
)
from offline.modules.actor.base import GaussianActor
from offline.modules.actor.utils import compute_actions, gaussian_log_likelihood
from offline.modules.base import TrainState, TrainStateWithTarget
from offline.modules.critic import QCriticEnsemble, VCritic
from offline.td3bc.core import critic_loss_fn as td_loss_fn
from offline.types import ArrayLike, BoolArray


def actor_loss_fn(
    policy: LBPPolicy,
    key: Array,
    key_data: Any,
    means: Array,
    observations: Array,
):
    """Negative ensemble-minimum Q-value of the actor's proposed actions.

    The actor produces unsquashed offsets that are added to ``means`` to form
    actions. The critic ensemble scores those actions, the minimum is taken
    over the leading (ensemble) axis, and the negated batch mean is returned.
    """
    proposed_actions = means + compute_actions(
        actor=policy.actor,
        key=key,
        key_data=key_data,
        observations=observations,
        squash=False,
    )
    ensemble_q = policy.critic(observations, proposed_actions)
    worst_case_q = jnp.min(ensemble_q, axis=0)
    return -jnp.mean(worst_case_q)


def behavior_cloning_loss_fn(
    actor: GaussianActor, actions: Array, observations: Array
):
    """Mean negative Gaussian log-likelihood of ``actions`` under ``actor``.

    The actor maps observations to per-dimension means and standard
    deviations. The dataset actions are scored with
    ``gaussian_log_likelihood`` and the negated mean is returned.
    """
    predicted_means, predicted_stds = actor(observations)
    per_sample_log_prob = gaussian_log_likelihood(
        means=predicted_means,
        samples=actions,
        stds=predicted_stds,
    )
    return -jnp.mean(per_sample_log_prob)


def compute_mean_norms(deltas: Array, stds: Array):
    """Expected L1 norm of ``N(deltas, stds**2)``, summed over the last axis.

    For each element this is the folded-normal mean

        E|X| = std * sqrt(2/pi) * exp(-delta**2 / (2 std**2))
               + delta * erf(delta / (std * sqrt(2)))

    with X ~ N(delta, std**2). ``stds`` must be non-zero, since it is used as
    a divisor.
    """
    # delta / (std * sqrt(2)); its square is delta**2 / (2 std**2).
    scaled = deltas / stds / jnp.sqrt(2)
    spread_term = stds * jnp.sqrt(2 / jnp.pi) * jnp.exp(-jnp.square(scaled))
    shift_term = deltas * erf(scaled)
    return jnp.sum(spread_term + shift_term, axis=-1)


def compute_ood_mask(deltas: ArrayLike, stds: ArrayLike, threshold: float):
    """Flag samples whose standardized squared distance exceeds ``threshold``.

    The distance is ``sum((deltas / stds) ** 2)`` over the last axis, so the
    mask has the input shape minus that axis. The boolean result is wrapped
    in ``stop_gradient`` so it acts as a constant inside differentiated code.
    """
    standardized_sq_dist = jnp.sum(jnp.square(deltas / stds), axis=-1)
    is_out_of_distribution = standardized_sq_dist > threshold
    return lax.stop_gradient(is_out_of_distribution)


def critic_loss_fn(
    critic: QCriticEnsemble,
    actions: Array,
    actions_regularizer: Array,
    observations: Array,
    observations_regularizer: Array,
    ood_mask: Array,
    regularizer_weight: float,
    targets: Array,
    targets_regularizer: Array,
):
    """TD regression loss plus an OOD-masked regression penalty.

    The dataset batch (index 0) and the regularizer batch (index 1) are
    stacked on a new leading axis so the critic is evaluated once; its output
    is laid out as [ensemble_size, 2, ...].

    loss = mean((Q - targets)^2)
           + regularizer_weight * mean(ood_mask * (Q_reg - targets_reg)^2)

    The regularizer mean divides by all entries, not only the masked ones.

    Returns:
        ``(loss, metrics)``. The ``"train/QR"`` and ``"train/QTR"`` metrics
        average only where ``ood_mask`` is true, so they are NaN (0/0) when
        the mask is entirely false.
    """
    # [ensemble_size, 2, ...]
    stacked_q = critic(
        jnp.stack((observations, observations_regularizer)),
        jnp.stack((actions, actions_regularizer)),
    )
    # Each slice is [ensemble_size, ...].
    q_data = stacked_q[:, 0, ...]
    q_reg = stacked_q[:, 1, ...]
    td_targets = jnp.broadcast_to(targets, q_data.shape)
    reg_targets = jnp.broadcast_to(targets_regularizer, q_reg.shape)

    td_loss = jnp.mean(squared_error(q_data, td_targets))
    regularizer_loss = jnp.mean(ood_mask * squared_error(q_reg, reg_targets))
    loss = td_loss + regularizer_weight * regularizer_loss

    metrics = {
        "loss/Q": loss,
        "loss/REG": regularizer_loss,
        "loss/TD": td_loss,
        "train/Q": q_data.mean(),
        "train/QR": q_reg.mean(where=ood_mask),
        "train/QT": td_targets.mean(),
        "train/QTR": reg_targets.mean(where=ood_mask),
        "train/OOD": ood_mask.mean(),
    }
    return loss, metrics


def v_learning_loss_fn(vcritic: VCritic, observations: Array, targets: Array):
    """Mean squared error between ``vcritic(observations)`` and ``targets``.

    Returns:
        ``(loss, metrics)``. The metrics dict holds the loss and the mean
        predicted and target values.
    """
    values = vcritic(observations)
    loss = jnp.mean(squared_error(values, targets))
    metrics = {
        "loss/V": loss,
        "train/V": jnp.mean(values),
        "train/VT": jnp.mean(targets),
    }
    return loss, metrics


@jit
def behavior_cloning_step(
    actions: Array,
    graphdef: nnx.GraphDef[TrainState[GaussianActor]],
    graphstate: nnx.GraphState | nnx.VariableState,
    observations: Array,
):
    """One optimizer step of behavior cloning on the actor.

    The train state is rebuilt from ``graphdef``/``graphstate``. The actor is
    updated with the gradient of ``behavior_cloning_loss_fn``, and the new
    graph state is returned together with the loss.
    """
    state = nnx.merge(graphdef, graphstate)
    bc_loss, bc_grads = nnx.value_and_grad(behavior_cloning_loss_fn)(
        state.model, actions, observations
    )
    state.optimizer.update(bc_grads)
    new_graphstate = nnx.split(state)[1]
    return new_graphstate, {"loss/BC": bc_loss}


@jit
def compute_batch_means_stds_values(
    graphdef: nnx.GraphDef[BehaviorState],
    graphstate: nnx.GraphState | nnx.VariableState,
    observations: Array,
):
    """Evaluate the behavior actor and value critic on ``observations``.

    Returns:
        ``(means, stds, values)``. ``means`` and ``stds`` come from
        ``state.actor`` and ``values`` from ``state.critic``.
    """
    behavior = nnx.merge(graphdef, graphstate)
    actor_means, actor_stds = behavior.actor(observations)
    state_values = behavior.critic(observations)
    return actor_means, actor_stds, state_values


@jit
def v_learning_step(
    graphdef: nnx.GraphDef[TrainStateWithTarget[VCritic]],
    graphstate: nnx.GraphState | nnx.VariableState,
    observations: Array,
    targets: Array,
):
    """One optimizer step regressing the value critic onto ``targets``.

    Only the online model is updated; the target network is not touched.
    Returns the new graph state and the metrics from ``v_learning_loss_fn``.
    """
    state = nnx.merge(graphdef, graphstate)
    value_grads, metrics = nnx.grad(v_learning_loss_fn, has_aux=True)(
        state.model, observations, targets
    )
    state.optimizer.update(value_grads)
    return nnx.split(state)[1], metrics


def compute_targets(
    critic: QCriticEnsemble,
    dones: BoolArray,
    gamma: float,
    next_actions: Array,
    next_observations: Array,
    rewards: Array,
):
    """One-step bootstrapped targets using the ensemble minimum.

    targets = gamma * min_ensemble Q(s', a') * (1 - done) + reward

    The ensemble axis (axis 0) is kept with size 1, so the result broadcasts
    against per-member Q-values of shape [ensemble_size, ...].
    """
    next_q = jnp.min(critic(next_observations, next_actions), axis=0, keepdims=True)
    continuation = 1 - dones
    return gamma * next_q * continuation + rewards


def _pretrain_critic_step(
    actions: Array,
    dones: BoolArray,
    gamma: float,
    next_means: Array,
    next_observations: Array,
    observations: Array,
    rewards: Array,
    train_state: TrainStateWithTarget[QCriticEnsemble],
):
    """Apply one TD update to the online critic (no target update).

    Bootstrapped targets come from ``train_state.target.model`` via
    ``compute_targets``, with ``next_means`` used as the next actions. The
    online critic ``train_state.model`` is updated in place through
    ``train_state.optimizer``.

    Returns:
        A metrics dict with ``"loss/pretrain/Q"``, ``"pretrain/Q"`` and
        ``"pretrain/QT"``. These are copied from ``td_loss_fn``'s
        ``"loss/Q"``, ``"train/Q"`` and ``"train/QT"`` respectively.
    """
    targets = compute_targets(
        critic=train_state.target.model,
        dones=dones,
        gamma=gamma,
        next_actions=next_means,
        next_observations=next_observations,
        rewards=rewards,
    )
    grad_fn = nnx.grad(td_loss_fn, has_aux=True)
    grads, results = grad_fn(train_state.model, actions, observations, targets)
    train_state.optimizer.update(grads)
    # Build the renamed metrics directly. The previous in-place write into
    # `results` was a dead store, because `results` is never returned.
    return {
        "loss/pretrain/Q": results["loss/Q"],
        "pretrain/Q": results["train/Q"],
        "pretrain/QT": results["train/QT"],
    }


@jit
def pretrain_critic_step(
    actions: Array,
    dones: BoolArray,
    gamma: float,
    graphdef: nnx.GraphDef[TrainStateWithTarget[QCriticEnsemble]],
    graphstate: nnx.GraphState | nnx.VariableState,
    next_means: Array,
    next_observations: Array,
    observations: Array,
    rewards: Array,
):
    """Jitted critic pretraining step without a target-network update.

    Rebuilds the train state, runs ``_pretrain_critic_step`` on it, and
    returns ``(new graph state, metrics)``.
    """
    state = nnx.merge(graphdef, graphstate)
    metrics = _pretrain_critic_step(
        train_state=state,
        actions=actions,
        dones=dones,
        gamma=gamma,
        next_means=next_means,
        next_observations=next_observations,
        observations=observations,
        rewards=rewards,
    )
    return nnx.split(state)[1], metrics


@jit
def pretrain_critic_step_with_target_update(
    actions: Array,
    dones: BoolArray,
    gamma: float,
    graphdef: nnx.GraphDef[TrainStateWithTarget[QCriticEnsemble]],
    graphstate: nnx.GraphState | nnx.VariableState,
    next_means: Array,
    next_observations: Array,
    observations: Array,
    rewards: Array,
    tau: float,
):
    """Jitted critic pretraining step followed by a target-network update.

    Same as ``pretrain_critic_step``, but afterwards calls
    ``state.target.update`` with the freshly updated online model and
    ``tau``. The exact update rule is defined by the target's ``update``
    method.
    """
    state = nnx.merge(graphdef, graphstate)
    metrics = _pretrain_critic_step(
        train_state=state,
        actions=actions,
        dones=dones,
        gamma=gamma,
        next_means=next_means,
        next_observations=next_observations,
        observations=observations,
        rewards=rewards,
    )
    # The target update must follow the online update so it sees new weights.
    state.target.update(model=state.model, tau=tau)
    return nnx.split(state)[1], metrics


def train_actor_step(
    actor_optimizer: nnx.Optimizer,
    key: Array,
    key_data: Any,
    means: Array,
    observations: Array,
    policy: LBPPolicy,
) -> dict[str, Array]:
    """One optimizer step on the actor part of ``policy``.

    Gradients of ``actor_loss_fn`` are taken with respect to argument 0
    (``policy``), restricted to the state selected by ``ActorFilter``
    (presumably the actor's parameters). They are applied through
    ``actor_optimizer``.
    """
    actor_only = nnx.DiffState(0, ActorFilter)
    actor_loss, actor_grads = nnx.value_and_grad(
        actor_loss_fn, argnums=actor_only
    )(policy, key, key_data, means, observations)
    actor_optimizer.update(actor_grads)
    return {"loss/actor": actor_loss}


def _sample_regularizer_deltas(
    policy: LBPPolicy,
    noise_batch_size: int,
    observations_regularizer: Array,
    std_multiplier: float,
    stds_regularizer: Array,
    white_noise: Array,
) -> Array:
    """Build the action offsets used by the critic regularizer.

    Along axis -2:

    - the first ``noise_batch_size`` entries are
      ``stds_regularizer * std_multiplier * white_noise``;
    - the remaining entries are actor samples ``mean + std * white_noise``,
      computed with the policy in eval mode.

    The two parts are concatenated in that order.
    """
    head = slice(None, noise_batch_size)
    tail = slice(noise_batch_size, None)
    # Actor forward pass with the policy in eval mode, then `.train()` is
    # called on it again.
    policy.eval()
    actor_means, actor_stds = policy.actor(observations_regularizer[..., tail, :])
    policy.train()
    sampled_from_actor = actor_means + actor_stds * white_noise[..., tail, :]
    scaled_noise = (
        stds_regularizer[..., head, :] * std_multiplier * white_noise[..., head, :]
    )
    return jnp.concatenate((scaled_noise, sampled_from_actor), axis=-2)


@partial(jit, static_argnames=("noise_batch_size", "regularizer_weight"))
def train_critic_step(
    actions: Array,
    baseline_regularizer: Array,
    dones: BoolArray,
    gamma: float,
    graphdef: nnx.GraphDef[LBPTrainState],
    graphstate: nnx.GraphState | nnx.VariableState,
    lipschitz_constant: float,
    means_regularizer: Array,
    min_target: float,
    next_means: Array,
    next_observations: Array,
    noise_batch_size: int,
    observations: Array,
    observations_regularizer: Array,
    ood_threshold: float,
    regularizer_weight: float,
    rewards: Array,
    std_multiplier: float,
    stds_regularizer: Array,
    step: int,
    train_critic_key: Array,
    train_regularizer_key: Array,
):
    """One critic update, optionally with the OOD regularizer.

    TD targets come from the target policy. Its actor proposes offsets that
    are added to ``next_means``, and its critic scores them through
    ``compute_targets``.

    ``regularizer_weight`` is static. When it is positive, offsets are
    sampled by ``_sample_regularizer_deltas``. Their regression targets are
    ``baseline_regularizer - gamma * lipschitz_constant * mean_norms``,
    clipped below at ``min_target``, and ``critic_loss_fn`` combines the TD
    loss with the masked penalty. Otherwise only ``td_loss_fn`` is used.

    Returns:
        ``(new graph state, metrics)``, where the metrics come from the
        selected loss function.
    """
    state = nnx.merge(graphdef, graphstate)
    target_model = state.target_policy.model
    next_deltas = compute_actions(
        actor=target_model.actor,
        key=train_critic_key,
        key_data=step,
        observations=next_observations,
        squash=False,
    )
    targets = compute_targets(
        critic=target_model.critic,
        dones=dones,
        gamma=gamma,
        next_actions=next_deltas + next_means,
        next_observations=next_observations,
        rewards=rewards,
    )
    online_critic = state.policy.critic

    if regularizer_weight > 0:
        white_noise = normal(
            fold_in(train_regularizer_key, step), stds_regularizer.shape
        )
        deltas_regularizer = _sample_regularizer_deltas(
            policy=state.policy,
            noise_batch_size=noise_batch_size,
            observations_regularizer=observations_regularizer,
            std_multiplier=std_multiplier,
            stds_regularizer=stds_regularizer,
            white_noise=white_noise,
        )
        mean_norms = compute_mean_norms(
            deltas=deltas_regularizer, stds=stds_regularizer
        )
        # A leading axis of size 1 broadcasts against the ensemble axis.
        targets_regularizer = jnp.expand_dims(
            jnp.clip(
                baseline_regularizer - gamma * lipschitz_constant * mean_norms,
                min=min_target,
            ),
            0,
        )
        ood_mask = compute_ood_mask(
            deltas=deltas_regularizer,
            stds=stds_regularizer,
            threshold=ood_threshold,
        )
        grads, results = nnx.grad(critic_loss_fn, has_aux=True)(
            online_critic,
            actions=actions,
            actions_regularizer=deltas_regularizer + means_regularizer,
            observations=observations,
            observations_regularizer=observations_regularizer,
            ood_mask=ood_mask,
            regularizer_weight=regularizer_weight,
            targets=targets,
            targets_regularizer=targets_regularizer,
        )
    else:
        grads, results = nnx.grad(td_loss_fn, has_aux=True)(
            online_critic, actions, observations, targets
        )

    state.critic_optimizer.update(grads)
    return nnx.split(state)[1], results


@partial(jit, static_argnames=("noise_batch_size", "regularizer_weight"))
def train_actor_critic_step(
    actions: Array,
    baseline_regularizer: Array,
    dones: BoolArray,
    gamma: float,
    graphdef: nnx.GraphDef[LBPTrainState],
    graphstate: nnx.GraphState | nnx.VariableState,
    lipschitz_constant: float,
    means_actor: Array,
    means_regularizer: Array,
    min_target: float,
    next_means: Array,
    next_observations: Array,
    noise_batch_size: int,
    observations: Array,
    observations_actor: Array,
    observations_regularizer: Array,
    ood_threshold: float,
    regularizer_weight: float,
    rewards: Array,
    std_multiplier: float,
    stds_regularizer: Array,
    step: int,
    tau: float,
    train_actor_key: Array,
    train_critic_key: Array,
    train_regularizer_key: Array,
):
    """Critic update, then actor update, then target-policy update.

    The critic step runs first via ``train_critic_step``. The actor step then
    sees the updated critic, which is switched to eval mode. Finally,
    ``target_policy.update`` is called with the online policy and ``tau``.

    Returns:
        ``(new graph state, metrics)``. Actor and critic metrics are merged,
        with critic keys taking precedence on collision.
    """
    graphstate, critic_metrics = train_critic_step(
        actions=actions,
        baseline_regularizer=baseline_regularizer,
        dones=dones,
        gamma=gamma,
        graphdef=graphdef,
        graphstate=graphstate,
        lipschitz_constant=lipschitz_constant,
        means_regularizer=means_regularizer,
        min_target=min_target,
        next_means=next_means,
        next_observations=next_observations,
        noise_batch_size=noise_batch_size,
        observations=observations,
        observations_regularizer=observations_regularizer,
        ood_threshold=ood_threshold,
        regularizer_weight=regularizer_weight,
        rewards=rewards,
        std_multiplier=std_multiplier,
        stds_regularizer=stds_regularizer,
        step=step,
        train_critic_key=train_critic_key,
        train_regularizer_key=train_regularizer_key,
    )
    state = nnx.merge(graphdef, graphstate)
    policy = state.policy
    # The critic only scores actor actions here, so it is put in eval mode.
    policy.critic.eval()
    actor_metrics = train_actor_step(
        actor_optimizer=state.actor_optimizer,
        key=train_actor_key,
        key_data=step,
        means=means_actor,
        observations=observations_actor,
        policy=policy,
    )
    state.target_policy.update(model=policy, tau=tau)
    return nnx.split(state)[1], {**actor_metrics, **critic_metrics}
