from typing import cast

import numpy as np
import pandas as pd


def get_string_tokens(
    result: pd.DataFrame,
) -> list[str]:
    """Return the tokens of the (single) string that was memorized.

    The tokens are read from the index of a memorization_log dataframe:
    they are the values of the 'character' level on the rows whose
    'epoch' level equals 0, in row order.

    Args:
        result: Dataframe whose index has (at least) the levels
            'string', 'epoch' and 'character'.

    Returns:
        The tokens as a plain Python list (not a pandas Index), so that
        list semantics such as truthiness, ``==`` and ``+`` behave as the
        return annotation promises.

    Raises:
        AssertionError: If the 'string' level does not contain exactly
            one unique value, or that value is not 0.
    """
    index = result.index
    strings = index.get_level_values("string").unique()
    assert len(strings) == 1 and strings[0] == 0, (
        "expected exactly one memorized string with id 0 in the 'string' "
        "index level"
    )
    first_epoch_rows = index.get_level_values("epoch") == 0
    tokens = index.get_level_values("character")[first_epoch_rows]
    # Materialize to a real list: a pandas Index compares elementwise and
    # raises on truthiness checks, which contradicts the declared list type.
    return cast(list[str], tokens.tolist())


def get_max_epoch(
    result: pd.DataFrame,
) -> int:
    """Return the largest value found in the 'epoch' level of the index
    of a memorization_log dataframe, converted to an int."""
    epoch_values = result.index.get_level_values("epoch")
    # The cast only informs the type checker; int() does the conversion.
    largest_epoch = cast(float, epoch_values.max())
    return int(largest_epoch)


def reindex_positionwise(
    result: pd.DataFrame,
) -> pd.DataFrame:
    """Re-index a memorization_log dataframe by token position.

    The 'character' index level is moved into a regular column and the
    index is replaced by a two-level ('epoch', 'token_index') index, where
    'token_index' is the position of each token in the string.

    Positions are assigned by repeating 0..len(tokens)-1 once per epoch
    (max_epoch + 1 times) in row order, so this relies on the rows being
    laid out epoch after epoch with the tokens in the same order each time.
    """
    last_epoch = get_max_epoch(result)
    string_tokens = get_string_tokens(result)

    # One pass over the string's positions, repeated for every epoch.
    positions_in_string = list(range(len(string_tokens)))
    repeated_positions = np.array(positions_in_string * (last_epoch + 1))

    epoch_values = result.index.get_level_values("epoch")
    positional_index = pd.MultiIndex.from_arrays(
        [epoch_values, repeated_positions],
        names=["epoch", "token_index"],
    )

    # Keep the token text as a column, then swap in the new index wholesale.
    with_character_column = result.reset_index(level=["character"])
    return with_character_column.set_index(positional_index)
