from functools import partial

from flax import nnx
from jax import Array, jit, lax, numpy as jnp
from jax.random import fold_in, normal
from optax import squared_error

from offline.bppo.core import EPS
from offline.bppo.tc.modules import (
    BPPOTCPolicy,
    BehaviorState,
    QCritic,
    select_values,
)
from offline.modules.actor.ensemble import GaussianActorEnsembleWithIndices
from offline.modules.actor.utils import gaussian_log_likelihood
from offline.modules.base import TrainState, TrainStateWithTarget
from offline.modules.mlp import MLP
from offline.types import BoolArray


def behavior_cloning_loss_fn(
    actor: GaussianActorEnsembleWithIndices,
    actions: Array,
    assignments: Array,
    observations: Array,
):
    """Negative mean log-likelihood of dataset actions under the actor.

    Args:
        actor: Gaussian actor ensemble, called as
            ``actor(assignments, observations)`` and returning
            ``(means, stds)``.
        actions: Dataset actions scored under the predicted Gaussians.
        assignments: Per-sample indices forwarded to the actor.
        observations: Per-sample observations forwarded to the actor.

    Returns:
        Scalar loss ``-mean(log_likelihood)``. It is differentiated with
        respect to ``actor`` in ``behavior_cloning_step``.
    """
    predicted_means, predicted_stds = actor(assignments, observations)
    per_sample_log_likelihood = gaussian_log_likelihood(
        means=predicted_means, samples=actions, stds=predicted_stds
    )
    mean_log_likelihood = jnp.mean(per_sample_log_likelihood)
    return -mean_log_likelihood


def bppo_loss_fn(
    actor: GaussianActorEnsembleWithIndices,
    actions: Array,
    advantages: Array,
    assignments: Array,
    clip_epsilon: float,
    entropy_weight: float,
    observations: Array,
    old_log_likelihoods: Array,
):
    """PPO-style clipped surrogate actor loss with an optional entropy bonus.

    The probability ratio ``exp(log p_new - log p_old)`` is compared with
    its clamp to ``[1 - clip_epsilon, 1 + clip_epsilon]``. The loss is the
    negated batch mean of the element-wise minimum of the two surrogates
    ``ratio * advantages`` and ``clamped_ratio * advantages``.

    Args:
        actor: Gaussian actor ensemble, called as
            ``actor(assignments, observations)`` -> ``(means, stds)``.
        actions: Actions whose likelihood ratio is evaluated.
        advantages: Per-sample advantage estimates.
        assignments: Per-sample indices forwarded to the actor.
        clip_epsilon: Half-width of the ratio trust region.
        entropy_weight: When ``> 0``, subtracts
            ``entropy_weight * mean(sum(log(stds), axis=-1))`` from the loss.
            This is a Python-level branch, so the value must be concrete
            (``train_step`` marks it static).
        observations: Per-sample observations forwarded to the actor.
        old_log_likelihoods: Log-likelihoods of ``actions`` under the
            behavior (old) policy.

    Returns:
        ``(loss, metrics)``. ``metrics["loss/actor"]`` is the total loss,
        including the entropy term when enabled. ``metrics["train/clipfrac"]``
        is the fraction of ratios strictly outside the trust region.
    """
    # NOTE(review): original shape annotation was [..., num_embeddings, action_dim]
    means, stds = actor(assignments, observations)
    new_log_likelihoods = gaussian_log_likelihood(
        means=means, samples=actions, stds=stds
    )
    ratios = jnp.exp(new_log_likelihoods - old_log_likelihoods)

    lower_bound = 1 - clip_epsilon
    upper_bound = 1 + clip_epsilon

    # Diagnostic only: share of samples whose ratio left the trust region.
    outside_region = jnp.logical_or(ratios > upper_bound, ratios < lower_bound)
    clipped_fraction = jnp.mean(outside_region)

    unclipped_objective = ratios * advantages
    clipped_objective = lax.clamp(lower_bound, ratios, upper_bound) * advantages
    # Pessimistic bound: keep the smaller surrogate per sample.
    loss = -jnp.mean(jnp.minimum(unclipped_objective, clipped_objective))

    if entropy_weight > 0:
        # Summed log-stds per sample; assuming a diagonal Gaussian policy this
        # equals its entropy up to an additive constant.
        summed_log_stds = jnp.sum(jnp.log(stds), axis=-1)
        loss = loss - entropy_weight * jnp.mean(summed_log_stds)

    metrics = {"loss/actor": loss, "train/clipfrac": clipped_fraction}
    return loss, metrics


