from ai_control_games.generic_game import Game, RedTeamStrategy
from ai_control_games.generic_optimisation import SingleObjectiveOptimiserOutput, \
    FrontierExplorerWithGridSearch
from ai_control_games.trusted_monitor.optimization.abstract_explorer import TMFrontierExplorer
from ai_control_games.trusted_monitor.optimization.blue_best_response import TMSearch2DSimple, TMFixedAudit
from ai_control_games.trusted_monitor.optimization.single_objective_eq import TMFrontierExplorerIteratedDet, \
    TMFrontierExplorerIteratedMixed, TMFrontierExplorerDoubleOracle


class TMFrontierExplorerWithGridSearch(FrontierExplorerWithGridSearch, TMFrontierExplorer):
    """Trusted-monitor frontier explorer backed by the generic grid search.

    Combines ``FrontierExplorerWithGridSearch`` with ``TMFrontierExplorer``.
    The single-objective equilibrium is taken from the grid-search base; the
    two best-response hooks are deliberately left unimplemented because the
    grid search does not use them.
    """

    def __init__(self, blue_team_policies: list, *args, **kwargs):
        """Initialise both bases explicitly.

        Args:
            blue_team_policies: Forwarded to
                ``FrontierExplorerWithGridSearch.__init__``.
            *args: Forwarded to ``TMFrontierExplorer.__init__``.
            **kwargs: Forwarded to ``TMFrontierExplorer.__init__``.
        """
        # The two bases take different arguments, so each initialiser is
        # invoked directly rather than through a single super() chain.
        FrontierExplorerWithGridSearch.__init__(self, blue_team_policies)
        TMFrontierExplorer.__init__(self, *args, **kwargs)

    def get_single_obj_equilibrium(self, game: Game, weight: float) -> SingleObjectiveOptimiserOutput:
        """Return the result of the grid-search base's equilibrium routine."""
        # Name the grid-search implementation explicitly so the choice of
        # base does not depend on the method resolution order.
        grid_search_equilibrium = FrontierExplorerWithGridSearch.get_single_obj_equilibrium
        return grid_search_equilibrium(self, game, weight)

    def get_value_and_blue_best_response(self, game: Game, weight: float, red_pol: RedTeamStrategy):
        """Unsupported for grid search; always raises ``NotImplementedError``."""
        raise NotImplementedError("Not needed for grid search")

    def get_initial_red_team_strategy(self, game: Game) -> RedTeamStrategy:
        """Unsupported for grid search; always raises ``NotImplementedError``."""
        raise NotImplementedError("Not needed for grid search")


class TMFrontierExplorerIteratedDetFixedAud(TMFixedAudit, TMFrontierExplorerIteratedDet):
    """Frontier explorer combining ``TMFixedAudit`` with ``TMFrontierExplorerIteratedDet``.

    ``TMFixedAudit`` is listed first, so its attributes take precedence over
    those of ``TMFrontierExplorerIteratedDet`` in the method resolution order.
    Presumably ``TMFixedAudit`` supplies the blue best response and the
    iterated explorer supplies the equilibrium search -- confirm against the
    base classes.
    """

    def __init__(self, q_audit, *args, **kwargs):
        """Initialise both bases explicitly, in order.

        Args:
            q_audit: Passed as the first argument to ``TMFixedAudit.__init__``
                (presumably the fixed audit parameter -- confirm against
                ``TMFixedAudit``).
            *args: Forwarded to both base initialisers.
            **kwargs: Forwarded to both base initialisers.
        """
        # NOTE(review): both initialisers receive the same *args/**kwargs; if
        # either base also calls super().__init__, some setup may run twice --
        # confirm against the base classes.
        TMFixedAudit.__init__(self, q_audit, *args, **kwargs)
        TMFrontierExplorerIteratedDet.__init__(self, *args, **kwargs)


