from functools import lru_cache
from ai_control_games.trusted_monitor.tm_game import TrustedMonitoringGame, SimpleTMRedTeamPolicy, TMRedTeamPolicy, \
    TMBlueTeamPolicy
from ai_control_games.trusted_monitor.tm_blue_team_policy_subtypes import DeterministicThresholdBTS, \
    SimpleTMBlueTeamPolicy
from ai_control_games.generic_game import MixedBlueTeamPolicy
from abc import ABC, abstractmethod
from typing import Union

import sys

from ai_control_games.other_utils import *

# Module-wide latch: flips to True after the first warning has been printed.
HAS_WARNED_ALREADY = False


def maybe_warn(message: str):
    """
    Print `message`, but only the first time this function is called.

    Every later call is a silent no-op, regardless of the message passed.
    """
    global HAS_WARNED_ALREADY
    if HAS_WARNED_ALREADY:
        return
    print(message)
    HAS_WARNED_ALREADY = True


# TODO set up a nice way to test calculations against each other
# TODO generalise abstract class to work for all games


class CalculateSafetyManager(ABC):
    """
    Abstract base for objects that compute the safety of a TrustedMonitoringGame.

    Subclasses provide `_calculate_safety_inner` and may override
    `_accepted_BTS_Class` / `_accepted_RTS_Class` to restrict which
    blue/red team policy types `calculate_safety` will accept.
    """
    _accepted_BTS_Class = TMBlueTeamPolicy
    _accepted_RTS_Class = TMRedTeamPolicy

    @abstractmethod
    def _calculate_safety_inner(
            self,
            game: TrustedMonitoringGame,
            bts: _accepted_BTS_Class,
            rts: _accepted_RTS_Class,
    ) -> float:
        """Return the safety for (game, bts, rts); implemented by subclasses."""
        raise NotImplementedError

    def calculate_safety(
            self,
            game: TrustedMonitoringGame,
            bts: _accepted_BTS_Class,
            rts: _accepted_RTS_Class,
            should_verify=True,
    ) -> float:
        """
        Type-check the policies, compute the safety and optionally sanity-check it.

        :param game: the game to evaluate
        :param bts: blue team policy; must be an instance of `_accepted_BTS_Class`
        :param rts: red team policy; must be an instance of `_accepted_RTS_Class`
        :param should_verify: when True, assert that the result lies in [0, 1]
            and satisfies the extra checks for the simple policy types
        :return: the value produced by `_calculate_safety_inner`
        :raises TypeError: if either policy is not of an accepted type
        """
        if not isinstance(bts, self._accepted_BTS_Class):
            raise TypeError(f"{bts=} does not match {self._accepted_BTS_Class=} for {self}")
        if not isinstance(rts, self._accepted_RTS_Class):
            raise TypeError(f"{rts=} does not match {self._accepted_RTS_Class=} for {self}")

        safety = self._calculate_safety_inner(game, bts, rts)
        if not should_verify:
            return safety

        # TODO fix the weird pydantic bug that comes in here
        assert 0.0 <= safety <= 1.0, (safety, bts, rts)
        if isinstance(rts, SimpleTMRedTeamPolicy):
            # With beta == 0 the result is expected to be (close to) 1.
            assert rts.beta > 0.0 or floats_are_close(safety, 1.0), (rts.beta, safety)
        if isinstance(bts, SimpleTMBlueTeamPolicy):
            # With q_defer >= 1 the result is expected to be (close to) 1,
            # unless the human auditor has a non-zero error rate.
            assert bts.q_defer < 1 or floats_are_close(safety, 1.0) \
                or game.ma.human_FPR > 0 or game.ma.human_FNR > 0, (bts.q_defer, safety, bts)
        return safety


