from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import lru_cache
from typing import NamedTuple

import numpy as np
import torch
import random

from ai_control_games.generic_game import GameRules, ModellingAssumptions, Game, BlueTeamPolicy, RedTeamStrategy
from ai_control_games.other_utils import floats_are_close
from ai_control_games.distributions import DistOver0to1


class TMRedAction(NamedTuple):
    """Red team's move on one problem: submit a backdoored solution or not.

    Being a one-field named tuple, an instance compares equal to the plain
    tuples ``(True,)`` / ``(False,)`` stored in the constants below.
    """
    bd_or_not: bool  # True when the red team backdoors the problem

    # Canonical action values, kept as plain one-element tuples.
    DONT = (False,)
    BD = (True,)


class TMBlueObsType(NamedTuple):
    """What the blue team sees before choosing an action for a problem."""
    # Suspicion score assigned to the current solution.
    sus_score: float
    # Problems still to be handled, counting the current one.
    num_problems_left: int
    # Audit coins the blue team has not yet spent.
    num_coins_left: int


# TODO replace this cursed typing
class RunAuditOrDeferType(NamedTuple):
    """
    Blue team action for a single problem, encoded as a one-element tuple:

    * ``(0,)`` - audit the solution
    * ``(1,)`` - defer (do not use the untrusted solution)
    * ``(2,)`` - run the untrusted solution as-is
    """
    run_audit_or_defer: int  # integer action code, see the table above

    # Named constants for each action code (plain tuples, so they compare
    # equal to instances built from the same code).
    AUDIT = (0,)
    DEFER = (1,)
    RUN = (2,)

    # Every available action, in ascending code order.
    ALL_ACTIONS = (AUDIT, DEFER, RUN)


class TMBlueTeamPolicy(BlueTeamPolicy, ABC):
    """Interface every blue-team policy for the trusted-monitoring game implements."""

    @abstractmethod
    def act(self, obs: TMBlueObsType) -> RunAuditOrDeferType:
        """Choose audit, defer or run for the problem described by ``obs``."""
        raise NotImplementedError

    @abstractmethod
    def reset(self):
        """Clear per-game state; the game calls this before each simulation."""
        raise NotImplementedError


class TMRedTeamPolicy(RedTeamStrategy, ABC):
    """Base class for red-team policies in the trusted-monitoring game."""

    def act(self, obs) -> TMRedAction:
        """Return the red team's action for the next problem.

        Subclasses are expected to override this; the base implementation
        always raises ``NotImplementedError``.
        """
        raise NotImplementedError

    @abstractmethod
    def reset(self):
        """Clear per-game state; the game calls this before each simulation."""
        raise NotImplementedError

    def __repr__(self):
        """Render as ``ClassName(attr=value, ...)`` from the instance attributes.

        Floats are shown with four decimal places; every other value falls back
        to its own ``repr`` so non-numeric attributes no longer make ``repr()``
        raise.
        """
        name = self.__class__.__name__
        parts = []
        for key, value in self.__dict__.items():
            if isinstance(value, float):
                parts.append(f"{key}={value:.4f}")
            else:
                parts.append(f"{key}={value!r}")
        return f"{name}({', '.join(parts)})"


@dataclass(frozen=True, eq=True, order=True, unsafe_hash=False)
class SimpleTMRedTeamPolicy(TMRedTeamPolicy):
    """Memoryless red team: each call to ``act`` backdoors with probability ``beta``."""
    beta: float

    def __post_init__(self):
        # beta is a probability, so it has to lie in the closed unit interval.
        assert 0 <= self.beta <= 1.0, f"beta={self.beta} must be between 0 and 1.0"

    def act(self, obs: None) -> TMRedAction:
        """Return ``(True,)`` to backdoor or ``(False,)`` to behave; ``obs`` is ignored."""
        # A zero rate never attacks, and takes no random draw to decide so.
        if self.beta == 0.0:
            return (False,)
        draw = np.random.uniform(0, 1, 1).item()
        # Attack unless the uniform draw lands strictly above beta.
        attack = not (draw > self.beta)
        return (attack,)

    def reset(self):
        """Nothing to reset: the policy keeps no per-game state."""
        pass

    def __repr__(self):
        rate = f"{self.beta:.4f}"
        return f"RTS(beta={rate})"


