from abc import ABC, abstractmethod
from dataclasses import dataclass

import numpy as np
from tqdm import trange


class RedTeamStrategy(ABC):
    """Abstract interface that every red-team strategy must implement.

    ``__init__``, ``act`` and ``reset`` are all declared abstract, so a
    concrete subclass has to define each of them before it can be
    instantiated.
    """

    @abstractmethod
    def __init__(self):
        """Construct the strategy; the base implementation does nothing."""

    @abstractmethod
    def act(self, obs) -> object:
        """Choose an action for the observation ``obs``.

        The base implementation returns ``None``; subclasses supply the
        actual decision logic.
        """

    @abstractmethod
    def reset(self):
        """Reset the strategy; the base implementation does nothing."""


class BlueTeamPolicy(ABC):
    """Abstract interface that every blue-team policy must implement.

    ``__init__``, ``act`` and ``reset`` are all declared abstract, so a
    concrete subclass has to define each of them before it can be
    instantiated.
    """

    @abstractmethod
    def __init__(self):
        """Construct the policy; the base implementation does nothing."""

    @abstractmethod
    def act(self, obs) -> object:
        """Choose an action for the observation ``obs``.

        The base implementation returns ``None``; subclasses supply the
        actual decision logic.
        """

    @abstractmethod
    def reset(self):
        """Reset the policy; the base implementation does nothing."""


@dataclass(frozen=True, eq=True, order=True, unsafe_hash=False)
class MixedBlueTeamPolicy(BlueTeamPolicy):
    """Blue-team policy that plays one of several sub-policies, re-drawn on each reset.

    Attributes:
        policies: The candidate sub-policies.
        ps: Selection probability of each sub-policy; must have the same
            length as ``policies``, every entry in [0, 1], and sum to 1.

    ``reset()`` must be called before ``act()`` so that a sub-policy has been
    drawn. The index of the drawn sub-policy is stored per instance, so
    separate (including nested) mixed policies cannot overwrite each other's
    selection.
    """

    policies: tuple[BlueTeamPolicy, ...]
    ps: tuple[float, ...]

    def __post_init__(self):
        """Validate ``ps`` against ``policies`` and create the selection state."""
        assert len(self.policies) == len(self.ps), \
            "policies and ps must have the same length"
        probs = np.array(self.ps)
        assert np.all(probs >= 0.0), "probabilities must be non-negative"
        assert np.all(probs <= 1.0), "probabilities must not exceed 1"
        assert np.isclose(np.sum(probs), 1.0), "probabilities must sum to 1"
        # The dataclass is frozen, so attribute assignment has to bypass its
        # __setattr__. The one-element array is mutated in place by reset();
        # -1 marks "no sub-policy drawn yet". It is not a dataclass field, so
        # it takes no part in __eq__, ordering, __hash__ or __repr__.
        object.__setattr__(self, "_cur_state", np.array([-1], dtype=np.int32))

    @property
    def cur_policy(self):
        """The sub-policy drawn by the most recent ``reset()``.

        Raises:
            AssertionError: if ``reset()`` has not been called yet.
        """
        index = int(self._cur_state[0])
        assert index != -1, "reset() must be called before a policy is selected"
        return self.policies[index]

    def act(self, obs) -> object:
        """Delegate to the currently selected sub-policy's ``act``."""
        return self.cur_policy.act(obs)

    def reset(self):
        """Reset every sub-policy, then draw the active one according to ``ps``."""
        for policy in self.policies:
            policy.reset()
        self._cur_state[0] = np.random.choice(len(self.policies), p=self.ps)


