from functools import partial
from typing import Any

from flax import nnx
from jax import Array, jit, numpy as jnp
from jax.random import fold_in, normal
from optax import softmax_cross_entropy_with_integer_labels, squared_error

from offline.lbp.core import compute_mean_norms, compute_ood_mask
from offline.lbp.tc.modules import (
    ActorFilter,
    BehaviorState,
    LBPTCPolicy,
    LBPTCTrainState,
    QCriticEnsemble,
    mask_values,
)
from offline.lbp.tc.modules.auto_encoder import TCAutoEncoder
from offline.modules.actor.ensemble import GaussianActorEnsemble
from offline.modules.actor.utils import compute_actions, gaussian_log_likelihood
from offline.modules.base import TrainState, TrainStateWithTarget
from offline.modules.mlp import MLP


def actor_loss_fn(
    policy: LBPTCPolicy,
    key: Array,
    key_data: Any,
    mask: Array,
    means: Array,
    observations: Array,
):
    """Return the actor objective: the negated mean of the masked Q-values.

    Unsquashed outputs are drawn from ``policy.actor`` via
    ``compute_actions`` and handed to ``policy.compute_candidates_values``
    as ``deltas``. Only the third element of that call's result is used;
    it is multiplied by ``mask``, averaged and negated.
    """
    # [..., action_dim]
    sampled_deltas = compute_actions(
        actor=policy.actor,
        key=key,
        key_data=key_data,
        observations=observations,
        squash=False,
    )
    # [..., num_behaviors]
    _, _, candidate_qvalues = policy.compute_candidates_values(
        deltas=sampled_deltas,
        mask=mask,
        means=means,
        observations=observations,
    )
    return -jnp.mean(candidate_qvalues * mask)


def behavior_cloning_loss_fn(
    actor: GaussianActorEnsemble,
    actions: Array,
    assignments: Array,
    observations: Array,
):
    """Return the negative mean log-likelihood of ``actions``.

    For every sample, the mean/std pair indexed by ``assignments`` is
    picked along the second-to-last axis of the actor outputs, and
    ``actions`` are scored with ``gaussian_log_likelihood`` under it.
    """
    # [..., 1, 1]
    gather_indices = jnp.expand_dims(assignments, (-1, -2))

    def pick_assigned(per_behavior: Array) -> Array:
        # [..., num_behaviors, action_dim] -> [..., action_dim]
        gathered = jnp.take_along_axis(per_behavior, gather_indices, axis=-2)
        return jnp.squeeze(gathered, axis=-2)

    # [..., num_behaviors, action_dim]
    all_means, all_stds = actor(observations)
    log_probs = gaussian_log_likelihood(
        means=pick_assigned(all_means),
        samples=actions,
        stds=pick_assigned(all_stds),
    )
    return -jnp.mean(log_probs)


def classifier_loss_fn(
    classifier: MLP, assignments: Array, observations: Array
):
    """Return the mean softmax cross-entropy of the classifier.

    The classifier output for ``observations`` is used as logits and
    ``assignments`` as the integer labels.
    """
    per_sample_loss = softmax_cross_entropy_with_integer_labels(
        classifier(observations), assignments
    )
    return jnp.mean(per_sample_loss)


