import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
import unittest

from StateActionTracker import StateActionStack
from fix_logic import find_similar_top_k, unfix


class TestFixLogic(unittest.TestCase):
    """Tests for ``fix_logic`` helpers and ``StateActionStack`` fixing actions."""

    def setUp(self):
        """Build a fresh tracker before every test.

        A per-test fixture (instead of a class-level one) keeps the tests
        independent: ``test_unfix`` mutates the tracker, so sharing a single
        instance would make other tests depend on execution order.
        """
        self.dfs_tracker = StateActionStack()
        # Two states, each pushed with the action set {0, 1}: [0, 5, 5, 5] is
        # fixed to action 0, then [0, 4, 4, 4] (pushed last) is fixed to 1.
        self.dfs_tracker.push_state([0, 5, 5, 5], {0, 1})
        self.dfs_tracker.set_fixing_action([0, 5, 5, 5], 0)
        self.dfs_tracker.push_state([0, 4, 4, 4], {0, 1})
        self.dfs_tracker.set_fixing_action([0, 4, 4, 4], 1)

    def test_find_closest_valid_pair(self):
        """Similar top probabilities come back as (index, prob) pairs, highest first."""
        # The threshold arguments (0.2, 0.1, 0.1) are interpreted by
        # find_similar_top_k; the same values are used in every test here.
        valid_probs = [0.1, 0.1, 0.45, 0.42]
        self.assertEqual(find_similar_top_k(valid_probs, 0.2, 0.1, 0.1), ((2, 0.45), (3, 0.42)))
        probs = [0.25, 0.22, 0.23, 0.24]
        groups = find_similar_top_k(probs, 0.2, 0.1, 0.1)
        self.assertEqual(groups, ((0, 0.25), (3, 0.24), (2, 0.23), (1, 0.22)))

    def test_find_closest_valid_pair_no_group(self):
        """No group is returned when one probability clearly dominates the rest."""
        valid_probs = [0.5, 0.1, 0.1, 0.2]
        self.assertIsNone(find_similar_top_k(valid_probs, 0.2, 0.1, 0.1))

    def test_is_fixing(self):
        """Both fixture states report the fixing action set in ``setUp``."""
        self.assertEqual(self.dfs_tracker.get_fixing_action([0, 5, 5, 5]), 0)
        self.assertEqual(self.dfs_tracker.get_fixing_action([0, 4, 4, 4]), 1)
        self.assertTrue(self.dfs_tracker.is_fixed([0, 5, 5, 5]))
        self.assertTrue(self.dfs_tracker.is_fixed([0, 4, 4, 4]))

    def test_unfix(self):
        """``unfix`` switches only the most recently fixed state's action."""
        unfix(
            1, self.dfs_tracker, 2, [[[0, 5, 5, 5], [0, 4, 4, 4]]], [-1, -1], [[0, 1]], -1
        )  # Unfix [0, 4, 4, 4] because it was added last, switching it from action 1 to 0
        self.assertEqual(self.dfs_tracker.get_fixing_action([0, 4, 4, 4]), 0)
        self.assertEqual(self.dfs_tracker.get_fixing_action([0, 5, 5, 5]), 0)  # [0, 5, 5, 5] should be unchanged


if __name__ == "__main__":
    # Run the tests in an explicit order and propagate the outcome as the
    # process exit status, so shell/CI callers can detect failures.
    suite = unittest.TestSuite()
    suite.addTest(TestFixLogic("test_find_closest_valid_pair"))
    suite.addTest(TestFixLogic("test_find_closest_valid_pair_no_group"))
    suite.addTest(TestFixLogic("test_is_fixing"))
    suite.addTest(TestFixLogic("test_unfix"))
    runner = unittest.TextTestRunner()
    result = runner.run(suite)
    sys.exit(0 if result.wasSuccessful() else 1)
