import torch.nn.functional as F
import torch

class good_and_bad_model(torch.nn.Module):
    """Combine a "good" and a "bad" model that share the same inputs.

    Each sub-model is run on the same ``input_ids``. Its loss is computed only
    on the positions selected by its own mask. The returned logits take the
    good model's values where ``good_masks`` is set and the bad model's values
    where ``bad_masks`` is set (zero elsewhere).

    Args:
        good_model: Module whose call result exposes ``.logits``. The code
            slices dim 1 and swaps dims 1/2, so presumably
            ``(batch, seq, vocab)`` — confirm against the models used.
        bad_model: Same contract as ``good_model``.
        num_virtual_tokens: Number of leading sequence positions dropped from
            each model's logits before loss/combination (presumably
            prompt-tuning virtual tokens — confirm). ``0`` keeps all positions.
    """

    def __init__(self, good_model, bad_model, num_virtual_tokens=0):
        super().__init__()
        self.good_model = good_model
        self.bad_model = bad_model
        # Leading positions to strip from every logits tensor.
        self.num_virtual_tokens = num_virtual_tokens

    def _run_branch(self, model, input_ids, masks, labels):
        """Return ``(masked_logits, loss)`` for one sub-model; ``loss`` is None without labels."""
        logits = model(input_ids=input_ids).logits
        # Drop the virtual-token prefix so logits line up with masks/labels.
        logits = logits[:, self.num_virtual_tokens:, :].contiguous()
        masks = masks.to(logits.device)

        loss = None
        if labels is not None:
            # F.cross_entropy wants the class dimension at index 1.
            per_token = F.cross_entropy(
                logits.permute(0, 2, 1),
                labels.to(logits.device),
                reduction='none',
            )
            # NOTE(review): .mean() divides by *all* positions, not just the
            # masked ones, so the loss scale depends on mask coverage.
            loss = (per_token * masks).mean()

        # Out-of-place: the sliced logits may be the model's own output tensor,
        # which autograd may have saved for backward; mutating it in place can
        # break backward() and would also alter the caller-visible output.
        masked_logits = logits.masked_fill(masks.unsqueeze(-1) == 0, 0)
        return masked_logits, loss

    def forward(self, input_ids, good_masks, bad_masks, labels=None):
        """Run both models and merge their masked losses and logits.

        Args:
            input_ids: Token ids passed unchanged to both sub-models.
            good_masks: Mask broadcastable to the sliced logits' first two
                dims; non-zero entries select positions for the good model.
            bad_masks: Same as ``good_masks`` for the bad model.
            labels: Targets compared position-by-position with the sliced
                logits (no shift is applied here — presumably the caller has
                already shifted them; confirm). Entries equal to
                ``F.cross_entropy``'s default ``ignore_index`` (-100) contribute
                zero loss. If ``None``, no loss is computed.

        Returns:
            dict with ``"logits"`` (good logits on good positions plus bad
            logits on bad positions, zero elsewhere) and, when ``labels`` is
            given, ``"loss"`` (sum of both masked losses). Results live on the
            good model's logits device.
        """
        good_logits, good_loss = self._run_branch(self.good_model, input_ids, good_masks, labels)
        bad_logits, bad_loss = self._run_branch(self.bad_model, input_ids, bad_masks, labels)

        # The two models may sit on different devices; combine on the good one.
        device = good_logits.device
        final_logits = good_logits + bad_logits.to(device)

        if labels is None:
            return {"logits": final_logits}
        total_loss = good_loss + bad_loss.to(device)
        return {"loss": total_loss, "logits": final_logits}