def critic_loss_fn(
    critic: QCriticEnsemble,
    actions: Array,
    actions_regularizer: Array,
    mask: Array,
    observations: Array,
    observations_regularizer: Array,
    ood_mask: Array,
    regularizer_weight: float,
    targets: Array,
    targets_regularizer: Array,
):
    """Return the critic loss (TD term plus weighted regularizer) and metrics.

    The data batch and the regularizer batch are stacked on a new leading
    axis so the critic is evaluated once; the two halves are then split
    apart again along axis 1 of the critic output. The TD term is the mean
    of the ``mask``-weighted squared error against ``targets``; the
    regularizer term is the mean of the ``ood_mask``-weighted squared error
    against ``targets_regularizer``.

    Returns:
        ``(loss, metrics)`` where ``loss`` is
        ``td_loss + regularizer_weight * regularizer_loss``.
    """
    # [2, ..., observation_dim]
    stacked_observations = jnp.stack((observations, observations_regularizer))
    # [2, ..., action_dim]
    stacked_actions = jnp.stack((actions, actions_regularizer))
    # [ensemble_size, 2, ..., num_behaviors]
    q_all = critic(stacked_observations, stacked_actions)
    # [ensemble_size, ..., num_behaviors] for each half of the stack
    q_data = q_all[:, 0, ...]
    q_reg = q_all[:, 1, ...]

    # targets: [...] -> [1, ..., 1] -> [ensemble_size, ..., num_behaviors]
    td_targets = jnp.broadcast_to(
        jnp.expand_dims(targets, (0, -1)), q_data.shape
    )
    reg_targets = jnp.broadcast_to(targets_regularizer, q_reg.shape)

    td_loss = jnp.mean(squared_error(q_data, td_targets) * mask)
    regularizer_loss = jnp.mean(ood_mask * squared_error(q_reg, reg_targets))
    total_loss = td_loss + regularizer_weight * regularizer_loss

    metrics = {
        "loss/Q": total_loss,
        "loss/REG": regularizer_loss,
        "loss/TD": td_loss,
        "train/Q": q_data.mean(where=mask),
        "train/QR": q_reg.mean(where=ood_mask),
        "train/QT": td_targets.mean(where=mask),
        "train/QTR": reg_targets.mean(where=ood_mask),
        "train/OOD": ood_mask.mean(),
    }
    return total_loss, metrics


def tcae_loss_fn(
    tcae: TCAutoEncoder,
    actions: Array,
    commitment_cost: float,
    decode_indices: Array,
    decode_mask: Array,
    latent_indices: Array,
    observations: Array,
    reward_weight: float,
    rewards: Array,
    transition_weight: float,
):
    """Return the weighted auto-encoder loss and a metrics dict.

    The auto-encoder is called once; its action loss is taken as-is while
    the latent, reward and transition losses are scaled by
    ``commitment_cost``, ``reward_weight`` and ``transition_weight``
    respectively before being added.
    """
    outputs = tcae(
        actions=actions,
        decode_indices=decode_indices,
        decode_mask=decode_mask,
        latent_indices=latent_indices,
        observations=observations,
        rewards=rewards,
    )
    # Accumulate in the same left-to-right order as a single chained sum.
    total_loss = outputs.loss_action
    total_loss = total_loss + commitment_cost * outputs.loss_latent
    total_loss = total_loss + reward_weight * outputs.loss_reward
    total_loss = total_loss + transition_weight * outputs.loss_transition

    metrics = {
        "loss/TCAE": total_loss,
        "loss/TCAE/action": outputs.loss_action,
        "loss/TCAE/latent": outputs.loss_latent,
        "loss/TCAE/reward": outputs.loss_reward,
        "loss/TCAE/transition": outputs.loss_transition,
        "train/TCAE/perplexity": outputs.perplexity,
    }
    return total_loss, metrics


def td_loss_fn(
    critic: QCriticEnsemble,
    actions: Array,
    mask: Array,
    observations: Array,
    targets: Array,
):
    """Return the masked TD loss of the critic ensemble and metrics.

    ``mask`` gets a leading axis and ``targets`` a leading and a trailing
    axis so both line up with the critic output; the loss is the mean of
    the mask-weighted squared error over all entries.
    """
    # [1, ..., num_behaviors]
    ensemble_mask = jnp.expand_dims(mask, axis=0)
    # [ensemble_size, ..., num_behaviors]
    q_pred = critic(observations, actions)
    # [...] -> [1, ..., 1] -> shape of q_pred
    q_target = jnp.broadcast_to(
        jnp.expand_dims(targets, (0, -1)), q_pred.shape
    )
    loss = jnp.mean(squared_error(q_pred, q_target) * ensemble_mask)
    metrics = {
        "loss/Q": loss,
        "train/Q": q_pred.mean(where=ensemble_mask),
        "train/QT": q_target.mean(where=ensemble_mask),
    }
    return loss, metrics