class CalculateUsefulnessManager(ABC):
    """
    Abstract base for objects that compute the usefulness of a TrustedMonitoringGame.

    Subclasses provide `_calculate_usefulness_inner` and may override
    `_accepted_BTS_Class` to restrict the accepted blue team policy types.
    """
    _accepted_BTS_Class = Union[TMBlueTeamPolicy, MixedBlueTeamPolicy]

    @abstractmethod
    def _calculate_usefulness_inner(
            self,
            game: TrustedMonitoringGame,
            bts: TMBlueTeamPolicy,
    ) -> float:
        """Return the usefulness for (game, bts); implemented by subclasses."""
        raise NotImplementedError

    def calculate_usefulness(
            self,
            game: TrustedMonitoringGame,
            bts: TMBlueTeamPolicy,
            should_verify=True,
    ) -> float:
        """
        Compute the usefulness, clamp tiny overshoots above 1 and optionally verify it.

        :param game: the game to evaluate
        :param bts: blue team policy; asserted to be an instance of `_accepted_BTS_Class`
        :param should_verify: when True, run the additional sanity assertions
        :return: the usefulness, set to exactly 1.0 if it was above but close to 1.0
        :raises ValueError: if the computed value is above 1.0 and not close to it
        """
        assert isinstance(bts, self._accepted_BTS_Class), (self, bts, self._accepted_BTS_Class)
        usefulness = self._calculate_usefulness_inner(game, bts)

        if usefulness > 1.0:
            # Anything that is not a near-1 overshoot is a genuine error.
            if not floats_are_close(usefulness, 1.0):
                raise ValueError(f"Usefulness > 1.0: {usefulness}")
            usefulness = 1.0
            maybe_warn("Warning: Usefulness > 1.0, but it's close!")

        if not should_verify:
            return usefulness

        assert isinstance(game, TrustedMonitoringGame)
        assert isinstance(bts, self._accepted_BTS_Class), bts
        assert isinstance(usefulness, float)
        assert 0.0 <= usefulness <= 1.0, (usefulness, bts)
        if isinstance(bts, SimpleTMBlueTeamPolicy):
            # With q_defer == 0 the result is expected to be (close to) 1,
            # unless the human auditor can raise false alarms.
            assert bts.q_defer > 0 or floats_are_close(usefulness, 1.0) or game.ma.human_FPR > 0, usefulness
        return usefulness


class CalcSafetyManagerFlexibleDefault(CalculateSafetyManager):
    """
    Dispatching safety calculator: forwards to the first registered manager
    whose (blue policy type, red policy type) pair matches the arguments.
    """
    _accepted_BTS_Class = Union[TMBlueTeamPolicy, MixedBlueTeamPolicy]
    _accepted_RTS_Class = Union[SimpleTMRedTeamPolicy, MixedRedTeamPolicy]

    def __init__(self):
        # Insertion order is the lookup priority: earlier entries win.
        registry = {}
        registry[(SimpleTMBlueTeamPolicy, SimpleTMRedTeamPolicy)] = \
            CalculateSafetyViaRecurrenceForSimpleBTS()
        registry[(SimpleTMBlueTeamPolicy, MixedRedTeamPolicy)] = \
            CalculateSafetyViaRecurrenceForSimpleBTSMixedRTS()
        registry[(DeterministicThresholdBTS, SimpleTMRedTeamPolicy)] = \
            CalculateSafetyViaRecurrenceForDeterministicThresholdBTS()
        registry[(MixedBlueTeamPolicy, SimpleTMRedTeamPolicy)] = \
            CalculateSafetyViaRecurrenceForMixedBTS()
        self._calc_safety_managers = registry

    def _calculate_safety_inner(
            self,
            game: TrustedMonitoringGame,
            bts: _accepted_BTS_Class,
            rts: _accepted_RTS_Class,
    ) -> float:
        """
        Delegate to the first registered manager matching the policy types.

        :raises ValueError: if no registered pair matches (bts, rts)
        """
        for policy_types, manager in self._calc_safety_managers.items():
            blue_type, red_type = policy_types
            if not isinstance(bts, blue_type):
                continue
            if not isinstance(rts, red_type):
                continue
            return manager.calculate_safety(game, bts, rts)
        raise ValueError(f"Unsupported bts and rts combination: {bts}, {rts}")


