import numpy as np
import torch
from lm import nn_log_probs
from lm import nn_next_token_probs

class LMPolicy:
    """Sampling policy backed by a causal language model.

    The wrapped ``model`` is called with a 2-D integer tensor of token ids and
    must return an object exposing ``.logits`` indexable as
    ``[batch, position, vocab]``.
    """

    def __init__(self, model, vocab_size, temperature=1.0):
        """Store the model and sampling configuration.

        Args:
            model: Callable language model (see class docstring).
            vocab_size: Number of token ids ``next_token`` samples from; it
                must equal the size of the model's last logits dimension.
            temperature: Stored on the instance.
        """
        self.model = model
        self.vocab_size = vocab_size
        # NOTE(review): temperature is stored but never applied to the logits
        # anywhere in this class -- confirm whether sampling should use it.
        self.temperature = temperature

    def _input_device(self):
        """Return the device that input tensors should be moved to."""
        # Prefer an explicit ``device`` attribute when the model exposes one.
        device = getattr(self.model, "device", None)
        if device is not None:
            return device
        # Otherwise follow the model's first parameter, if it has any.
        parameters = getattr(self.model, "parameters", None)
        if callable(parameters):
            first = next(iter(parameters()), None)
            if first is not None:
                return first.device
        # Unknown model type: keep the historical CUDA behaviour when a GPU
        # exists, but do not crash on CPU-only machines.
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _last_position_probs(self, token_batch):
        """Return softmax probabilities at the last position of each row.

        Args:
            token_batch: Rectangular nested list of token ids
                (``torch.tensor`` rejects ragged input).

        Returns:
            CPU tensor with one probability row per input sequence.
        """
        input_ids = torch.tensor(token_batch).to(self._input_device())
        with torch.no_grad():
            logits = self.model(input_ids).logits[:, -1, :].cpu()
            return torch.softmax(logits, dim=-1)

    def next_token_probs(self, sequence):
        """Return the next-token distribution for one token-id sequence.

        Returns:
            1-D CPU tensor of probabilities over the model's output vocabulary.
        """
        return self._last_position_probs([sequence])[0]

    def batch_next_token_probs(self, sequences):
        """Return next-token distributions for a batch of sequences.

        Args:
            sequences: Token-id sequences that all have the same length.

        Returns:
            2-D CPU tensor with one row of probabilities per sequence.
        """
        return self._last_position_probs(sequences)

    def next_token(self, sequence):
        """Sample one next-token id from the model's distribution.

        Raises:
            ValueError: From ``np.random.choice`` if the number of
                probabilities differs from ``vocab_size``.
        """
        # Renormalize in float64: a float32 softmax may not sum to 1 closely
        # enough for np.random.choice's tolerance check.
        probs = self.next_token_probs(sequence).double().numpy()
        probs /= probs.sum()
        return np.random.choice(self.vocab_size, p=probs)

    def next_k_tokens(self, sequence, k):
        """Sample ``k`` tokens autoregressively, starting from ``sequence``.

        ``sequence`` itself is not modified.

        Returns:
            List of the ``k`` sampled token ids, in generation order.
        """
        context = list(sequence)
        generated = []
        for _ in range(k):
            token = self.next_token(context)
            generated.append(token)
            # Feed plain ints back to the model so the context stays a
            # homogeneous list of Python integers.
            context.append(int(token))
        return generated

    def sequence_prob(self, sequence):
        """Return the probability the model assigns to ``sequence``.

        ``sequence`` should include BOS and EOS.

        The value is ``exp`` of the log-probability reported by
        ``nn_log_probs``, so it underflows to 0.0 for long or very unlikely
        sequences.
        """
        # Presumably nn_log_probs returns one log-probability per input
        # sequence; only the first (and only) entry is used -- TODO confirm.
        log_prob = nn_log_probs(self.model, [sequence], batch_size=1)[0]
        return np.exp(log_prob)