@dataclass(frozen=True, eq=True, order=True, unsafe_hash=False)
class TMGameRules(GameRules):
    """Size of a trusted-monitoring game and its audit budget.

    ``num_problems * audit_proportion`` must be (numerically close to) a whole
    number; that whole number is the audit budget, ``num_coins``.
    """
    num_problems: int
    audit_proportion: float

    @property
    def num_coins(self):
        """Number of audits the blue team may perform in one game.

        Rounds to the nearest integer rather than truncating: ``__post_init__``
        accepts products that are only *close* to an integer, and truncating a
        product such as ``100 * 0.29 == 28.999999999999996`` would drop a coin.
        """
        return round(self.num_problems * self.audit_proportion)

    def __post_init__(self):
        """Check that the audit budget is a whole number.

        Raises:
            ValueError: if ``num_problems * audit_proportion`` is neither an
                integer nor within ``floats_are_close`` tolerance of one.
        """
        # float() so the check also works when both factors are ints
        # (int.is_integer() only exists on Python 3.12+).
        actual = float(self.num_problems * self.audit_proportion)
        if actual.is_integer():
            return
        if not floats_are_close(actual, float(round(actual))):
            raise ValueError(f"{self.num_problems=} * {self.audit_proportion=} must be an integer")


_SAMPLE_BSZ = int(1e4)  # number of samples pre-drawn per batch


# TODO make this a general wrapper for DistOver0to1 instances
class SampleBatchManager:
    """Serve samples from a pre-drawn batch to avoid per-call sampling overhead.

    A batch of ``_SAMPLE_BSZ`` uniform draws on [0, 1) is generated at once and,
    when ``icdf`` is set, mapped through it element-wise. ``sample`` and
    ``sample_many`` then hand out consecutive, non-overlapping entries and
    trigger a fresh batch when the current one runs out.
    """
    # Optional callable applied to every uniform draw (an inverse CDF).
    icdf = None

    _sample_batch = torch.zeros((_SAMPLE_BSZ,))
    # Index of the next unread entry of `_sample_batch`. Starting at the batch
    # size marks the all-zero placeholder above as exhausted, so the first
    # request draws a real batch.
    _sample_index = _SAMPLE_BSZ

    def set_icdf(self, icdf):
        """Install the transform applied to uniform draws and draw a fresh batch."""
        self.icdf = icdf
        self.resample_batch()

    def _draw(self, num_samples: int) -> torch.Tensor:
        """Draw ``num_samples`` uniform values and map them through ``icdf`` if set."""
        batch = torch.distributions.uniform.Uniform(0.0, 1.0).sample((num_samples,))
        if self.icdf is not None:
            batch.apply_(self.icdf)
        return batch

    def resample_batch(self):
        """Replace the current batch with new samples and rewind the read index."""
        self._sample_batch = self._draw(_SAMPLE_BSZ)
        self._sample_index = 0

    def sample(self) -> float:
        """Return the next unread sample, refilling the batch once it is exhausted."""
        if self._sample_index >= _SAMPLE_BSZ:
            self.resample_batch()
        value = float(self._sample_batch[self._sample_index].item())
        self._sample_index += 1
        return value

    def sample_many(self, num_samples: int) -> np.ndarray:
        """Return ``num_samples`` unread samples as a 1-D array.

        The samples are consumed, so successive calls never return the same
        entries. For requests that fit in a batch the result is a view onto the
        internal batch; copy it before modifying it in place.

        Raises:
            ValueError: if ``num_samples`` is negative.
        """
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")
        if num_samples > _SAMPLE_BSZ:
            # Larger than a whole batch: draw directly rather than silently
            # returning a truncated slice.
            return self._draw(num_samples).numpy()
        if self._sample_index + num_samples > _SAMPLE_BSZ:
            # Not enough unread samples left to serve the request contiguously.
            self.resample_batch()
        start = self._sample_index
        end = start + num_samples
        self._sample_index = end
        return self._sample_batch[start:end].numpy()