class EstimateSafetyViaMC(CalculateSafetyManager):
    """Safety calculator that delegates to the game's Monte-Carlo estimator."""
    _accepted_BTS_Class = Union[TMBlueTeamPolicy, MixedBlueTeamPolicy]
    _accepted_RTS_Class = SimpleTMRedTeamPolicy

    # Number of samples requested from `game.estimate_MC_safety`.
    _sample_size: int = int(1e3)

    def _calculate_safety_inner(
            self,
            game: TrustedMonitoringGame,
            bts: _accepted_BTS_Class,
            rts: SimpleTMRedTeamPolicy,
    ) -> float:
        """
        Return the first element of `game.estimate_MC_safety(sample_size, bts, rts)`.

        The second element of the returned pair is ignored.
        """
        n_samples = self._sample_size
        est_safety, _ = game.estimate_MC_safety(n_samples, bts, rts)
        # p(1-p)/n-style variance figure; only consumed by the optional debug print below.
        est_sample_variance = est_safety * (1 - est_safety) / n_samples
        # print(f"Estimated safety: {est_safety:.3f} ± {np.sqrt(est_sample_variance):.3f}")
        return est_safety


class EstimateUsefulnessViaMC(CalculateUsefulnessManager):
    """Usefulness calculator that delegates to the game's Monte-Carlo estimator."""
    _accepted_BTS_Class = Union[TMBlueTeamPolicy, MixedBlueTeamPolicy]

    # Number of samples requested from `game.estimate_MC_usefulness`.
    _sample_size: int = int(1e3)

    def _calculate_usefulness_inner(
            self,
            game: TrustedMonitoringGame,
            bts: _accepted_BTS_Class,
    ) -> float:
        """
        Return the first element of `game.estimate_MC_usefulness(sample_size, bts)`.

        The second element of the returned pair is ignored.
        """
        n_samples = self._sample_size
        est_usefulness, _ = game.estimate_MC_usefulness(n_samples, bts)
        # Treating the sample mean as if it were the true population mean,
        # this is the variance the sample mean would then have.
        # It is only consumed by the optional debug print below.
        est_sample_variance = est_usefulness * (1 - est_usefulness) / n_samples
        # print(f"Estimated usefulness: {est_usefulness:.3f} ± {np.sqrt(est_sample_variance):.3f}")
        return est_usefulness


class CalcUsefulnessManagerFlexibleDefault(CalculateUsefulnessManager):
    """
    Dispatching usefulness calculator: forwards to the first registered
    manager whose blue policy type matches the argument.
    """
    _accepted_BTS_Class = Union[
        SimpleTMBlueTeamPolicy, DeterministicThresholdBTS, MixedBlueTeamPolicy]

    def __init__(self):
        # Insertion order is the lookup priority: earlier entries win.
        # Note: no dedicated entry is registered for SimpleTMBlueTeamPolicy.
        registry = {}
        registry[DeterministicThresholdBTS] = CalculateUsefulnessViaRecurrenceForDeterministicThresholdBTS()
        registry[MixedBlueTeamPolicy] = CalculateUsefulnessViaRecurrenceForMixedBTS()
        self._calc_usefulness_managers = registry

    def _calculate_usefulness_inner(
            self,
            game: TrustedMonitoringGame,
            bts: _accepted_BTS_Class,
    ) -> float:
        """
        Delegate to the first registered manager matching the policy type.

        :raises ValueError: if no registered type matches `bts`
        """
        for policy_type, manager in self._calc_usefulness_managers.items():
            if not isinstance(bts, policy_type):
                continue
            return manager.calculate_usefulness(game, bts)
        # Nothing matched
        raise ValueError(f"Unknown blue team strategy: {bts}")