def high_level_loss_fn(
    qcritic: MLP, assignments: Array, observations: Array, targets: Array
):
    """Squared-error regression of the high-level Q-values onto ``targets``.

    ``qcritic(observations)`` yields one value per codebook entry. Only the
    entry selected by each sample's ``assignments`` index is regressed.

    Returns:
        ``(loss, metrics)``. ``loss`` is the mean squared error. ``metrics``
        contains the loss and the batch means of predictions and targets.
    """
    # [..., codebook_size]
    all_qvalues = qcritic(observations)
    # Gather the assigned entry, then drop the singleton gathered axis -> [...]
    gather_index = jnp.expand_dims(assignments, -1)
    chosen_qvalues = jnp.take_along_axis(all_qvalues, gather_index, axis=-1)
    chosen_qvalues = jnp.squeeze(chosen_qvalues, axis=-1)

    loss = jnp.mean(squared_error(chosen_qvalues, targets))
    metrics = {
        "loss/HLQ": loss,
        "pretrain/HLQ": jnp.mean(chosen_qvalues),
        "pretrain/HLQT": jnp.mean(targets),
    }
    return loss, metrics


def sarsa_loss_fn(
    qcritic: QCritic,
    actions: Array,
    assignments: Array,
    observations: Array,
    targets: Array,
):
    """Mean squared error between ``qcritic(assignments, observations, actions)`` and ``targets``.

    Returns:
        ``(loss, metrics)``. ``metrics`` holds the loss and the batch means
        of predictions and targets.
    """
    # NOTE(review): original shape annotation was [..., num_embeddings]
    qvalues = qcritic(assignments, observations, actions)
    loss = jnp.mean(squared_error(qvalues, targets))
    metrics = {
        "loss/Q": loss,
        "pretrain/Q": jnp.mean(qvalues),
        "pretrain/QT": jnp.mean(targets),
    }
    return loss, metrics


@jit
def behavior_cloning_step(
    actions: Array,
    assignments: Array,
    graphdef: nnx.GraphDef[TrainState[GaussianActorEnsembleWithIndices]],
    graphstate: nnx.GraphState | nnx.VariableState,
    observations: Array,
):
    """Run one jitted behavior-cloning optimizer step.

    Steps:
      1. Rebuild the train state from ``graphdef`` / ``graphstate``.
      2. Differentiate ``behavior_cloning_loss_fn`` with respect to the model.
      3. Apply the optimizer update.

    Returns:
        ``(new_graphstate, {"loss/BC": loss})``.
    """
    state = nnx.merge(graphdef, graphstate)
    loss_and_grad = nnx.value_and_grad(behavior_cloning_loss_fn)
    bc_loss, bc_grads = loss_and_grad(
        state.model, actions, assignments, observations
    )
    state.optimizer.update(bc_grads)
    new_graphstate = nnx.split(state)[1]
    return new_graphstate, {"loss/BC": bc_loss}


@jit
def compute_batch_means_stds_values(
    assignments: Array,
    graphdef: nnx.GraphDef[BehaviorState],
    graphstate: nnx.GraphState | nnx.VariableState,
    observations: Array,
):
    """Evaluate the behavior actor and critic on one batch.

    Returns:
        ``(means, stds, values)``. ``means`` and ``stds`` come from
        ``actor(assignments, observations)``. ``values`` is the critic
        output at each sample's ``assignments`` index along the last axis.
    """
    # assumes observations is [..., observation_dim] — per original comment
    behavior = nnx.merge(graphdef, graphstate)
    means, stds = behavior.actor(assignments, observations)
    all_values = behavior.critic(observations)
    # take_along_axis keeps a size-1 trailing axis; index it away.
    gathered = jnp.take_along_axis(all_values, assignments[..., None], axis=-1)
    return means, stds, gathered[..., 0]