# TODO tidy up this class (possibly separating algorithmic speedup from modelling assumptions)
@dataclass(order=True, unsafe_hash=False, eq=True, frozen=True)
class TMModellingAssumptions(ModellingAssumptions):
    """Statistical assumptions behind the trusted-monitoring game.

    Holds the suspicion-score distributions for backdoored and safe solutions
    and the human auditor's error rates, and offers helpers to sample from them
    and to compute exceedance probabilities.

    Equality and hashing consider only ``bd_dist`` and ``safe_dist``; the human
    error rates are ignored.
    """
    bd_dist: DistOver0to1    # suspicion-score distribution of backdoored solutions
    safe_dist: DistOver0to1  # suspicion-score distribution of safe solutions
    human_FNR: float = 0     # probability that an audit misses a real backdoor
    human_FPR: float = 0     # probability that an audit flags a safe solution

    # Class-level fallbacks only: __post_init__ gives every instance its own
    # pair of managers, so instances with different distributions never
    # overwrite each other's sampler.
    _bd_batch_manager = SampleBatchManager()
    _sf_batch_manager = SampleBatchManager()

    def __post_init__(self):
        """Validate the distributions and set up per-instance batch samplers."""
        self.validate_sus_dists()
        # The dataclass is frozen, so instance attributes must be set through
        # object.__setattr__.
        object.__setattr__(self, "_sf_batch_manager", SampleBatchManager())
        object.__setattr__(self, "_bd_batch_manager", SampleBatchManager())
        self._sf_batch_manager.set_icdf(self.safe_dist.inv_cdf)
        self._bd_batch_manager.set_icdf(self.bd_dist.inv_cdf)

    def __hash__(self):
        return hash((self.bd_dist, self.safe_dist))

    def __eq__(self, other):
        # Let Python handle unrelated types instead of raising AttributeError.
        if not isinstance(other, TMModellingAssumptions):
            return NotImplemented
        return (self.bd_dist == other.bd_dist) and (self.safe_dist == other.safe_dist)

    # NOTE(review): lru_cache on methods keys on `self` and keeps the most
    # recently used instances alive for the lifetime of the cache.
    @lru_cache(maxsize=1)
    def get_q_a_s_tensor(self, grid_resolution: int):
        """Return ``grid_resolution`` evenly spaced values of q covering [0, 1]."""
        return torch.linspace(0.0, 1.0, grid_resolution, dtype=torch.float32)

    @lru_cache(maxsize=1)
    def get_tpr_tensor(self, grid_resolution: int):
        """Return ``get_prob_bd_exceeds_q`` evaluated at every grid value of q."""
        q_a_s_tensor = self.get_q_a_s_tensor(grid_resolution)
        return torch.tensor([self.get_prob_bd_exceeds_q(q) for q in q_a_s_tensor], dtype=torch.float32)

    @lru_cache(maxsize=1)
    def get_fpr_tensor(self, grid_resolution: int):
        """Return ``get_prob_safe_exceeds_q`` evaluated at every grid value of q."""
        q_a_s_tensor = self.get_q_a_s_tensor(grid_resolution)
        return torch.tensor([self.get_prob_safe_exceeds_q(q) for q in q_a_s_tensor], dtype=torch.float32)

    def get_prob_safe_exceeds_q(self, q: float) -> float:
        """Return the probability that a safe score exceeds the threshold for ``q``.

        The threshold used throughout is ``safe_dist.inv_cdf(1 - q)``, so this
        probability is ``q`` itself.
        """
        return q

    @lru_cache(maxsize=1001)
    def get_prob_bd_exceeds_q(self, q: float) -> float:
        """Return the probability that a backdoored score exceeds the threshold for ``q``."""
        # Threshold that a safe score exceeds with probability q.
        threshold = self.safe_dist.inv_cdf(1.0 - q)
        p_bd_exceeds_threshold = 1.0 - self.bd_dist.cdf(threshold)
        return p_bd_exceeds_threshold

    def get_prob_exceeds_q(self, q: float, is_bd: bool) -> float:
        """Dispatch to the backdoored or safe exceedance probability."""
        if is_bd:
            return self.get_prob_bd_exceeds_q(q)
        else:
            return self.get_prob_safe_exceeds_q(q)

    def get_prob_audit_breach(self) -> float:
        """Return the probability that a human audit misses a backdoor."""
        return self.human_FNR

    def get_prob_audit_false_alarm(self) -> float:
        """Return the probability that a human audit flags a safe solution."""
        return self.human_FPR

    def sample_audit_breach(self) -> bool:
        """Sample whether a human audit of a backdoored solution misses it."""
        return (random.random() < self.get_prob_audit_breach())

    def sample_audit_false_alarm(self) -> bool:
        """Sample whether a human audit of a safe solution raises a false alarm."""
        return (random.random() < self.get_prob_audit_false_alarm())

    def sample_many_sus_scores(self, num: int, is_bd: bool) -> np.ndarray:
        """Return ``num`` suspicion scores from the backdoored or safe sampler."""
        if is_bd:
            return self._bd_batch_manager.sample_many(num)
        else:
            return self._sf_batch_manager.sample_many(num)

    def sample_sus_score(self, is_bd: bool) -> float:
        """Return one suspicion score from the backdoored or safe sampler."""
        if is_bd:
            return self._bd_batch_manager.sample()
        else:
            return self._sf_batch_manager.sample()

    def validate_sus_dists(self):
        """Validate both distributions and the boundary exceedance probabilities."""
        self.safe_dist.validate_dist()
        self.bd_dist.validate_dist()
        # At q = 0 a backdoored score should essentially never exceed the threshold.
        assert floats_are_close(0.0, self.get_prob_bd_exceeds_q(0.0)), self.get_prob_bd_exceeds_q(0.0)
        assert floats_are_close(1.0, self.get_prob_safe_exceeds_q(1.0)), self.get_prob_safe_exceeds_q(1.0)


