# eval_utils.py
from typing import Any, Dict, Iterable, List, Mapping, Sequence, Union, Callable, Tuple
from itertools import product
import time
import numpy as np
from evaluation import evaluate

# One fully resolved sweep point: axis name -> the value chosen for that axis.
Spec = Dict[str, Any]
# A sweep description: axis name -> either a sequence of choices, or a callable
# that receives the partial spec built so far and returns the choices for that
# axis (this is how dependencies between axes are expressed).
SweepSpace = Mapping[str, Union[Sequence[Any], Callable[[Spec], Sequence[Any]]]]


def _as_list(x):
    if callable(x):  # will be resolved later
        return x
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x]


def expand_sweep(sweep: SweepSpace) -> List[Spec]:
    """
    Enumerate every concrete spec described by a sweep.

    Axes are visited in the mapping's iteration order, the first axis varying
    slowest. An axis value is either a plain value / list / tuple (normalised
    through `_as_list`) or a callable. A callable is invoked with the *partial*
    spec chosen so far (only the axes that precede it) and must return an
    iterable of choices, which lets later axes depend on earlier ones.

    Returns a list of independent dicts, one per combination. An axis with no
    choices prunes every combination that reaches it.
    """
    axis_names = list(sweep.keys())
    n_axes = len(axis_names)
    partial: Spec = {}  # shared scratch spec, mutated in place while walking

    def walk(depth: int):
        # Every axis has a value: emit a snapshot so later mutation of
        # `partial` cannot alter what was already produced.
        if depth == n_axes:
            yield dict(partial)
            return
        name = axis_names[depth]
        axis = sweep[name]
        # Callables see the partially built spec and decide their own choices.
        options = list(axis(partial)) if callable(axis) else _as_list(axis)
        for option in options:
            partial[name] = option
            yield from walk(depth + 1)
            # Backtrack so sibling branches never see this axis's value.
            partial.pop(name, None)

    return list(walk(0))


def make_eval_sweep(agent_cfg: dict, overrides: Dict[str, Any] | None = None):
    """
    Build a data-driven sweep space. `overrides` can provide runtime edits
    like {'eval_chunk': [1], 'beta': [0.0, 0.1]} to trim the grid.
    """
    sweep = {
        'actor_num_samples': agent_cfg['eval_actor_num_samples_sweep'],
        'actor_type': agent_cfg['eval_actor_type_sweep'],
        # dependent axes:
        'beta': lambda s: [agent_cfg['eval_beta_sweep'][0]]
        if s.get('actor_num_samples') == 1 or s.get('actor_type') == 'bon'
        else agent_cfg['eval_beta_sweep'],
        'q_star_beta': lambda s: [agent_cfg['eval_q_star_beta_sweep'][0]]
        if s.get('actor_num_samples') == 1 or s.get('actor_type') == 'bon'
        else agent_cfg['eval_q_star_beta_sweep'],
        'num_rtg_samples': lambda s: [agent_cfg['eval_num_rtg_samples_sweep'][0]]
        if s.get('actor_num_samples') == 1 or s.get('actor_type') == 'bon'
        else agent_cfg['eval_num_rtg_samples_sweep'],
    }
    # Lightweight overrides (lists or singletons). Callables are allowed too,
    # but for CLI JSON you’ll mainly pass lists.
    if overrides:
        sweep.update(overrides)
    print('sweep:', sweep)
    return sweep


def spec_keyprefix(spec: Dict[str, Any], ordered_keys: List[str]) -> str:
    """Render selected spec entries as a slash-separated ``key=value`` string.

    Entries appear in the order given by `ordered_keys`; a key missing from
    `spec` raises ``KeyError``. An empty key list yields an empty string.
    """
    segments = ['{}={}'.format(key, spec[key]) for key in ordered_keys]
    return '/'.join(segments)


def run_one_eval(
    agent,
    env,
    spec: Dict[str, Any],
    num_eval_episodes: int,
    num_video_episodes: int,
    video_frame_skip: int,
    action_dim: int = None,
    agent_name: str = 'evor',
):
    """Run `evaluate` once for a single sweep spec.

    For the 'evor' agent the spec's `beta`, `actor_num_samples`,
    `q_star_beta`, `num_rtg_samples` and `actor_type` entries are forwarded
    as `sample_actions_kwargs`; any other agent gets an empty kwargs dict and
    `spec` is not read at all.

    Returns:
        A 4-tuple ``(eval_info, renders, 0, 0)``. The trajectories returned by
        `evaluate` are discarded.
    """
    sampling_kwargs = {}
    if agent_name == 'evor':
        forwarded = (
            'beta',
            'actor_num_samples',
            'q_star_beta',
            'num_rtg_samples',
            'actor_type',
        )
        sampling_kwargs = {name: spec[name] for name in forwarded}

    eval_info, _trajs, renders = evaluate(
        agent=agent,
        env=env,
        num_eval_episodes=num_eval_episodes,
        num_video_episodes=num_video_episodes,
        video_frame_skip=video_frame_skip,
        action_dim=action_dim,
        sample_actions_kwargs=sampling_kwargs,
    )
    # NOTE: the last two slots were meant for elapsed wall-clock time and the
    # total number of observations across trajectories; that instrumentation
    # is currently disabled, so both are fixed placeholders.
    return eval_info, renders, 0, 0