def v_learning_loss_fn(
    vcritic: MLP, assignments: Array, observations: Array, targets: Array
):
    """Return the mean squared error of the assigned value output.

    From the critic output, only the entry indexed by ``assignments``
    along the last axis is compared against ``targets``.
    """
    # [..., codebook_size]
    all_values = vcritic(observations)
    # [..., 1] index -> [..., 1] gathered value -> [...]
    selected_values = jnp.squeeze(
        jnp.take_along_axis(
            all_values, jnp.expand_dims(assignments, -1), axis=-1
        ),
        -1,
    )
    loss = jnp.mean(squared_error(selected_values, targets))
    metrics = {
        "loss/V": loss,
        "pretrain/V": jnp.mean(selected_values),
        "pretrain/VT": jnp.mean(targets),
    }
    return loss, metrics


@jit
def behavior_cloning_step(
    actions: Array,
    assignments: Array,
    graphdef: nnx.GraphDef[TrainState[GaussianActorEnsemble]],
    graphstate: nnx.GraphState | nnx.VariableState,
    observations: Array,
):
    """Run one jitted behavior-cloning update.

    The train state is rebuilt from ``graphdef``/``graphstate``, the
    gradient of ``behavior_cloning_loss_fn`` with respect to its model is
    applied through its optimizer, and the re-split graph state is
    returned together with the loss under ``"loss/BC"``.
    """
    state = nnx.merge(graphdef, graphstate)
    bc_loss, bc_grads = nnx.value_and_grad(behavior_cloning_loss_fn)(
        state.model, actions, assignments, observations
    )
    state.optimizer.update(bc_grads)
    _, updated_graphstate = nnx.split(state)
    return updated_graphstate, {"loss/BC": bc_loss}


@jit
def compute_batch_embeddings(
    graphdef: nnx.GraphDef[TCAutoEncoder],
    graphstate: nnx.GraphState | nnx.VariableState,
    lengths: Array,
    observations: Array,
    rewards: Array,
):
    """Return the auto-encoder embeddings for a batch.

    ``lengths`` gets a trailing axis and is decremented by one to form
    the ``latent_indices`` passed to ``compute_embeddings``.
    """
    autoencoder = nnx.merge(graphdef, graphstate)
    last_indices = jnp.expand_dims(lengths, -1) - 1
    return autoencoder.compute_embeddings(
        latent_indices=last_indices,
        observations=observations,
        rewards=rewards,
    )


@jit
def compute_batch_mask_means_stds_values(
    graphdef: nnx.GraphDef[BehaviorState],
    graphstate: nnx.GraphState | nnx.VariableState,
    log_threshold: float,
    observations: Array,
):
    """Evaluate critic, classifier and actor of a behavior state on a batch.

    A behavior is kept in the mask when its classifier logit is no more
    than ``log_threshold`` below the largest logit for that observation.

    Returns:
        ``(mask, means, stds, values)`` with, per the shapes noted below,
        mask ``[..., num_behaviors]``, means and stds
        ``[..., num_behaviors, action_dim]`` and values
        ``[..., num_behaviors]``.
    """
    # observations: [..., observation_dim]
    behavior_state = nnx.merge(graphdef, graphstate)
    values = behavior_state.critic(observations)
    logits = behavior_state.classifier(observations)
    # Lowest logit that still counts as an admissible behavior.
    logit_cutoff = jnp.max(logits, axis=-1, keepdims=True) - log_threshold
    mask = logits >= logit_cutoff
    means, stds = behavior_state.actor(observations)
    return mask, means, stds, values


