"""
author: Anonymous
"""
import os
import sys
import inspect

# Make the project root (the parent of this file's directory) importable so the
# `scheduling`, `solvers` and `utils` imports below resolve when this file is
# run directly. `__file__` is used instead of inspect.currentframe(), which may
# return None on Python implementations without stack-frame support.
currentdir = os.path.dirname(os.path.abspath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

import json
import numpy as np


import unittest
from scheduling.environment import SchedulingEnvironment, generate_environment
from solvers.milp_solver import solve_with_MILP
from utils.utils import set_seed

class TestSchedulingEnv(unittest.TestCase):
    """Tests for environment generation, rendering, the MILP solver and observations.

    Every test that needs ``test/sample_problem_5_10_example.json`` regenerates
    it through ``_generate_example`` instead of relying on another test having
    run first (unittest executes test methods in alphabetical order, not in
    definition order).
    """

    # Locations, relative to the project root, of the generated sample files.
    _CONFIG_RELPATH = 'test/sample_problem_5_10_config.json'
    _EXAMPLE_RELPATH = 'test/sample_problem_5_10_example.json'

    def _generate_example(self):
        """Write the sample config and generate the example environment from it.

        Seeds with 10 and replays the exact call sequence that the expected
        values asserted in these tests were recorded with.

        Returns:
            str: Absolute path of the generated example environment JSON file.
        """
        set_seed(10)
        config_file = os.path.join(parentdir, self._CONFIG_RELPATH)
        # This also re-seeds and calls generate_environment(config_file); keep
        # it, since the asserted values were produced with this call in place.
        self.test_init_generate_environment_config(config_file)
        save_location = os.path.join(parentdir, self._EXAMPLE_RELPATH)
        generate_environment(config_file, save_location)
        return save_location

    def test_init_generate_environment_config(self, filename = 'config/sample_problem_5_10.json'):
        """Write a sample config (5 agents, 10 tasks, 100x100 map) and generate from it.

        Runs as a test with the default *filename* and is also used by
        ``_generate_example`` with an absolute path.

        Args:
            filename: Destination of the config JSON. Relative paths are
                resolved against the project root; absolute paths are used
                as-is (``os.path.join`` discards ``parentdir`` for them).
        """
        config = {
            'num_agents': [5, 5],
            'num_tasks': [10, 10],
            'map': {
                'width': 100,
                'height': 100
            },
            'task_config': {
                'duration': {
                    'min': 10,
                    'max': 100
                },
                'time_window_percentage': [0.2, 0.8],
                'wait_time_percentage': [0.5, 0.5],
                'wait_time_duration': [10, 100]
            },
            'agent_config': {
                'speed': {
                    'min': 1,
                    'max': 10
                }
            },
            'obstacle_config': {
                'number': [5, 10],
                'radius': [5, 25]
            }
        }

        path = os.path.join(parentdir, filename)
        # The default target directory (``config/``) may not exist yet.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=4)
        set_seed(10)
        generate_environment(path)

    def test_generate_environment(self):
        """Generate the example instance and check its recorded properties."""
        save_location = self._generate_example()

        env = SchedulingEnvironment(save_location)

        self.assertEqual(env.num_agents, 5)
        self.assertEqual(env.num_tasks, 10)
        self.assertEqual(env.max_speed, 10)
        self.assertEqual(env.width, 100)
        self.assertEqual(env.height, 100)
        self.assertEqual(len(env.tasks), 10)
        self.assertEqual(len(env.agents), 5)
        self.assertEqual(len(env.duration), 50)
        self.assertEqual(len(env.wait_time_constraints), 5)
        self.assertEqual([int(key) for key in env.wait_time_constraints.keys()], [5, 7, 8, 9, 4])
        self.assertEqual([len(env.wait_time_constraints[key]) for key in env.wait_time_constraints.keys()], [1, 1, 1, 1, 1])

    def test_render_save(self):
        """Smoke test: regenerate the example instance and call ``render('file')`` on it."""
        env = SchedulingEnvironment(self._generate_example())
        env.render('file')

    def test_milp_solver(self):
        """Solve the example instance with the MILP solver and check the recorded result."""
        env = SchedulingEnvironment(self._generate_example())
        set_seed(10)
        feasible, status, makespan, schedule, _duration = solve_with_MILP(env)

        self.assertEqual(feasible, True)
        self.assertEqual(status, 2)
        self.assertEqual(makespan, 935.0)
        self.assertEqual(len(schedule), 10)
        self.assertEqual(schedule, [(3, 1), (4, 10), (5, 7), (5, 3), (5, 9), (5, 8), (5, 5), (5, 4), (5, 6), (5, 2)])

    def test_observation(self):
        """Check the layout of ``reset()``'s observation and that ``get_observation()`` reproduces it."""
        env = SchedulingEnvironment(self._generate_example())
        obs, _ = env.reset()
        self.assertEqual(len(obs), 14)
        self.assertEqual(list(obs), ['agent', 'task', 'state', 'agent_to_task', 'travel', 'wait_time', 'assigned', 'agent_to_state', 'task_to_state', 'task_assignment', 'agent_to_task_assignment', 'task_to_task_assignment', 'task_assignment_to_task_assignment', 'task_to_task_select'])
        new_obs = env.get_observation()

        # NOTE(review): 'travel' is assumed to be the renamed
        # 'agent_task_task_travel_time' entry - confirm.
        compared_keys = ['agent', 'task', 'agent_to_task', 'travel', 'wait_time', 'assigned']
        # A misspelt key here previously made the comparisons silently skip.
        missing = set(compared_keys) - set(obs)
        self.assertFalse(missing, f"Comparison keys not in observation: {missing}")

        edge_attrs = ('source_nodes', 'target_nodes', 'edge_features')
        for key in compared_keys:
            with self.subTest(key=key):
                self.assertIn(key, new_obs)
                old_value, new_value = obs[key], new_obs[key]
                if hasattr(old_value, 'edge_features'):
                    # Edge-style entries: compare each edge component.
                    for attr in edge_attrs:
                        old_part = getattr(old_value, attr)
                        new_part = getattr(new_value, attr)
                        self.assertTrue(np.array_equal(old_part, new_part), f"Key: {key} - {old_part} - {new_part}")
                else:
                    # np.array_equal also fails on shape mismatch instead of broadcasting.
                    self.assertTrue(np.array_equal(old_value, new_value), f"Key: {key} - {old_value} - {new_value}")
        
if __name__ == '__main__':
    # Discover and run every TestCase defined in this module when executed as a script.
    unittest.main(module='__main__')