import numpy as np

from offline.bppo.tc.iql.types import HLQLearningBatch
from offline.types import IntArray, OfflineData
from offline.utils.data import Dataset
from offline.utils.dataset import prepare_q_learning_dataset


def prepare_high_level_q_learning_dataset(
    assignments: IntArray, data: OfflineData
):
    """Build the dataset used for high-level Q-learning.

    The transition fields (observations, next observations, rewards, dones)
    are taken from the batch produced by ``prepare_q_learning_dataset``; the
    per-sample ``assignments`` are filtered with a boolean mask and attached
    alongside them.

    Args:
        assignments: Per-sample assignment indices, indexable by a boolean
            mask derived from ``data.dones`` / ``data.terminals``.
        data: Offline data exposing ``dones`` and ``terminals``.

    Returns:
        A ``Dataset`` wrapping an ``HLQLearningBatch``.
    """
    # A sample is dropped when it is marked done but not terminal, i.e. the
    # trajectory ended by truncation rather than by reaching a terminal state.
    truncated_end = np.logical_and(data.dones, np.logical_not(data.terminals))
    keep = np.logical_not(truncated_end)

    # NOTE(review): presumably prepare_q_learning_dataset applies the same
    # filter so the batch fields line up with assignments[keep] -- confirm.
    transitions = prepare_q_learning_dataset(data=data).data

    hl_batch = HLQLearningBatch(
        assignments=assignments[keep],
        dones=transitions.dones,
        next_observations=transitions.next_observations,
        observations=transitions.observations,
        rewards=transitions.rewards,
    )
    return Dataset(hl_batch)
