from flax import nnx
import numpy as np

from offline.bppo.tc.core import compute_batch_means_stds_values
from offline.bppo.tc.modules import BehaviorState
from offline.bppo.tc.types import BPPOTCBatch, HLQLearningBatch, SarsaBatch
from offline.envs import dataset
from offline.lbp.tc.types import AssignmentBatch
from offline.modules.actor.ensemble import GaussianActorEnsembleWithIndices
from offline.modules.mlp import MLP
from offline.types import FloatArray, IntArray, OfflineData
from offline.utils.data import DataLoader, Dataset
from offline.utils.nnx import compute_mlp_outputs


def compute_mask(classifier: MLP, observations: FloatArray, threshold: float):
    """Flag the classifier outputs that lie close to the maximum output.

    An entry is selected when its score is at least ``max - log(threshold)``,
    where ``max`` is taken over the last axis. If the scores are softmax
    logits (presumably the case -- confirm against the classifier), this is
    the same as keeping the classes whose probability is at least
    ``1 / threshold`` times the largest probability.

    NOTE(review): with ``threshold < 1`` the cutoff rises above the maximum
    and nothing is selected; ``threshold == 1`` keeps only the maxima.
    Confirm that callers pass ``threshold >= 1``.

    Args:
        classifier: Network evaluated through ``compute_mlp_outputs``.
        observations: Inputs fed to the classifier.
        threshold: Ratio controlling how far below the maximum a score may
            fall and still be selected.

    Returns:
        Boolean array with the same shape as the classifier outputs.
    """
    scores = compute_mlp_outputs(classifier, observations)
    # keepdims lets the cutoff broadcast against every score on the last axis.
    cutoff = np.max(scores, axis=-1, keepdims=True) - np.log(threshold)
    return scores >= cutoff


def compute_means_stds_values(
    actor: GaussianActorEnsembleWithIndices,
    assignments: IntArray,
    critic: MLP,
    observations: FloatArray,
    batch_size: int = 256,
) -> tuple[FloatArray, FloatArray, FloatArray]:
    """Run ``compute_batch_means_stds_values`` over the data in mini-batches.

    The actor and critic are bundled into a ``BehaviorState`` and split with
    ``nnx.split``; the resulting graph definition and state are passed to
    every batch evaluation. Per-batch results are concatenated along the
    first axis.

    NOTE(review): the outputs only line up with ``observations`` if the
    ``DataLoader`` iterates in dataset order (no shuffling) -- confirm its
    defaults.

    Args:
        actor: Actor module placed in the ``BehaviorState``.
        assignments: Per-sample assignments, batched alongside the
            observations.
        critic: Critic module placed in the ``BehaviorState``.
        observations: Observations to evaluate.
        batch_size: Number of samples per evaluation batch.

    Returns:
        Tuple ``(means, stds, values)`` with the concatenated per-batch
        outputs.
    """
    # Split once so every batch reuses the same graph definition and state.
    behavior_state = BehaviorState(actor=actor, critic=critic)
    graphdef, graphstate = nnx.split(behavior_state)
    samples = Dataset(
        AssignmentBatch(assignments=assignments, observations=observations)
    )
    # drop_last=False: the trailing partial batch must be evaluated as well.
    batches = DataLoader(samples, batch_size=batch_size, drop_last=False)
    outputs = [
        compute_batch_means_stds_values(
            assignments=chunk.assignments,
            graphdef=graphdef,
            graphstate=graphstate,
            observations=chunk.observations,
        )
        for chunk in batches
    ]
    all_means = np.concatenate([means for means, _, _ in outputs])
    all_stds = np.concatenate([stds for _, stds, _ in outputs])
    all_values = np.concatenate([values for _, _, values in outputs])
    return all_means, all_stds, all_values


