import torch
import numpy as np
import torch.nn.functional as F
from torch import nn

class DiffusionRM(nn.Module):
    """Scalar regression head on top of a base model's last hidden states.

    For every sequence in the batch, the final-layer hidden state of the
    token immediately before the first pad token (or of the last token when
    the sequence contains no pad token) is projected to a single value by a
    bias-free linear layer. That value is trained with mean-squared error
    against the per-sequence labels passed to ``forward``.
    """

    def __init__(self, base_model, tokenizer):
        """Wrap ``base_model`` and attach the scalar head.

        Args:
            base_model: Module exposing ``config.hidden_size``. It is called
                with ``input_ids`` and ``output_hidden_states`` and must
                return an object with ``hidden_states`` and ``logits``.
            tokenizer: Object exposing ``pad_token_id``.
        """
        super().__init__()
        self.config = base_model.config
        self.num_padding_at_beginning = 0
        self.pretrained_model = base_model
        self.tokenizer = tokenizer
        self.PAD_ID = tokenizer.pad_token_id
        self.v_head = nn.Linear(self.config.hidden_size, 1, bias=False)
        self.loss_function_with_score = torch.nn.MSELoss()

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Forward the gradient-checkpointing request to the wrapped model.

        Args:
            gradient_checkpointing_kwargs: Passed through unchanged to the
                wrapped model's ``gradient_checkpointing_enable``. Defaults
                to ``None`` so the method can be called without arguments.
        """
        self.pretrained_model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
        )

    def _last_token_index(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return, per row, the index of the token preceding the first pad."""
        bs, seq_len = input_ids.shape[0], input_ids.shape[1]
        end = torch.full((bs,), seq_len, dtype=torch.long, device=input_ids.device)
        if self.PAD_ID is not None:
            pad_mask = input_ids == self.PAD_ID
            # argmax returns the first maximal position, i.e. the first pad;
            # rows without any pad keep `seq_len` as their end marker.
            first_pad = pad_mask.int().argmax(dim=1)
            end = torch.where(pad_mask.any(dim=1), first_pad, end)
        # When the very first token is a pad this yields -1, which indexes
        # the last position (same wrap-around as plain Python indexing).
        return end - 1

    def forward(
        self,
        input_ids: torch.Tensor,
        kto_tags: torch.Tensor,
        output_hidden_states=False,
        return_dict=False,
        use_cache=False,
    ):
        """Score each sequence and compute the MSE loss against ``kto_tags``.

        Args:
            input_ids: 2-D tensor of token ids, indexed as (batch, position).
            kto_tags: Tensor holding one regression target per sequence; it
                is reshaped to a column to match the head output.
            output_hidden_states: Accepted for interface compatibility.
                Hidden states are always requested from the wrapped model
                because the head needs them.
            return_dict: Accepted for interface compatibility; not used.
            use_cache: Accepted for interface compatibility; not used.

        Returns:
            A tuple ``(logits, loss, rewards)`` where ``logits`` is the
            wrapped model's ``logits`` output, ``loss`` is the scalar MSE
            loss and ``rewards`` has one row with one value per sequence.
        """
        # The head reads outputs.hidden_states, so it must always be requested
        # regardless of the caller-supplied flag.
        outputs = self.pretrained_model(input_ids=input_ids, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]

        last_idx = self._last_token_index(input_ids).to(hidden_states.device)
        batch_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        features = hidden_states[batch_idx, last_idx]

        rewards = self.v_head(features)
        # Match the head output's dtype/device instead of hard-coding bfloat16
        # so the loss is well-defined for any model precision.
        labels = kto_tags.reshape(-1, 1).to(device=rewards.device, dtype=rewards.dtype)
        loss = self.loss_function_with_score(rewards, labels)

        return (outputs.logits, loss, rewards)

        

    


    