def compute_targets(
    critic: QCriticEnsemble,
    dones: Array,
    gamma: float,
    next_actions: Array,
    next_mask: Array,
    next_observations: Array,
    rewards: Array,
):
    """Return bootstrapped targets maximised over the last axis.

    Each next observation is repeated along a new second-to-last axis to
    match ``next_actions``; the critic output is reduced with a minimum
    over axis 0 and only the diagonal of its last two axes is kept. The
    result is discounted, gated by ``1 - dones``, offset by ``rewards``,
    passed through ``mask_values`` with ``next_mask`` and finally
    maximised over the last axis.
    """
    # [..., num_behaviors, observation_dim]
    tiled_shape = next_actions.shape[:-1] + (next_observations.shape[-1],)
    tiled_observations = jnp.broadcast_to(
        jnp.expand_dims(next_observations, -2), tiled_shape
    )
    # [ensemble_size, ..., num_behaviors, num_behaviors]
    ensemble_q = critic(tiled_observations, next_actions)
    # [..., num_behaviors, num_behaviors]
    min_q = jnp.min(ensemble_q, axis=0)
    # [..., num_behaviors]
    per_behavior_q = jnp.diagonal(min_q, axis1=-2, axis2=-1)
    # [..., 1]
    done_column = jnp.expand_dims(dones, -1)
    reward_column = jnp.expand_dims(rewards, -1)
    # [..., num_behaviors]
    bootstrapped = gamma * per_behavior_q * (1 - done_column) + reward_column
    masked = mask_values(next_mask, bootstrapped)
    # [...]
    return jnp.max(masked, axis=-1)


def _pretrain_critic_step(
    actions: Array,
    dones: Array,
    gamma: float,
    mask: Array,
    next_mask: Array,
    next_means: Array,
    next_observations: Array,
    observations: Array,
    rewards: Array,
    train_state: TrainStateWithTarget[QCriticEnsemble],
):
    """Apply one critic pretraining update to ``train_state`` in place.

    Targets come from ``compute_targets`` evaluated on the target model
    (``train_state.target.model``) with ``next_means`` used as the next
    actions. The gradient of ``td_loss_fn`` with respect to
    ``train_state.model`` is then applied through
    ``train_state.optimizer``.

    Returns:
        A metrics dict with the TD-loss results re-keyed for pretraining:
        ``"loss/pretrain/Q"``, ``"pretrain/Q"`` and ``"pretrain/QT"``.
    """
    targets = compute_targets(
        critic=train_state.target.model,
        dones=dones,
        gamma=gamma,
        next_actions=next_means,
        next_mask=next_mask,
        next_observations=next_observations,
        rewards=rewards,
    )
    grad_fn = nnx.grad(td_loss_fn, has_aux=True)
    grads, results = grad_fn(
        train_state.model,
        actions=actions,
        mask=mask,
        observations=observations,
        targets=targets,
    )
    train_state.optimizer.update(grads)
    # Only the re-keyed entries are exposed; the raw ``results`` dict is
    # not returned, so nothing is written back into it.
    return {
        "loss/pretrain/Q": results["loss/Q"],
        "pretrain/Q": results["train/Q"],
        "pretrain/QT": results["train/QT"],
    }


@jit
def pretrain_critic_step(
    actions: Array,
    dones: Array,
    gamma: float,
    graphdef: nnx.GraphDef[TrainStateWithTarget[QCriticEnsemble]],
    graphstate: nnx.GraphState | nnx.VariableState,
    mask: Array,
    next_mask: Array,
    next_means: Array,
    next_observations: Array,
    observations: Array,
    rewards: Array,
):
    """Run one jitted critic pretraining update without a target update.

    The train state is rebuilt from ``graphdef``/``graphstate``, updated
    by ``_pretrain_critic_step`` and split again.

    Returns:
        ``(graphstate, metrics)`` with the metrics produced by
        ``_pretrain_critic_step``.
    """
    state = nnx.merge(graphdef, graphstate)
    metrics = _pretrain_critic_step(
        train_state=state,
        observations=observations,
        actions=actions,
        rewards=rewards,
        dones=dones,
        gamma=gamma,
        mask=mask,
        next_observations=next_observations,
        next_means=next_means,
        next_mask=next_mask,
    )
    _, updated_graphstate = nnx.split(state)
    return updated_graphstate, metrics


