import math

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
class HAPDACalibration:
    """Score text with a proxy LM's log-likelihood, re-weighted per token.

    Two reference causal LMs (``hum_model`` and ``mac_model`` -- presumably a
    human-text and a machine-text model; NOTE(review): confirm) supply, at
    each position, how much their next-token distributions disagree and how
    uncertain (high-entropy) they are. Positions where the two models disagree
    strongly while their mean entropy is low receive the most weight. The
    detection score is the weighted mean of the proxy model's (``pxy_model``)
    log-probability of each observed next token.

    All models are fed the same ``input_ids``, so their logits must have the
    same vocabulary width. The code assumes Hugging Face style logits of shape
    ``(batch, seq_len, vocab)`` and 2-D ``input_ids`` -- TODO confirm for any
    non-HF model.
    """

    def __init__(self, hum_model_name, mac_model_name, pxy_model_name, tokenizer_name):
        """Load the three causal LMs and the tokenizer via ``from_pretrained``."""
        self.hum_model = AutoModelForCausalLM.from_pretrained(hum_model_name)
        self.mac_model = AutoModelForCausalLM.from_pretrained(mac_model_name)
        self.pxy_model = AutoModelForCausalLM.from_pretrained(pxy_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    @staticmethod
    def _softmax_probs(model, input_ids, attention_mask):
        """Run ``model`` and return the softmax of its logits over the last dim."""
        outputs = model(input_ids, attention_mask=attention_mask, output_hidden_states=False)
        return F.softmax(outputs.logits, dim=-1)

    def compute_prediction_disagreement(self, input_ids, attention_mask):
        """Return the element-wise ``|p_hum - p_mac|`` of the two models' distributions.

        The result keeps the vocabulary axis (same shape as the logits). Sum it
        over the last dim to get a single per-position disagreement (L1 distance).
        """
        hum_probs = self._softmax_probs(self.hum_model, input_ids, attention_mask)
        mac_probs = self._softmax_probs(self.mac_model, input_ids, attention_mask)
        return torch.abs(hum_probs - mac_probs)

    def compute_entropy(self, probs):
        """Return the Shannon entropy (in nats) of ``probs`` along its last dim.

        ``torch.special.entr`` defines 0 * log(0) = 0, so no epsilon is needed.
        An added 1e-8 can underflow to 0 in float16, which would turn exact-zero
        probabilities into 0 * -inf = NaN.
        """
        return torch.special.entr(probs).sum(dim=-1)

    def _uncertainty_from_probs(self, hum_probs, mac_probs):
        """Mean entropy of both distributions, divided by ``log(vocab)``.

        The result lies in [0, 1], up to float rounding.
        """
        vocab_size = hum_probs.size(-1)
        entropy_sum = self.compute_entropy(hum_probs) + self.compute_entropy(mac_probs)
        if vocab_size < 2:
            # A one-outcome distribution has zero entropy; avoid 0 / log(1) = 0 / 0.
            return torch.zeros_like(entropy_sum)
        return entropy_sum / (2 * math.log(vocab_size))

    def compute_uncertainty(self, input_ids, attention_mask):
        """Return the normalized mean entropy of the two reference models per position.

        The value is ``(H_hum + H_mac) / (2 * log(vocab))``: 1 means both
        distributions are uniform, 0 means both are one-hot.
        """
        hum_probs = self._softmax_probs(self.hum_model, input_ids, attention_mask)
        mac_probs = self._softmax_probs(self.mac_model, input_ids, attention_mask)
        return self._uncertainty_from_probs(hum_probs, mac_probs)

    def _raw_token_weights(self, input_ids, attention_mask):
        """Unnormalized per-position weights: L1 disagreement * (1 - uncertainty)."""
        # One forward pass per reference model, shared by both factors.
        hum_probs = self._softmax_probs(self.hum_model, input_ids, attention_mask)
        mac_probs = self._softmax_probs(self.mac_model, input_ids, attention_mask)
        # Reduce over the vocabulary so there is one disagreement per position,
        # matching the per-position uncertainty it is multiplied by.
        disagreement = torch.abs(hum_probs - mac_probs).sum(dim=-1)
        # Clamp guards against rounding pushing the uncertainty slightly above 1.
        confidence = (1 - self._uncertainty_from_probs(hum_probs, mac_probs)).clamp_min(0)
        return disagreement * confidence

    @staticmethod
    def _normalize_weights(weights, mask):
        """Zero out masked positions and scale each row to sum to 1.

        A row whose weights are all zero (e.g. the reference models agree
        everywhere) falls back to uniform weights over its unmasked positions.
        A row with no unmasked positions stays all zeros.
        """
        mask = mask.to(weights.dtype)
        weights = weights * mask
        total = weights.sum(dim=-1, keepdim=True)
        uniform = mask / mask.sum(dim=-1, keepdim=True).clamp_min(1)
        # The clamp only matters in the branch torch.where discards; it keeps that branch NaN-free.
        safe_total = total.clamp_min(torch.finfo(weights.dtype).tiny)
        return torch.where(total > 0, weights / safe_total, uniform)

    def compute_token_weights(self, input_ids, attention_mask):
        """Return one weight per input position, summing to 1 over each sequence.

        Positions with ``attention_mask == 0`` get weight 0. ``attention_mask``
        may be ``None``, in which case every position is valid.
        """
        mask = torch.ones_like(input_ids) if attention_mask is None else attention_mask
        weights = self._raw_token_weights(input_ids, attention_mask)
        return self._normalize_weights(weights, mask)

    def compute_final_detection_score(self, input_ids, attention_mask):
        """Return the token-weighted mean proxy log-likelihood of each sequence.

        Logits at position ``t`` predict token ``t + 1``. Position ``t`` is
        therefore scored with ``log p_pxy(input_ids[:, t + 1])`` and weighted by
        the reference-model weight computed at the same position. Only positions
        whose context token and target token are both unmasked are scored; their
        weights are renormalized to sum to 1.

        Returns:
            One score per sequence, of shape ``(batch,)``. A sequence with no
            scorable position scores 0.

        Raises:
            ValueError: if the sequences are shorter than two tokens.
        """
        if input_ids.size(-1) < 2:
            raise ValueError("need at least two tokens to score next-token likelihoods")
        mask = torch.ones_like(input_ids) if attention_mask is None else attention_mask
        # A position is scorable only if both its context and its target are real tokens.
        step_mask = mask[:, :-1] * mask[:, 1:]
        raw_weights = self._raw_token_weights(input_ids, attention_mask)[:, :-1]
        weights = self._normalize_weights(raw_weights, step_mask)

        pxy_outputs = self.pxy_model(input_ids, attention_mask=attention_mask, output_hidden_states=False)
        # log_softmax stays finite where log(softmax) would hit log(0) = -inf.
        log_probs = F.log_softmax(pxy_outputs.logits, dim=-1)
        targets = input_ids[:, 1:].unsqueeze(-1)
        token_log_likelihoods = log_probs[:, :-1].gather(-1, targets).squeeze(-1)
        # Zero out unscored positions so a -inf there cannot give 0 * -inf = NaN.
        token_log_likelihoods = token_log_likelihoods.masked_fill(step_mask == 0, 0.0)
        return (weights * token_log_likelihoods).sum(dim=-1)

