# misc stuff for shortening notebooks

import pandas as pd
import numpy as np


def infer_length_column(base_col_name, dataframe, args=None):
    """Return the name of the token-count column to use for ``base_col_name``.

    Candidate suffixes are tried in order of preference:

    1. ``_num_tokens_scored`` -- the count computed at detection time (ideal).
    2. ``_length`` -- for the outputs the generation-time token count, and for
       the baseline the initial count based on tokenization and slice.

    Args:
        base_col_name: Column-name prefix that each suffix is appended to.
        dataframe: Object whose ``columns`` is searched for the candidates.
        args: Optional object exposing an ``ignore_repeated_ngrams`` attribute.
            ``None`` (the default) is treated as if that flag were not set.

    Returns:
        The first candidate column name that is present in ``dataframe.columns``.

    Raises:
        ValueError: If none of the candidate columns exist.
    """
    # `args` defaults to None, so guard before reading the flag; a non-None
    # `args` that lacks the attribute still raises AttributeError as before.
    ignore_repeated_ngrams = args is not None and args.ignore_repeated_ngrams

    if ignore_repeated_ngrams:
        # if we're ignoring repeated ngrams, then we need to use the length column
        # since the num_tokens_scored column will be wrong/short
        # though this isn't a perfect solution bc there can be retokenization differences
        col_suffixes = ["_length"]
    else:
        col_suffixes = ["_num_tokens_scored", "_length"]

    for suf in col_suffixes:
        length_column_name = f"{base_col_name}{suf}"
        if length_column_name in dataframe.columns:
            return length_column_name

    raise ValueError(
        f"Could not find length column for {base_col_name}. Note, `_num_tokens_generated` suffix is deprecated in favor of `_length`."
    )


def filter_text_col_length(
    df, text_col_name=None, count_suffix="_num_tokens_scored", upper_T=205, lower_T=195
):
    """Keep only the rows whose token count lies within ``[lower_T, upper_T]``.

    The count column is looked up as ``text_col_name + count_suffix``. Both
    bounds are inclusive. A one-line summary of how many rows were dropped is
    printed to stdout.

    Args:
        df: Frame to filter; it is not modified.
        text_col_name: Prefix of the count column. Must not be ``None``.
        count_suffix: Suffix appended to ``text_col_name`` to form the count
            column name.
        upper_T: Largest count that is kept.
        lower_T: Smallest count that is kept.

    Returns:
        The filtered frame.
    """
    assert text_col_name is not None

    count_col = text_col_name + count_suffix
    rows_before = len(df)

    # one combined mask instead of two successive filters
    counts = df[count_col]
    in_range = (counts >= lower_T) & (counts <= upper_T)
    kept = df[in_range]

    rows_dropped = rows_before - len(kept)
    print(f"Dropped {rows_dropped} rows filtering {text_col_name}, new len {len(kept)}")

    return kept