@jit
def pretrain_critic_step_with_target_update(
    actions: Array,
    dones: Array,
    gamma: float,
    graphdef: nnx.GraphDef[TrainStateWithTarget[QCriticEnsemble]],
    graphstate: nnx.GraphState | nnx.VariableState,
    mask: Array,
    next_mask: Array,
    next_means: Array,
    next_observations: Array,
    observations: Array,
    rewards: Array,
    tau: float,
):
    """Run one jitted critic pretraining update, then update the target.

    Same as the plain pretraining step, except that after
    ``_pretrain_critic_step`` has run, ``train_state.target.update`` is
    called with the freshly updated model and ``tau``.

    Returns:
        ``(graphstate, metrics)`` with the metrics produced by
        ``_pretrain_critic_step``.
    """
    state = nnx.merge(graphdef, graphstate)
    metrics = _pretrain_critic_step(
        train_state=state,
        observations=observations,
        actions=actions,
        rewards=rewards,
        dones=dones,
        gamma=gamma,
        mask=mask,
        next_observations=next_observations,
        next_means=next_means,
        next_mask=next_mask,
    )
    # The target update must follow the optimizer step performed above.
    state.target.update(model=state.model, tau=tau)
    _, updated_graphstate = nnx.split(state)
    return updated_graphstate, metrics


def train_actor_step(
    actor_optimizer: nnx.Optimizer,
    key: Array,
    key_data: Any,
    mask: Array,
    means: Array,
    observations: Array,
    policy: LBPTCPolicy,
) -> dict[str, Array]:
    """Apply one actor update and return its loss under ``"loss/actor"``.

    ``actor_loss_fn`` is differentiated with respect to the part of
    ``policy`` (argument 0) selected by ``ActorFilter``, and the resulting
    gradients are passed to ``actor_optimizer.update``.
    """
    actor_only = nnx.DiffState(0, ActorFilter)
    actor_loss, actor_grads = nnx.value_and_grad(
        actor_loss_fn, argnums=actor_only
    )(policy, key, key_data, mask, means, observations)
    actor_optimizer.update(actor_grads)
    return {"loss/actor": actor_loss}