def _high_level_train_step(
    assignments: Array,
    dones: BoolArray,
    gamma: float,
    next_mask: Array,
    next_observations: Array,
    observations: Array,
    rewards: Array,
    train_state: TrainStateWithTarget[MLP],
):
    """Apply one Q-learning update to the high-level critic, in place.

    Not jitted itself; callers wrap it in a jitted merge/split step.

    The bootstrap target is computed as
    ``max(select_values(next_mask, Q_target(next_observations)), axis=-1)
    * gamma * (1 - dones) + rewards``. The online model is then regressed
    onto it via ``high_level_loss_fn``.

    Returns:
        The metrics dict produced by ``high_level_loss_fn``.
    """
    # [..., codebook_size]; NOTE(review): select_values presumably masks
    # entries that are unavailable at the next state — confirm in its definition.
    next_qvalues = select_values(
        next_mask, train_state.target.model(next_observations)
    )
    # Greedy bootstrap over the last axis -> [...]
    best_next_qvalues = jnp.max(next_qvalues, axis=-1)
    bootstrapped = best_next_qvalues * gamma * (1 - dones) + rewards

    loss_grad = nnx.grad(high_level_loss_fn, has_aux=True)
    grads, metrics = loss_grad(
        train_state.model,
        assignments=assignments,
        observations=observations,
        targets=bootstrapped,
    )
    train_state.optimizer.update(grads)
    return metrics


@jit
def high_level_train_step(
    assignments: Array,
    dones: BoolArray,
    gamma: float,
    graphdef: nnx.GraphDef[TrainStateWithTarget[MLP]],
    graphstate: nnx.GraphState | nnx.VariableState,
    next_mask: Array,
    next_observations: Array,
    observations: Array,
    rewards: Array,
):
    """Jitted high-level critic update that leaves the target network unchanged.

    Returns:
        ``(new_graphstate, metrics)`` where ``metrics`` comes from
        ``_high_level_train_step``.
    """
    state = nnx.merge(graphdef, graphstate)
    metrics = _high_level_train_step(
        train_state=state,
        assignments=assignments,
        dones=dones,
        gamma=gamma,
        next_mask=next_mask,
        next_observations=next_observations,
        observations=observations,
        rewards=rewards,
    )
    new_graphstate = nnx.split(state)[1]
    return new_graphstate, metrics


@jit
def high_level_train_step_with_target_update(
    assignments: Array,
    dones: BoolArray,
    gamma: float,
    graphdef: nnx.GraphDef[TrainStateWithTarget[MLP]],
    graphstate: nnx.GraphState | nnx.VariableState,
    next_mask: Array,
    next_observations: Array,
    observations: Array,
    rewards: Array,
):
    """Jitted high-level critic update followed by a target-network hard update.

    The target is refreshed via ``train_state.target.hard_update(model)``
    after the gradient step.

    Returns:
        ``(new_graphstate, metrics)`` where ``metrics`` comes from
        ``_high_level_train_step``.
    """
    state = nnx.merge(graphdef, graphstate)
    metrics = _high_level_train_step(
        train_state=state,
        assignments=assignments,
        dones=dones,
        gamma=gamma,
        next_mask=next_mask,
        next_observations=next_observations,
        observations=observations,
        rewards=rewards,
    )
    # Sync the target only after the online model has been updated.
    state.target.hard_update(state.model)
    new_graphstate = nnx.split(state)[1]
    return new_graphstate, metrics


def _sarsa_step(
    actions: Array,
    assignments: Array,
    dones: Array,
    gamma: float,
    next_actions: Array,
    next_observations: Array,
    observations: Array,
    rewards: Array,
    train_state: TrainStateWithTarget[QCritic],
):
    """Apply one SARSA-style TD update to the Q critic, in place.

    Not jitted itself; callers wrap it in a jitted merge/split step.

    The target evaluates the target network on the dataset's next action:
    ``Q_target(assignments, next_observations, next_actions) * gamma
    * (1 - dones) + rewards``. The online model is regressed onto it via
    ``sarsa_loss_fn``.

    Returns:
        The metrics dict produced by ``sarsa_loss_fn``.
    """
    bootstrap_values = train_state.target.model(
        assignments, next_observations, next_actions
    )
    td_targets = bootstrap_values * gamma * (1 - dones) + rewards

    sarsa_grad = nnx.grad(sarsa_loss_fn, has_aux=True)
    grads, metrics = sarsa_grad(
        train_state.model, actions, assignments, observations, td_targets
    )
    train_state.optimizer.update(grads)
    return metrics


