from dataclasses import dataclass
from typing import Literal

import pandas as pd

from lib_llm.eval.memorization.dynamics.utils import reindex_positionwise


@dataclass
class Substring:
    """A token sequence together with the places it occurs in a longer string.

    Fields are stored exactly as given; nothing here checks that
    ``token_ids``, ``tokens`` and ``positions`` are mutually consistent.
    """

    # Token ids that make up the substring, in order.
    token_ids: list[int]
    # Token strings for the substring (presumably parallel to ``token_ids`` --
    # TODO confirm against the producer of these objects).
    tokens: list[str]
    # Start position of every occurrence of the substring in the string.
    positions: list[int]


# Selects which token rows a substring filter keeps:
#   "all"       -- every row, no filtering
#   "first"     -- rows inside the first occurrence of each substring
#   "remaining" -- rows outside the first occurrence of every substring
KeepSpec = Literal[
    "all",
    "first",
    "remaining",
]


def filter_for_substrings(
    result: pd.DataFrame,
    substrings: list[Substring],
    keep: KeepSpec,
) -> pd.DataFrame:
    """Filter token rows of ``result`` by the first occurrence of substrings.

    For each substring, the token positions spanned by its *first*
    occurrence (``positions[0]`` up to ``positions[0] + len(token_ids)``,
    exclusive) are collected. ``keep`` then selects:

    - ``"all"``: return ``result`` unchanged (the same object, not a copy).
    - ``"first"``: only rows whose token index lies in one of those spans.
    - ``"remaining"``: only rows whose token index lies outside all of those
      spans (this includes later occurrences of the substrings).

    Args:
        result: Frame to filter. Token positions are read from the
            ``"token_index"`` index level of ``reindex_positionwise(result)``.
        substrings: Substrings whose first occurrences define the spans.
            A substring with no recorded position contributes no span.
        keep: Which rows to keep; see above.

    Returns:
        The selected rows of ``result``, with its original index.

    Raises:
        ValueError: If ``keep`` is not ``"all"``, ``"first"`` or
            ``"remaining"``.
    """
    if keep == "all":
        return result
    if keep not in ("first", "remaining"):
        # Previously any unknown value silently behaved like "first".
        raise ValueError(
            f"keep must be one of 'all', 'first', 'remaining'; got {keep!r}"
        )

    position_index = reindex_positionwise(result).index.get_level_values(
        "token_index"
    )

    first_occurrence_positions: set[int] = set()
    for substring in substrings:
        if not substring.positions:
            # No occurrence recorded, so there is no span to mark (the old
            # code raised IndexError here).
            continue
        start = substring.positions[0]
        first_occurrence_positions.update(
            range(start, start + len(substring.token_ids))
        )

    in_first_occurrence = position_index.isin(first_occurrence_positions)
    # "remaining" is the complement of the first-occurrence spans; negating the
    # mask avoids materialising set(position_index) and subtracting from it.
    mask = in_first_occurrence if keep == "first" else ~in_first_occurrence
    # NOTE(review): the mask is computed on the reindexed frame but applied to
    # ``result``; this assumes reindex_positionwise preserves row count and
    # order -- confirm against its implementation.
    return result.loc[mask]
