"""Heterogenous Graph Attention Network (HetGAT) solver for Scheduling Problem"""

import numpy as np

import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical

from scheduling.environment import SchedulingEnvironment
from models.graph_scheduler import GraphSchedulerCritic


class HetGatCritic(nn.Module):
    """Critic module that wraps a ``GraphSchedulerCritic``.

    Both accessors forward their arguments to the wrapped critic, which is
    expected to return exactly four outputs, and hand back only the first.
    """

    def __init__(self, mode='attention', num_heads=8):
        """Build the wrapped graph critic.

        Args:
            mode (str): Mode of the graph, must be one of
                ['no_edge', 'edge', 'attention']. Forwarded to the critic
                as its ``graph`` argument.
            num_heads (int): Forwarded unchanged to the critic as
                ``num_heads`` (presumably the attention head count --
                confirm against ``GraphSchedulerCritic``).
        """
        super().__init__()
        # The task-assignment node is always enabled for this critic.
        self.critic = GraphSchedulerCritic(
            num_heads=num_heads,
            task_assignment_node=True,
            graph=mode,
        )

    def get_value(self, x, action):
        """Return the critic's first output for ``x`` and ``action``.

        Args:
            x: Input passed through to the critic as its first argument.
            action: Passed through to the critic as its second argument.

        Returns:
            The first of the four outputs produced by the critic.
        """
        critic_out = self.critic(x, action)
        # Strict 4-way unpack: any other output arity raises ValueError.
        value, _second, _third, _fourth = critic_out
        return value

    def get_q_values(self, x):
        """Return the critic's first output for ``x`` with no action given.

        Args:
            x: Input passed through to the critic as its only argument.

        Returns:
            The first of the four outputs produced by the critic
            (presumably per-action Q-values, going by the method name --
            confirm against ``GraphSchedulerCritic``).
        """
        critic_out = self.critic(x)
        # Strict 4-way unpack: any other output arity raises ValueError.
        value, _second, _third, _fourth = critic_out
        return value
    
    