from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import lru_cache
from typing import Optional

import numpy as np
import torch
from icecream import ic

from ai_control_games.generic_game import MixedBlueTeamPolicy
from ai_control_games.trusted_monitor.tm_game import TMBlueObsType, RunAuditOrDeferType, TMBlueTeamPolicy, \
    TrustedMonitoringGame
from ai_control_games.distributions import DistOver0to1
from ai_control_games.other_utils import floats_are_close


class DeterministicThresholdBTS(TMBlueTeamPolicy, ABC):
    """Blue-team policy that chooses RUN / DEFER / AUDIT by thresholding ``s``.

    Subclasses provide an audit rate and a defer rate for each ``(m, c)``.
    Each rate ``q`` is converted into a threshold with
    ``safe_dist.inv_cdf(1 - q)`` and the first element of the observation is
    compared against those thresholds in ``act``.

    ``m`` and ``c`` are presumably the number of problems and audit coins
    remaining -- TODO confirm against the game definition.
    """

    safe_dist: DistOver0to1

    # NOTE: implementations of this method must be deterministic
    @abstractmethod
    def get_q_audit(self, m: int, c: int) -> float:
        """Return the audit rate to use for the given ``(m, c)``."""
        raise NotImplementedError

    # NOTE: implementations of this method must be deterministic
    @abstractmethod
    def get_q_defer(self, m: int, c: int) -> float:
        """Return the defer rate to use for the given ``(m, c)``."""
        raise NotImplementedError

    # NOTE(review): lru_cache on a method includes ``self`` in the cache key, so
    # instances must be hashable and up to ``maxsize`` entries keep a reference
    # to their instance alive.
    @lru_cache(maxsize=100)
    def get_threshold_from_q(self, q: float) -> float:
        """Return the threshold ``t`` satisfying ``safe_dist.cdf(t) == 1 - q``."""
        # P(safe score < t) = 1 - q, so invert the safe-distribution CDF there.
        return self.safe_dist.inv_cdf(1.0 - q)

    def act(self, obs: TMBlueObsType) -> RunAuditOrDeferType:
        """Pick an action for the observation ``(s, m, c)``.

        AUDIT is only considered when ``c`` is non-zero and wins whenever ``s``
        exceeds the audit threshold; otherwise DEFER is chosen when ``s``
        exceeds the defer threshold, and RUN in every remaining case.
        """
        s, m, c = obs
        defer_threshold = self.get_threshold_from_q(self.get_q_defer(m, c))

        # The audit rate is only queried when c != 0.
        if c != 0:
            audit_threshold = self.get_threshold_from_q(self.get_q_audit(m, c))
            if s > audit_threshold:
                return RunAuditOrDeferType.AUDIT

        if s > defer_threshold:
            return RunAuditOrDeferType.DEFER
        return RunAuditOrDeferType.RUN

    def reset(self):
        """Do nothing: this policy keeps no per-episode state."""
        pass

    def __post_init__(self):
        """Spot-check that both rate functions are valid and repeatable.

        Each function is called twice on a handful of small ``(m, c)`` pairs;
        the result must lie in [0, 1] and be equal across the two calls.
        """
        for m in (1, 2):
            for c in (0, 1):
                for rate_fn in (self.get_q_audit, self.get_q_defer):
                    first = rate_fn(m, c)
                    second = rate_fn(m, c)
                    assert 0 <= first <= 1, first
                    assert first == second, (first, second, "should be deterministic")


@dataclass(frozen=True, eq=True, order=True, unsafe_hash=False)
class SimpleTMBlueTeamPolicy(DeterministicThresholdBTS):
    """Threshold policy whose audit and defer rates are fixed constants."""

    # TODO standardise use of q_audit vs q_defer
    q_audit: float
    q_defer: float
    safe_dist: DistOver0to1

    def __post_init__(self):
        """Check that both stored rates lie in [0, 1] (audit first, then defer)."""
        for rate in (self.q_audit, self.q_defer):
            assert 0 <= rate <= 1.0, rate

    def get_q_audit(self, m: int, c: int) -> float:
        """Return the constant audit rate; ``m`` and ``c`` are ignored."""
        return self.q_audit

    def get_q_defer(self, m: int, c: int) -> float:
        """Return the constant defer rate; ``m`` and ``c`` are ignored."""
        return self.q_defer


