import collections

# Compatibility shim: re-expose ``collections.abc.Callable`` under the legacy
# name ``collections.Callable`` (the ABC aliases were removed from the
# ``collections`` namespace in Python 3.10).
# NOTE(review): presumably needed by a dependency imported below that still
# uses the old alias -- confirm which one, and drop the shim once it is fixed.
# The submodule is imported explicitly so that ``collections.abc`` is
# guaranteed to be bound, rather than relying on some other module having
# imported it already.
import collections.abc

collections.Callable = collections.abc.Callable

from ai_control_games.trusted_monitor.rwr_data import CONT_TM_SF_BARS_READ_OFF_FIG_4A, CONT_TM_BD_BARS_READ_OFF_FIG_4A
from ai_control_games.trusted_monitor.tm_blue_team_policy_subtypes import SimpleTMBlueTeamPolicy

from ai_control_games.other_utils import XXXs_profile
from ai_control_games.distributions import SimplePolyDistOver0to1, NORMALISED_NORMAL_APPROX_OF_RW, UniformDistOver0to1

from ai_control_games.trusted_monitor.tm_game import TMGameRules, SimpleTMRedTeamPolicy, TMModellingAssumptions, \
    TrustedMonitoringGame, \
    TMRedAction, TMBlueObsType, RunAuditOrDeferType
from ai_control_games.trusted_monitor.tm_batch_game_manager import TMBatchGameManager_SimpleBTPolicy, \
    TMBatchGameManager_SimpleBTPolicy_Simple_or_Adaptive_Audit, TMBatchGameManager_MixedBTPolicy
from unittest import TestCase


class TestTM(TestCase):
    """Smoke tests for the trusted-monitoring policies, rules, assumptions and game.

    Checks use the ``unittest`` assertion methods rather than bare ``assert``
    statements so that they still run under ``python -O`` and report the
    offending values on failure.
    """

    def test_tm_red_team_policy(self):
        """The simple red-team policy returns one of the two red actions."""
        tm_red_team_policy = SimpleTMRedTeamPolicy(beta=0.5)
        self.assertIn(tm_red_team_policy.act(None), [TMRedAction.BD, TMRedAction.DONT])

    def test_tm_blue_team_policy(self):
        """The simple blue-team policy returns audit, defer or run for an observation."""
        tm_blue_team_policy = SimpleTMBlueTeamPolicy(q_audit=0.5, q_defer=0.5, safe_dist=UniformDistOver0to1())
        action = tm_blue_team_policy.act(TMBlueObsType(sus_score=0.6, num_problems_left=10, num_coins_left=5))
        self.assertIn(
            action,
            [RunAuditOrDeferType.AUDIT, RunAuditOrDeferType.DEFER, RunAuditOrDeferType.RUN],
        )

    def test_tm_gr(self):
        """Game rules expose the problem/coin counts and reject a bad audit proportion."""
        tm_gr = TMGameRules(num_problems=100, audit_proportion=0.02)
        self.assertEqual(tm_gr.num_problems, 100)
        self.assertEqual(tm_gr.num_coins, 2)

        # 100 * 0.005 is half a coin.
        # NOTE(review): presumably rejected because the coin count must be a
        # whole number -- confirm against TMGameRules.
        with self.assertRaises(ValueError):
            TMGameRules(num_problems=100, audit_proportion=0.005)

    def test_tm_ma(self):
        """Sampled suspicion scores lie in [0, 1] for several backdoor distributions."""
        dists = [SimplePolyDistOver0to1(power=power) for power in range(1, 4)] + [NORMALISED_NORMAL_APPROX_OF_RW]

        for dist in dists:
            tm_ma = TMModellingAssumptions(bd_dist=dist, safe_dist=UniformDistOver0to1())
            tm_ma.validate_sus_dists()
            for is_bd in (True, False):
                # Two draws per case; every draw is checked against both
                # bounds instead of one bound per draw.
                for _ in range(2):
                    score = tm_ma.sample_sus_score(is_bd=is_bd)
                    context = f"is_bd={is_bd}, bd_dist={dist!r}"
                    self.assertGreaterEqual(score, 0.0, msg=context)
                    self.assertLessEqual(score, 1.0, msg=context)

    def test_tm_game(self):
        """One simulated game yields a boolean outcome and a usefulness in [0, 1]."""
        poly_tm_dist = SimplePolyDistOver0to1(power=2)
        tm_ma = TMModellingAssumptions(bd_dist=poly_tm_dist, safe_dist=UniformDistOver0to1())
        g = TrustedMonitoringGame(
            ma=tm_ma,
            gr=TMGameRules(num_problems=100, audit_proportion=0.02)
        )
        red_team = SimpleTMRedTeamPolicy(beta=0.5)
        blue_team = SimpleTMBlueTeamPolicy(q_audit=0.5, q_defer=0.5, safe_dist=UniformDistOver0to1())
        # The third element of the result is not inspected by this test.
        was_backdoored, usefulness, _info = g.simulate_one_game(blue_team, red_team)
        self.assertIn(was_backdoored, [True, False])
        self.assertGreaterEqual(usefulness, 0.0)
        self.assertLessEqual(usefulness, 1.0)

    def test_TM_batch_game_manager(self):
        """The default batch manager can produce MC safety and usefulness estimates."""
        bgm = TMBatchGameManager_SimpleBTPolicy(should_verify=True)

        def func():
            bgm.get_MC_safety_estimates(num_samples=10)

        # The safety estimate is run through the project's profiling helper
        # as a zero-argument callable.
        XXXs_profile(f=func, num_rows=100)

        bgm.get_MC_usefulness_estimates(num_samples=10)


