# eval/evaluator.py
"""
Evaluation utilities to run a trained policy on tasksets and compute hit-rate and other metrics.
"""
from typing import List, Tuple
import numpy as np
from data.dataset_builder import Task
from chronoscore.environment import TaskSchedulingEnvironment, IDLE_TASK_ID
from chronoscore.schedulers import rate_monotonic_schedule, earliest_deadline_first

def evaluate_policy(policy, tasks: List[Task], n_quanta: int = 5, return_schedule: bool = False):
    """Roll out ``policy`` greedily on a task set and return its hit rate.

    A fresh ``TaskSchedulingEnvironment`` is built for ``tasks`` and stepped
    for at most ``env.L`` steps (stopping early if the environment reports
    ``done``). At each step the action with the highest value returned by
    ``policy(state)`` is taken, unless the environment reports that there are
    no jobs to do, in which case ``IDLE_TASK_ID`` is used instead.

    Args:
        policy: Callable mapping an environment state to a tensor of action
            values (must support ``.argmax().item()``).
        tasks: Task set to schedule.
        n_quanta: Forwarded unchanged to ``TaskSchedulingEnvironment``.
        return_schedule: When true, also return the list of actions taken.

    Returns:
        ``hit_rate`` (accumulated reward divided by ``env.max_reward``, or
        ``0.0`` when ``env.max_reward`` is not positive), or the tuple
        ``(hit_rate, schedule)`` when ``return_schedule`` is true.
    """
    # Imported here because the module never imports torch at file level;
    # without this the rollout fails with NameError on the first step.
    import torch

    env = TaskSchedulingEnvironment(tasks=tasks, n_quanta=n_quanta)
    state = env.reset()
    schedule = []
    total_reward = 0
    for _ in range(env.L):
        if env.no_jobs_to_do():
            # Nothing is runnable: the policy's choice would be discarded,
            # so skip the forward pass entirely.
            action = IDLE_TASK_ID
        else:
            with torch.no_grad():
                action = int(policy(state).argmax().item())
        state, reward, done = env.step(action)
        schedule.append(action)
        total_reward += reward
        if done:
            break
    hit_rate = total_reward / env.max_reward if env.max_reward > 0 else 0.0
    if return_schedule:
        return hit_rate, schedule
    return hit_rate

def evaluate_classics(tasks: List[Task]) -> dict:
    """Score the rate-monotonic and EDF baseline schedulers on ``tasks``.

    The schedule length ``L`` is taken from the environment's
    ``_compute_lcm()``; each baseline produces a schedule of that length,
    which is scored with ``compute_hits``.

    Args:
        tasks: Task set to schedule.

    Returns:
        A dict with keys ``"RM"`` and ``"EDF"``, each holding the number of
        hits divided by ``len(tasks) * L``. Both values are ``0.0`` when that
        denominator is not positive (e.g. an empty task set).
    """
    # _compute_lcm is a private helper of the environment; it is used here
    # only to obtain the schedule length shared by both baselines.
    L = TaskSchedulingEnvironment(tasks, n_quanta=5)._compute_lcm()
    # NOTE(review): hits are normalised by len(tasks) * L rather than by the
    # number of released jobs; confirm this is comparable with the
    # env.max_reward normalisation used by evaluate_policy.
    total = len(tasks) * L
    if total <= 0:
        # Mirror evaluate_policy's guard instead of raising ZeroDivisionError.
        return {"RM": 0.0, "EDF": 0.0}
    rm_hits = compute_hits(tasks, rate_monotonic_schedule(tasks, L))
    edf_hits = compute_hits(tasks, earliest_deadline_first(tasks, L))
    return {"RM": rm_hits / total, "EDF": edf_hits / total}

# Helper compute_hits, copied from utils.metrics (or utils); kept minimal here.
def compute_hits(tasks: List[Task], schedule: List[int]) -> int:
    """Count the jobs that complete no later than their deadline.

    Every task starts with one job released at time 0 whose absolute deadline
    is ``task.deadline``. Each schedule slot runs the task whose id it holds
    for one time unit (slots equal to ``IDLE_TASK_ID`` run nothing). Whenever
    the elapsed time is a multiple of a task's period, that task's remaining
    execution is reset to ``task.exectime`` and its deadline becomes the
    current time plus ``task.deadline``; any unfinished work is dropped.

    Args:
        tasks: Tasks exposing ``id``, ``exectime``, ``deadline`` and ``period``.
        schedule: Task id (or ``IDLE_TASK_ID``) executed in each time slot.

    Returns:
        Number of jobs whose remaining execution reached exactly zero at or
        before their deadline.
    """
    remaining = {}
    due = {}
    for task in tasks:
        remaining[task.id] = task.exectime
        due[task.id] = task.deadline

    hits = 0
    # ``now`` is the time at the END of the slot, hence the 1-based clock.
    for now, chosen in enumerate(schedule, start=1):
        if chosen != IDLE_TASK_ID:
            left = remaining[chosen] - 1
            remaining[chosen] = left
            if left == 0 and now <= due[chosen]:
                hits += 1
        # Release the next job of every task whose period boundary is now.
        for task in tasks:
            if now % task.period != 0:
                continue
            remaining[task.id] = task.exectime
            due[task.id] = now + task.deadline
    return hits
