import torch
from transformers import LogitsWarper
import numpy as np
import time


class DenoiseBase:
    """
    Base class for attacks that try to remove fixed-group green-list
    watermarks from next-token distributions.

    Args:
        reference_model: Model used as the comparison / replacement source
            by subclasses. It is only stored here.
        strength: Attack strength. It is stored as-is; subclasses decide how to
            interpret it (for example as a threshold on ``get_weight`` output).
        attack_type: Name of the attack variant. It is stored as-is for
            subclasses to dispatch on.
    """

    def __init__(
        self,
        reference_model,
        strength=2.0,
        attack_type=None,
    ):
        self.reference_model = reference_model
        self.strength = strength
        self.attack_type = attack_type

    def get_weight(self, scores, top_k=20, num_bins=100):
        """Return a per-row "peakedness" weight in [0, 1] for ``scores``.

        The weight is the mean gap between consecutive probabilities among the
        ``top_k`` most likely tokens. It is quantised upward to a multiple of
        ``1 / num_bins`` and clamped to [0, 1].

        Args:
            scores: Logits whose last dimension is the vocabulary, e.g.
                ``(batch, vocab)``.
            top_k: Number of top probabilities to inspect. It is capped at the
                vocabulary size, so small vocabularies no longer make
                ``torch.topk`` raise.
            num_bins: Quantisation resolution of the returned weight.

        Returns:
            A float32 tensor with shape ``scores.shape[:-1]`` on the same
            device as ``scores``.
        """
        # Compute in float32: half-precision probabilities are coarse, and the
        # bins below are float32, so bucketize sees matching dtypes.
        probs = torch.softmax(scores, dim=-1, dtype=torch.float32)
        k = min(top_k, probs.shape[-1])
        top_probs, _ = torch.topk(probs, k=k, dim=-1)
        if k < 2:
            # With fewer than two candidates there are no gaps. The mean of an
            # empty tensor would be NaN, so report "no signal" instead.
            signal = torch.zeros(probs.shape[:-1], device=probs.device)
        else:
            # Mean of consecutive gaps. It telescopes to (p_1 - p_k) / (k - 1).
            signal = torch.mean(top_probs[..., :-1] - top_probs[..., 1:], dim=-1)
        bins = torch.linspace(0, 1, num_bins + 1, device=signal.device)
        # With right=False, bucketize returns i such that bins[i-1] < v <= bins[i].
        # Dividing by num_bins therefore rounds the signal up to a bin edge.
        weight = torch.bucketize(signal, bins).float() / num_bins
        return torch.clamp(weight, 0, 1)


class DenoiseWatermarkLogitsWarper(DenoiseBase, LogitsWarper):
    """Logits warper that stochastically swaps in a reference model's logits.

    For every row of ``scores`` a keep-probability is computed with
    ``get_weight``, and probabilities below ``self.strength`` are zeroed. A
    Bernoulli draw with that probability then decides per row whether the
    original ``scores`` are kept or replaced by ``reference_model``'s logits
    for the last position of ``input_ids``.

    Because ``get_weight`` is clamped to [0, 1], any ``strength`` above 1
    (including the default 2.0) zeroes every weight. The reference logits are
    then always used.
    """

    def __call__(
        self, input_ids: torch.Tensor, scores: torch.Tensor
    ) -> torch.FloatTensor:
        """Return either ``scores`` or the reference model's logits, per row.

        Args:
            input_ids: Token ids passed unchanged to ``reference_model``.
                NOTE(review): assumes the reference model shares the
                generator's tokenizer/vocabulary — confirm.
            scores: Next-token logits of the generating model, ``(batch, vocab)``.

        Raises:
            ValueError: if ``attack_type`` is not ``"our_attack"``.
        """
        if self.attack_type != "our_attack":
            raise ValueError("Invalid type!")

        weight = self.get_weight(scores)
        # masked_fill keeps everything on weight's device. The previous
        # torch.tensor(0.0) was a CPU tensor mixed into a possibly-CUDA op.
        weight = weight.masked_fill(weight < self.strength, 0.0)
        # One draw per batch row. Calling .item() on the draws would raise
        # for batch size > 1.
        keep = torch.bernoulli(weight).bool()
        if bool(keep.all()):
            # No row needs the reference model, so skip the forward pass.
            return scores

        ref_logits = self._reference_logits(input_ids, scores.device)
        if not bool(keep.any()):
            return ref_logits
        # Mixed batch: pick per row. This requires matching vocab sizes.
        return torch.where(keep.unsqueeze(-1), scores, ref_logits)

    def _reference_logits(self, input_ids, device):
        """Return ``reference_model``'s logits at the last position of ``input_ids``.

        The model is moved to ``device`` for the forward pass and moved back
        to CPU afterwards, also when the forward pass raises. Presumably this
        keeps accelerator memory free between steps.
        """
        self.reference_model.to(device)
        try:
            with torch.no_grad():
                return self.reference_model(input_ids).logits[:, -1, :]
        finally:
            self.reference_model.cpu()