@dataclass(frozen=True, eq=True, order=True, unsafe_hash=False)
class MixedRedTeamPolicy(RedTeamStrategy):
    """Red-team strategy that plays one of several sub-strategies, re-drawn on each reset.

    Attributes:
        policies: The candidate sub-strategies.
        ps: Selection probability of each sub-strategy; must have the same
            length as ``policies``, every entry in [0, 1], and sum to 1.

    ``reset()`` must be called before ``act()`` so that a sub-strategy has
    been drawn. The index of the drawn sub-strategy is stored per instance,
    so separate (including nested) mixed strategies cannot overwrite each
    other's selection.
    """

    policies: tuple[RedTeamStrategy, ...]
    ps: tuple[float, ...]

    def __post_init__(self):
        """Validate ``ps`` against ``policies`` and create the selection state."""
        assert len(self.policies) == len(self.ps), \
            "policies and ps must have the same length"
        probs = np.array(self.ps)
        assert np.all(probs >= 0.0), "probabilities must be non-negative"
        assert np.all(probs <= 1.0), "probabilities must not exceed 1"
        assert np.isclose(np.sum(probs), 1.0), "probabilities must sum to 1"
        # The dataclass is frozen, so attribute assignment has to bypass its
        # __setattr__. The one-element array is mutated in place by reset();
        # -1 marks "no sub-strategy drawn yet". It is not a dataclass field,
        # so it takes no part in __eq__, ordering, __hash__ or __repr__.
        object.__setattr__(self, "_cur_state", np.array([-1], dtype=np.int32))

    @property
    def cur_policy(self):
        """The sub-strategy drawn by the most recent ``reset()``.

        Raises:
            AssertionError: if ``reset()`` has not been called yet.
        """
        index = int(self._cur_state[0])
        assert index != -1, "reset() must be called before a policy is selected"
        return self.policies[index]

    def act(self, obs) -> object:
        """Delegate to the currently selected sub-strategy's ``act``."""
        return self.cur_policy.act(obs)

    def reset(self):
        """Reset every sub-strategy, then draw the active one according to ``ps``."""
        for policy in self.policies:
            policy.reset()
        self._cur_state[0] = np.random.choice(len(self.policies), p=self.ps)


class ModellingAssumptions:
    """Base type for a game's modelling assumptions (see the ``Game.ma`` annotation).

    Declares no attributes of its own.
    """

    def __init__(self):
        """Forward construction to the next ``__init__`` in the MRO.

        Calling ``super().__init__()`` keeps the class usable in cooperative
        multiple inheritance.
        """
        super().__init__()


class GameRules(ABC):
    """Base type for the rules of a game (see the ``Game.gr`` annotation).

    It cannot be constructed directly: ``__init__`` always raises, so a
    subclass must define its own ``__init__`` without delegating to this one.
    """

    def __init__(self):
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError()


class Game(ABC):
    """Abstract game between a blue-team policy and a red-team strategy.

    Subclasses implement ``simulate_one_game``, ``__eq__`` and ``__hash__``,
    and override ``benign_red_team_policy``. The Monte-Carlo estimators below
    are built on top of ``simulate_one_game``.
    """

    ma: ModellingAssumptions
    gr: GameRules

    @property
    def benign_red_team_policy(self):
        """Red-team strategy used by ``estimate_MC_usefulness``.

        Raises:
            NotImplementedError: always in this base class; subclasses
                override the property.
        """
        raise NotImplementedError()

    @abstractmethod
    def simulate_one_game(
            self,
            blue_team: BlueTeamPolicy,
            red_team: RedTeamStrategy,
    ) -> tuple[bool, float, dict]:
        """Play one game and report its outcome.

        Returns:
            A ``(was_backdoored, usefulness, info)`` triple: a ``bool``, a
            ``float`` and a ``dict``. The estimators in this class read
            ``was_backdoored`` and ``usefulness``.

        Raises:
            NotImplementedError: always in this base class.
        """
        raise NotImplementedError()

    def estimate_MC_safety(
            self,
            sample_size: int,
            blue_team: BlueTeamPolicy,
            red_team: RedTeamStrategy,
    ) -> tuple[float, dict]:
        """Estimate safety as the fraction of simulated games without a backdoor.

        Runs ``simulate_one_game`` ``sample_size`` times behind a tqdm
        progress bar created with ``delay=10``.

        Returns:
            The mean of the per-game safety indicators (1.0 when the game was
            not backdoored, 0.0 otherwise) and an empty info dict.
        """
        safe_flags = np.zeros((sample_size,))
        for game_idx in trange(sample_size, delay=10):
            backdoored, _, _ = self.simulate_one_game(blue_team, red_team)
            safe_flags[game_idx] = 0.0 if backdoored else 1.0
        return safe_flags.mean(), {}

    def estimate_MC_usefulness(
            self,
            sample_size: int,
            blue_team: BlueTeamPolicy
    ) -> tuple[float, dict]:
        """Estimate usefulness as the mean usefulness against the benign red team.

        Runs ``simulate_one_game`` ``sample_size`` times with
        ``benign_red_team_policy`` as the red team.

        Returns:
            The mean of the per-game usefulness values and an empty info dict.
        """
        scores = np.zeros((sample_size,))
        # Looked up once, after the allocation, and reused for every game.
        benign_red = self.benign_red_team_policy
        for game_idx in range(sample_size):
            _, game_usefulness, _ = self.simulate_one_game(blue_team, benign_red)
            scores[game_idx] = game_usefulness
        return scores.mean(), {}

    @abstractmethod
    def __eq__(self, other):
        """Equality between games; must be defined by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def __hash__(self):
        """Hash of the game; must be defined by subclasses."""
        raise NotImplementedError()