@partial(jit, static_argnames=("noise_batch_size", "regularizer_weight"))
def train_critic_step(
    actions: Array,
    assignments_regularizer: Array,
    baseline_regularizer: Array,
    dones: Array,
    gamma: float,
    graphdef: nnx.GraphDef[LBPTCTrainState],
    graphstate: nnx.GraphState | nnx.VariableState,
    lipschitz_constant: float,
    mask_regularizer: Array,
    means_regularizer: Array,
    min_target: float,
    mask: Array,
    next_mask: Array,
    next_means: Array,
    next_observations: Array,
    noise_batch_size: int,
    observations: Array,
    observations_regularizer: Array,
    ood_threshold: float,
    regularizer_weight: float,
    rewards: Array,
    std_multiplier: float,
    stds_regularizer: Array,
    step: int,
    train_critic_key: Array,
    train_regularizer_key_noise: Array,
    train_regularizer_key_policy: Array,
):
    """Run one jitted critic update and return the new graph state.

    TD targets are built from the target policy: deltas are sampled from
    the target actor for ``next_observations``, scored with
    ``compute_candidates_values`` (second return value), maximised over
    the last axis and combined as
    ``rewards + gamma * (1 - dones) * max``.

    ``noise_batch_size`` and ``regularizer_weight`` are static jit
    arguments, so the Python ``if`` statements on them below are resolved
    at trace time.

    When ``regularizer_weight > 0`` the critic is trained with
    ``critic_loss_fn`` on an additional regularizer batch. The first
    ``noise_batch_size`` entries of that batch (along the axis sliced
    below) get actions sampled as ``mean + std_multiplier * std * eps``
    around the assigned behavior; the remaining entries get actions
    produced by the current policy. Otherwise only ``td_loss_fn`` is
    used.

    Returns:
        ``(graphstate, results)`` where ``results`` is the metrics dict
        of whichever loss function was used.
    """
    train_state = nnx.merge(graphdef, graphstate)
    # [..., num_behaviors, action_dim]
    next_deltas = compute_actions(
        actor=train_state.target_policy.model.actor,
        key=train_critic_key,
        key_data=step,
        observations=next_observations,
        squash=False,
    )
    # [..., num_behaviors]
    _, targets, _ = train_state.target_policy.model.compute_candidates_values(
        deltas=next_deltas,
        mask=next_mask,
        means=next_means,
        observations=next_observations,
    )
    # [...]
    targets = rewards + gamma * (1 - dones) * jnp.max(targets, axis=-1)
    if regularizer_weight > 0:
        # Everything after the first ``noise_batch_size`` entries is the
        # part of the regularizer batch that receives policy actions.
        # [..., policy_batch_size, observation_dim]
        observations_policy = observations_regularizer[
            ..., noise_batch_size:, :
        ]
        # The actor is evaluated in eval mode and switched back to train
        # mode right afterwards.
        train_state.policy.actor.eval()
        means_policy, stds_policy = train_state.policy.actor(
            observations_policy
        )
        train_state.policy.actor.train()
        # Sample deltas as mean + std * eps, with the key folded with
        # ``step``.
        deltas_policy = means_policy + stds_policy * normal(
            fold_in(train_regularizer_key_policy, step), shape=stds_policy.shape
        )
        # The critic is in eval mode only while the policy computes its
        # actions; it is back in train mode before gradients are taken.
        train_state.policy.critic.eval()
        # [..., policy_batch_size, action_dim]
        policy_actions, _ = train_state.policy.compute_actions(
            deltas=deltas_policy,
            mask=mask_regularizer[..., noise_batch_size:, :],
            means=means_regularizer[..., noise_batch_size:, :, :],
            observations=observations_policy,
        )
        train_state.policy.critic.train()
        if noise_batch_size > 0:
            # NOTE: this rebinds the ``assignments_regularizer`` parameter
            # to the sliced index array with two trailing singleton axes.
            # [..., noise_batch_size, 1, 1]
            assignments_regularizer = jnp.expand_dims(
                assignments_regularizer[..., :noise_batch_size], (-1, -2)
            )
            # [..., noise_batch_size, num_behaviors, action_dim]
            means_noise = means_regularizer[..., :noise_batch_size, :, :]
            stds_noise = stds_regularizer[..., :noise_batch_size, :, :]
            # Keep only the mean/std of the assigned behavior.
            # [..., noise_batch_size, action_dim]
            means_noise = jnp.squeeze(
                jnp.take_along_axis(
                    means_noise, assignments_regularizer, axis=-2
                ),
                axis=-2,
            )
            stds_noise = jnp.squeeze(
                jnp.take_along_axis(
                    stds_noise, assignments_regularizer, axis=-2
                ),
                axis=-2,
            )
            # Noise actions: assigned mean plus scaled Gaussian noise.
            noise = means_noise + std_multiplier * stds_noise * normal(
                fold_in(train_regularizer_key_noise, step),
                shape=stds_noise.shape,
            )
            # Noise actions come first, policy actions second, matching
            # the slicing of the regularizer batch above.
            # [..., batch_size, action_dim]
            actions_regularizer = jnp.concatenate(
                (noise, policy_actions), axis=-2
            )
        else:
            actions_regularizer = policy_actions
        # Offset of every regularizer action from each behavior mean.
        # [..., batch_size, num_behaviors, action_dim]
        deltas_regularizer = (
            jnp.expand_dims(actions_regularizer, axis=-2) - means_regularizer
        )
        # [..., batch_size, num_behaviors]
        mean_norms = compute_mean_norms(
            deltas=deltas_regularizer, stds=stds_regularizer
        )
        # Regularizer target: the baseline lowered by
        # gamma * lipschitz_constant * mean_norms, clipped from below at
        # ``min_target``.
        # [..., batch_size, num_behaviors]
        targets_regularizer = (
            baseline_regularizer - gamma * lipschitz_constant * mean_norms
        )
        targets_regularizer = jnp.clip(targets_regularizer, min=min_target)
        # [..., batch_size, num_behaviors]
        ood_mask = compute_ood_mask(
            deltas=deltas_regularizer,
            stds=stds_regularizer,
            threshold=ood_threshold,
        )
        grad_fn = nnx.grad(critic_loss_fn, has_aux=True)
        # The OOD mask is multiplied by ``mask_regularizer`` so that only
        # entries enabled in both contribute to the regularizer term.
        grads, results = grad_fn(
            train_state.policy.critic,
            actions=actions,
            actions_regularizer=actions_regularizer,
            mask=mask,
            observations=observations,
            observations_regularizer=observations_regularizer,
            ood_mask=ood_mask * mask_regularizer,
            regularizer_weight=regularizer_weight,
            targets=targets,
            targets_regularizer=targets_regularizer,
        )
    else:
        # No regularizer: plain masked TD loss on the data batch.
        grad_fn = nnx.grad(td_loss_fn, has_aux=True)
        grads, results = grad_fn(
            train_state.policy.critic,
            actions=actions,
            mask=mask,
            observations=observations,
            targets=targets,
        )
    train_state.critic_optimizer.update(grads)
    _, graphstate = nnx.split(train_state)
    return graphstate, results


