import unittest
import logging
import random
from main import Game, Configuration
from array import array

class Test(unittest.TestCase):
    """Regression tests for the parallel-link game model in ``main``.

    Covers ``Game.w_max``, agreement between the brute-force and the two
    dynamic-programming pure-NE solvers, the kdim-parameter and demand
    generators, and ``Configuration.utility`` under the "kD non-monotonic"
    cost model.
    """

    def test_wmax(self):
        """Under "debug_fixed", ``w_max()`` equals the player count ``n``."""
        for n in [3, 4]:
            for m in [2, 5]:
                # subTest reports the failing (n, m) and keeps sweeping.
                with self.subTest(n=n, m=m):
                    game = Game.generate_parallel_link_model(n, m, "debug_fixed")
                    self.assertEqual(game.w_max(), n)

    def test_quit_after_first_NE(self):
        """Solvers asked to quit after the first pure NE report exactly one.

        The first boolean flag (``True`` here) is the "quit after first NE"
        switch, as the test name indicates. Under "debug_fixed", brute force
        must report the all-zeros profile (one entry per player) as that
        equilibrium.
        """
        for n in [3, 4]:
            for m in [2, 5]:
                for util in ["debug_fixed"]:
                    with self.subTest(n=n, m=m, util=util):
                        game = Game.generate_parallel_link_model(n, m, util)
                        (dpo_num, _) = game.dynamic_program_set_based(True, False)
                        (dpu_num, _) = game.dynamic_program_table_based(True, False)
                        self.assertEqual(dpo_num, dpu_num)
                        # Second flag set: the "with memory" variants.
                        (dpo_mem_num, _) = game.dynamic_program_set_based(True, True)
                        (dpu_mem_num, _) = game.dynamic_program_table_based(True, True)
                        (bf_num, bf_all) = game.brute_force(True, True)
                        self.assertEqual(dpo_num, 1)
                        self.assertEqual(dpo_mem_num, 1)
                        self.assertEqual(dpu_mem_num, 1)
                        self.assertEqual(bf_num, 1)
                        self.assertEqual(bf_all, [array('L', [0] * game.n)])

    # Test to ensure that algorithms enumerate entire
    #   strategy/aggregate space and test that
    #   algorithms share the same NE.
    def test_every_strategy_profile_a_NE(self):
        """All solvers enumerate the full space and agree on the equilibria.

        Under "debug_fixed" every strategy profile is a pure NE. The DP
        solvers must therefore count every configuration from
        ``generate_all_configurations()``, and brute force must count all
        ``m ** n`` profiles. All three solvers must return the same set of
        equilibria.
        """
        for n in [3, 4]:
            for m in [2, 5]:
                for util in ["debug_fixed"]:
                    with self.subTest(n=n, m=m, util=util):
                        game = Game.generate_parallel_link_model(n, m, util)
                        (dpo_num, dpo_all) = game.dynamic_program_set_based(False, True)
                        (dpu_num, dpu_all) = game.dynamic_program_table_based(False, True)
                        (bf_num, bf_all) = game.brute_force(False, True)
                        (all_configurations, _, _) = game.generate_all_configurations()
                        if util == "debug_fixed":
                            self.assertEqual(dpo_num, len(all_configurations))
                            self.assertEqual(dpu_num, len(all_configurations))
                            self.assertEqual(bf_num, m ** n)
                        self.assertEqual(len(dpo_all), bf_num)
                        # Sort every list, not just brute force's, so the
                        # check does not depend on the order in which each
                        # algorithm happens to discover equilibria.
                        expected = sorted(bf_all)
                        self.assertEqual(sorted(dpo_all), expected)
                        self.assertEqual(sorted(dpu_all), expected)

    def test_generate_kdim_parameters(self):
        """A ``None`` parameter spec yields empty alpha, beta and z."""
        (alpha, beta, z, _) = Game.generate_kdim_parameters(None, 1, 2)
        self.assertEqual(alpha, [])
        self.assertEqual(beta, [])
        self.assertEqual(z, [])

    def test_generate_demand_vectors(self):
        """Demands default to all ones; "random_int" demands are reproducible."""
        demand = Game.generate_demand_vectors(None, n=5, k=2)
        self.assertEqual(demand, [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1]])

        # Seed the global RNG so the "random_int" draw is deterministic.
        random.seed(20)
        demand = Game.generate_demand_vectors(("random_int", 1, 3), n=5, k=2)
        self.assertEqual(demand, [[3, 3], [1, 2], [3, 3], [1, 2], [3, 1]])

    def _check_kd_non_monotonic(self, k, seed, aggregate, expected_utilities):
        """Check one "kD non-monotonic" game for cost dimension ``k``.

        Builds a 7-player, 3-link game with the given ``seed`` and checks
        three things:

        * its ``alpha``, ``beta`` and ``z`` match a replay of the global RNG
          from the same seed;
        * ``Configuration(aggregate).utility`` for players 0 and 2 on each of
          their three actions equals ``expected_utilities``, in action order.

        NOTE(review): the replay assumes the generator seeds the global
        ``random`` module with ``seed`` and draws alpha, beta, then z in
        exactly this order — confirm against ``main``.
        """
        m = 3
        w_ceiling = 10
        max_cost = 100
        configuration = Configuration(aggregate)
        game = Game.generate_parallel_link_model(
            7, m, "kD non-monotonic", k=k,
            kdim_cost_parameters=("seed_non-monotonic_int", max_cost, w_ceiling),
            demand_parameters=("random_int", 0, 5), seed=seed)
        random.seed(seed)
        self.assertEqual(game.alpha, [random.randint(0, max_cost) for _ in range(m)])
        self.assertEqual(game.beta, [random.randint(0, max_cost) for _ in range(m)])
        self.assertEqual(
            game.z,
            [[[random.randint(0, max_cost) for _ in range(w_ceiling)]
              for _ in range(k)]
             for _ in range(m)])
        for player in (0, 2):
            for action_index, expected in enumerate(expected_utilities):
                with self.subTest(k=k, player=player, action=action_index):
                    action = game.actions[player][action_index]
                    self.assertEqual(configuration.utility(game, player, action), expected)

    def test_utility_kD_non_monotonic(self):
        """Utilities under the "kD non-monotonic" cost model for k = 1 and 2."""
        # 1D case
        with self.subTest(k=1):
            self._check_kd_non_monotonic(
                k=1, seed=33, aggregate=(1, 2, 9),
                expected_utilities=(-4993, -791, -5741))
        # 2D case
        with self.subTest(k=2):
            self._check_kd_non_monotonic(
                k=2, seed=42, aggregate=(1, 2, 9, 9, 0, 0),
                expected_utilities=(-2362, -637, -244))

if __name__ == '__main__':
    # Surface DEBUG-level log output from the code under test.
    logging.basicConfig(level=logging.DEBUG)
    # Discover and run every test in this module, one line per test.
    unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))