class CalculateSafetyViaRecurrenceForDeterministicThresholdBTS(CalculateSafetyManager):
    """
    Safety for a DeterministicThresholdBTS against a SimpleTMRedTeamPolicy,
    computed with a memoised recurrence over (rounds left, coins left).
    """
    _accepted_BTS_Class = DeterministicThresholdBTS
    _accepted_RTS_Class = SimpleTMRedTeamPolicy

    @staticmethod
    def _get_lambdas(
            game: TrustedMonitoringGame,
            bts: DeterministicThresholdBTS,
            m: int,
            c: int
    ) -> tuple:
        """
        Build the six coefficients of the safety recurrence for one game state.

        The audit and defer thresholds are queried from the policy for this
        particular state, so they may differ between states.
        With A = audit, D = defer, B = the event the `bd` probabilities refer to:
        λ1 = P[D|B] - P[~A|~B]
        λ2 = P[~A|~B]            (defer or continue, given ~B)
        λ3 = P[accurate A|B]
        λ4 = P[inaccurate A|~B]
        λ5 = P[accurate A|~B]
        λ6 = P[D|B, c=0]
        :param game: supplies the probabilities through `game.ma`
        :param bts: supplies `get_q_audit(m, c)` and `get_q_defer(m, c)`
        :param m: first argument forwarded to the threshold getters
        :param c: second argument forwarded to the threshold getters
        :return: the tuple (λ1, λ2, λ3, λ4, λ5, λ6)
        """
        monitor = game.ma
        audit_q = bts.get_q_audit(m, c)
        defer_q = bts.get_q_defer(m, c)
        bd_over_audit_q = monitor.get_prob_bd_exceeds_q(audit_q)
        bd_over_defer_q = monitor.get_prob_bd_exceeds_q(defer_q)
        fnr = monitor.get_prob_audit_breach()
        fpr = monitor.get_prob_audit_false_alarm()

        # When the audit threshold is at least the defer threshold, deferral never happens.
        audit_covers_defer = audit_q >= defer_q

        # --- Probabilities given B ---
        p_defer_given_bd = 0.0 if audit_covers_defer else bd_over_defer_q - bd_over_audit_q
        p_true_audit_given_bd = bd_over_audit_q * (1 - fnr)

        # --- Probabilities given ~B ---
        p_audit_given_safe = monitor.get_prob_safe_exceeds_q(audit_q)
        p_audit_or_defer_given_safe = monitor.get_prob_safe_exceeds_q(
            audit_q if audit_covers_defer else defer_q
        )
        p_defer_given_safe = p_audit_or_defer_given_safe - p_audit_given_safe
        p_continue_given_safe = 1 - p_audit_or_defer_given_safe

        lambda_1 = p_defer_given_bd - p_defer_given_safe - p_continue_given_safe
        lambda_2 = p_defer_given_safe + p_continue_given_safe
        lambda_3 = p_true_audit_given_bd
        lambda_4 = p_audit_given_safe * fpr
        lambda_5 = p_audit_given_safe * (1 - fpr)
        # Used for the states with no coins left.
        lambda_6 = bd_over_defer_q
        return lambda_1, lambda_2, lambda_3, lambda_4, lambda_5, lambda_6

    def _calculate_safety_inner(
            self,
            game: TrustedMonitoringGame,
            bts: DeterministicThresholdBTS,
            rts: SimpleTMRedTeamPolicy,
    ) -> float:
        """
        Evaluate the recurrence from (num_problems, num_coins), recomputing the
        lambdas for every visited state.

        :return: the safety as a float, set to 1.0 if it was above but close to 1.0
        :raises ValueError: if the result is above 1.0 and not close to it
        """
        assert isinstance(bts, DeterministicThresholdBTS)

        C = game.gr.num_coins
        M = game.gr.num_problems

        @lru_cache(maxsize=None)
        def w(
                t: int,  # Rounds remaining (0 means the game has finished)
                c: int,  # Coins remaining (one coin is spent per audit)
        ) -> float:
            # A finished game contributes 1
            if t == 0:
                return 1

            # The coefficients depend on the state, so fetch them afresh
            beta = rts.beta
            lam_1, lam_2, lam_3, lam_4, lam_5, lam_6 = self._get_lambdas(game, bts, t, c)

            if c == 0:
                # w(t,0) = ((1 - β) + λ₆ β) w(t - 1, 0)
                no_coins_coeff = (1 - beta) + (lam_6 * beta)
                return no_coins_coeff * w(t - 1, 0)

            # w(t,c) = (λ₁ β + λ₂) w(t-1,c) + λ₅ (1 - β) w(t-1,c-1) + λ₃ β + λ₄ (1-β)
            same_coins_coeff = (lam_1 * beta) + lam_2
            fewer_coins_coeff = lam_5 * (1 - beta)
            constant = lam_3 * beta + lam_4 * (1 - beta)
            return (same_coins_coeff * w(t - 1, c)) + (fewer_coins_coeff * w(t - 1, c - 1)) + constant

        # Warm the cache in slices so the final call does not hit the recursion limit
        stride = sys.getrecursionlimit() // 10
        if M > stride:
            for horizon in range(0, M + 1, stride):
                w(horizon, C)

        sol = w(
            t=M,
            c=C
        )
        w.cache_clear()

        del w

        if sol > 1.0:
            # Anything that is not a near-1 overshoot is a genuine error.
            if not floats_are_close(sol, 1.0):
                raise ValueError(f"Safety > 1.0: {sol}")
            sol = 1.0
            maybe_warn("Warning: Safety > 1.0, but it's close!")
        return float(sol)


