from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Iterable, Callable, List, Tuple


def run_ablation(train_fn: Callable[[Dict[str, Any]], Dict[str, float]], grid: Dict[str, Iterable]) -> Dict[str, Dict[str, float]]:
    """Run ``train_fn`` once per combination in the Cartesian product of ``grid``.

    Args:
        train_fn: Called with one configuration dict that maps every grid key
            to a single chosen value. Returns a dict of metric name -> value.
        grid: Mapping from parameter name to the iterable of values to sweep.
            Combinations are enumerated in ``itertools.product`` order over
            the mapping's key order.

    Returns:
        Dict keyed by a tag of the form ``"k1=v1,k2=v2"`` (values rendered
        via f-string formatting), mapping to whatever ``train_fn`` returned
        for that configuration. An empty ``grid`` yields one run with ``{}``
        under the tag ``""``. Any empty value list yields no runs.

    Raises:
        ValueError: If two *different* configurations render to the same tag
            (e.g. values ``1`` and ``"1"``). Without this check one result
            would silently overwrite the other. Identical repeated
            configurations are still run, and the later result is kept.
    """
    from itertools import product

    keys = list(grid)
    results: Dict[str, Dict[str, float]] = {}
    # Tag -> snapshot of the config that produced it, used to detect collisions.
    seen: Dict[str, Dict[str, Any]] = {}

    for values in product(*(grid[k] for k in keys)):
        cfg = dict(zip(keys, values))
        tag = ",".join(f"{k}={v}" for k, v in cfg.items())
        # Check before training so a collision does not waste a run.
        if tag in seen and seen[tag] != cfg:
            raise ValueError(
                f"distinct configurations {seen[tag]!r} and {cfg!r} "
                f"share the result tag {tag!r}"
            )
        # Snapshot before calling train_fn, in case it mutates cfg.
        seen[tag] = dict(cfg)
        results[tag] = train_fn(cfg)
    return results


def topk(results: Dict[str, Dict[str, float]], metric: str, k: int = 5, reverse: bool = True) -> List[Tuple[str, float]]:
    """Return the top ``k`` ``(tag, value)`` pairs ranked by ``metric``.

    Args:
        results: Mapping from run tag to that run's metric dict.
        metric: Metric name to rank by. Runs lacking it are skipped.
        k: Maximum number of pairs to return. Values ``<= 0`` return ``[]``.
            Without this guard, a negative ``k`` would slice from the end and
            return almost every run.
        reverse: ``True`` (default) ranks largest first; ``False`` ranks
            smallest first.

    Returns:
        Up to ``k`` ``(tag, value)`` tuples in ranked order. The sort is
        stable, so ties keep the iteration order of ``results``. NaN values
        are not specially handled.
    """
    if k <= 0:
        return []
    vals = [(tag, res[metric]) for tag, res in results.items() if metric in res]
    vals.sort(key=lambda x: x[1], reverse=reverse)
    return vals[:k]


def summarize(results: Dict[str, Dict[str, float]], metrics: List[str]) -> Dict[str, float]:
    """Aggregate each requested metric across all runs that report it.

    For every name in ``metrics`` that at least one run reports, the output
    gets three float entries, in this order: ``"<name>/mean"``,
    ``"<name>/max"`` and ``"<name>/min"``. Metrics that no run reports are
    omitted entirely.
    """
    summary: Dict[str, float] = {}
    for name in metrics:
        observed = [run[name] for run in results.values() if name in run]
        if not observed:
            # Nothing to aggregate; leave this metric out of the summary.
            continue
        summary[f"{name}/mean"] = float(sum(observed) / len(observed))
        summary[f"{name}/max"] = float(max(observed))
        summary[f"{name}/min"] = float(min(observed))
    return summary


def _demo():
    """Run a tiny synthetic ablation and print its top-3 runs and summary.

    The fake ``train_fn`` derives a pseudo-metric in ``[0, 0.99]`` from the
    config using CRC-32 rather than the builtin ``hash``. The builtin ``hash``
    of a ``str`` is salted per process (PYTHONHASHSEED), which would make the
    printed output differ between runs.
    """
    import zlib

    def train_fn(cfg):
        # Deterministic across interpreter processes, unlike hash(str(cfg)).
        digest = zlib.crc32(str(cfg).encode("utf-8"))
        return {"metric": float(digest % 100) / 100.0}

    grid = {"a": [1, 2, 3], "b": ["x", "y"]}
    res = run_ablation(train_fn, grid)
    print(topk(res, "metric", k=3))
    print(summarize(res, ["metric"]))


# Script entry point: running this file directly executes the synthetic demo.
if __name__ == "__main__":
    _demo()
                
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
