import math

import torch
import torch.nn.functional as F
from torch import nn, optim
from transformers import AutoModelForCausalLM, AutoTokenizer

class HAPDAFineTuner(nn.Module):
    """Combine an alignment loss and a distinctiveness loss over three causal LMs.

    A reference model, a "human" model and a "machine" model are loaded with
    ``AutoModelForCausalLM``. ``forward`` runs all three on the same batch and
    returns ``align_loss + lambda_param * dis_loss``.

    Args:
        ref_model_name: Name or path of the reference model. It is also used to
            load ``self.tokenizer``.
        hum_model_name: Name or path of the "human" model.
        mac_model_name: Name or path of the "machine" model.
        lambda_param: Weight of the distinctiveness loss in the total loss.
    """

    def __init__(self, ref_model_name, hum_model_name, mac_model_name, lambda_param=1.0):
        super().__init__()
        # NOTE(review): the reference model is registered as a trainable
        # submodule here; if it is meant to stay frozen, confirm that the
        # caller excludes its parameters from the optimizer.
        self.ref_model = AutoModelForCausalLM.from_pretrained(ref_model_name)
        self.hum_model = AutoModelForCausalLM.from_pretrained(hum_model_name)
        self.mac_model = AutoModelForCausalLM.from_pretrained(mac_model_name)
        # Not used inside this class; presumably kept for callers to build input_ids.
        self.tokenizer = AutoTokenizer.from_pretrained(ref_model_name)
        self.lambda_param = lambda_param

    def forward(self, input_ids, attention_mask, labels):
        """Return ``align_loss + self.lambda_param * dis_loss`` for one batch.

        Args:
            input_ids: Token ids passed positionally to all three models.
            attention_mask: Forwarded unchanged to all three models.
            labels: Forwarded to all three models; each model's ``.loss`` is used,
                so presumably labels must be non-None — confirm with the caller.

        Returns:
            The combined loss tensor.
        """
        ref_outputs = self.ref_model(input_ids, attention_mask=attention_mask, labels=labels)
        hum_outputs = self.hum_model(input_ids, attention_mask=attention_mask, labels=labels)
        mac_outputs = self.mac_model(input_ids, attention_mask=attention_mask, labels=labels)

        align_loss = self._alignment_loss(hum_outputs.loss, mac_outputs.loss, ref_outputs.loss)
        dis_loss = self._distinctiveness_loss(hum_outputs.logits, mac_outputs.logits)
        return align_loss + self.lambda_param * dis_loss

    @staticmethod
    def _alignment_loss(hum_score, mac_score, ref_score):
        """Return -logsigmoid(delta_mac) - logsigmoid(delta_hum) for the given losses."""
        # NOTE(review): the same reference loss is used for both ratios, so it
        # cancels algebraically (delta_mac == log(mac_score) - log(hum_score)).
        # Confirm whether separate reference scores were intended.
        log_ratio_mac = torch.log(mac_score / ref_score)
        log_ratio_hum = torch.log(hum_score / ref_score)
        delta_mac = log_ratio_mac - log_ratio_hum
        delta_hum = log_ratio_hum - log_ratio_mac
        # logsigmoid is the numerically stable form of log(sigmoid(x)).
        return -F.logsigmoid(delta_mac) - F.logsigmoid(delta_hum)

    @staticmethod
    def _distinctiveness_loss(hum_logits, mac_logits):
        """Return ``2*log(2) - JS(p_mac, p_hum)`` with softmax taken over the last dim."""
        log_p_mac = F.log_softmax(mac_logits, dim=-1)
        log_p_hum = F.log_softmax(hum_logits, dim=-1)
        # log of the mixture m = 0.5 * (p_mac + p_hum), computed without
        # leaving log space so underflowed probabilities do not yield -inf.
        log_m = torch.logaddexp(log_p_mac, log_p_hum) - math.log(2.0)
        # F.kl_div(input, target) computes KL(target || exp(input)); JS needs
        # KL(p || m), so m is the input and p the (log-space) target.
        kl_mac = F.kl_div(log_m, log_p_mac, reduction='batchmean', log_target=True)
        kl_hum = F.kl_div(log_m, log_p_hum, reduction='batchmean', log_target=True)
        js_divergence = 0.5 * (kl_mac + kl_hum)
        # NOTE(review): 'batchmean' sums over every non-batch position (including
        # any padded ones) and divides by the batch size only — confirm intended.
        return 2 * math.log(2.0) - js_divergence