class CalculateUsefulnessViaRecurrenceForDeterministicThresholdBTS(CalculateUsefulnessManager):
    """
    Usefulness for a DeterministicThresholdBTS, computed with a memoised
    recurrence over (rounds left, coins left) and averaged over the rounds.
    """
    _accepted_BTS_Class = DeterministicThresholdBTS

    def _calculate_usefulness_inner(
            self,
            game: TrustedMonitoringGame,
            bts: SimpleTMBlueTeamPolicy,
            should_verify=True,
    ) -> float:
        """
        Evaluate the recurrence from (num_problems, num_coins) and divide by num_problems.

        :param game: supplies `gr.num_coins`, `gr.num_problems` and the `ma` probabilities
        :param bts: supplies `get_q_audit(t, c)` and `get_q_defer(t, c)`
        :param should_verify: accepted but not used by this implementation
        :return: accumulated usefulness divided by the number of problems
        """
        monitor_of = game  # `game.ma` is looked up at each use, as the recurrence runs

        @lru_cache(maxsize=None)
        def u(t, c):
            # No rounds left: nothing more to gain
            if t == 0:
                return 0

            if c == 0:
                # Without coins only the defer threshold matters
                defer_q = bts.get_q_defer(t, c)
                p_not_deferred = 1 - monitor_of.ma.get_prob_safe_exceeds_q(defer_q)
                return p_not_deferred + u(t - 1, 0)

            audit_q = bts.get_q_audit(t, c)
            defer_q = bts.get_q_defer(t, c)
            fpr = monitor_of.ma.get_prob_audit_false_alarm()

            p_audit = monitor_of.ma.get_prob_safe_exceeds_q(audit_q)
            if defer_q < audit_q:
                p_defer = 0.0
            else:
                p_defer_or_audit = monitor_of.ma.get_prob_safe_exceeds_q(defer_q)
                p_defer = p_defer_or_audit - ((1 - fpr) * p_audit)

            # Usefulness gained in this round
            total = (1 - p_defer)
            # Future usefulness on the branch where an audit spends a coin
            total += ((1 - fpr) * p_audit) * u(t - 1, c - 1)
            # Future usefulness on the branch where no audit happens
            total += (1 - p_audit) * u(t - 1, c)
            return total

        C = game.gr.num_coins
        M = game.gr.num_problems

        # Warm the cache in slices so the final call does not hit the recursion limit
        stride = sys.getrecursionlimit() // 10
        if M > stride:
            for horizon in range(0, M + 1, stride):
                u(horizon, C)

        ur = u(M, C)

        u.cache_clear()
        del u

        return ur / M


