import itertools
from abc import ABC

import numpy as np
from tqdm import tqdm
import sys

from ai_control_games.generic_game import Game, RedTeamStrategy, BlueTeamPolicy
from ai_control_games.other_utils import grid_then_ternary_search

from ai_control_games.trusted_monitor.optimization.abstract_explorer import TMFrontierExplorer
from ai_control_games.trusted_monitor.tm_blue_team_policy_subtypes import SimpleTMBlueTeamPolicy, \
    VariableAuditFixedDefer
from ai_control_games.trusted_monitor.tm_game import TrustedMonitoringGame, SimpleTMRedTeamPolicy
from ai_control_games.trusted_monitor.trusted_monitor_calculations import CSM_FLEXI, CUM_FLEXI


class TMSearch2DSimple(TMFrontierExplorer, ABC):
    """Frontier explorer whose blue best response is found by a zooming 2D grid search.

    The search runs over the ``q_audit`` and ``q_defer`` arguments of
    ``SimpleTMBlueTeamPolicy``. Each zoom evaluates a
    ``grid_res_per_zoom x grid_res_per_zoom`` grid and then recurses on the box
    spanned by the grid neighbours of the best point, until that box is
    narrower than ``self.blue_br_tol`` in both dimensions.
    """

    def __init__(
            self,
            grid_res_per_zoom: int = 21,
            **kwargs
    ):
        """Store the per-zoom grid resolution and forward ``kwargs`` to the parent.

        Args:
            grid_res_per_zoom: Number of grid points per dimension in every zoom.
                Must be greater than 3: with 3 points an interior optimum keeps
                the same bounding box, so the zoom would never shrink.
            **kwargs: Passed unchanged to ``super().__init__``.
        """
        assert grid_res_per_zoom > 3, "Need at least 4 points to avoid infinite loop in ternary search"
        self.grid_res_per_zoom = grid_res_per_zoom
        super().__init__(**kwargs)

    def get_value_and_blue_best_response(
            self,
            game: TrustedMonitoringGame,
            weight,
            red_team_strategy: SimpleTMRedTeamPolicy,
            lower_bounds=(0.0, 0.0),
            upper_bounds=(1.0, 1.0),
    ) -> tuple[float, BlueTeamPolicy]:
        """Return the best weighted value found and the blue policy achieving it.

        The objective maximised is ``weight * safety + (1 - weight) * usefulness``,
        with safety from ``CSM_FLEXI.calculate_safety`` and usefulness from
        ``CUM_FLEXI.calculate_usefulness``.

        Args:
            game: Game the policies are evaluated on; ``game.ma.safe_dist`` is
                passed to every candidate policy.
            weight: Weight on safety; ``1 - weight`` is the weight on usefulness.
            red_team_strategy: Red strategy that safety is computed against.
            lower_bounds: ``(q_audit, q_defer)`` lower corner of the search box.
            upper_bounds: ``(q_audit, q_defer)`` upper corner of the search box.

        Returns:
            ``(value, policy)`` for the best grid point of the final zoom.

        Raises:
            ValueError: If no grid point has a loss below ``+inf`` (for example
                when every evaluation is NaN), so no best point can be chosen.
        """

        def make_policy(_q_a, _q_d):
            return SimpleTMBlueTeamPolicy(q_audit=_q_a, q_defer=_q_d, safe_dist=game.ma.safe_dist)

        def weighted_value(_bts):
            _s = CSM_FLEXI.calculate_safety(game, _bts, red_team_strategy)
            _u = CUM_FLEXI.calculate_usefulness(game, _bts)
            return (weight * _s) + ((1 - weight) * _u)

        def loss(_q_a, _q_d):
            # Negated so that the grid search below is a minimisation.
            return -1 * weighted_value(make_policy(_q_a, _q_d))

        q_as = np.linspace(lower_bounds[0], upper_bounds[0], self.grid_res_per_zoom)
        q_ds = np.linspace(lower_bounds[1], upper_bounds[1], self.grid_res_per_zoom)
        min_so_far = float("inf")
        argmin_ind_so_far = None
        diff = [(hi - lo) for lo, hi in zip(lower_bounds, upper_bounds)]
        grid = itertools.product(enumerate(q_as), enumerate(q_ds))
        progress = tqdm(grid, total=len(q_as) * len(q_ds), delay=5,
                        desc=f"Searching audit and defer FPRs {diff}", leave=False)
        for (q_a_ind, q_a), (q_d_ind, q_d) in progress:
            current_loss = loss(q_a, q_d)
            # Strict "<" keeps the first minimiser on ties and skips NaN losses.
            if current_loss < min_so_far:
                min_so_far = current_loss
                argmin_ind_so_far = (q_a_ind, q_d_ind)
        if argmin_ind_so_far is None:
            raise ValueError(
                "No grid point produced a loss below +inf (all losses NaN or infinite); "
                "cannot select a blue best response"
            )
        # =====
        # . . U  -> top right of square
        # . * .  -> the best we found
        # L . .  -> bottom left of square
        # =====
        best_q_a_ind, best_q_d_ind = argmin_ind_so_far
        lower = (max(0, best_q_a_ind - 1), max(0, best_q_d_ind - 1))
        upper = (min(len(q_as) - 1, best_q_a_ind + 1), min(len(q_ds) - 1, best_q_d_ind + 1))
        new_lower_bounds = (q_as[lower[0]], q_ds[lower[1]])
        new_upper_bounds = (q_as[upper[0]], q_ds[upper[1]])
        q_a_diff = new_upper_bounds[0] - new_lower_bounds[0]
        q_d_diff = new_upper_bounds[1] - new_lower_bounds[1]
        converged = q_a_diff < self.blue_br_tol and q_d_diff < self.blue_br_tol
        # If the zoomed box equals the current box, recursing would repeat this
        # exact search forever (box already at float resolution, or a
        # non-positive tolerance). Return the best point instead.
        stalled = (new_lower_bounds == tuple(lower_bounds)
                   and new_upper_bounds == tuple(upper_bounds))
        if converged or stalled:
            bts = make_policy(q_as[best_q_a_ind], q_ds[best_q_d_ind])
            return weighted_value(bts), bts
        return self.get_value_and_blue_best_response(
            game=game,
            weight=weight,
            red_team_strategy=red_team_strategy,
            lower_bounds=new_lower_bounds,
            upper_bounds=new_upper_bounds
        )


