"""
author: Anonymous
"""
import os
import sys
import inspect

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 

import torch

from scheduling.environment import SchedulingEnvironment
from models.graph_scheduler import GraphSchedulerCritic

import unittest
class TestGraphCritic(unittest.TestCase):
    def setUp(self):
        self.env = SchedulingEnvironment("test/sample_problem_5_10_example.json")
        
        pass
    
    def test_initialization(self):
        critic = GraphSchedulerCritic()
        self.assertEqual(critic.outputs, ['value'])
    
    def test_forward(self):
        critic = GraphSchedulerCritic()
        observation, _ = self.env.reset()
        output = critic.forward(observation)
        self.assertEqual(output.shape, torch.Size([1, 1]))
        
    # def test_agent_centric_model(self):
    #     scheduler = GraphSchedulerCritic(outputs = ['agent'])
    #     observation, _ = self.env.reset()
    #     output = scheduler.forward(observation, mode='agent')
    #     print(f"Agent Output: {output}")
    #     print(f"Agent Softmax: {torch.softmax(output, dim=0)}")
    #     self.assertEqual(output.shape, torch.Size([5, 1]))
    #     unit_selected = torch.argmax(output).item()
    #     dependent_action = {'agent': unit_selected}
    #     output = scheduler.forward(observation, mode='task', dependent_action=dependent_action)
    #     print(f"Task Output: {output}")
    #     print(f"Task Softmax: {torch.softmax(output, dim=0)}")
    #     self.assertEqual(output.shape, torch.Size([10, 1]))
        
if __name__ == '__main__':
    unittest.main(argv=[''], exit=False)