from unittest import TestCase

import numpy as np

from ai_control_games.generic_game import BlueTeamPolicy, RedTeamStrategy, MixedBlueTeamPolicy
from ai_control_games.trusted_monitor.optimization.concrete_optimisers import *
from ai_control_games.trusted_monitor.tm_blue_team_policy_subtypes import SimpleTMBlueTeamPolicy
from ai_control_games.trusted_monitor.tm_game import TrustedMonitoringGame, TMModellingAssumptions, TMGameRules


class TestTMOptimisationForSimpleTMBlue(TestCase):
    """Smoke tests for the trusted-monitoring frontier explorers.

    Every test builds one explorer, asks it for a frontier on the shared game
    created in ``setUp`` and validates each returned point with
    ``check_frontier_point``.
    """

    def setUp(self):
        """Create the game fixture and frontier resolution used by every test.

        Using ``setUp`` (rather than overriding ``__init__``) means a failure
        while building the fixture is reported as an error of the individual
        test instead of breaking test loading.
        """
        # Imported locally, as the original code did, rather than at module level.
        from ai_control_games.trusted_monitor.rwr_data import BETA_APPROX_OF_RW_SF, BETA_APPROX_OF_RW_BD

        self.num_points_along_frontier = 33
        self.game = TrustedMonitoringGame(
            ma=TMModellingAssumptions(safe_dist=BETA_APPROX_OF_RW_SF, bd_dist=BETA_APPROX_OF_RW_BD),
            gr=TMGameRules(num_problems=10, audit_proportion=0.2),
        )

    def _assert_in_unit_interval(self, value, message):
        """Fail with *message* unless ``0 <= value <= 1.0``."""
        self.assertTrue(0 <= value <= 1.0, message)

    def check_frontier_point(self, frontier_points):
        """Validate every frontier point and print a summary sorted by weight.

        Args:
            frontier_points: Iterable of points exposing ``weight``, ``safety``,
                ``usefulness``, ``blue_team_strategy``, ``red_team_strategy``
                and ``info``. It is iterated twice, so it must be re-iterable.

        Raises:
            AssertionError: If a weight, safety or usefulness value lies
                outside ``[0, 1]``, or a strategy / info field has an
                unexpected type.
        """
        # TestCase assertion methods are used instead of bare ``assert`` so the
        # checks still run when Python is started with -O.
        for fp in frontier_points:
            self._assert_in_unit_interval(fp.weight, f"Weight: {fp.weight}")
            self._assert_in_unit_interval(fp.safety, f"Safety: {fp.safety}")
            self._assert_in_unit_interval(fp.usefulness, f"Usefulness: {fp.usefulness}")
            # A blue-team strategy may be either a plain or a mixed policy.
            self.assertIsInstance(
                fp.blue_team_strategy,
                (BlueTeamPolicy, MixedBlueTeamPolicy),
                f"{fp.blue_team_strategy=}",
            )
            self.assertIsInstance(
                fp.red_team_strategy,
                RedTeamStrategy,
                f"Red Team Strategy: {fp.red_team_strategy}",
            )
            self.assertIsInstance(fp.info, dict)

        # Diagnostic output: one line per point, in order of increasing weight.
        for fp in sorted(frontier_points, key=lambda x: x.weight):
            print(f"s={fp.safety:.3f}, u={fp.usefulness:.3f}, w={fp.weight:.3f}")

    def _explore_and_check(self, explorer):
        """Run *explorer* on the shared game and validate the resulting frontier."""
        frontier_points = explorer.explore_frontier(
            self.game,
            num_points_along_frontier=self.num_points_along_frontier,
        )
        self.check_frontier_point(frontier_points)

    def test_grid_opt(self):
        """Grid search over a 41 x 41 grid of (q_audit, q_defer) policies."""
        btss = [
            SimpleTMBlueTeamPolicy(q_audit=q_audit, q_defer=q_defer, safe_dist=self.game.ma.safe_dist)
            for q_audit in np.linspace(0, 1.0, 41)
            for q_defer in np.linspace(0, 1.0, 41)
        ]
        self._explore_and_check(TMFrontierExplorerWithGridSearch(btss))

    def test_iter_det_opt(self):
        """``TMFrontierExplorerIteratedDetFixedAud`` with ``q_audit=0.3``."""
        self._explore_and_check(TMFrontierExplorerIteratedDetFixedAud(q_audit=0.3))

    def test_iter_det_opt_2d(self):
        """``TMFrontierExplorerIteratedDet2DGridSimple`` with default arguments."""
        self._explore_and_check(TMFrontierExplorerIteratedDet2DGridSimple())

    def test_iter_mixed_fixed_aud(self):
        """``TMFrontierExplorerMixedFixedAud`` with ``q_audit=0.3``."""
        self._explore_and_check(TMFrontierExplorerMixedFixedAud(q_audit=0.3))

    def test_iter_mixed_2d_grid(self):
        """``TMFrontierExplorerMixed2DGridMixed`` with default arguments."""
        self._explore_and_check(TMFrontierExplorerMixed2DGridMixed())

    def test_double_oracle(self):
        """``TMFrontierExplorerWithDoubleOracle`` with default arguments."""
        self._explore_and_check(TMFrontierExplorerWithDoubleOracle())
