import numpy as np
from typing import Sequence, Literal, Callable, Any, Iterable
import json
from scipy.interpolate import interp1d
from pathlib import Path


# Mapping from an evaluation key (the raw JSON key, or whatever a key mapper
# turns it into) to its Result record. `Result` is defined further down; the
# forward reference is fine because `type` aliases are evaluated lazily.
type Results[K] = dict[K, Result]
# Results keyed by a pair of ints. NOTE(review): what the two ints mean is not
# visible in this module — confirm against the callers.
type MultResults = Results[tuple[int, int]]
# Anything this module accepts as a filesystem path (passed to `open`/`Path`).
type PathLike = Path | str


class Result(dict[str, Any]):
    """One evaluation record: a plain ``str -> Any`` dict.

    The ``"count"`` entry is additionally exposed as the read/write
    attribute ``count``; every other field uses ordinary item access.
    """

    def _read_count(self) -> int:
        # Plain item lookup: a record without "count" raises KeyError.
        return self["count"]

    def _write_count(self, value: int) -> None:
        self["count"] = value

    # Assemble the property by hand, then drop the helper functions so the
    # class namespace exposes only `count`.
    count = property(_read_count, _write_count)
    del _read_count, _write_count


def get_eval_results[K](path: PathLike, kmap: Callable[[str], K] = lambda k: k) -> Results[K]:
    with open(path, 'rt') as f:
        data = json.load(f)
    results: Results[K] = {}
    assert isinstance(data, dict)
    for k, v in data.items():
        assert isinstance(v, dict)
        if kmap is not None:
            try:
                k = kmap(k)
            except ValueError:
                continue
        results[k] = Result(**v)
    return results


def tabulate_results[ExperimentKey, ResultKey](
    settings: Sequence[ExperimentKey],
    colmap: Callable[[str], Iterable[ResultKey]],
    experiments: Sequence[dict[ExperimentKey | Literal["path"], Any]],
    root: PathLike | None = None,
    stat: str | Callable[[Result], float]  = "score",
    weight: str | Callable[[Result], float] = "count",
    percentage: bool = False,
) -> list[dict[ExperimentKey | ResultKey, Any]]:
    
    entries: list[dict[ExperimentKey | ResultKey, Any]] = []

    for i, experiment in enumerate(experiments):
        path = Path(experiment["path"])
        if root is not None:
            path = Path(root) / path
        results = get_eval_results(path)
        
        weights: dict[ResultKey, float] = {}
        sums: dict[ResultKey, float] = {}

        for k, r in results.items():
            v = r.get(stat) if isinstance(stat, str) else stat(r)
            w = r.get(weight, 0.) if isinstance(weight, str) else weight(r)
            if w == 0 or v is None or np.isnan(v) or np.isnan(w):
                continue
            cols = colmap(k)
            for col in cols:
                weights[col] = weights.get(col, 0) + w
                sums[col] = sums.get(col, 0) + v * w

        entry: dict[ExperimentKey | ResultKey, Any] = {}
        for col in settings:
            entry[col] = experiment.get(col)
        for col, w in weights.items():
            s = sums[col]
            v = s / w if w > 0 else np.nan
            if percentage:
                v = 100 * v
            entry[col] = v

        entries.append(entry)
    
    return entries




def smooth_points(
    xs: Sequence[float] | np.ndarray,
    ys: Sequence[float] | np.ndarray,
    *,
    n: int = 100,
    alpha: float = 0.5,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute a smooth line based on weighted sum of y values using inverse distance.

    Parameters:
    - xs: list of x coordinates
    - ys: list of y coordinates
    - num_points: number of smooth points to generate

    Returns:
    - x_smooth: smoothed x values
    - y_smooth: smoothed y values
    """
    
    # Convert input lists to numpy arrays for easier manipulation
    if not isinstance(xs, np.ndarray):
        xs = np.array(xs)
    if not isinstance(ys, np.ndarray):
        ys = np.array(ys)
    # Define the range of x values for the smooth line
    new_xs = np.linspace(xs.min(), xs.max(), n)
    
    interp_func = interp1d(xs, ys, kind='linear')
    ys = interp_func(new_xs)

    new_ys = np.zeros_like(new_xs)
    # Calculate the weighted sum for each x in x_smooth
    for i, x in enumerate(new_xs):
        # Calculate weights based on inverse distance
        weights = alpha ** np.abs(new_xs - x)
        weights /= np.sum(weights)  # Normalize weights
        new_ys[i] = np.sum(weights * ys)
    
    return new_xs, new_ys
