from mow.dataset.chat import ChatDatasetBuilder


def _tile[T](lst: list[T], n: list[int]) -> list[T]:
    return sum(([elem] * i for elem, i in zip(lst, n)), [])


def _flatten(batch):
    """Explode a batch so every entry of ``histories`` becomes its own row.

    ``batch["histories"]`` holds, per input row, a list of values. The
    result has one row per such value, stored under ``"history"``. Every
    other column is repeated once per value produced by its row so that
    all returned columns stay aligned. A row whose ``histories`` list is
    empty therefore produces no output rows.

    Args:
        batch: Mapping from column name to a list of per-row values; must
            contain a ``"histories"`` column.

    Returns:
        A dict of the exploded columns. The intermediate ``"histories"``
        column is not part of the result.
    """
    histories = batch["histories"]
    len_histories = [len(h) for h in histories]
    # Linear flatten; ``sum(histories, [])`` would copy the accumulator on
    # every addition.
    ret = {"history": [item for group in histories for item in group]}
    for key in batch:
        # "history" was rebuilt above. "histories" is only an intermediate
        # column: tiling it would duplicate every nested list once per
        # output row, and (assuming Hugging Face ``datasets.map`` semantics
        # -- TODO confirm) a column returned here is kept even when the
        # caller lists it in ``remove_columns``.
        if key not in ("history", "histories"):
            ret[key] = _tile(batch[key], len_histories)
    return ret


class ChatHistoryMixin(ChatDatasetBuilder):
    """Dataset-builder mixin that can expand chat histories into prefixes."""

    def expand(self, desc: str = "Expanding histories"):
        """Turn every example into one example per non-empty history prefix.

        Two ``map`` passes are chained:

        1. Per example, attach a ``"histories"`` field holding
           ``history[:1]``, ``history[:2]``, ... up to the full history.
        2. Per batch, run ``_flatten`` so that each prefix becomes its own
           row, asking ``map`` to remove the ``"histories"`` column.

        Args:
            desc: Label prefix passed to both ``map`` calls; ``" - Step 1"``
                and ``" - Step 2"`` are appended to it.

        Returns:
            Whatever the second ``map`` call returns (presumably a new
            dataset builder -- depends on ``ChatDatasetBuilder.map``).
        """

        def _with_prefixes(example):
            # Collect every leading slice of the conversation, shortest first.
            turns = example["history"]
            prefixes = [turns[:stop] for stop in range(1, len(turns) + 1)]
            return {"histories": prefixes}

        prefixed = self.map(
            _with_prefixes,
            batched=False,
            desc=f"{desc} - Step 1",
        )
        return prefixed.map(
            _flatten,
            batched=True,
            desc=f"{desc} - Step 2",
            remove_columns=["histories"],
        )
