import argparse
import pandas as pd

def count_tokens_recursive(x):
    """
    Recursively count the tokens held in an arbitrarily nested container.

    Rules:
      - A list-like value (list, tuple, array, ...) contributes the sum of
        the counts of its elements, so an empty one contributes 0 and a flat
        one contributes its length.
      - Anything else (including a string) is a single token and counts as 1.

    Args:
        x: A token, or a (possibly nested) list-like of tokens.

    Returns:
        int: Total number of tokens found at any nesting depth.
    """
    # is_list_like is used instead of isinstance(x, list) because cells may
    # arrive as arrays or tuples rather than plain lists; it reports str and
    # bytes as not list-like, so those are counted as one token each.
    if not pd.api.types.is_list_like(x):
        return 1
    # Summing over an empty container yields 0, which covers the empty case.
    return sum(count_tokens_recursive(item) for item in x)

def average_token_length(file_path, column='cont_tokens'):
    """
    Return the mean per-row token count for one column of a Parquet file.

    Args:
        file_path: Location of the Parquet file, passed to pd.read_parquet.
        column: Name of the column whose cells are counted.

    Returns:
        The mean of the per-row counts produced by count_tokens_recursive.

    Raises:
        ValueError: If the file has no column named `column`.
    """
    frame = pd.read_parquet(file_path)

    # Fail early with a clear message instead of a KeyError further down.
    column_present = column in frame.columns
    if not column_present:
        raise ValueError(f"The Parquet file does not contain the '{column}' column.")

    # One count per row, then average across rows.
    per_row_counts = frame[column].apply(count_tokens_recursive)
    mean_count = per_row_counts.mean()
    return mean_count

if __name__ == "__main__":
    # Command-line entry point: parse the arguments, compute the average, print it.
    parser = argparse.ArgumentParser(description="Compute the average number of tokens per row in a Parquet file.")
    # The path is mandatory: without it file_path would be None and the read
    # would fail with an unhelpful error, so let argparse report the omission.
    parser.add_argument("--file_path", type=str, required=True, help="Path to the Parquet file")
    parser.add_argument("--column", type=str, default="cont_tokens", help="Name of the column containing tokens (default: 'cont_tokens')")
    args = parser.parse_args()

    avg_length = average_token_length(args.file_path, args.column)
    print("Average number of tokens per row:", avg_length)
