import torch
import torch.nn.functional as F
from transformers import LogitsProcessor
import logging

logger = logging.getLogger(__name__)

class PrimalThresholdProcessor(LogitsProcessor):
    """Logits processor that keeps the top-k tokens and renormalizes them.

    For every distribution along the last axis of ``scores`` the processor:

    1. converts the scores to probabilities with a softmax,
    2. keeps the ``k = min(k_max, vocab_size)`` most probable tokens,
    3. redistributes the probability mass of the discarded tokens over the
       kept ones (uniform shift of ``sqrt(p)`` for ``alpha == 1.5``, uniform
       shift of ``p`` for ``alpha == 2.0``),
    4. returns ``log`` of the result, with ``-inf`` for every token whose
       final probability is not greater than ``epsilon``.

    Args:
        alpha: Redistribution rule; must be ``1.5`` or ``2.0``.
        k_max: Maximum number of tokens kept per distribution (at least 1).
        epsilon: Tokens whose final probability is ``<= epsilon`` get ``-inf``.
        device: Device the scores are moved to before processing. The
            returned tensor lives on this device as well.

    Raises:
        ValueError: If ``alpha`` is not ``1.5`` or ``2.0``, or ``k_max < 1``.
    """

    def __init__(self,
                 alpha=1.5,
                 k_max=50,
                 epsilon=1e-9,
                 device="cuda"):
        if alpha not in [1.5, 2.0]:
            raise ValueError(f"alpha must be 1.5 or 2.0, but got {alpha}")
        # k_max is used as a slice bound and as a divisor; zero or negative
        # values would silently produce inf/NaN or slice from the wrong end.
        if k_max < 1:
            raise ValueError(f"k_max must be a positive integer, but got {k_max}")
        self.alpha = alpha
        self.k_max = k_max
        self.epsilon = epsilon
        self.device = device

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return thresholded log-probabilities for ``scores``.

        Args:
            input_ids: Accepted for the logits-processor call signature;
                not used by this processor.
            scores: Raw scores whose last axis is the vocabulary. Any number
                of leading axes is accepted (e.g. ``(batch, vocab)`` or
                ``(batch, seq, vocab)``).

        Returns:
            A tensor with the same shape as ``scores``, located on
            ``self.device``, holding log-probabilities for kept tokens and
            ``-inf`` for all others.

        Raises:
            ValueError: If ``self.alpha`` was changed to an unsupported value
                after construction.
        """
        # NOTE: the result stays on self.device; callers whose tensors live
        # elsewhere must construct the processor with a matching device.
        scores = scores.to(self.device)
        probs = F.softmax(scores, dim=-1)

        sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
        k = min(self.k_max, probs.shape[-1])
        # Ellipsis indexing keeps this independent of the number of leading axes.
        top_probs = sorted_probs[..., :k]
        top_indices = sorted_indices[..., :k]
        sum_top = top_probs.sum(dim=-1, keepdim=True)

        if self.alpha == 1.5:
            # Find the shift c with sum((sqrt(p_i) + c)^2) == 1 over the top-k,
            # i.e. the non-negative root of k*c^2 + 2*S*c + (P - 1) = 0
            # where S = sum(sqrt(p_i)) and P = sum(p_i).
            sqrt_top = torch.sqrt(top_probs)
            sum_sqrt = sqrt_top.sum(dim=-1, keepdim=True)
            # Clamp guards the sqrt against tiny negative values from rounding.
            discriminant = torch.clamp(sum_sqrt**2 - k * (sum_top - 1.0), min=0.0)
            shift = (torch.sqrt(discriminant) - sum_sqrt) / k
            new_top = (sqrt_top + shift) ** 2
        elif self.alpha == 2.0:
            # Spread the discarded mass evenly over the top-k probabilities.
            new_top = top_probs + (1.0 - sum_top) / k
        else:
            raise ValueError(f"alpha must be 1.5 or 2.0, but got {self.alpha}")

        # Put the adjusted values back at their vocabulary positions; every
        # token outside the top-k keeps probability zero.
        modified_probs = torch.zeros_like(probs).scatter_(-1, top_indices, new_top)
        modified_probs = F.normalize(modified_probs, p=1, dim=-1)

        keep = modified_probs > self.epsilon
        # log(0) = -inf (and NaN for rounding-induced negatives) only occurs
        # at positions that are overwritten with -inf by the mask.
        return torch.log(modified_probs).masked_fill(~keep, float("-inf"))


class PrimalThresholdModelWrapper:
    """Wrap a model so its logits are post-processed by a threshold processor.

    Calling the wrapper (or its ``forward``) runs the wrapped model and, when
    the result has a ``logits`` attribute, replaces it with
    ``threshold_processor(input_ids, logits)``. ``generate`` forwards to the
    model's ``generate`` with the threshold processor appended to the
    ``logits_processor`` list. Every other attribute is looked up on the
    wrapped model.

    Args:
        model: Callable model object exposing ``forward`` and ``generate``.
        threshold_processor: Callable taking ``(input_ids, scores)`` and
            returning the processed scores.
    """

    def __init__(self, model, threshold_processor):
        # Attribute access on the wrapper falls through to the model via
        # __getattr__, so nothing has to be copied onto the wrapper here.
        self.model = model
        self.threshold_processor = threshold_processor

    def _threshold_outputs(self, outputs, args, kwargs):
        """Apply the threshold processor to ``outputs.logits`` in place.

        The input ids are taken from the first positional argument if it has
        a ``shape`` attribute, otherwise from the ``input_ids`` keyword.
        Outputs are returned untouched when they carry no ``logits``
        attribute, when no input ids were found, or when the logits have no
        ``shape`` attribute.
        """
        if not hasattr(outputs, "logits"):
            return outputs

        if args and hasattr(args[0], "shape"):
            input_ids = args[0]
        else:
            input_ids = kwargs.get("input_ids")

        if input_ids is not None and hasattr(outputs.logits, "shape"):
            outputs.logits = self.threshold_processor(input_ids, outputs.logits)
        return outputs

    def __call__(self, *args, **kwargs):
        """Call the wrapped model and threshold the logits of its output."""
        outputs = self.model(*args, **kwargs)
        return self._threshold_outputs(outputs, args, kwargs)

    def generate(self, *args, **kwargs):
        """Run ``model.generate`` with the threshold processor appended.

        Any ``logits_processor`` supplied by the caller (a list, another
        iterable of processors, or ``None``) is copied into a new list, so
        the caller's object is never mutated; the threshold processor is
        added as the last entry.
        """
        supplied = kwargs.get("logits_processor")
        processors = [] if supplied is None else list(supplied)
        processors.append(self.threshold_processor)
        kwargs["logits_processor"] = processors
        return self.model.generate(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Call ``model.forward`` and threshold the logits of its output."""
        outputs = self.model.forward(*args, **kwargs)
        return self._threshold_outputs(outputs, args, kwargs)

    def __getattr__(self, name):
        """Delegate attributes not found on the wrapper to the wrapped model."""
        # __getattr__ only runs after normal lookup failed. If "model" itself
        # is missing (instance created via __new__, e.g. while being copied
        # or unpickled), reading self.model below would recurse forever.
        if name == "model":
            raise AttributeError(name)
        return getattr(self.model, name)