import unittest
from reconfiguration import *
from vote_reader import read_file
from os import path

class MyTestCase(unittest.TestCase):

    def committee_score(self, m, voters, committee, rule):
        """

        Parameters
        ----------
        m : int
            The number of alternatives.
        voters : array[set[int]]
            The voters in the instance.
        committee : set[int]
            The committee whose score is computed
        rule : str
            The voting rule.

        Returns
        -------
        float
            The score of the committee according to the voting rule,
        """
        if rule == "cc":
            vector = [1] + [0 for _ in m]
        if rule == "pav":
            vector = [(1 / (x + 1)) for x in m]
        score = 0
        for vote in voters:
            index = 0
            for c in committee:
                if c in vote:
                    score += vector[index]
                    index += 1
        return score

    def construct_trivial(self, m, k, rule, opt_req=1):
        """
        Construct a ReconfigurationSolverExhaustive with one voter who approves one alternative and m - 1 other
            alternatives.

        Parameters
        ----------
        m : int
            The number of alternatives
        k : int
            Committee size
        rule : str
            The voting rule
        opt_req : float
            Optimality requirement

        Returns
        -------
        ReconfigurationSolverExhaustive
            The solver object.

        """
        voters = [{0}]
        solver = ReconfigurationSolverExhaustive(m, voters, rule, k, symdiff = 2,
                                                 optimality_requirement=opt_req)
        return solver

    def construct_slightly_nontrivial(self, m, k, rule, opt_req=1):
        """
        Construct a ReconfigurationSolverExhaustive with a voter who approves alternatives {0, 1} and a voter who
        approves alternatives {0, 2} and m - 3 other alternatives.

        Parameters
        ----------
        m : int
            The number of alternatives
        k : int
            Committee size
        rule : str
            The voting rule
        opt_req : float
            Optimality requirement

        Returns
        -------
        ReconfigurationSolverExhaustive
            The solver object.

        """
        voters = [{0, 1}, {0, 2}]
        solver = ReconfigurationSolverExhaustive(m, voters, rule, k, symdiff=2,
                                                 optimality_requirement=opt_req)
        return solver

    def read_test_2(self, opt_req=1):
        voters, m, renames = read_file(path.join("test_input", "test_data_2.txt"))
        solver = ReconfigurationSolverExhaustive(m, voters, rule="cc", k=3, symdiff=2, optimality_requirement=opt_req)
        return solver, renames

    def read_test_4(self, opt_req=1):
        voters, m, renames = read_file(path.join("test_input", "test_data_4.txt"))
        solver = ReconfigurationSolverExhaustive(m, voters, rule="pav", k=4, symdiff=2, optimality_requirement=opt_req)
        return solver, renames

    def test_gives_one_optimal_committee1(self):
        solver = self.construct_trivial(1,1,"cc")
        expected = [{0}]
        self.assertEqual(solver.committees, expected)

    def test_gives_one_optimal_committee2(self):
        solver = self.construct_trivial(2,1,"cc")
        expected = [{0}]
        self.assertEqual(solver.committees, expected)

    def test_gives_one_optimal_committee3(self):
        solver = self.construct_trivial(2,2,"cc")
        expected = [{0, 1}]
        self.assertEqual(solver.committees, expected)

    def test_gives_multiple_optimal_committees1(self):
        solver = self.construct_trivial(3,2,"cc")
        expected = [{0, 1}, {0,2}]
        self.assertEqual(solver.committees, expected)

    def test_gives_multiple_optimal_committees2(self):
        solver = self.construct_trivial(4,2,"pav")
        expected = [{0, 1}, {0,2}, {0,3}]
        self.assertEqual(solver.committees, expected)

    def test_graph_trivial(self):
        solver = self.construct_trivial(1, 1, "cc")
        self.assertEqual(len(solver.G.nodes),  1)
        self.assertEqual(solver.G.nodes[0][COMMITTEE],  {0})

    def test_graph_non_trivial(self):
        solver = self.construct_trivial(3, 2, "cc")
        self.assertEqual(len(solver.G.nodes), 2)
        self.assertTrue(solver.G.has_edge(0,1))

    def test_pav_easy_case(self):
        solver = self.construct_slightly_nontrivial(4, 2, "pav")
        expected = [{0, 1}, {0,2}]
        self.assertCountEqual(expected, solver.committees)

    def test_cc_easy_case(self):
        solver = self.construct_slightly_nontrivial(4, 2, "cc")
        expected = [{0, 1}, {0,2}, {1, 2}, {0, 3}]
        self.assertCountEqual(expected, solver.committees)
        self.assertEqual(1, solver.get_connected_components_nro())

    def test_pav_decrease_optimality(self):
        solver = self.construct_slightly_nontrivial(4, 2, "pav", opt_req=0.66)
        expected = [{0, 1}, {0, 2}, {1, 2}, {0, 3}]
        self.assertCountEqual(expected, solver.committees)
        self.assertEqual(1, solver.get_connected_components_nro())

    def test_cc_decrease_optimality(self):
        solver = self.construct_slightly_nontrivial(5, 2, "cc", opt_req=0.5)
        expected = [{0, 1}, {0, 2}, {1, 2}, {0, 3}, {1,3}, {2,3}, {0, 4}, {1,4}, {2,4}]
        self.assertCountEqual(expected, solver.committees)
        self.assertEqual(1, solver.get_connected_components_nro())

    def test_disconnected_committees(self):
        solver, renames = self.read_test_2()
        self.assertEqual(2, solver.get_connected_components_nro())
        self.assertEqual(2, len(solver.G.nodes))
        expected_committees = [{16384, 2862, 7745}, {16384, 16879, 9340}]
        actual_committees = [{renames[y] for y in x} for x in solver.get_winning_committees()]
        self.assertEqual(expected_committees, actual_committees)

    def test_make_disconnected_committees_connected(self):
        solver, renames = self.read_test_2(opt_req=0.991)
        self.assertEqual(1, solver.get_connected_components_nro())
        self.assertEqual(4, len(solver.G.nodes))
        self.assertEqual(3, solver.get_path_length((0,2,6), (3,6,7)))
        self.assertEqual(0, solver.compute_all_extra_swaps_needed())

    def test_disconnected_committees_pav(self):
        solver, renames = self.read_test_4()
        self.assertEqual(2, solver.get_connected_components_nro())
        self.assertEqual(2, len(solver.G.nodes))
        expected_committees = [{0,1,2,5}, {0,3,4,5}]
        self.assertEqual(expected_committees, solver.get_winning_committees())

    def test_make_disconnected_committees_connected2(self):
        solver, renames = self.read_test_4(opt_req=0.996)
        self.assertEqual(1, solver.get_connected_components_nro())
        self.assertEqual(4, len(solver.G.nodes))

    def test_fixed_alternatives_works_pav(self):
        voters = [{1, 2}, {2, 0, 3}, {1}]
        min_score = 0
        rule = "pav"
        m = 4
        k = 2
        winning_committees = get_all_winning_committees(m, voters, rule, k, min_score=min_score, fixed_alts=[0])
        self.assertCountEqual(winning_committees, [{0,1}, {0,2}, {0,3}])

    def test_fixed_alternatives_works_cc(self):
        voters = [{1, 2}, {2, 0, 3}, {1}]
        min_score = 0
        rule = "cc"
        m = 4
        k = 2
        winning_committees = get_all_winning_committees(m, voters, rule, k, min_score=min_score, fixed_alts=[0])
        self.assertCountEqual(winning_committees, [{0,1}, {0,2}, {0,3}])

    def test_min_score_works_pav(self):
        voters = [{1, 2}, {2, 0, 3}, {1}]
        min_score = 3
        rule = "pav"
        m = 4
        k = 2
        winning_committees = get_all_winning_committees(m, voters, rule, k, min_score=min_score)
        self.assertCountEqual(winning_committees, [{0,1}, {1,2}, {1,3}])

    def test_min_score_works_cc(self):
        voters = [{1, 2}, {2, 0, 3}, {1}]
        min_score = 3
        rule = "cc"
        m = 4
        k = 2
        winning_committees = get_all_winning_committees(m, voters, rule, k, min_score=min_score)
        self.assertCountEqual(winning_committees, [{0,1}, {1,2}, {1,3}])

    def test_min_score_and_fixed_works_cc(self):
        voters = [{1, 2}, {2, 0, 3}, {1}]
        min_score = 3
        rule = "cc"
        m = 4
        k = 2
        winning_committees = get_all_winning_committees(m, voters, rule, k, min_score=min_score, fixed_alts=[3])
        self.assertCountEqual(winning_committees, [{1,3}])

    def test_heuristic_cc_no_path(self):
        voters = [{0, 1}, {1, 2}, {2, 3}, {0, 3}, {0, 1}, {1, 2}, {2, 3}, {0, 3}, {4}]
        rule = "cc"
        m = 5
        k = 2
        init_C = {0,2}
        goal_C = {1,3}
        path = graph_union_heuristic(voters, rule, m, init_C, goal_C)
        self.assertIsNone(path)

    def test_heuristic_cc_has_path(self):
        voters = [{0, 1}, {1, 2}, {2, 3}, {0, 3}, {0, 1}, {1, 2}, {2, 3}, {0, 3}, {4}]
        rule = "cc"
        m = 5
        k = 2
        init_C = {0,2}
        goal_C = {1,3}
        path = graph_union_heuristic(voters, rule, m, init_C, goal_C, delta_s=2)
        possible_mids = [{0, 1}, {1, 2}, {2, 3}, {0, 3}, {0, 1}, {1, 2}, {2, 3}, {0, 3}]
        self.assertIn(path[1], possible_mids)
        self.assertEqual(len(path), 3)

    def test_heuristic_cc_has_path_through_middle(self):
        voters = [{0, 1}, {1, 2}, {2, 3}, {0, 3}, {4}, {4}, {4}, {4}, {4}, {4}]
        rule = "cc"
        m = 5
        k = 2
        init_C = {0,2}
        goal_C = {1,3}
        path = graph_union_heuristic(voters, rule, m, init_C, goal_C, delta_s=0)
        print(path)
        expected_paths = [[{0,2}, {2,4}, {1,4}, {1,3}],
                          [{0,2}, {2,4}, {3,4}, {1,3}],
                          [{0,2}, {0,4}, {1,4}, {1,3}],
                          [{0,2}, {0,4}, {3,4}, {1,3}]]
        self.assertIn(path, expected_paths)

    def test_heuristic_pav_has_path_through_middle(self):
        voters = [{0, 1}, {1, 2}, {2, 3}, {0, 3}, {4}, {4}, {4}, {4}, {4}, {4}]
        rule = "pav"
        m = 5
        k = 2
        init_C = {0,2}
        goal_C = {1,3}
        path = graph_union_heuristic(voters, rule, m, init_C, goal_C, delta_s=0)
        print(path)
        expected_paths = [[{0,2}, {2,4}, {1,4}, {1,3}],
                          [{0,2}, {2,4}, {3,4}, {1,3}],
                          [{0,2}, {0,4}, {1,4}, {1,3}],
                          [{0,2}, {0,4}, {3,4}, {1,3}]]
        self.assertIn(path, expected_paths)

    def test_heuristic_uses_middle_alts_first(self):
        voters = [{0, 1}, {1, 2}, {2, 3}, {0, 3}, {4}, {4}, {4}, {4}, {4}, {4}]
        rule = "pav"
        m = 5
        k = 2
        init_C = {0,2}
        goal_C = {1,3}
        path = graph_union_heuristic(voters, rule, m, init_C, goal_C, delta_s=0.5)
        possible_mids = [{0, 1}, {1, 2}, {2, 3}, {0, 3}, {0, 1}, {1, 2}, {2, 3}, {0, 3}]
        self.assertIn(path[1], possible_mids)
        self.assertEqual(len(path), 3)

    def test_brute_force_cc_no_path(self):
        voters = [{0, 1}, {1, 2}, {2, 3}, {0, 3}, {0, 1}, {1, 2}, {2, 3}, {0, 3}, {4}]
        rule = "cc"
        m = 5
        k = 2
        init_C = {0, 2}
        goal_C = {1, 3}
        path = brute_force_solve(voters, rule, m, init_C, goal_C)
        self.assertIsNone(path)

    def test_brute_force_cc_has_path(self):
        voters = [{0, 1}, {1, 2}, {2, 3}, {0, 3}, {0, 1}, {1, 2}, {2, 3}, {0, 3}, {4}]
        rule = "cc"
        m = 5
        k = 2
        init_C = {0, 2}
        goal_C = {1, 3}
        path = brute_force_solve(voters, rule, m, init_C, goal_C, delta_s=2)
        possible_mids = [{0, 1}, {1, 2}, {2, 3}, {0, 3}, {0, 1}, {1, 2}, {2, 3}, {0, 3}]
        self.assertIn(path[1], possible_mids)
        self.assertEqual(len(path), 3)

    def test_brute_force_cc_has_path_through_middle(self):
        voters = [{0, 1}, {1, 2}, {2, 3}, {0, 3}, {4}, {4}, {4}, {4}, {4}, {4}]
        rule = "cc"
        m = 5
        k = 2
        init_C = {0, 2}
        goal_C = {1, 3}
        path = brute_force_solve(voters, rule, m, init_C, goal_C, delta_s=0)
        print(path)
        expected_paths = [[{0, 2}, {2, 4}, {1, 4}, {1, 3}],
                          [{0, 2}, {2, 4}, {3, 4}, {1, 3}],
                          [{0, 2}, {0, 4}, {1, 4}, {1, 3}],
                          [{0, 2}, {0, 4}, {3, 4}, {1, 3}]]
        self.assertIn(path, expected_paths)

    def test_brute_force_pav_has_path_through_middle(self):
        voters = [{0, 1}, {1, 2}, {2, 3}, {0, 3}, {4}, {4}, {4}, {4}, {4}, {4}]
        rule = "pav"
        m = 5
        init_C = {0, 2}
        goal_C = {1, 3}
        path = brute_force_solve(voters, rule, m, init_C, goal_C, delta_s=0)
        print(path)
        expected_paths = [[{0, 2}, {2, 4}, {1, 4}, {1, 3}],
                          [{0, 2}, {2, 4}, {3, 4}, {1, 3}],
                          [{0, 2}, {0, 4}, {1, 4}, {1, 3}],
                          [{0, 2}, {0, 4}, {3, 4}, {1, 3}]]
        self.assertIn(path, expected_paths)


if __name__ == "__main__":
    # Run every test case defined in this module when executed as a script.
    unittest.main()
