from collections.abc import Iterable

from pyspark.sql import functions as F
from pyspark.sql import Window
from pyspark.sql.dataframe import DataFrame


def collect_lists(
    df: DataFrame,
    group_by: str | Iterable[str],
    order_by: str | Iterable[str],
) -> DataFrame:
    """Collect sequences into lists and add auxiliary columns.

    All columns of `df` other than the `group_by` columns are collected into
    lists, one row per distinct `group_by` key, with list elements ordered by
    the `order_by` columns. Two kinds of auxiliary columns are added:

    * ``_seq_len``: the length of each collected sequence;
    * ``_last_<name>``: for every `order_by` column ``<name>``, the last
      (i.e. largest after sorting) value of that column in the sequence.

    The output columns are laid out as: `group_by` columns, `order_by` list
    columns, the remaining list columns in the order they appear in
    ``df.columns``, ``_seq_len``, and finally the ``_last_*`` columns.

    Args:
        df: DataFrame containing all sequences.
        group_by: column(s) identifying a sequence.
        order_by: column(s) used for ordering sequences.

    Returns:
        A dataframe with collected lists and auxiliary columns.
    """
    if isinstance(order_by, str):
        order_by = (order_by,)
    order_by = list(order_by)

    if isinstance(group_by, str):
        group_by = (group_by,)
    group_by = list(group_by)

    # Keep the remaining columns in their original dataframe order. Building
    # this list from a set difference would make the column order (and the
    # tie-breaking order used by the sort below) depend on string hashing,
    # which differs between Python processes.
    key_cols = set(group_by) | set(order_by)
    seq_cols = [c for c in df.columns if c not in key_cols]

    # `order_by` fields go first in the struct so that sorting the array of
    # structs orders the elements by them; the remaining fields only break ties.
    struct_fields = order_by + seq_cols

    list_cols = ["s." + c for c in struct_fields]
    last_cols = [
        F.element_at("s." + c, -1).alias("_last_" + c) for c in order_by
    ]

    return (
        df.select(*group_by, F.struct(*struct_fields).alias("s"))
        .groupBy(*group_by)
        .agg(F.sort_array(F.collect_list("s")).alias("s"))
        .select(
            *group_by,
            *list_cols,
            F.size("s").alias("_seq_len"),
            *last_cols,
        )
    )


def cat_freq(df: DataFrame, cols: str | Iterable[str]) -> list[DataFrame]:
    """Compute value frequency ranks for columns (0 for the most frequent value).

    For every requested column a separate dataframe is built with two columns:
    the column itself (one row per distinct value) and ``freq_rank``, the
    0-based rank of the value when values are sorted by descending number of
    occurrences. Values with equal counts are ranked by ascending value, so
    the ranks are reproducible between runs.

    Args:
        df: dataframe.
        cols: column(s) for which to count the occurrences. A single column
            name may be passed as a plain string.

    Returns:
        List of dataframes with value frequency ranks, one per column, in the
        order of `cols`.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(cols, str):
        cols = (cols,)

    val_counts = []
    for col in cols:
        # dummy partition F.lit(0) to suppress WindowExec warning
        # "No Partition Defined for Window operation! ..."
        rank_window = Window.partitionBy(F.lit(0)).orderBy(
            F.col("count").desc(),
            # secondary key: deterministic ranks for equally frequent values
            F.col(col).asc(),
        )
        # row_number() starts from 1, ranks start from 0
        freq_rank = (F.row_number().over(rank_window) - 1).alias("freq_rank")

        val_counts.append(
            df.select(col).groupBy(col).count().select(col, freq_rank)
        )

    return val_counts
