from typing import Callable, TypeVar, Mapping

import numpy as np


# Every function passed through `register` is appended here, in call order.
metric_registry = []
# Result type of a metric function: metric name -> numeric or boolean value.
NumMap = Mapping[str, float | int | bool]
# Type variable standing for any concrete mapping type compatible with NumMap.
NM = TypeVar("NM", bound=NumMap)


def register(
    func: Callable[[list[np.ndarray]], NM],
) -> Callable[[list[np.ndarray]], NM]:
    """Append *func* to ``metric_registry`` and return it unchanged.

    Usable as a decorator: the decorated function keeps its name and
    behaviour, and is additionally collected in the module-level registry.
    """
    metric_registry.append(func)
    return func


@register
def basic(data: list[np.ndarray]) -> NumMap:
    """Compute basic size statistics of a corpus.

    Args:
        data: Lines of the corpus, one array of tokens per line.

    Returns:
        Total token count, line count, mean and population standard
        deviation of the number of tokens per line, and the number of
        distinct tokens and distinct lines.

    Raises:
        ValueError: If ``data`` is empty (``np.concatenate`` needs at least
            one array).
    """
    # Concatenate once and reuse the result for every token-level statistic.
    catted = np.concatenate(data)
    token_count = len(catted)
    line_lengths = [len(line) for line in data]
    return {
        "Token Count": token_count,
        "Line Count": len(data),
        "Tokens per Line": token_count / len(data),
        "Tokens per Line SD": float(np.std(line_lengths)),
        "Unique Tokens": len(np.unique(catted)),
        # Lines are compared through the hash of their raw bytes, so only one
        # integer per distinct line is kept in memory.
        "Unique Lines": len({hash(line.data.tobytes()) for line in data}),
    }


@register
def entropy_1gram(data: list[np.ndarray]) -> NumMap:
    """Compute the Shannon entropy, in bits, of the token distribution.

    Args:
        data: Lines of the corpus, one array of tokens per line.

    Returns:
        "1-gram Entropy": entropy of the empirical token frequencies taken
        over all lines together.
        "1-gram Normalized Entropy": that entropy divided by ``log2`` of the
        number of distinct tokens. NaN when exactly one distinct token
        occurs, because the ratio is then 0/0.

    Raises:
        ValueError: If ``data`` is empty (``np.concatenate`` needs at least
            one array).
    """
    catted = np.concatenate(data)
    counts = np.unique(catted, return_counts=True)[1]
    normed = counts / counts.sum()
    # Plain Python floats keep the result independent of NumPy scalar types.
    ent = float(-(normed * np.log2(normed)).sum())

    if len(counts) == 1:
        # Maximum entropy is log2(1) == 0, so the normalised value is 0/0.
        # Report it as NaN explicitly instead of dividing by zero.
        normalized = float("nan")
    else:
        normalized = ent / float(np.log2(len(counts)))

    return {
        "1-gram Entropy": ent,
        "1-gram Normalized Entropy": normalized,
    }


@register
def conditional_entropy_2gram(data: list[np.ndarray]) -> NumMap:
    """Compute bigram entropy and a conditional-entropy estimate, in bits.

    Bigrams are pairs of adjacent tokens within a single line; they never
    span a line boundary, and lines shorter than two tokens contribute none.

    Args:
        data: Lines of the corpus, one 1-D array of tokens per line.

    Returns:
        "2-gram Entropy": entropy of the empirical bigram frequencies.
        "2-gram Conditional Entropy": the bigram entropy minus the
        "1-gram Entropy" reported by ``entropy_1gram`` for the same data.
        Both values are NaN when the corpus contains no bigram at all.
    """
    windows = [
        np.lib.stride_tricks.sliding_window_view(x, (2,), axis=-1)
        for x in data
        if len(x) >= 2
    ]
    if not windows:
        # No line has two tokens: the bigram distribution is undefined, and
        # np.concatenate would reject the empty list.
        undefined = float("nan")
        return {
            "2-gram Entropy": undefined,
            "2-gram Conditional Entropy": undefined,
        }

    catted = np.concatenate(windows)
    # axis=0 makes each (first, second) row count as one bigram.
    counts = np.unique(catted, axis=0, return_counts=True)[1]
    normed = counts / counts.sum()
    ent = float(-(normed * np.log2(normed)).sum())
    unigram_ent = float(entropy_1gram(data)["1-gram Entropy"])
    return {
        "2-gram Entropy": ent,
        "2-gram Conditional Entropy": ent - unigram_ent,
    }


@register
def entropy_per_line(data: list[np.ndarray]) -> NumMap:
    """Average the summed unigram information content (bits) over lines.

    Each token is scored with ``-log2`` of its smoothed relative frequency in
    the whole corpus; the scores are summed per line and the mean of those
    sums is returned under "Entropy per Line".

    Token values are used directly as array indices, so they are expected to
    be non-negative integers.
    """
    all_tokens = np.concatenate(data)
    token_values, occurrences = np.unique(all_tokens, return_counts=True)

    # Dense frequency table addressed by token value; values that never
    # occur keep a zero count until smoothing is applied.
    table = np.zeros(np.max(token_values) + 1, dtype=float)
    table[token_values] = occurrences
    # Smoothing
    table += 1e-10
    log_probs = np.log2(table / table.sum())

    line_totals = [log_probs[line].sum() for line in data]
    mean_bits = -np.mean(line_totals)
    return {"Entropy per Line": mean_bits}


@register
def end_of_sentence(data: list[np.ndarray]) -> NumMap:
    """Detect if end-of-sentence token present.

    A token counts as the EoS token when every line ends with it and, within
    every line, its occurrences form one unbroken run at the end of the line
    (nothing other than EoS follows an EoS).

    Args:
        data: Lines of the corpus, one 1-D array of tokens per line.

    Returns:
        "EoS Token Present": whether such a token exists.
        "EoS Padding": whether, in addition, all lines have the same length
        and at least one line contains the EoS token before its final
        position.
        Both are False when ``data`` is empty, contains an empty line, or the
        lines do not all end with the same token.
    """
    not_found = {"EoS Token Present": False, "EoS Padding": False}

    # A line without tokens has no final token, so it cannot end in EoS.
    if any(len(line) == 0 for line in data):
        return not_found

    candidates = {line[-1] for line in data}
    if len(candidates) != 1:
        return not_found
    (eos,) = candidates

    padding_exists = False
    for line in data:
        is_eos = line == eos
        # The EoS tokens form a suffix iff no EoS is directly followed by a
        # non-EoS token; comparing neighbours checks that in a single pass.
        if np.any(is_eos[:-1] & ~is_eos[1:]):
            return not_found
        padding_exists = padding_exists or bool(is_eos[:-1].any())

    same_length = len({len(line) for line in data}) == 1
    return {
        "EoS Token Present": True,
        "EoS Padding": same_length and padding_exists,
    }