@partial(jit, static_argnames=("noise_batch_size", "regularizer_weight"))
def train_actor_critic_step(
    actions: Array,
    assignments_regularizer: Array,
    baseline_regularizer: Array,
    dones: Array,
    gamma: float,
    graphdef: nnx.GraphDef[LBPTCTrainState],
    graphstate: nnx.GraphState | nnx.VariableState,
    lipschitz_constant: float,
    mask_actor: Array,
    mask_regularizer: Array,
    means_actor: Array,
    means_regularizer: Array,
    min_target: float,
    mask: Array,
    next_mask: Array,
    next_means: Array,
    next_observations: Array,
    noise_batch_size: int,
    observations: Array,
    observations_actor: Array,
    observations_regularizer: Array,
    ood_threshold: float,
    regularizer_weight: float,
    rewards: Array,
    std_multiplier: float,
    stds_regularizer: Array,
    step: int,
    tau: float,
    train_actor_key: Array,
    train_critic_key: Array,
    train_regularizer_key_noise: Array,
    train_regularizer_key_policy: Array,
):
    """Run one jitted critic update, actor update and target update.

    The critic is updated first through ``train_critic_step``. The train
    state is then rebuilt from the returned graph state, the critic is
    switched to eval mode, the actor is updated with ``train_actor_step``
    and finally ``target_policy.update`` is called with the current
    policy and ``tau``.

    Returns:
        ``(graphstate, metrics)`` where ``metrics`` is the actor metrics
        merged with the critic metrics (critic entries win on key
        clashes).
    """
    critic_graphstate, critic_metrics = train_critic_step(
        # Graph and step bookkeeping.
        graphdef=graphdef,
        graphstate=graphstate,
        step=step,
        train_critic_key=train_critic_key,
        train_regularizer_key_noise=train_regularizer_key_noise,
        train_regularizer_key_policy=train_regularizer_key_policy,
        # Data batch.
        observations=observations,
        actions=actions,
        rewards=rewards,
        dones=dones,
        mask=mask,
        next_observations=next_observations,
        next_means=next_means,
        next_mask=next_mask,
        # Regularizer batch.
        observations_regularizer=observations_regularizer,
        assignments_regularizer=assignments_regularizer,
        baseline_regularizer=baseline_regularizer,
        mask_regularizer=mask_regularizer,
        means_regularizer=means_regularizer,
        stds_regularizer=stds_regularizer,
        # Hyper-parameters.
        gamma=gamma,
        lipschitz_constant=lipschitz_constant,
        min_target=min_target,
        noise_batch_size=noise_batch_size,
        ood_threshold=ood_threshold,
        regularizer_weight=regularizer_weight,
        std_multiplier=std_multiplier,
    )
    state = nnx.merge(graphdef, critic_graphstate)
    # The actor is trained against the critic in eval mode.
    state.policy.critic.eval()
    actor_metrics = train_actor_step(
        actor_optimizer=state.actor_optimizer,
        policy=state.policy,
        key=train_actor_key,
        key_data=step,
        observations=observations_actor,
        means=means_actor,
        mask=mask_actor,
    )
    state.target_policy.update(model=state.policy, tau=tau)
    _, final_graphstate = nnx.split(state)
    return final_graphstate, actor_metrics | critic_metrics


