from abc import abstractclassmethod, abstractmethod

import numpy as np

from libmab.learners import Learner


class Attacker(Learner):
    """Base class for attackers that corrupt the rewards seen by a learner.

    Subclasses implement :meth:`attack`, which maps an observed reward and the
    pulled arm to a corruption value.

    Args:
        n_arms: number of arms, forwarded to ``Learner.__init__``.
        target: index of the arm the attacker wants to promote (presumably
            the arm the victim learner should end up favouring — confirm
            against concrete subclasses).
    """

    def __init__(self, n_arms, target):
        super().__init__(n_arms)
        self.target = target

    # ``attack`` takes ``self``, so it is an abstract *instance* method.
    # The deprecated ``abstractclassmethod`` would have bound the class as
    # the first argument instead. Note: abstractness is only enforced if
    # ``Learner``'s metaclass is ``ABCMeta`` — NOTE(review): confirm.
    @abstractmethod
    def attack(self, reward, arm):
        """Return the corruption to apply to ``reward`` observed on ``arm``.

        Must be overridden by subclasses.
        """
        pass


class OracleAttacker:
    """Reward-poisoning attacker with oracle knowledge of the true arm means.

    Pulls of the target arm are never corrupted. For any other arm, the
    corruption is the smallest non-negative amount that puts the arm's mean
    ``epsilon`` below the target's mean. This assumes the caller subtracts the
    returned value from the observed reward — NOTE(review): confirm at the
    call site.

    Args:
        n_arms: number of arms.
        target: index of the arm the attacker wants to promote.
        means: true expected reward of each arm, indexable by arm index.
        epsilon: margin by which attacked arms are pushed below the target.
    """

    def __init__(self, n_arms, target, means, epsilon: float = .05):
        self.n_arms = n_arms
        self.target = target
        self.means = means
        self.epsilon = epsilon

    def attack(self, reward, arm) -> float:
        """Return the corruption for a pull of ``arm``.

        Args:
            reward: the observed reward. It is not used by the oracle attack,
                but is kept for interface compatibility.
            arm: index of the pulled arm.

        Returns:
            ``0.0`` for the target arm. Otherwise
            ``max(0, means[arm] - means[target] + epsilon)``.
        """
        if arm == self.target:
            return 0.0
        gap = self.means[arm] - self.means[self.target] + self.epsilon
        # Arms already at least ``epsilon`` below the target need no
        # corruption. A negative value would *raise* their reward and only
        # add attack cost.
        return float(max(0.0, gap))
