from pathlib import Path

import torch
import typer

from hallucinations.dirs import DatasetDir
from hallucinations.features.attention_weights import (
    log_det_attnn_over_dataset,
    yield_stacked_attentions,
)


def main(
    dataset_dir: Path = typer.Option(..., help="Path to the dataset directory"),
) -> None:
    # CLI entry point: run `log_det_attnn_over_dataset` over every attention
    # shard of the dataset and persist the combined results with `torch.save`.
    #
    # NOTE: documented with comments rather than a docstring on purpose --
    # Typer shows a command function's docstring as its `--help` text, so
    # adding one would change the CLI output.
    dataset = DatasetDir(dataset_dir)

    # Shards are consumed one at a time from the iterable returned here;
    # `remove_padding=True` is forwarded unchanged to the helper.
    shards = yield_stacked_attentions(dataset, remove_padding=True)

    # `extend` appends each element of a shard's result individually, so the
    # saved object is one flat list spanning all shards
    # (presumably one entry per dataset example -- TODO confirm against
    # `log_det_attnn_over_dataset`).
    collected: list[torch.Tensor] = []
    for shard in shards:
        collected.extend(log_det_attnn_over_dataset(shard))  # type: ignore[arg-type]

    # Output location is defined by the DatasetDir wrapper, not by this script.
    torch.save(collected, dataset.log_det_attn_file)


# Only start the CLI when this file is executed directly (not on import);
# `typer.run` is handed `main` and takes care of invoking it.
if __name__ == "__main__":
    typer.run(main)