@jit
def train_classifier_step(
    assignments: Array,
    graphdef: nnx.GraphDef[TrainState[MLP]],
    graphstate: nnx.GraphState | nnx.VariableState,
    observations: Array,
):
    """Run one jitted classifier update.

    The gradient of ``classifier_loss_fn`` with respect to the model is
    applied through the train state's optimizer; the re-split graph state
    is returned together with the loss under ``"loss/CLS"``.
    """
    state = nnx.merge(graphdef, graphstate)
    cls_loss, cls_grads = nnx.value_and_grad(classifier_loss_fn)(
        state.model, assignments, observations
    )
    state.optimizer.update(cls_grads)
    _, updated_graphstate = nnx.split(state)
    return updated_graphstate, {"loss/CLS": cls_loss}


@jit
def train_tcae_step(
    actions: Array,
    commitment_cost: float,
    decode_indices: Array,
    decode_mask: Array,
    graphdef: nnx.GraphDef[TrainState[TCAutoEncoder]],
    graphstate: nnx.GraphState | nnx.VariableState,
    latent_indices: Array,
    observations: Array,
    rewards: Array,
    reward_weight: float,
    transition_weight: float,
):
    """Run one jitted auto-encoder update.

    The gradient of ``tcae_loss_fn`` with respect to the model is applied
    through the train state's optimizer.

    Returns:
        ``(graphstate, metrics)`` with the metrics dict produced by
        ``tcae_loss_fn``.
    """
    state = nnx.merge(graphdef, graphstate)
    # Arguments are positional and follow the parameter order of
    # ``tcae_loss_fn``.
    loss_inputs = (
        actions,
        commitment_cost,
        decode_indices,
        decode_mask,
        latent_indices,
        observations,
        reward_weight,
        rewards,
        transition_weight,
    )
    tcae_grads, metrics = nnx.grad(tcae_loss_fn, has_aux=True)(
        state.model, *loss_inputs
    )
    state.optimizer.update(tcae_grads)
    _, updated_graphstate = nnx.split(state)
    return updated_graphstate, metrics


@jit
def v_learning_step(
    assignments: Array,
    graphdef: nnx.GraphDef[TrainStateWithTarget[MLP]],
    graphstate: nnx.GraphState | nnx.VariableState,
    observations: Array,
    targets: Array,
):
    """Run one jitted value-function update.

    The gradient of ``v_learning_loss_fn`` with respect to the model is
    applied through the train state's optimizer.

    Returns:
        ``(graphstate, metrics)`` with the metrics dict produced by
        ``v_learning_loss_fn``.
    """
    state = nnx.merge(graphdef, graphstate)
    value_grads, metrics = nnx.grad(v_learning_loss_fn, has_aux=True)(
        state.model, assignments, observations, targets
    )
    state.optimizer.update(value_grads)
    _, updated_graphstate = nnx.split(state)
    return updated_graphstate, metrics