class TMFrontierExplorerIteratedDet2DGridSimple(TMSearch2DSimple, TMFrontierExplorerIteratedDet):
    """Frontier explorer combining ``TMSearch2DSimple`` with ``TMFrontierExplorerIteratedDet``.

    ``TMSearch2DSimple`` is listed first, so its attributes take precedence
    over those of ``TMFrontierExplorerIteratedDet`` in the method resolution
    order.
    """

    def __init__(self, *args, **kwargs):
        """Initialise both bases explicitly, forwarding identical arguments to each."""
        # NOTE(review): both initialisers receive the same *args/**kwargs; if
        # either base also calls super().__init__, some setup may run twice --
        # confirm against the base classes.
        TMSearch2DSimple.__init__(self, *args, **kwargs)
        TMFrontierExplorerIteratedDet.__init__(self, *args, **kwargs)


class TMFrontierExplorerMixedFixedAud(TMFixedAudit, TMFrontierExplorerIteratedMixed):
    """Frontier explorer combining ``TMFixedAudit`` with ``TMFrontierExplorerIteratedMixed``.

    ``TMFixedAudit`` is listed first, so its attributes take precedence over
    those of ``TMFrontierExplorerIteratedMixed`` in the method resolution
    order.
    """

    def __init__(self, q_audit, *args, **kwargs):
        """Initialise both bases explicitly, in order.

        Args:
            q_audit: Passed as the first argument to ``TMFixedAudit.__init__``
                (presumably the fixed audit parameter -- confirm against
                ``TMFixedAudit``).
            *args: Forwarded to both base initialisers.
            **kwargs: Forwarded to both base initialisers.
        """
        # NOTE(review): both initialisers receive the same *args/**kwargs; if
        # either base also calls super().__init__, some setup may run twice --
        # confirm against the base classes.
        TMFixedAudit.__init__(self, q_audit, *args, **kwargs)
        TMFrontierExplorerIteratedMixed.__init__(self, *args, **kwargs)


class TMFrontierExplorerMixed2DGridMixed(TMSearch2DSimple, TMFrontierExplorerIteratedMixed):
    """Frontier explorer combining ``TMSearch2DSimple`` with ``TMFrontierExplorerIteratedMixed``.

    ``TMSearch2DSimple`` is listed first, so its attributes take precedence
    over those of ``TMFrontierExplorerIteratedMixed`` in the method resolution
    order.

    NOTE(review): the class name ends in "2DGridMixed" although the search
    base is ``TMSearch2DSimple`` -- confirm the name is intentional.
    """

    def __init__(self, *args, **kwargs):
        """Initialise both bases explicitly, forwarding identical arguments to each."""
        # NOTE(review): both initialisers receive the same *args/**kwargs; if
        # either base also calls super().__init__, some setup may run twice --
        # confirm against the base classes.
        TMSearch2DSimple.__init__(self, *args, **kwargs)
        TMFrontierExplorerIteratedMixed.__init__(self, *args, **kwargs)


class TMFrontierExplorerWithDoubleOracle(TMSearch2DSimple, TMFrontierExplorerDoubleOracle):
    """Frontier explorer combining ``TMSearch2DSimple`` with ``TMFrontierExplorerDoubleOracle``.

    ``TMSearch2DSimple`` is listed first, so its attributes take precedence
    over those of ``TMFrontierExplorerDoubleOracle`` in the method resolution
    order.
    """

    def __init__(self, *args, **kwargs):
        """Initialise both bases explicitly, forwarding identical arguments to each."""
        # NOTE(review): both initialisers receive the same *args/**kwargs; if
        # either base also calls super().__init__, some setup may run twice --
        # confirm against the base classes.
        TMSearch2DSimple.__init__(self, *args, **kwargs)
        TMFrontierExplorerDoubleOracle.__init__(self, *args, **kwargs)


# Registry of the trusted-monitor frontier-explorer classes defined above,
# listed in definition order. Entries are the classes themselves, not
# instances; constructor arguments differ between them (e.g. the grid-search
# explorer takes blue_team_policies, the *FixedAud variants take q_audit).
TM_OPTIMISERS = [
    TMFrontierExplorerWithGridSearch,
    TMFrontierExplorerIteratedDetFixedAud,
    TMFrontierExplorerIteratedDet2DGridSimple,
    TMFrontierExplorerMixedFixedAud,
    TMFrontierExplorerMixed2DGridMixed,
    TMFrontierExplorerWithDoubleOracle,
]
