"""
Unit tests for the MILP-based scheduler in ``solvers.milp_solver``.

author: Anonymous
"""
import os
import sys
import inspect

# Put the project root (the parent of this file's directory) first on
# sys.path so that the project-local packages imported below
# (``solvers``, ``scheduling``, ``utils``) resolve when this file is run
# directly as a script.
_this_file = os.path.abspath(inspect.getfile(inspect.currentframe()))
currentdir = os.path.dirname(_this_file)
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

import numpy as np
from solvers.milp_solver import solve_with_MILP, MILP_Solver, solve_with_MILP_given_partial_schedule, warmstart_MILP

from scheduling.environment import SchedulingEnvironment
from scheduling.agent import Agent
from utils.edge_space import EdgeInstance

import json

from utils.utils import set_seed

import unittest

class MILP_SolverUnitTests(unittest.TestCase):
    """Regression tests for the MILP solver on ``sample_problem_5_10_example.json``.

    Each test builds a ``SchedulingEnvironment`` from that file, resolved
    relative to the project root ``parentdir``, in ``'euclidean'`` mode. It
    then compares the solver's output with values recorded for this instance.
    """

    # (task_id, agent_id) assignments, in order, that solve_with_MILP returns
    # for the sample problem after set_seed(10). This is a tuple so that no
    # test can mutate the shared copy; use _expected_schedule() to get a list.
    _EXPECTED_SCHEDULE = (
        (8, 0), (9, 2), (2, 2), (5, 3), (4, 1),
        (0, 4), (7, 4), (3, 2), (6, 4), (1, 4),
    )

    def setUp(self):
        """Resolve the sample-problem path used by every test."""
        self.save_location = os.path.join(
            parentdir, 'test', 'sample_problem_5_10_example.json')

    def setY(self):
        """No-op kept for backward compatibility.

        NOTE(review): this looks like a typo for ``setUp``; unittest never
        calls it. Per-test setup lives in ``setUp``.
        """

    def _make_env(self):
        """Build a fresh SchedulingEnvironment for the sample problem."""
        return SchedulingEnvironment(self.save_location, mode='euclidean')

    def _expected_schedule(self):
        """Return a new list copy of the recorded schedule."""
        return list(self._EXPECTED_SCHEDULE)

    @staticmethod
    def _by_agent(schedule):
        """Stable-sort (task_id, agent_id) pairs by agent_id.

        Because ``sorted`` is stable, each agent's task order is preserved.
        Comparing two sorted schedules therefore checks the per-agent task
        sequences and ignores how different agents' assignments interleave.
        """
        return sorted(schedule, key=lambda pair: pair[1])

    def test_environment(self):
        """Smoke test: solve_with_MILP runs on the sample problem.

        Only the arity of the result (five values) is checked here.
        """
        set_seed(10)
        env = self._make_env()
        out = solve_with_MILP(env)
        self.assertEqual(len(out), 5)

    def test_from_observation(self):
        """MILP_Solver.get_action reproduces the first two recorded assignments."""
        set_seed(10)
        env = self._make_env()
        obs, _ = env.reset()
        agent = MILP_Solver()
        agent.set_environment(env)

        task_id, agent_id = agent.get_action(obs)[0]
        self.assertEqual(task_id, 8)
        self.assertEqual(agent_id, 0)

        # Apply the first action, then ask again from the new observation.
        obs, _, _, _ = env.step((task_id, agent_id))
        task_id, agent_id = agent.get_action(obs)[0]
        self.assertEqual(task_id, 9)
        self.assertEqual(agent_id, 2)

    def test_consistency(self):
        """The first action for the initial observation does not depend on the seed."""
        set_seed(10)
        env = self._make_env()
        obs, _ = env.reset()
        agent = MILP_Solver()
        agent.set_environment(env)
        task_id, agent_id = agent.get_action(obs)[0]
        self.assertEqual(task_id, 8)
        self.assertEqual(agent_id, 0)

        # Rebuild everything under a different seed; the decision must match.
        set_seed(11)
        env = self._make_env()
        obs, _ = env.reset()
        agent = MILP_Solver()
        agent.set_environment(env)
        task_id_2, agent_id_2 = agent.get_action(obs)[0]

        self.assertEqual(task_id, task_id_2)
        self.assertEqual(agent_id, agent_id_2)

    def test_assignment_extraction(self):
        """solve_with_MILP returns the recorded assignment order (element -2)."""
        set_seed(10)
        env = self._make_env()
        env.reset()
        out = solve_with_MILP(env)
        self.assertListEqual(out[-2], self._expected_schedule())

    def test_partial_schedule_solver(self):
        """Fixing a prefix of the recorded schedule reproduces the full solution.

        Prefixes are always sliced from the schedule in its assignment order.
        A copy re-sorted by agent is not a valid prefix of what the solver
        produced.
        """
        set_seed(10)
        env = self._make_env()
        schedule = self._expected_schedule()
        expected_by_agent = self._by_agent(schedule)
        actual_out = solve_with_MILP(env)

        # With only the last one or two assignments left open, the solver is
        # expected to reproduce the recorded per-agent task sequences.
        for n_free in (1, 2):
            with self.subTest(n_free=n_free):
                output = solve_with_MILP_given_partial_schedule(
                    env, schedule[:-n_free])
                self.assertListEqual(self._by_agent(output[-2]), expected_by_agent)

        # With half the schedule open, the exact assignment may legitimately
        # differ, so only the objective (element -3) is compared.
        output = solve_with_MILP_given_partial_schedule(env, schedule[:-5])
        self.assertAlmostEqual(output[-3], actual_out[-3])

    def test_durations(self):
        """Warm-starting the MILP from its own schedule keeps the objective.

        ``warmstart_MILP`` is seeded with the schedule that ``solve_with_MILP``
        just returned. The warm-started objective (element 2) is expected to
        equal the cold-start one. Schedules and durations (element -1) from
        both runs go into the failure message for diagnosis.
        """
        set_seed(10)  # same seed as the other tests, for reproducibility
        env = self._make_env()
        actual_out = solve_with_MILP(env)
        actual_performance = actual_out[2]
        actual_schedule = actual_out[-2]
        actual_durations = actual_out[-1]

        warm_start_out = warmstart_MILP(env, actual_schedule)
        warm_start_performance = warm_start_out[2]
        warm_start_schedule = warm_start_out[-2]
        warm_start_durations = warm_start_out[-1]

        self.assertAlmostEqual(
            warm_start_performance,
            actual_performance,
            msg=(
                f"schedule: {actual_schedule} | {warm_start_schedule}\n"
                f"durations: {actual_durations} | {warm_start_durations}"
            ),
        )
        
if __name__ == '__main__':
    # Only runs when this file is executed directly; unittest's CLI entry
    # point collects the TestCase classes from this module and runs them.
    unittest.main(module='__main__')