import math

from flax import nnx
import numpy as np

from offline.lbp.tc.core import compute_batch_mask_means_stds_values
from offline.lbp.tc.modules import BehaviorState
from offline.lbp.tc.types import (
    ActorBatch,
    QLearningBatch,
    RegularizerBatch,
    TCBatch,
    VLearningBatch,
)
from offline.lbp.utils import compute_discounted_returns
from offline.modules.actor.ensemble import GaussianActorEnsemble
from offline.modules.mlp import MLP
from offline.types import BoolArray, FloatArray, IntArray, OfflineData
from offline.utils.data import ArrayDataLoader, Dataset, TrajectoryDataset
from offline.utils import dataset


def compute_mask_means_stds_values(
    actor: GaussianActorEnsemble,
    classifier: MLP,
    critic: MLP,
    observations: FloatArray,
    threshold: float,
    batch_size: int = 256,
) -> tuple[BoolArray, FloatArray, FloatArray, FloatArray]:
    """Evaluate the behavior modules on ``observations`` batch by batch.

    The actor, classifier and critic are bundled into a ``BehaviorState``,
    split with ``nnx.split`` and applied to successive batches of
    ``observations`` through ``compute_batch_mask_means_stds_values``. The
    per-batch outputs are concatenated along the first axis.

    Args:
        actor: Actor ensemble stored in the behavior state.
        classifier: Classifier network stored in the behavior state.
        critic: Critic network stored in the behavior state.
        observations: Observations to evaluate.
        threshold: Forwarded to the batch helper as ``log(threshold)``; it
            must be strictly positive. Presumably a cutoff on the classifier
            output -- TODO confirm against the batch helper.
        batch_size: Number of observations per batch handed to the loader.

    Returns:
        A ``(mask, means, stds, values)`` tuple, each entry being the
        concatenation of the corresponding per-batch output.

    Raises:
        ValueError: If ``threshold`` is not strictly positive, or if the
            loader yields no batch (``np.concatenate`` of an empty list).
    """
    # Loop-invariant: compute the log once instead of once per batch. Doing
    # it first also rejects an invalid threshold before any model work.
    log_threshold = math.log(threshold)
    train_state = BehaviorState(
        actor=actor, classifier=classifier, critic=critic
    )
    graphdef, graphstate = nnx.split(train_state)
    # NOTE(review): callers index the outputs alongside the observations, so
    # this assumes ArrayDataLoader iterates in order without shuffling --
    # confirm.
    loader = ArrayDataLoader(
        batch_size=batch_size, data=observations, drop_last=False
    )
    mask_list, means_list, stds_list, values_list = [], [], [], []
    for batch in loader:
        mask, means, stds, values = compute_batch_mask_means_stds_values(
            graphdef=graphdef,
            graphstate=graphstate,
            log_threshold=log_threshold,
            observations=batch,
        )
        mask_list.append(mask)
        means_list.append(means)
        stds_list.append(stds)
        values_list.append(values)
    return (
        np.concatenate(mask_list),
        np.concatenate(means_list),
        np.concatenate(stds_list),
        np.concatenate(values_list),
    )


def normalize_rewards(
    assignments: IntArray, data: OfflineData, gamma: float, eps: float = 1e-5
) -> float:
    """Return a reward scale derived from per-group mean discounted returns.

    Discounted returns are computed with ``compute_discounted_returns``,
    treating a sample as an episode end when either ``data.dones`` or
    ``data.terminals`` is set. For every assignment index from ``0`` to
    ``max(assignments)`` the absolute value of the mean discounted return of
    the samples carrying that index is taken; the result is the reciprocal
    of the largest such value plus ``eps``.

    Args:
        assignments: Integer group index per sample.
        data: Offline data providing ``dones``, ``terminals`` and ``rewards``.
        gamma: Discount factor forwarded to ``compute_discounted_returns``.
        eps: Added to the denominator to avoid division by zero.

    Returns:
        ``1 / (max_i |mean(discounted_returns[assignments == i])| + eps)``.
    """
    dones = np.logical_or(data.dones, data.terminals)
    discounted_returns = compute_discounted_returns(
        dones=dones, gamma=gamma, rewards=data.rewards
    )
    # NOTE(review): if every assignment is negative the loop body never runs
    # and the function returns -0.0 -- confirm negative labels cannot occur.
    max_mean_discounted_return = -np.inf
    for index in range(np.max(assignments) + 1):
        group_returns = discounted_returns[assignments == index]
        if group_returns.size == 0:
            # The mean of an empty selection is NaN. Because every comparison
            # with NaN is False, ``max(nan, running)`` yields NaN and the
            # following ``max(value, nan)`` yields ``value``, which would
            # silently discard the running maximum. Skip empty groups.
            continue
        max_mean_discounted_return = max(
            np.abs(np.mean(group_returns)),
            max_mean_discounted_return,
        )
    return 1 / float(max_mean_discounted_return + eps)


