# experiments/benchmarks.py
"""
Batch-run benchmarks comparing the ChronosCore transformer agent against the classical
schedulers (RM, EDF). A vanilla tabular or MLP DQN baseline is also intended as a
comparison point, but this runner does not evaluate one yet.

Outputs:
  - CSV containing per-taskset hit rates for each method (columns: set, rm, edf, agent)
  - summary printed to stdout
This is a lightweight runner meant for small-scale reproduction experiments.
"""
import os
import csv
from typing import List
import numpy as np
import torch

from configs import Config
from data.dataset_builder import generate_random_taskset
from chronoscore.schedulers import rate_monotonic_schedule, earliest_deadline_first
from eval.evaluator import compute_hits
from models.transformer_agent import TransformerAgent
from models.value_mapper import MaskedGreedyMapper
from chronoscore.environment import TaskSchedulingEnvironment, IDLE_TASK_ID

def run_benchmark(cfg: Config, n_sets: int = 50, n_tasks: int = 5, n_cores: int = 1, out_csv: str = "experiments/benchmark_results.csv") -> List[dict]:
    """Benchmark RM, EDF and the transformer agent on random task sets.

    For each of ``n_sets`` randomly generated task sets this computes the
    rate-monotonic and earliest-deadline-first hit rates, rolls the agent out
    in a ``TaskSchedulingEnvironment``, and records one row per task set.
    The rows are written to ``out_csv`` and a per-method mean is printed.

    Args:
        cfg: Configuration object. This function reads ``device``,
            ``total_utilization``, ``min_period``, ``max_period`` and
            ``n_quanta`` from it, and passes it to ``TransformerAgent``.
        n_sets: Number of random task sets to evaluate.
        n_tasks: Number of tasks per generated task set.
        n_cores: Number of cores. With ``n_cores == 1`` the action is the
            argmax of the agent's output; otherwise ``MaskedGreedyMapper``
            is consulted and the first non-idle assignment is used.
        out_csv: Destination CSV path; its parent directory is created if
            it does not exist.

    Returns:
        A list of dicts with keys ``"set"``, ``"rm"``, ``"edf"`` and
        ``"agent"``, one per task set, in generation order.
    """
    os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)
    rows = []
    device = cfg.device
    # The agent is freshly (randomly) initialised here; no checkpoint is loaded,
    # so a real comparison assumes trained weights are supplied elsewhere.
    agent = TransformerAgent(cfg, n_tasks).to(device)
    # Inference only: switch off train-time behaviour such as dropout so the
    # rollout is deterministic for a given set of weights.
    agent.eval()
    mapper = MaskedGreedyMapper(n_tasks, n_cores)

    for sidx in range(n_sets):
        tasks = generate_random_taskset(n_tasks=n_tasks, total_utilization=cfg.total_utilization,
                                       min_period=cfg.min_period, max_period=cfg.max_period)
        env = TaskSchedulingEnvironment(tasks, n_quanta=cfg.n_quanta)

        # Classical schedules over the horizon reported by the environment.
        # NOTE(review): this uses the private _compute_lcm() while the rollout
        # below iterates over env.L; presumably they agree - confirm.
        horizon = env._compute_lcm()
        rm = rate_monotonic_schedule(tasks, horizon)
        edf = earliest_deadline_first(tasks, horizon)
        # Guard the normalisation the same way the agent hit rate is guarded,
        # so an empty task set or zero horizon yields 0.0 instead of raising.
        denom = len(tasks) * horizon
        rm_hits = compute_hits(tasks, rm) / denom if denom > 0 else 0.0
        edf_hits = compute_hits(tasks, edf) / denom if denom > 0 else 0.0

        # Agent policy rollout (greedy argmax, or masked-greedy for multi-core).
        state = env.reset()
        total_hits = 0
        # No gradients are needed for evaluation; skip autograd bookkeeping.
        with torch.no_grad():
            for _ in range(env.L):
                qvals = agent(state).detach().cpu()   # [n_tasks+1]
                if n_cores == 1:
                    # Single core: act on the highest-valued entry directly.
                    action = int(qvals.argmax().item())
                else:
                    # Per-task values exclude the final entry, which is treated
                    # as the idle action here.
                    per_task_q = qvals[:-1]
                    # A task is available when it has a front job that is not done.
                    task_mask = [bool(env.jobs.get(task.id) and not env.jobs[task.id][0].is_done()) for task in tasks]
                    assignments = mapper.map(per_task_q, task_mask)  # list length n_cores
                    # The environment takes a single action per step, so the
                    # per-core assignments are collapsed to the first non-idle
                    # one (idle if every core was assigned idle).
                    first_assigned = next((a for a in assignments if a != IDLE_TASK_ID), IDLE_TASK_ID)
                    action = int(first_assigned)
                state, r, done = env.step(action)
                total_hits += r
                if done:
                    break
        agent_hit_rate = total_hits / env.max_reward if env.max_reward > 0 else 0.0

        rows.append({"set": sidx, "rm": rm_hits, "edf": edf_hits, "agent": float(agent_hit_rate)})

    # Write one CSV row per task set.
    with open(out_csv, "w", newline="", encoding="utf-8") as fh:
        writer = csv.DictWriter(fh, fieldnames=["set", "rm", "edf", "agent"])
        writer.writeheader()
        writer.writerows(rows)
    print(f"Wrote benchmark results to {out_csv}")

    # Print the per-method mean hit rate promised by the module docstring.
    if rows:
        print(f"Mean hit rate over {len(rows)} task sets:")
        for method in ("rm", "edf", "agent"):
            mean_rate = sum(float(row[method]) for row in rows) / len(rows)
            print(f"  {method}: {mean_rate:.4f}")
    return rows

if __name__ == "__main__":
    # Script entry point: run a small single-core benchmark (20 task sets)
    # using the default configuration's task count.
    default_cfg = Config()
    run_benchmark(
        default_cfg,
        n_sets=20,
        n_tasks=default_cfg.n_tasks,
        n_cores=1,
    )