def prepare_bppo_tc_dataset(
    actor: GaussianActorEnsembleWithIndices,
    assignments: IntArray,
    observations: FloatArray,
    vcritic: MLP,
    batch_size: int = 256,
):
    """Build a ``Dataset`` of ``BPPOTCBatch`` with precomputed model outputs.

    The means, stds and values returned by ``compute_means_stds_values`` are
    stored next to the assignments and observations they were computed from.

    Args:
        actor: Actor passed to ``compute_means_stds_values``.
        assignments: Per-sample assignments.
        observations: Observations to evaluate and store.
        vcritic: Critic passed to ``compute_means_stds_values`` as ``critic``.
        batch_size: Batch size used while evaluating the models.

    Returns:
        ``Dataset`` wrapping a single ``BPPOTCBatch``.
    """
    behavior_means, behavior_stds, behavior_values = compute_means_stds_values(
        actor=actor,
        assignments=assignments,
        batch_size=batch_size,
        critic=vcritic,
        observations=observations,
    )
    bppo_batch = BPPOTCBatch(
        assignments=assignments,
        means=behavior_means,
        observations=observations,
        stds=behavior_stds,
        values=behavior_values,
    )
    return Dataset(bppo_batch)


def prepare_high_level_q_learning_dataset(
    assignments: IntArray, classifier: MLP, data: OfflineData, threshold: float
):
    """Build a ``Dataset`` of ``HLQLearningBatch`` from offline data.

    Transitions come from ``dataset.prepare_q_learning_dataset``. The mask
    stored as ``next_mask`` is computed by ``compute_mask`` on the next
    observations.

    NOTE(review): ``assignments`` is filtered here with ``~timeouts`` so it
    matches the transitions; this presumes
    ``dataset.prepare_q_learning_dataset`` drops exactly the same samples --
    verify against that helper.

    Args:
        assignments: Per-sample assignments aligned with ``data``.
        classifier: Classifier passed to ``compute_mask``.
        data: Offline data providing ``timeouts`` and the transitions.
        threshold: Threshold forwarded to ``compute_mask``.

    Returns:
        ``Dataset`` wrapping a single ``HLQLearningBatch``.
    """
    # Boolean selector: drop end-of-trajectory samples of truncated
    # trajectories (those flagged as timeouts).
    not_truncated = np.logical_not(data.timeouts)
    transitions = dataset.prepare_q_learning_dataset(data=data).data
    next_mask = compute_mask(
        classifier=classifier,
        observations=transitions.next_observations,
        threshold=threshold,
    )
    hl_batch = HLQLearningBatch(
        assignments=assignments[not_truncated],
        dones=transitions.dones,
        next_mask=next_mask,
        next_observations=transitions.next_observations,
        observations=transitions.observations,
        rewards=transitions.rewards,
    )
    return Dataset(hl_batch)


def prepare_sarsa_dataset(assignments: IntArray, data: OfflineData):
    """Build a ``Dataset`` of ``SarsaBatch`` from offline data.

    Transitions come from ``dataset.prepare_sarsa_dataset``; the assignments
    are attached after removing the samples flagged as timeouts.

    NOTE(review): this presumes ``dataset.prepare_sarsa_dataset`` drops
    exactly the samples where ``data.timeouts`` is set, so that the filtered
    assignments line up with the transitions -- verify against that helper.

    Args:
        assignments: Per-sample assignments aligned with ``data``.
        data: Offline data providing ``timeouts`` and the transitions.

    Returns:
        ``Dataset`` wrapping a single ``SarsaBatch``.
    """
    # Boolean selector: drop end-of-trajectory samples of truncated
    # trajectories (those flagged as timeouts).
    not_truncated = np.logical_not(data.timeouts)
    transitions = dataset.prepare_sarsa_dataset(data=data).data
    sarsa_batch = SarsaBatch(
        actions=transitions.actions,
        assignments=assignments[not_truncated],
        dones=transitions.dones,
        next_actions=transitions.next_actions,
        next_observations=transitions.next_observations,
        observations=transitions.observations,
        rewards=transitions.rewards,
    )
    return Dataset(sarsa_batch)