class CalculateSafetyViaRecurrenceForSimpleBTS(CalculateSafetyViaRecurrenceForDeterministicThresholdBTS):
    """
    Safety for a SimpleTMBlueTeamPolicy against a SimpleTMRedTeamPolicy.

    Same recurrence as the parent class, but the coefficients are computed
    a single time instead of once per visited state.
    """
    _accepted_BTS_Class = SimpleTMBlueTeamPolicy
    _accepted_RTS_Class = SimpleTMRedTeamPolicy

    def _get_lambdas(
            self,
            game: TrustedMonitoringGame,
            bts: DeterministicThresholdBTS,
            m: int,
            c: int
    ) -> tuple:
        """Return the parent class's lambdas for the given arguments, unchanged."""
        lambdas = super()._get_lambdas(game, bts, m, c)
        return lambdas

    def _calculate_safety_inner(self, game: TrustedMonitoringGame, bts: SimpleTMBlueTeamPolicy,
                                rts: SimpleTMRedTeamPolicy) -> float:
        """
        Evaluate the recurrence from (num_problems, num_coins) with fixed coefficients.

        :return: the safety as a float, set to 1.0 if it was above but close to 1.0
        :raises ValueError: if the result is above 1.0 and not close to it
        """
        assert isinstance(bts, SimpleTMBlueTeamPolicy)

        C = game.gr.num_coins
        M = game.gr.num_problems

        beta = rts.beta
        # The lambdas are taken once, at (M, C), and reused for every state
        # (per the original note, SimpleBTS does not use M or C to calculate them).
        lam_1, lam_2, lam_3, lam_4, lam_5, lam_6 = self._get_lambdas(game, bts, M, C)
        # w(t,c) = (λ₁ β + λ₂) w(t-1,c) + λ₅ (1 - β) w(t-1,c-1) + λ₃ β + λ₄ (1-β)
        same_coins_coeff = (lam_1 * beta) + lam_2
        fewer_coins_coeff = lam_5 * (1 - beta)
        constant = lam_3 * beta + lam_4 * (1 - beta)
        # w(t,0) = ((1 - β) + λ₆ β) w(t - 1, 0)
        no_coins_coeff = (1 - beta) + (lam_6 * beta)

        @lru_cache(maxsize=None)
        def w(
                t: int,  # Rounds remaining (0 means the game has finished)
                c: int,  # Coins remaining (one coin is spent per audit)
        ) -> float:
            if t == 0:
                return 1
            if c == 0:
                return no_coins_coeff * w(t - 1, 0)
            return (same_coins_coeff * w(t - 1, c)) + (fewer_coins_coeff * w(t - 1, c - 1)) + constant

        # Warm the cache in slices so the final call does not hit the recursion limit
        stride = sys.getrecursionlimit() // 10
        if M > stride:
            for horizon in range(0, M + 1, stride):
                w(horizon, C)

        sol = w(
            t=M,
            c=C
        )
        w.cache_clear()

        del w

        if sol > 1.0:
            # Anything that is not a near-1 overshoot is a genuine error.
            if not floats_are_close(sol, 1.0):
                raise ValueError(f"Safety > 1.0: {sol}")
            sol = 1.0
            maybe_warn("Warning: Safety > 1.0, but it's close!")
        return float(sol)


class CalculateSafetyViaRecurrenceForSimpleBTSMixedRTS(CalculateSafetyViaRecurrenceForSimpleBTS):
    """
    Safety for a SimpleTMBlueTeamPolicy against a MixedRedTeamPolicy.

    The result is the sum, over the component red team policies, of each
    component's weight times the safety computed by the parent class.
    """
    _accepted_BTS_Class = SimpleTMBlueTeamPolicy
    _accepted_RTS_Class = MixedRedTeamPolicy

    def _get_lambdas(
            self,
            game: TrustedMonitoringGame,
            bts: DeterministicThresholdBTS,
            m: int,
            c: int
    ) -> tuple:
        """Return the parent class's lambdas for the given arguments, unchanged."""
        return super()._get_lambdas(game, bts, m, c)

    def _calculate_safety_inner(self, game: TrustedMonitoringGame, bts: SimpleTMBlueTeamPolicy,
                                rts: MixedRedTeamPolicy) -> float:
        """
        Return the weighted sum of the per-component safeties.

        :param game: the game to evaluate
        :param bts: the simple blue team policy
        :param rts: mixed red team policy; `rts.ps` and `rts.policies` are paired up
        :return: sum of `p * safety(policy)`; set to exactly 1.0 when the sum
            lands above but close to 1.0
        """
        # Start from a float so that an empty mixture still returns a float.
        weighted_sum = 0.0
        for (prob, simple_rts) in zip(rts.ps, rts.policies):
            assert isinstance(simple_rts, SimpleTMRedTeamPolicy)
            safety = super()._calculate_safety_inner(game, bts, simple_rts)
            weighted_sum += prob * safety

        # Accumulating weighted floats can overshoot 1.0 by round-off even when
        # every component is <= 1.0; treat that the same way the single-policy
        # recurrences do. Values far above 1.0 are returned unchanged.
        if weighted_sum > 1.0 and floats_are_close(weighted_sum, 1.0):
            weighted_sum = 1.0
            maybe_warn("Warning: Safety > 1.0, but it's close!")
        return weighted_sum


