# chronoscore/environment.py
"""
TaskSchedulingEnvironment for ChronosCore.
Implements job release, step transition, reward calculation and slack quantization.
Each Job wraps a Task NamedTuple imported from data.dataset_builder.
"""
from typing import List, Tuple
from collections import deque
import numpy as np
from data.dataset_builder import Task

IDLE_TASK_ID = -1  # action value meaning "execute no job during this time-step"


class Job:
    """A single released instance of a task.

    Attributes:
        task: the task this job was released from.
        release_time: time-step at which the job was released.
        deadline: absolute deadline (``release_time + task.deadline``).
        remaining: execution time units still required.
        time_until_deadline: relative countdown to the deadline. It is only
            accurate if ``tick`` is called exactly once per elapsed time-step.
    """

    def __init__(self, task: "Task", release_time: int):
        self.task = task
        self.release_time = release_time
        self.deadline = release_time + task.deadline
        self.remaining = task.exectime
        self.time_until_deadline = task.deadline

    def tick(self, work: bool):
        """Advance the job by one time-step.

        Args:
            work: True if the job was executed during this step, in which case
                one unit of remaining execution time is consumed.
        """
        if work:
            self.remaining -= 1
        self.time_until_deadline -= 1

    def is_done(self) -> bool:
        """Return True once no execution time remains."""
        return self.remaining <= 0


class TaskSchedulingEnvironment:
    """Discrete-time scheduling environment over a fixed set of tasks.

    Every task owns a FIFO queue of pending jobs. One job per task is released
    at time 0 and a further one at every multiple of the task's period. The
    observable state is a tuple holding, per task, the quantized slack of the
    job at the front of its queue. An episode is flagged done once ``time``
    reaches the least common multiple of all task periods.
    """

    def __init__(self, tasks: "List[Task]", n_quanta: int = 5):
        """Create the environment and reset it.

        Args:
            tasks: non-empty collection of tasks; each must expose ``id``,
                ``period``, ``exectime`` and ``deadline``.
            n_quanta: number of slack buckets the largest relative deadline
                is divided into; must be at least 1.

        Raises:
            ValueError: if ``tasks`` is empty or ``n_quanta`` is not positive.
        """
        if not tasks:
            raise ValueError("tasks must contain at least one task")
        if n_quanta < 1:
            raise ValueError("n_quanta must be a positive integer")
        self.tasks = tasks
        self.n_quanta = n_quanta
        self.max_deadline = max(t.deadline for t in tasks)
        # Bucket width; never below 1 so quantization cannot divide by zero.
        self.quanta_base = max(1, self.max_deadline // n_quanta)
        self.reset()

    def reset(self):
        """Restart at time 0 with one fresh job per task and return the state."""
        self.time = 0
        # Each task has a queue of pending jobs, oldest first.
        self.jobs = {t.id: deque([Job(t, 0)]) for t in self.tasks}
        self.L = self._compute_lcm()
        self.state = self._compute_state()
        return self.state

    def _compute_lcm(self):
        """Return the least common multiple of all task periods."""
        from math import gcd

        lcm = 1
        for t in self.tasks:
            lcm = lcm * t.period // gcd(lcm, t.period)
        return lcm

    def _compute_state(self):
        """Return the per-task tuple of quantized slack values.

        For a task whose queue is empty (or whose front job is already done)
        the entry is ``max_deadline``.
        """
        state = []
        for t in self.tasks:
            queue = self.jobs.get(t.id)
            if queue and not queue[0].is_done():
                front = queue[0]
                slack = front.time_until_deadline - front.remaining
                state.append(self._quantize(slack))
            else:
                state.append(self.max_deadline)
        return tuple(state)

    def _quantize(self, slack: int):
        """Round ``slack`` down to a multiple of ``quanta_base``.

        Any negative slack collapses to the single value ``-quanta_base``.
        """
        if slack < 0:
            return -self.quanta_base
        return int(self.quanta_base * (slack // self.quanta_base))

    def step(self, action: int) -> Tuple[Tuple[int, ...], int, bool]:
        """Execute one time-step applying ``action``.

        Args:
            action: id of the task whose front job should run, or
                ``IDLE_TASK_ID``. An id with no pending job behaves like idle.

        Returns:
            ``(next_state, reward, done)``. The reward is +1 when the executed
            job completes at or before its absolute deadline and -1 when it
            completes after it; otherwise 0. Jobs that are never completed
            produce no reward.
        """
        reward = 0
        slot_end = self.time + 1

        # Pick the job (if any) executed during the slot [time, time + 1).
        running = None
        if action != IDLE_TASK_ID:
            queue = self.jobs.get(action)
            if queue:
                running = queue[0]

        # Every pending job ages by one step, not only the queue fronts;
        # otherwise the countdown of a waiting job is stale once it reaches
        # the front and its reported slack is too optimistic.
        for queue in self.jobs.values():
            for job in queue:
                job.tick(work=job is running)

        if running is not None and running.is_done():
            reward += 1 if slot_end <= running.deadline else -1
            self.jobs[action].popleft()

        # Release jobs that arrive at the end of this slot. This happens after
        # execution and aging so a job can neither run nor age before its own
        # release time.
        for t in self.tasks:
            if slot_end % t.period == 0:
                self.jobs.setdefault(t.id, deque()).append(Job(t, slot_end))

        self.time = slot_end
        self.state = self._compute_state()
        done = self.time >= self.L
        return self.state, reward, done

    def sample_random_action(self):
        """Return a uniformly random task id among tasks with a pending,
        unfinished front job, or ``IDLE_TASK_ID`` when there is none."""
        import random

        valid = [
            t.id
            for t in self.tasks
            if self.jobs.get(t.id) and not self.jobs[t.id][0].is_done()
        ]
        if not valid:
            return IDLE_TASK_ID
        return random.choice(valid)

    def get_current_utilization(self) -> float:
        """Return the sum of ``exectime / period`` over all tasks."""
        return sum(t.exectime / t.period for t in self.tasks)
