import torch
from ddt import data, ddt, unpack
from torch.nn.functional import one_hot

from bandit2.bandit_attacker import BanditUniformRandomAttacker
from bandit2.bandit_env import BanditEnv, BanditUniformRandomController

from .utils import TestCase


@ddt
class BanditUniformRandomAttackerTestCase(TestCase):
    """Tests for ``BanditUniformRandomAttacker``: budget bound and reward poisoning.

    NOTE(review): neither test seeds torch's RNG, so each run checks a single
    random draw; consider seeding if these tests turn out to be flaky.
    """

    def test_uniform_random_attacker__has_attack_within_budget(self, tolerance: float = 1e-5):
        """The norm of the attacker's ``means`` must not exceed ``budget`` (up to ``tolerance``)."""
        budget = 3
        n_arms = 100
        # NOTE(review): the meaning of the positional arguments ``1`` and ``0.3``
        # is defined in bandit2.bandit_attacker and is not visible here.
        attacker = BanditUniformRandomAttacker(1, n_arms, 0.3, budget)

        upper_bound = budget + tolerance
        attack_norm = attacker.means.norm()
        assert attack_norm <= upper_bound

    def test_reward_plus_attack__gets_poisoned_reward(self, tolerance: float = 1e-5):
        """Every deployed reward must be perturbed, and by less than ``budget``.

        Checks, elementwise:
        ``tolerance < |rewards - rewards_original| < budget + tolerance``.
        """
        n_steps, n_arms, budget = 100, 10, 3

        # Construction order is kept as-is: each constructor may draw from the RNG.
        env = BanditEnv.sample(1, n_steps, n_arms)
        attacker = BanditUniformRandomAttacker(1, n_arms, 0.3, budget)
        controller = BanditUniformRandomController(1, n_steps, n_arms)

        # NOTE(review): the meaning of the two trailing ``1.0`` arguments is
        # defined by ``BanditEnv.deploy`` and is not visible here.
        dataset = env.deploy(controller, attacker, 1.0, 1.0)

        poison = (dataset.rewards - dataset.rewards_original).abs()
        is_perturbed = poison > tolerance
        is_within_budget = poison < budget + tolerance
        assert (is_perturbed & is_within_budget).all()
