from __future__ import annotations

import itertools
from dataclasses import replace
from typing import Dict, List, Any, Tuple

import numpy as np

from .benchmark import BenchmarkConfig, run_one_environment


def cartesian_grid(param_grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Expand a parameter grid into the list of all its combinations.

    Each returned dict maps every key of ``param_grid`` to one of its
    candidate values. Combinations follow ``itertools.product`` order, with
    keys in the grid's iteration order. An empty grid yields a single empty
    combination. If any key has an empty value list, no combinations exist
    and the result is empty.
    """
    if not param_grid:
        # One "no overrides" combination so callers still run once.
        return [{}]
    names = tuple(param_grid)
    value_lists = (param_grid[name] for name in names)
    return [dict(zip(names, values)) for values in itertools.product(*value_lists)]


def evaluate_algo_on_env(cfg: BenchmarkConfig, algo_name: str) -> np.ndarray:
    """Collect the final cumulative regret of ``algo_name`` over ``cfg.R`` runs.

    Run ``rep`` (0-based) calls ``run_one_environment`` with
    ``env_seed = cfg.seed + 1000 * rep``. From the result it takes the entry
    for ``algo_name`` and keeps the last value of its cumulative sum.

    NOTE(review): presumably ``run_one_environment`` returns a mapping from
    algorithm name to a per-step regret sequence; confirm against its
    definition. That sequence must be non-empty, since ``[-1]`` is taken on
    its cumulative sum.

    Returns:
        A float array with one final cumulative regret per run, in run order.
    """

    def _final_cumreg(rep: int):
        per_step = run_one_environment(cfg, env_seed=cfg.seed + 1000 * rep)
        return np.cumsum(per_step[algo_name])[-1]

    return np.asarray([_final_cumreg(rep) for rep in range(cfg.R)], dtype=float)


def sweep_one_algo(base_cfg: BenchmarkConfig, algo_name: str, grid: Dict[str, List[Any]]) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Grid-search the hyperparameters of a single algorithm.

    Every combination produced by ``cartesian_grid(grid)`` is evaluated with
    ``evaluate_algo_on_env``. Each evaluation uses a copy of ``base_cfg``
    restricted to ``algo_name`` with that combination as its parameters. The
    combination with the lowest mean final cumulative regret is selected; on
    ties the earliest combination wins.

    Args:
        base_cfg: Configuration copied for each combination; not modified.
        algo_name: Name of the algorithm being tuned.
        grid: Mapping from parameter name to its candidate values.

    Returns:
        ``(best_params, rows)``. ``rows`` has one dict per combination holding
        ``"algo"``, the parameter values, ``"mean_final_cumreg"`` and ``"se"``.
        ``"se"`` is the standard error of the mean; it is NaN when fewer than
        two repetitions were run.

    Raises:
        ValueError: if the grid yields no combinations (some candidate list is
            empty) or no combination produced a mean below infinity (e.g. all
            means are NaN).
    """
    rows: List[Dict[str, Any]] = []
    best: Dict[str, Any] | None = None
    best_mean = float("inf")

    for params in cartesian_grid(grid):
        # replace() builds a fresh instance via __init__ (so __post_init__ sees
        # the overridden fields). It leaves base_cfg untouched and, unlike
        # attribute assignment, also works for a frozen dataclass.
        cfg = replace(base_cfg, algos=[algo_name], algo_params={algo_name: params})

        finals = evaluate_algo_on_env(cfg, algo_name)
        n = len(finals)
        mean = float(np.mean(finals))
        # ddof=1 needs at least two samples; below that NumPy would warn and
        # return NaN anyway, so produce the NaN explicitly.
        se = float(np.std(finals, ddof=1) / np.sqrt(n)) if n > 1 else float("nan")
        rows.append({"algo": algo_name, **params, "mean_final_cumreg": mean, "se": se})

        # Strict '<' keeps the first of tied combinations; NaN never compares less.
        if mean < best_mean:
            best_mean = mean
            best = params

    if best is None:
        # An explicit check rather than `assert`, which is stripped under -O.
        raise ValueError(
            f"sweep for {algo_name!r} found no usable combination: "
            "the grid is empty or every mean final cumulative regret was NaN/inf"
        )
    return best, rows