def prepare_actor_q_learning_regularizer_dataset(
    actor: GaussianActorEnsemble,
    assignments: IntArray,
    classifier: MLP,
    critic: MLP,
    data: OfflineData,
    offset: float,
    threshold: float,
    batch_size: int = 256,
):
    """Build the actor, Q-learning and regularizer datasets.

    The behavior modules are evaluated on every observation in ``data`` to
    obtain ``mask``, ``means``, ``stds`` and ``values``. Three datasets are
    then assembled:

    * actor: ``mask``, ``means`` and the observations, for every sample;
    * Q-learning: the transition fields returned by
      ``dataset.prepare_q_learning_dataset(data)`` together with ``mask``
      and the "next" mask/means, restricted to the kept samples;
    * regularizer: ``assignments``, ``values - offset`` as baseline,
      ``mask``, ``means``, ``stds`` and the observations, for every sample.

    Args:
        actor: Actor ensemble used to evaluate the observations.
        assignments: Integer group index per sample.
        classifier: Classifier network used to evaluate the observations.
        critic: Critic network used to evaluate the observations.
        data: Offline data to convert.
        offset: Subtracted from the critic values to form the baseline.
        threshold: Forwarded to ``compute_mask_means_stds_values``.
        batch_size: Evaluation batch size.

    Returns:
        A ``(actor, q_learning, regularizer)`` tuple of ``Dataset`` objects.
    """
    # A sample is dropped when it ends a trajectory without being a true
    # terminal, i.e. when the trajectory was truncated.
    truncated = np.logical_and(data.dones, np.logical_not(data.terminals))
    keep = np.logical_not(truncated)

    all_observations = data.observations
    mask, means, stds, values = compute_mask_means_stds_values(
        actor=actor,
        classifier=classifier,
        critic=critic,
        observations=all_observations,
        threshold=threshold,
        batch_size=batch_size,
    )

    actor_samples = ActorBatch(
        mask=mask, means=means, observations=all_observations
    )
    baseline = values - offset
    regularizer_samples = RegularizerBatch(
        assignments=assignments,
        baseline=baseline,
        mask=mask,
        means=means,
        observations=all_observations,
        stds=stds,
    )

    # Row i of the shifted arrays holds row i + 1 of the originals; the last
    # row wraps around to the first sample.
    shifted_mask = np.roll(mask, -1, axis=0)
    assert np.all(shifted_mask[:-1] == mask[1:])
    shifted_means = np.roll(means, -1, axis=0)
    assert np.all(shifted_means[:-1] == means[1:])

    # NOTE(review): assumes prepare_q_learning_dataset drops the same rows
    # as ``keep`` so that the fields below line up -- confirm.
    transitions = dataset.prepare_q_learning_dataset(data).data
    q_learning_samples = QLearningBatch(
        actions=transitions.actions,
        dones=transitions.dones,
        mask=mask[keep],
        next_mask=shifted_mask[keep],
        next_means=shifted_means[keep],
        next_observations=transitions.next_observations,
        observations=transitions.observations,
        rewards=transitions.rewards,
    )

    actor_dataset = Dataset(actor_samples)
    q_learning_dataset = Dataset(q_learning_samples)
    regularizer_dataset = Dataset(regularizer_samples)
    return actor_dataset, q_learning_dataset, regularizer_dataset


def prepare_tc_dataset(data: OfflineData, filter_trajectories: bool):
    """Wrap the offline data in a ``TrajectoryDataset`` of ``TCBatch`` fields.

    Actions and observations are passed through unchanged; rewards get an
    extra axis inserted at position 1. ``data.dones`` and
    ``filter_trajectories`` are forwarded to ``TrajectoryDataset``.

    Args:
        data: Offline data providing actions, observations, rewards, dones.
        filter_trajectories: Forwarded to ``TrajectoryDataset``.

    Returns:
        The constructed ``TrajectoryDataset``.
    """
    rewards_column = np.expand_dims(data.rewards, axis=1)
    tc_batch = TCBatch(
        actions=data.actions,
        observations=data.observations,
        rewards=rewards_column,
    )
    trajectories = TrajectoryDataset(
        data=tc_batch,
        dones=data.dones,
        filter_trajectories=filter_trajectories,
    )
    return trajectories


def prepare_v_learning_dataset(
    assignments: IntArray, data: OfflineData, gamma: float
):
    """Build the V-learning dataset with discounted returns as targets.

    A sample counts as an episode end when either ``data.dones`` or
    ``data.terminals`` is set; the discounted returns computed with
    ``compute_discounted_returns`` under that convention become the targets.

    Args:
        assignments: Integer group index per sample.
        data: Offline data providing observations, rewards, dones, terminals.
        gamma: Discount factor forwarded to ``compute_discounted_returns``.

    Returns:
        A ``Dataset`` wrapping a ``VLearningBatch`` of assignments,
        observations and targets.
    """
    episode_ends = np.logical_or(data.dones, data.terminals)
    targets = compute_discounted_returns(
        dones=episode_ends, gamma=gamma, rewards=data.rewards
    )
    v_learning_batch = VLearningBatch(
        assignments=assignments,
        observations=data.observations,
        targets=targets,
    )
    return Dataset(v_learning_batch)
