import torch
import torch.nn.functional as F
from transformers import LogitsProcessor
import logging

logger = logging.getLogger(__name__)

class PrimalThresholdProcessor(LogitsProcessor):
    """Logits processor that truncates each row to an adaptively chosen top-k.

    For every row of scores the processor:

    1. converts the scores to probabilities ``p`` with a softmax and sorts
       them in descending order;
    2. for each ``k`` in ``1..min(k_max, vocab_size)`` builds a candidate
       distribution ``hat_p`` supported on the ``k`` largest entries, using a
       closed-form shift that depends on ``alpha``;
    3. scores the candidate with ``sum(d_phi(hat_p, p)) + mu * k`` where
       ``d_phi`` is the Bregman divergence generated by
       ``phi(x) = x**alpha / (alpha * (alpha - 1))``;
    4. keeps the candidate with the smallest loss (the smallest ``k`` wins
       ties because the comparison is strict) and returns its logarithm,
       with ``-inf`` for entries that are not above ``epsilon``.

    The chosen ``k`` of every processed row is appended to
    ``all_best_k_values`` on each call; the list grows until
    ``clear_best_k_values`` is called.

    Args:
        alpha: Divergence order. Only ``1.5`` and ``2.0`` are supported.
        k_max: Largest support size that is tried (must be at least 1).
        mu: Weight of the ``mu * k`` support-size penalty.
        epsilon: Entries with probability not above this value get ``-inf``.
        device: Device the computation runs on; incoming scores are moved
            there. ``None`` keeps the computation on the scores' own device.

    Raises:
        ValueError: If ``alpha`` is not ``1.5`` or ``2.0`` or ``k_max < 1``.
    """

    def __init__(self,
                 alpha=1.5,
                 k_max=50,
                 mu=0.01,
                 epsilon=1e-9,
                 device="cuda"):
        if alpha not in [1.5, 2.0]:
            raise ValueError(f"alpha must be 1.5 or 2.0, but got {alpha}")
        # With no candidate k the search loop would never run and every
        # logit would come back as -inf, so reject that configuration early.
        if k_max < 1:
            raise ValueError(f"k_max must be at least 1, but got {k_max}")
        self.alpha = alpha
        self.k_max = k_max
        self.epsilon = epsilon
        self.device = device
        self.mu = mu

        # Constants derived from alpha, cached for f() and phi().
        self.alpha_minus_1 = self.alpha - 1
        self.inv_alpha_minus_1 = 1.0 / self.alpha_minus_1
        self.phi_denom = self.alpha * self.alpha_minus_1
        # One list of chosen k values (one entry per processed row) per call.
        self.all_best_k_values = []

    def f(self, x):
        """Return ``x**(alpha - 1) / (alpha - 1)``, the derivative of ``phi``."""
        return torch.pow(x, self.alpha_minus_1) / self.alpha_minus_1

    def phi(self, x):
        """Return ``x**alpha / (alpha * (alpha - 1))`` element-wise."""
        return torch.pow(x, self.alpha) / self.phi_denom

    def d_phi(self, y, x):
        """Return the element-wise Bregman divergence of ``y`` from ``x``."""
        return self.phi(y) - self.phi(x) - self.f(x) * (y - x)

    def compute_loss(self, hat_p, p, k_val):
        """Return the per-row loss ``sum(d_phi(hat_p, p)) + mu * k_val``.

        Args:
            hat_p: Candidate distribution, same shape as ``p``.
            p: Reference distribution; the sum runs over its last dimension.
            k_val: Support size of ``hat_p`` used for the penalty term.
        """
        div_term = self.d_phi(hat_p, p).sum(dim=-1)

        k_tensor = torch.tensor(k_val, device=div_term.device, dtype=div_term.dtype)
        l0_term = self.mu * k_tensor

        return div_term + l0_term

    def _top_k_candidate(self, top_k_probs, k_float):
        """Return the un-normalised candidate values for the kept entries.

        Args:
            top_k_probs: The ``k`` largest probabilities of each row.
            k_float: ``k`` as a 0-dim tensor in the working dtype.
        """
        sum_p_k = torch.sum(top_k_probs, dim=-1, keepdim=True)

        if self.alpha == 1.5:
            # Shift sqrt(p_i) by a common offset ``a`` chosen as the root of
            # sum((sqrt(p_i) + a)**2) = 1 over the kept entries.
            sqrt_p_k = torch.sqrt(top_k_probs)
            sum_sqrt_p_k = torch.sum(sqrt_p_k, dim=-1, keepdim=True)

            inner_term = sum_sqrt_p_k**2 - k_float * (sum_p_k - 1.0)
            # Guard the square root against small negative round-off.
            inner_term = torch.clamp(inner_term, min=0.0)

            numerator = torch.sqrt(inner_term) - sum_sqrt_p_k
            adjustment = numerator / k_float

            modified_sqrt_p_k = sqrt_p_k + adjustment
            return modified_sqrt_p_k**2

        if self.alpha == 2.0:
            # Spread the discarded mass evenly over the kept entries.
            adjustment = (1.0 - sum_p_k) / k_float
            return top_k_probs + adjustment

        # Only reachable if ``alpha`` was changed after construction.
        raise ValueError(f"alpha must be 1.5 or 2.0, but got {self.alpha}")

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return thresholded log-probabilities for ``scores``.

        Args:
            input_ids: Unused; accepted for the logits-processor call
                convention.
            scores: Logits whose last dimension is the vocabulary. Any
                leading dimensions are treated as independent rows.

        Returns:
            A tensor with the shape and device of the incoming ``scores``
            holding ``log(hat_p)`` for kept entries and ``-inf`` elsewhere.
        """
        original_device = scores.device
        original_shape = scores.shape

        if self.device is not None:
            scores = scores.to(self.device)
        work_device = scores.device
        # Collapse leading dimensions so the search below always sees
        # (rows, vocab); the original shape is restored before returning.
        scores = scores.reshape(-1, original_shape[-1])

        probs = F.softmax(scores, dim=-1)
        batch_size, vocab_size = probs.shape

        sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)

        k_upper_bound = min(self.k_max, vocab_size)

        min_losses_per_item = torch.full((batch_size,), float('inf'), device=work_device, dtype=scores.dtype)
        best_k_per_item = torch.ones((batch_size,), device=work_device, dtype=torch.long)
        final_modified_probs_sorted = torch.zeros_like(sorted_probs)

        for k_candidate in range(1, k_upper_bound + 1):
            current_k_top_k_probs = sorted_probs[:, :k_candidate]
            k_float = torch.tensor(k_candidate, device=work_device, dtype=scores.dtype)

            hat_p_k_values = self._top_k_candidate(current_k_top_k_probs, k_float)

            # Embed the candidate into a full-width row (zeros outside the
            # top k) and renormalise the kept entries.
            current_hat_p_sorted = torch.zeros_like(sorted_probs)
            current_hat_p_sorted[:, :k_candidate] = hat_p_k_values
            current_sum = torch.sum(current_hat_p_sorted[:, :k_candidate], dim=-1, keepdim=True)
            current_hat_p_sorted[:, :k_candidate] = current_hat_p_sorted[:, :k_candidate] / current_sum
            current_losses = self.compute_loss(current_hat_p_sorted, sorted_probs, k_candidate)

            # Strict comparison: on ties the smaller k found earlier is kept.
            update_mask = current_losses < min_losses_per_item
            min_losses_per_item[update_mask] = current_losses[update_mask]
            best_k_per_item[update_mask] = k_candidate
            final_modified_probs_sorted[update_mask] = current_hat_p_sorted[update_mask]

        self.all_best_k_values.append(best_k_per_item.cpu().tolist())

        # Undo the sort so every probability returns to its vocabulary slot.
        modified_probs = torch.zeros_like(probs)
        modified_probs.scatter_(dim=-1, index=sorted_indices, src=final_modified_probs_sorted)

        modified_probs = F.normalize(modified_probs, p=1, dim=-1)

        mask = modified_probs > self.epsilon
        modified_logits = torch.where(
            mask,
            torch.log(modified_probs),
            torch.tensor(-float('inf'), device=modified_probs.device, dtype=modified_probs.dtype)
        )

        # Hand the result back in the caller's layout and on the caller's
        # device so downstream code never sees a device or shape change.
        return modified_logits.reshape(original_shape).to(original_device)

    def get_all_best_k_values(self):
        """Return every recorded best ``k`` as one flat list, in call order."""
        return [
            k_value
            for k_values_for_step in self.all_best_k_values
            for k_value in k_values_for_step
        ]

    def clear_best_k_values(self):
        """Forget all recorded best ``k`` values."""
        self.all_best_k_values = []


class PrimalThresholdModelWrapper:
    """Wrap a model so that its logits pass through a threshold processor.

    ``__call__`` and ``forward`` run the wrapped model and, when the output
    has a ``logits`` attribute and input ids can be located, replace
    ``outputs.logits`` with ``threshold_processor(input_ids, logits)``.
    ``generate`` appends the processor to the ``logits_processor`` keyword
    argument before delegating. Any other attribute is looked up lazily on
    the wrapped model.

    Args:
        model: The object to wrap; it is called directly and through its
            ``forward`` / ``generate`` attributes.
        threshold_processor: Callable taking ``(input_ids, scores)`` and
            returning the processed scores.
    """

    class _ThresholdLogitsAdapter:
        """Plain callable that forwards ``(input_ids, scores)`` to a processor."""

        def __init__(self, processor):
            self.processor = processor

        def __call__(self, input_ids, scores):
            return self.processor(input_ids, scores)

    def __init__(self, model, threshold_processor):
        self.model = model
        self.threshold_processor = threshold_processor
        # Model attributes are deliberately NOT copied here: __getattr__
        # already forwards them on demand, and probing every public
        # attribute up front would evaluate all of the model's properties
        # and abort construction if any of them raised.

    def _apply_threshold(self, outputs, args, kwargs):
        """Run the processor over ``outputs.logits`` when inputs are found.

        The input ids are taken from the first positional argument if it has
        a ``shape`` attribute, otherwise from the ``input_ids`` keyword. The
        output is returned untouched when it has no ``logits`` attribute or
        no input ids can be located.
        """
        if not hasattr(outputs, "logits"):
            return outputs

        input_ids = None
        if len(args) > 0 and hasattr(args[0], "shape"):
            input_ids = args[0]
        elif "input_ids" in kwargs:
            input_ids = kwargs["input_ids"]

        # NOTE(review): the logits are handed to the processor exactly as
        # the model produced them; confirm the processor accepts that shape.
        if input_ids is not None and hasattr(outputs.logits, "shape"):
            outputs.logits = self.threshold_processor(input_ids, outputs.logits)

        return outputs

    def __call__(self, *args, **kwargs):
        """Call the wrapped model and post-process its logits."""
        outputs = self.model(*args, **kwargs)
        return self._apply_threshold(outputs, args, kwargs)

    def generate(self, *args, **kwargs):
        """Delegate to ``model.generate`` with the processor appended.

        Processors already supplied through ``logits_processor`` are kept,
        in order, ahead of the threshold processor. A missing or ``None``
        value is treated as an empty list, and the caller's sequence is
        never mutated.
        """
        existing_processors = kwargs.get("logits_processor")
        processors = [] if existing_processors is None else list(existing_processors)
        processors.append(self._ThresholdLogitsAdapter(self.threshold_processor))
        kwargs["logits_processor"] = processors

        return self.model.generate(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Call ``model.forward`` and post-process its logits."""
        outputs = self.model.forward(*args, **kwargs)
        return self._apply_threshold(outputs, args, kwargs)

    def __getattr__(self, name):
        """Forward attributes that are not found on the wrapper to the model.

        Raises:
            AttributeError: If the wrapper has no model yet or the model
                does not provide ``name``.
        """
        # Read the model straight from the instance dict: going through
        # ``self.model`` would re-enter __getattr__ and recurse without
        # bound on a half-built instance (e.g. during copy or unpickling).
        try:
            model = self.__dict__["model"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(model, name)
