from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Tuple

from ai_control_games.generic_game import Game, BlueTeamPolicy, RedTeamStrategy
from ai_control_games.generic_optimisation import FrontierExplorer, SingleObjectiveOptimiserOutput
from ai_control_games.other_utils import grid_then_ternary_search
from ai_control_games.trusted_monitor.tm_game import SimpleTMRedTeamPolicy
from ai_control_games.trusted_monitor.tm_blue_team_policy_subtypes import SimpleTMBlueTeamPolicy
from ai_control_games.trusted_monitor.trusted_monitor_calculations import CSM_FLEXI, CUM_FLEXI


class TMFrontierExplorer(FrontierExplorer, ABC):
    """Abstract frontier explorer for trusted-monitoring games.

    Provides safety and usefulness evaluation (delegating to ``CSM_FLEXI`` and
    ``CUM_FLEXI``), a search for the red team's response over ``beta`` in
    ``[0, 1]``, and starting policies for both teams. Concrete subclasses must
    implement ``get_single_obj_equilibrium``.
    """

    def __init__(
            self,
            red_response_grid_res: int = 32,
            should_track_s_and_u_pairs: bool = False,
            blue_br_tol: float = 1e-6,
            red_br_tol: float = 1e-6,
            equilibrium_tol: float = 1e-4
    ):
        """Store the search settings and set up optional result tracking.

        Args:
            red_response_grid_res: Grid resolution handed to
                ``grid_then_ternary_search`` when searching over ``beta``.
            should_track_s_and_u_pairs: When true, every call to
                ``get_safety_and_red_best_response`` records a
                ``(safety, usefulness)`` tuple in ``self.s_and_u_pairs``.
            blue_br_tol: Stored on the instance; not read by this class
                (presumably consumed by subclasses).
            red_br_tol: Passed as ``epsilon`` to the red-team search.
            equilibrium_tol: Stored on the instance; not read by this class
                (presumably consumed by subclasses).
        """
        # NOTE(review): FrontierExplorer.__init__ is not invoked here -- confirm
        # that the base class needs no initialisation of its own.
        self.red_response_grid_res = red_response_grid_res
        self.red_br_tol = red_br_tol
        self.blue_br_tol = blue_br_tol
        self.equilibrium_tol = equilibrium_tol

        # A set, so identical (safety, usefulness) tuples are stored only once.
        self.should_track_s_and_u_pairs = should_track_s_and_u_pairs
        self.s_and_u_pairs = set()

    @abstractmethod
    def get_single_obj_equilibrium(self, game: Game, weight: float) -> SingleObjectiveOptimiserOutput:
        """Return the single-objective optimiser output for ``game`` and ``weight``.

        Must be implemented by subclasses.
        """

    def get_safety_and_red_best_response(self, game: Game, blue_pol: BlueTeamPolicy) -> Tuple[float, RedTeamStrategy]:
        """Return the safety found by the red-team search and the matching red policy.

        The search itself is memoised in ``_get_safety_and_red_best_response_inner``.
        When tracking is enabled, the ``(safety, usefulness)`` tuple for
        ``blue_pol`` is also added to ``self.s_and_u_pairs``.
        """
        # lru_cache distinguishes positional from keyword arguments (and keyword
        # order), so the call shape below is kept fixed to keep hitting the same
        # cache entries.
        safety, best_response = TMFrontierExplorer._get_safety_and_red_best_response_inner(
            game,
            blue_pol,
            red_br_tol=self.red_br_tol,
            red_response_grid_res=self.red_response_grid_res,
        )
        if not self.should_track_s_and_u_pairs:
            return safety, best_response

        usefulness = self.get_usefulness(game, blue_pol)
        self.s_and_u_pairs.add((safety, usefulness))
        return safety, best_response

    @staticmethod
    @lru_cache(maxsize=None)
    def _get_safety_and_red_best_response_inner(
            game: Game, blue_pol: BlueTeamPolicy,
            red_br_tol: float,
            red_response_grid_res: int
    ) -> Tuple[float, RedTeamStrategy]:
        """Search ``beta`` in ``[0, 1]`` and return ``(safety, red policy)``.

        Results are cached without bound, so ``game`` and ``blue_pol`` must be
        hashable. The objective handed to ``grid_then_ternary_search`` is the
        safety of ``blue_pol`` against ``SimpleTMRedTeamPolicy(beta)``.
        NOTE(review): presumably the search minimises this objective (red's
        best response) -- confirm against ``grid_then_ternary_search``.
        """
        def safety_against(_beta):
            candidate = SimpleTMRedTeamPolicy(beta=_beta)
            return CSM_FLEXI.calculate_safety(game, blue_pol, rts=candidate)

        # If the grid lands in the basin of the global optimum, the ternary
        # search then refines the answer to within red_br_tol.
        found_beta, found_safety = grid_then_ternary_search(
            f=safety_against,
            left=0.0,
            right=1.0,
            f_options={},
            epsilon=red_br_tol,
            grid_res=red_response_grid_res,
        )
        return found_safety, SimpleTMRedTeamPolicy(beta=found_beta)

    def get_safety(self, game, blue_pol, red_pol):
        """Return the safety of ``blue_pol`` against ``red_pol``, as computed by ``CSM_FLEXI``."""
        return CSM_FLEXI.calculate_safety(game, blue_pol, red_pol)

    def get_usefulness(self, game, blue_pol):
        """Return the usefulness of ``blue_pol``, as computed by ``CUM_FLEXI``."""
        return CUM_FLEXI.calculate_usefulness(game, blue_pol)

    def get_initial_red_team_strategy(self, game):
        """Return the starting red policy: ``beta=0.5``, independent of ``game``."""
        return SimpleTMRedTeamPolicy(beta=0.5)

    def get_initial_blue_team_policy(self, game: Game) -> BlueTeamPolicy:
        """Return the starting blue policy derived from the game's audit proportion.

        ``q_audit`` is the audit proportion itself; ``q_defer`` is twice that
        value, capped at 0.5.
        """
        audit_prop = game.gr.audit_proportion
        initial_q_defer = min(0.5, 2 * audit_prop)
        return SimpleTMBlueTeamPolicy(
            q_audit=audit_prop,
            q_defer=initial_q_defer,
            safe_dist=game.ma.safe_dist,
        )