class TestTM_Calculations_SimpleBTPolicy(TestCase):
    """Run the exact safety/usefulness calculations and compare them with MC estimates.

    The batch-game-manager class under test is read from ``self.BGM_CLASS``
    and the Monte Carlo sample count from ``self.NUM_SAMPLES``. Both are
    assigned in ``__init__`` so that subclasses can replace them after calling
    the parent constructor.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.BGM_CLASS = TMBatchGameManager_SimpleBTPolicy
        self.NUM_SAMPLES = 200

    def _make_small_bgm(self):
        """Build the 10-problem, ``audit_prop=0.1`` manager used by most tests."""
        return self.BGM_CLASS(should_verify=True, num_problems=10, audit_prop=0.1)

    def test_CalculateSafety(self):
        """The exact safety calculation runs on the small manager."""
        self._make_small_bgm().get_exact_safety_calculation(should_tqdm=True)

    def test_CalculateUsefulness(self):
        """The exact usefulness calculation runs on the small manager."""
        self._make_small_bgm().get_exact_usefulness_calculation(should_tqdm=True)

    def test_CalculateSafety_against_MC(self):
        """Exact safety is checked against MC estimates on a 10-row random subset."""
        subset = self._make_small_bgm().get_random_subset(num_rows=10)
        subset.get_MC_safety_estimates(num_samples=self.NUM_SAMPLES)
        subset.get_exact_safety_calculation(should_tqdm=True)
        subset.assert_MC_vs_exact()

    def test_CalculateSafety_against_MC_RW_dists(self):
        """Same safety comparison, using the distributions read off Fig. 4A."""
        rw_config = {
            "should_verify": True,
            "num_problems": 100,
            "audit_prop": 0.02,
            "sf_distribution": CONT_TM_SF_BARS_READ_OFF_FIG_4A,
            "bd_distribution": CONT_TM_BD_BARS_READ_OFF_FIG_4A,
        }
        subset = self.BGM_CLASS(**rw_config).get_random_subset(num_rows=100)
        subset.get_MC_safety_estimates(num_samples=self.NUM_SAMPLES)
        subset.get_exact_safety_calculation(should_tqdm=True)
        subset.assert_MC_vs_exact()

    def test_CalculateUsefulness_against_MC(self):
        """Exact usefulness is checked against MC estimates on a 10-row random subset."""
        subset = self._make_small_bgm().get_random_subset(num_rows=10)
        subset.get_MC_usefulness_estimates(num_samples=self.NUM_SAMPLES)
        subset.get_exact_usefulness_calculation(should_tqdm=True)
        subset.assert_MC_vs_exact()


class TestTM_Calculations_Simple_or_Adaptive(TestTM_Calculations_SimpleBTPolicy):
    """Re-run the inherited calculation tests with the simple-or-adaptive-audit manager."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the values assigned by the parent constructor.
        self.NUM_SAMPLES = 200
        self.BGM_CLASS = TMBatchGameManager_SimpleBTPolicy_Simple_or_Adaptive_Audit


class TestTM_Calculations_Mixed(TestTM_Calculations_SimpleBTPolicy):
    """Re-run the inherited calculation tests with the mixed blue-team-policy manager."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the values assigned by the parent constructor; this variant
        # uses 1000 Monte Carlo samples instead of the parent's 200.
        self.NUM_SAMPLES = 1000
        self.BGM_CLASS = TMBatchGameManager_MixedBTPolicy