class CalculateUsefulnessViaRecurrenceForSimpleBTS(CalculateUsefulnessViaRecurrenceForDeterministicThresholdBTS):
    """Usefulness for a SimpleTMBlueTeamPolicy, computed by the parent class's recurrence."""
    _accepted_BTS_Class = SimpleTMBlueTeamPolicy

    def _calculate_usefulness_inner(self, game: TrustedMonitoringGame, bts: DeterministicThresholdBTS) -> float:
        """Forward (game, bts) to the parent implementation and return its result."""
        usefulness = super()._calculate_usefulness_inner(game, bts)
        return usefulness


# TODO make this adaptable to using many deterministic blue team strategies
class CalculateSafetyViaRecurrenceForMixedBTS(CalculateSafetyManager):
    """
    Safety for a MixedBlueTeamPolicy against a SimpleTMRedTeamPolicy.

    The result is the sum, over the component blue team policies, of each
    component's safety times its weight.
    """
    _accepted_BTS_Class = MixedBlueTeamPolicy
    _accepted_RTS_Class = SimpleTMRedTeamPolicy

    def _calculate_safety_inner(
            self,
            game: TrustedMonitoringGame,
            bts: MixedBlueTeamPolicy,
            rts: SimpleTMRedTeamPolicy,
    ) -> float:
        """
        Return the weighted sum of the per-component safeties.

        Each component is evaluated with a fresh CalcSafetyManagerFlexibleDefault.

        :param game: the game to evaluate
        :param bts: mixed blue team policy; `bts.policies` and `bts.ps` are paired up
        :param rts: the simple red team policy
        :return: sum of `safety(policy) * p`; set to exactly 1.0 when the sum
            lands above but close to 1.0
        :raises ValueError: if the sum is above 1.0 and not close to it
        """
        c = CalcSafetyManagerFlexibleDefault()
        policies = bts.policies
        ps = bts.ps
        safety_mixed = 0.0
        for policy, p in zip(policies, ps):
            s = c.calculate_safety(game, policy, rts)
            safety_mixed += s * p

        if safety_mixed > 1.0:
            # Accumulating weighted floats can overshoot 1.0 by round-off even
            # when every component is <= 1.0. Tolerate that exactly as the
            # single-policy recurrences do, and only fail on real violations.
            if floats_are_close(safety_mixed, 1.0):
                safety_mixed = 1.0
                maybe_warn("Warning: Safety > 1.0, but it's close!")
            else:
                raise ValueError(f"Safety > 1.0: {safety_mixed}")
        return safety_mixed


class CalculateUsefulnessViaRecurrenceForMixedBTS(CalculateUsefulnessManager):
    """
    Usefulness for a MixedBlueTeamPolicy: the sum, over the component
    policies, of each component's usefulness times its weight.
    """
    _accepted_BTS_Class = MixedBlueTeamPolicy

    def _calculate_usefulness_inner(
            self,
            game: TrustedMonitoringGame,
            bts: MixedBlueTeamPolicy,
    ) -> float:
        """
        Return the weighted sum of the per-component usefulness values.

        Each component is evaluated with a fresh CalcUsefulnessManagerFlexibleDefault;
        `bts.policies` and `bts.ps` are paired up positionally.
        """
        component_manager = CalcUsefulnessManagerFlexibleDefault()
        weighted_components = zip(bts.policies, bts.ps)

        usefulness_mixed = 0.0
        # Explicit accumulation keeps the left-to-right float addition order.
        for component, weight in weighted_components:
            component_usefulness = component_manager.calculate_usefulness(game, component)
            usefulness_mixed += component_usefulness * weight

        return usefulness_mixed


# Module-level instances of the dispatching ("flexible default") managers,
# created once when this module is imported.
CSM_FLEXI = CalcSafetyManagerFlexibleDefault()
CUM_FLEXI = CalcUsefulnessManagerFlexibleDefault()