@dataclass(frozen=True, eq=True, order=True, unsafe_hash=False)
class DiscreteSimpleTMBlueTeamPolicy(SimpleTMBlueTeamPolicy):
    """Constant-rate policy whose rates must sit on a grid of 3 decimal places."""

    # TODO standardise use of q_audit vs q_defer
    q_audit: float
    q_defer: float
    safe_dist: DistOver0to1

    def __post_init__(self):
        """Range-check both rates, then require each to match its 3-dp rounding."""
        # The parent hook performs the [0, 1] range assertions on both rates.
        super().__post_init__()
        for rate in (self.q_audit, self.q_defer):
            assert floats_are_close(rate, round(rate, 3)), rate


class MixTwoSimpleBTS(MixedBlueTeamPolicy):
    """Mixture of two ``SimpleTMBlueTeamPolicy`` instances over one ``safe_dist``."""

    def __init__(self, qa1, qd1, qa2, qd2, p, safe_dist: DistOver0to1):
        """Build the two component policies and register their weights.

        ``(qa1, qd1)`` and ``(qa2, qd2)`` are the audit/defer rates of the first
        and second component; they are passed to the parent with weights ``p``
        and ``1 - p`` respectively. All five numbers must lie in [0, 1].
        """
        # Validate the mixing weight first, then the rates in signature order.
        for value in (p, qa1, qd1, qa2, qd2):
            assert 0 <= value <= 1.0

        components = (
            SimpleTMBlueTeamPolicy(qa1, qd1, safe_dist),
            SimpleTMBlueTeamPolicy(qa2, qd2, safe_dist),
        )
        super().__init__(policies=components, ps=(p, 1 - p))


@dataclass(frozen=True, eq=True, order=True, unsafe_hash=False)
class VariableAuditTMBlueTeamPolicy(DeterministicThresholdBTS):
    """Threshold policy whose audit rate is rescaled by ``c / m``; defer rate is fixed."""

    q_audit: float
    q_defer: float
    init_audit_proportion: float
    safe_dist: DistOver0to1

    def get_q_audit(self, m: int, c: int) -> float:
        """Return ``q_audit * (c / m) / init_audit_proportion`` clamped to [0, 1].

        A NaN result is not clamped: it is reported through ``ic`` and returned
        unchanged.
        """
        # TODO - FIX THE BUG HERE - DOESN'T WORK FOR NUM PROBLEMS != 100
        # NOTE(review): m == 0 or init_audit_proportion == 0 divides by zero
        # here -- confirm callers never pass those values.
        effective_audit_rate = c / m
        rescaled = self.q_audit * effective_audit_rate / self.init_audit_proportion

        if rescaled > 1.0:
            return 1.0
        if rescaled < 0.0:
            return 0.0
        # NaN fails both comparisons above, so it can only show up here.
        if np.isnan(rescaled):
            ic(self.q_audit, effective_audit_rate, self.init_audit_proportion, m, c)
        return rescaled

    def get_q_defer(self, m: int, c: int) -> float:
        """Return the constant defer rate; ``m`` and ``c`` are ignored."""
        return self.q_defer


class VariableAuditFixedDefer(DeterministicThresholdBTS):
    """Threshold policy with a per-``(m, c)`` audit-rate table and one defer rate."""

    def __init__(
            self,
            game: TrustedMonitoringGame,
            defer_q: float,
            audit_qs: Optional[torch.Tensor] = None
    ):
        """Store the game, the defer rate and the audit-rate table.

        When ``audit_qs`` is omitted, a ``(num_problems + 1, num_coins + 1)``
        table filled with -1 is created; ``get_q_audit`` rejects such entries
        through its range assertion.
        """
        # NOTE(review): safe_dist is not assigned here although the inherited
        # act() reads self.safe_dist -- confirm it is provided elsewhere.
        self.g = game
        self.M = game.gr.num_problems
        self.C: int = game.gr.num_coins
        self.defer_q = defer_q
        if audit_qs is not None:
            self.audit_qs = audit_qs
        else:
            # -1 marks entries that have not been given a real rate.
            self.audit_qs = torch.full((self.M + 1, self.C + 1), -1.0)

    def get_q_audit(self, m: int, c: int) -> float:
        """Return the table entry ``audit_qs[m, c]`` as a float, asserting it is in [0, 1]."""
        q_a = float(self.audit_qs[m, c])
        assert 0.0 <= q_a <= 1.0, f"q_a(m,c) is not in [0, 1]: {q_a}=q_a({m},{c})"
        return q_a

    def get_q_defer(self, m: int, c: int) -> float:
        """Return the constant defer rate; ``m`` and ``c`` are ignored."""
        return self.defer_q