@jit
def sarsa_step(
    actions: Array,
    assignments: Array,
    dones: Array,
    gamma: float,
    graphdef: nnx.GraphDef[TrainStateWithTarget[QCritic]],
    graphstate: nnx.GraphState | nnx.VariableState,
    next_actions: Array,
    next_observations: Array,
    observations: Array,
    rewards: Array,
):
    """Jitted SARSA critic update that leaves the target network unchanged.

    Returns:
        ``(new_graphstate, metrics)`` where ``metrics`` comes from
        ``_sarsa_step``.
    """
    state = nnx.merge(graphdef, graphstate)
    metrics = _sarsa_step(
        train_state=state,
        actions=actions,
        assignments=assignments,
        dones=dones,
        gamma=gamma,
        next_actions=next_actions,
        next_observations=next_observations,
        observations=observations,
        rewards=rewards,
    )
    new_graphstate = nnx.split(state)[1]
    return new_graphstate, metrics


@jit
def sarsa_step_with_target_update(
    actions: Array,
    assignments: Array,
    dones: Array,
    gamma: float,
    graphdef: nnx.GraphDef[TrainStateWithTarget[QCritic]],
    graphstate: nnx.GraphState | nnx.VariableState,
    next_actions: Array,
    next_observations: Array,
    observations: Array,
    rewards: Array,
    tau: float,
):
    """Jitted SARSA critic update followed by ``target.update(model, tau)``.

    ``tau`` is forwarded unchanged to the target's ``update`` method.

    Returns:
        ``(new_graphstate, metrics)`` where ``metrics`` comes from
        ``_sarsa_step``.
    """
    state = nnx.merge(graphdef, graphstate)
    metrics = _sarsa_step(
        train_state=state,
        actions=actions,
        assignments=assignments,
        dones=dones,
        gamma=gamma,
        next_actions=next_actions,
        next_observations=next_observations,
        observations=observations,
        rewards=rewards,
    )
    # Update the target only after the online model has taken its step.
    state.target.update(state.model, tau)
    new_graphstate = nnx.split(state)[1]
    return new_graphstate, metrics


@partial(jit, static_argnames=("entropy_weight", "omega"))
def train_step(
    assignments: Array,
    clip_epsilon: float,
    entropy_weight: float,
    graphdef: nnx.GraphDef[TrainState[BPPOTCPolicy]],
    graphstate: nnx.GraphState | nnx.VariableState,
    means: Array,
    observations: Array,
    omega: float,
    stds: Array,
    step: int,
    train_key: Array,
    values: Array,
):
    """Run one jitted BPPO actor update.

    Steps:
      1. Sample ``actions = means + stds * noise``. The noise is standard
         normal, drawn with ``fold_in(train_key, step)``.
      2. Score the actions with ``model.critic``. Form advantages
         ``Q - values`` and standardize them by the batch mean and
         ``std + EPS``.
      3. If ``omega != 0.5``, scale positive advantages by ``omega`` and
         the rest by ``1 - omega``.
      4. Compute old log-likelihoods under ``(means, stds)``. Take one
         optimizer step on ``bppo_loss_fn``, differentiated with respect to
         ``model.actor``.

    ``entropy_weight`` and ``omega`` are static because they drive
    Python-level branches.

    Returns:
        ``(new_graphstate, metrics)`` where ``metrics`` comes from
        ``bppo_loss_fn``.
    """
    state = nnx.merge(graphdef, graphstate)

    step_key = fold_in(train_key, step)
    actions = means + stds * normal(step_key, shape=stds.shape)

    raw_advantages = state.model.critic(assignments, observations, actions) - values
    centered = raw_advantages - jnp.mean(raw_advantages)
    advantages = centered / (jnp.std(raw_advantages) + EPS)

    if omega != 0.5:
        # Asymmetric weighting: omega where advantage > 0, else 1 - omega.
        omega_full = jnp.broadcast_to(omega, advantages.shape)
        sign_weights = lax.select(advantages > 0, omega_full, 1 - omega_full)
        advantages = sign_weights * advantages

    old_log_likelihoods = gaussian_log_likelihood(
        means=means, samples=actions, stds=stds
    )
    actor_grad = nnx.grad(bppo_loss_fn, has_aux=True)
    grads, metrics = actor_grad(
        state.model.actor,
        actions=actions,
        advantages=advantages,
        assignments=assignments,
        clip_epsilon=clip_epsilon,
        entropy_weight=entropy_weight,
        observations=observations,
        old_log_likelihoods=old_log_likelihoods,
    )
    state.optimizer.update(grads)
    new_graphstate = nnx.split(state)[1]
    return new_graphstate, metrics