class TMFixedAudit(TMFrontierExplorer, ABC):
    """Frontier explorer that keeps ``q_audit`` fixed and searches only over ``q_defer``."""

    def __init__(self, q_audit: float, grid_res: int = 32, *args, **kwargs):
        """Record the fixed audit rate and grid resolution, then defer to the parent.

        Args:
            q_audit: Value passed as ``q_audit`` to every candidate policy.
            grid_res: Passed as ``grid_res`` to ``grid_then_ternary_search``.
            *args: Forwarded unchanged to ``super().__init__``.
            **kwargs: Forwarded unchanged to ``super().__init__``.
        """
        self.q_audit = q_audit
        self.grid_res = grid_res
        super().__init__(*args, **kwargs)

    def get_value_and_blue_best_response(self, game: Game, weight: float, red_pol: RedTeamStrategy) -> \
            tuple[float, BlueTeamPolicy]:
        """Return the weighted value and blue policy at the ``q_defer`` found by the search.

        The objective is ``weight * safety + (1 - weight) * usefulness``; its
        negation is handed to ``grid_then_ternary_search`` over ``[0.0, 1.0]``.

        Args:
            game: Game the policies are evaluated on; ``game.ma.safe_dist`` is
                passed to every candidate policy.
            weight: Weight on safety; ``1 - weight`` is the weight on usefulness.
            red_pol: Red strategy that safety is computed against; must not be None.

        Returns:
            ``(value, policy)`` where ``policy`` uses the fixed ``q_audit`` and
            the ``q_defer`` returned by the search.
        """
        assert red_pol is not None

        def policy_for(q_defer):
            return SimpleTMBlueTeamPolicy(q_audit=self.q_audit, q_defer=q_defer, safe_dist=game.ma.safe_dist)

        def weighted_value(policy):
            safety_score = CSM_FLEXI.calculate_safety(game, policy, red_pol)
            usefulness_score = CUM_FLEXI.calculate_usefulness(game, policy)
            return (weight * safety_score) + ((1 - weight) * usefulness_score)

        def negated_value(q_defer):
            # The search is handed the negated objective.
            return -1 * weighted_value(policy_for(q_defer))

        # If grid gets us in the basin of glob_opt, ternary search gets us to blue_br_tol
        search_kwargs = {
            "f": negated_value,
            "left": 0.0,
            "right": 1.0,
            "epsilon": self.blue_br_tol,
            "grid_res": self.grid_res,
            "f_options": {},
        }
        best_q_defer, _ = grid_then_ternary_search(**search_kwargs)

        # Re-evaluate at the chosen point rather than reusing the search's own value.
        best_policy = policy_for(best_q_defer)
        return weighted_value(best_policy), best_policy
