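"""
Unit tests for the graph-based scheduler model (``GraphModelWithEdge``,
imported here as ``GraphScheduler``) on a small sample scheduling problem.

author: Anonymous
"""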
import os
import sys
import inspect

# Put the parent of this file's directory at the front of sys.path so the
# ``scheduling`` and ``models`` packages imported below can be found when this
# file is executed directly (e.g. ``python test/<this file>``).
currentdir = os.path.dirname(
    os.path.abspath(inspect.getfile(inspect.currentframe()))
)
parentdir = os.path.dirname(currentdir)
sys.path[:0] = [parentdir]

import torch

from scheduling.environment import SchedulingEnvironment
from models.graph_scheduler import GraphModelWithEdge as GraphScheduler

import unittest
class TestGraphScheduler(unittest.TestCase):
    """Tests for ``GraphScheduler`` on the 5-agent / 10-task sample problem.

    The expected output shapes below (5 rows in agent mode, 10 rows in task
    mode) follow from that sample problem.
    """

    # Fixture path as originally written: relative to the project root.
    SAMPLE_PROBLEM = "test/sample_problem_5_10_example.json"

    def setUp(self):
        """Create a fresh environment for the sample problem before each test."""
        # Try the cwd-relative path first (original behaviour, works when run
        # from the project root). Otherwise resolve it against ``parentdir``
        # (the directory this module puts on sys.path), so the suite does not
        # depend on the caller's working directory.
        path = self.SAMPLE_PROBLEM
        if not os.path.exists(path):
            path = os.path.join(parentdir, path)
        self.env = SchedulingEnvironment(path)

    def test_initialization(self):
        """Requested outputs are stored and mapped to ``<output>_select`` nodes."""
        scheduler = GraphScheduler(outputs=['agent', 'task'])
        self.assertEqual(scheduler.outputs, ['agent', 'task'])
        self.assertEqual(scheduler.output_nodes, ['agent_select', 'task_select'])

    def test_forward(self):
        """Agent mode, then task mode with the chosen agent as dependent action."""
        scheduler = GraphScheduler(outputs=['agent', 'task'])
        observation, _ = self.env.reset()

        output = scheduler.forward(observation, mode='agent')
        self.assertEqual(output.shape, torch.Size([5, 1]))

        dependent_action = {'agent': torch.argmax(output).item()}
        output = scheduler.forward(observation, mode='task',
                                   dependent_action=dependent_action)
        self.assertEqual(output.shape, torch.Size([10, 1]))

    @unittest.skip("disabled in the original source (was commented out); "
                   "reason not recorded")
    def test_split_scheduling(self):
        """Separate single-output schedulers for agents and for tasks."""
        agent_scheduler = GraphScheduler(outputs=['agent'])
        task_scheduler = GraphScheduler(outputs=['task'])
        observation, _ = self.env.reset()
        output = agent_scheduler.forward(observation, mode='agent')
        self.assertEqual(output.shape, torch.Size([5, 1]))
        output = task_scheduler.forward(observation, mode='task')
        self.assertEqual(output.shape, torch.Size([10, 1]))

    def test_agent_centric_model(self):
        """An agent-only scheduler scores agents and rejects ``mode='task'``."""
        scheduler = GraphScheduler(outputs=['agent'])
        observation, _ = self.env.reset()

        output = scheduler.forward(observation, mode='agent')
        self.assertEqual(output.shape, torch.Size([5, 1]))

        dependent_action = {'agent': torch.argmax(output).item()}
        # The model was built with outputs=['agent'] only, so task mode is
        # expected to fail with a KeyError.
        with self.assertRaises(KeyError):
            scheduler.forward(observation, mode='task',
                              dependent_action=dependent_action)

    def test_task_first_model(self):
        """Pick a task first, then score agents given it as ``dependent_action``."""
        task_scheduler = GraphScheduler(outputs=['task'])
        observation, _ = self.env.reset()

        output = task_scheduler.forward(observation, mode='task')
        # Same per-mode shapes as test_forward expects.
        self.assertEqual(output.shape, torch.Size([10, 1]))
        dependent_action = {'task': torch.argmax(output).item()}

        agent_scheduler = GraphScheduler(outputs=['agent'])
        output = agent_scheduler(observation, mode='agent',
                                 dependent_action=dependent_action)
        self.assertEqual(output.shape, torch.Size([5, 1]))
        
if __name__ == '__main__':
    # Plain unittest.main() honours command-line options (-v, test selection)
    # and exits non-zero when a test fails, so shell/CI callers see failures.
    # The previous argv=[''], exit=False always exited with status 0.
    unittest.main()