from collections import Counter

def select_top_p_tokens_by_group_freq_diff(input_ids_list, scores, k, p, exclude_ids=None):
    """
    Select the p token IDs whose relative frequency differs most between the
    high-score and low-score groups of completions.

    The completions are ranked by score (descending). The first ``k`` form the
    high-score group and the last ``k`` form the low-score group. Within each
    group, a token's relative frequency is its count divided by the total
    number of (non-excluded) tokens in that group. Tokens are ranked by the
    absolute difference of the two relative frequencies.

    input_ids_list: List[List[int]] or List[Tensor]; any element exposing
        ``tolist()`` is converted to a plain list first.
    scores: List[float], score for each completion (indexed like input_ids_list)
    k: int, number of samples to take from each of the high/low score groups.
        If ``k <= 0`` no samples are selected and ``[]`` is returned.
    p: int, number of tokens to select. If ``p <= 0``, ``[]`` is returned.
    exclude_ids: set or list, special token IDs to exclude (e.g., pad, eos)

    Returns a list of at most ``p`` token IDs, ordered by decreasing frequency
    difference; ties are broken by ascending token ID so the result is
    deterministic.

    Note: when ``2 * k > len(scores)`` the two groups overlap, and the shared
    samples contribute to both groups' counts.
    """
    # Non-positive sizes select nothing. Guarding k here also avoids the
    # slicing pitfall where ``sorted_idx[-0:]`` is the *whole* list, which
    # would otherwise put every sample into the low-score group.
    if k <= 0 or p <= 0:
        return []

    excluded = set(exclude_ids) if exclude_ids is not None else set()

    sequences = [ids.tolist() if hasattr(ids, "tolist") else ids for ids in input_ids_list]

    # Sample indices ordered from highest to lowest score.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    high_indices = ranked[:k]
    low_indices = ranked[-k:]

    high_counter = Counter(
        t for idx in high_indices for t in sequences[idx] if t not in excluded
    )
    low_counter = Counter(
        t for idx in low_indices for t in sequences[idx] if t not in excluded
    )

    high_total = sum(high_counter.values())
    low_total = sum(low_counter.values())

    # Absolute difference of relative frequencies; a group with no tokens
    # (e.g. everything excluded) contributes a frequency of 0.
    freq_diff = {}
    for token in high_counter.keys() | low_counter.keys():
        freq_high = high_counter[token] / high_total if high_total > 0 else 0
        freq_low = low_counter[token] / low_total if low_total > 0 else 0
        freq_diff[token] = abs(freq_high - freq_low)

    # Largest difference first; token ID as a secondary key keeps the order
    # of tied tokens independent of set iteration order.
    ranked_tokens = sorted(freq_diff, key=lambda t: (-freq_diff[t], t))
    return ranked_tokens[:p]