# chronoscore/q_network.py
"""
Q-network head that consumes per-task encoder features and outputs one Q-value per action.
For single-core scheduling, the actions are the task indices plus a final idle action.
Multi-core action mapping is not handled in this module.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class QHead(nn.Module):
    """Q-value head over per-task encoder features.

    Produces one Q-value per task plus a trailing Q-value for the idle
    action, so the output has ``n_tasks + 1`` entries along its last dim.

    Args:
        latent_dim: Size of each task's feature vector.
        n_tasks: Number of tasks. The idle head consumes all task features
            concatenated, so its input width is ``n_tasks * latent_dim``.
    """

    def __init__(self, latent_dim: int, n_tasks: int):
        super().__init__()
        self.n_tasks = n_tasks
        self.latent_dim = latent_dim
        # Hidden width of the per-task MLP. Clamp to 1 so latent_dim == 1
        # does not produce a zero-width hidden layer.
        hidden = max(1, latent_dim // 2)
        # Shared MLP applied independently to each task: feature -> scalar Q.
        self.head = nn.Sequential(
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )
        # Idle head looks at every task's features at once (concatenated),
        # which is why its input size depends on n_tasks.
        self.idle_head = nn.Sequential(
            nn.Linear(n_tasks * latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, 1),
        )

    def forward(self, task_feats: torch.Tensor) -> torch.Tensor:
        """Compute Q-values for every task plus the idle action.

        Args:
            task_feats: Tensor of shape ``[n_tasks, latent_dim]``, or
                ``[*batch, n_tasks, latent_dim]`` for batched input.

        Returns:
            Tensor of shape ``[*batch, n_tasks + 1]``. The last entry along
            the final dim is the idle action's Q-value.
        """
        # [*batch, n_tasks, 1] -> [*batch, n_tasks]
        per_task_q = self.head(task_feats).squeeze(-1)
        # Flatten only the (n_tasks, latent_dim) dims so any batch dims are
        # preserved: [*batch, n_tasks * latent_dim] -> [*batch]
        idle_in = task_feats.flatten(start_dim=-2)
        idle_q = self.idle_head(idle_in).squeeze(-1)
        # Append idle Q along the action axis (the last dim).
        return torch.cat([per_task_q, idle_q.unsqueeze(-1)], dim=-1)