@dataclass(frozen=True, eq=True, order=True, unsafe_hash=False)
class TrustedMonitoringGame(Game):
    """Trusted-monitoring control game: modelling assumptions plus game rules."""
    ma: TMModellingAssumptions
    gr: TMGameRules

    @property
    def benign_red_team_policy(self):
        """Red-team policy that never inserts a backdoor (``beta = 0``)."""
        return SimpleTMRedTeamPolicy(beta=0.0)

    def simulate_one_game(
            self,
            blue_team: TMBlueTeamPolicy,
            red_team: SimpleTMRedTeamPolicy,
    ) -> tuple[bool, float, dict]:
        """Play one full game between ``blue_team`` and ``red_team``.

        Both policies are reset first. On each problem the blue team sees
        ``(sus_score, num_problems_left, num_coins_left)`` and chooses to run,
        defer or audit; the game stops early once a backdoor succeeds or an
        audit ends it.

        Returns:
            ``(was_backdoored, usefulness, info)`` where ``was_backdoored`` is
            True iff a backdoor got through, ``usefulness`` is the fraction of
            problems solved with the untrusted solution, and ``info`` holds
            ``"num_coins_used"`` and the ``"outcome"`` string.

        Raises:
            ValueError: if the blue team returns an action that is not one of
                ``RunAuditOrDeferType.ALL_ACTIONS``.
        """
        blue_team.reset()
        red_team.reset()

        m = self.gr.num_problems
        c = self.gr.num_coins
        # NOTE(review): "num_bds" and "num_problems_attempted" are never
        # updated below and always stay 0 - confirm the intended semantics.
        info = {
            "num_bds": 0,
            "num_coins_used": 0,
            "num_problems_attempted": 0,
            "outcome": None
        }

        # The red team receives no observation in this game.
        red_obs = None

        # First problem: the red team decides whether to backdoor it and the
        # monitor scores it (rounded to keep the observation space finite).
        red_action = red_team.act(red_obs)
        sus_score = round(self.ma.sample_sus_score(red_action[0]), 6)

        # Live states are (problems_left, coins_left, is_current_backdoored);
        # sink states are the strings "fail" and "success".
        state = (m, c, red_action[0])
        blue_obs = (sus_score, m, c)

        # now iterate M times, tracking number solved by U
        num_solved_by_U = 0
        # The loop variable must not be called `round`: that would make the
        # builtin a local name and break every round(...) call in this method.
        for _ in range(self.gr.num_problems):
            # a sink state ends the game
            if isinstance(state, str):
                break

            blue_action = blue_team.act(blue_obs)
            # red's action and the score sampled here belong to the NEXT problem
            red_action = red_team.act(red_obs)

            m, c, is_bd = state
            sus_score = round(self.ma.sample_sus_score(red_action[0]), 6)

            if blue_action == RunAuditOrDeferType.RUN:
                if is_bd:
                    # running a backdoored solution loses the game
                    info["outcome"] = "backdoor successful"
                    state = "fail"
                else:
                    # safe use of U's solution
                    num_solved_by_U += 1
                    state = (m - 1, c, red_action[0])

            elif blue_action == RunAuditOrDeferType.DEFER:
                # no coin used and U's solution is not counted as used
                state = (m - 1, c, red_action[0])

            elif blue_action == RunAuditOrDeferType.AUDIT:
                assert c > 0, "blue team audited with no audit coins left"
                info["num_coins_used"] += 1

                if is_bd:
                    # the audit catches the backdoor unless the human errs
                    if self.ma.sample_audit_breach():
                        info["outcome"] = "backdoor successful (human audit error)"
                        state = "fail"
                    else:
                        info["outcome"] = "caught backdoor"
                        state = "success"
                else:
                    # a false alarm also ends the game as a success, but the
                    # outcome records that it was a human error
                    if self.ma.sample_audit_false_alarm():
                        info["outcome"] = "false alarm (human audit error)"
                        state = "success"
                    else:
                        # audit passed: U's solution is used, one coin spent
                        num_solved_by_U += 1
                        state = (m - 1, c - 1, red_action[0])

            else:
                raise ValueError(f"Unknown blue team action: {blue_action!r}")

            # next observation for blue; there is none once a sink is reached
            if isinstance(state, str):
                blue_obs = None
            else:
                blue_obs = (sus_score, state[0], state[1])

        # no outcome reached after M rounds means safe conclusion to game
        if info["outcome"] is None:
            info["outcome"] = "finished safely"

        was_backdoored = (state == "fail")
        usefulness = (num_solved_by_U / self.gr.num_problems)

        return was_backdoored, usefulness, info
