# BERT sequence-classification model class with a method called 'power' that takes a dict of tokenized model inputs (e.g. token ids) and returns a scalar tensor: the weighted sum of the thresholded morality/utility/power scores.
from transformers import BertTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn as nn

# Module-wide default device: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class PowerRegularizer(nn.Module):
    """Score tokenized actions with a 3-label sequence classifier.

    The classifier emits three logits per input row which, per the original
    author's note, are ordered (morality, utility, power).  Each logit is
    multiplied by its valence and compared against its threshold; every
    comparison that passes contributes the matching weight to the result.

    Args:
        regularizer_lm: Model name or path handed to
            ``AutoModelForSequenceClassification.from_pretrained``.
        model_path: Checkpoint file read with ``torch.load`` and applied to
            the classifier with ``strict=False``.
        weights: Comma-separated string of exactly three floats, one weight
            per logit.
        utility_thresh: Three thresholds, one per logit, applied *after* the
            valence multiplication.

    Raises:
        ValueError: If ``weights`` does not contain exactly three values (or
            one of them cannot be parsed as a float).
    """

    def __init__(self, regularizer_lm, model_path, weights,
                 utility_thresh=(-1.0, -1.0, -3.0)):
        # morality, utility, power
        super().__init__()

        # Validate the cheap argument first so a bad call fails before the
        # (slow) model download / checkpoint load.
        parsed_weights = [float(w) for w in weights.split(",")]
        if len(parsed_weights) != 3:
            raise ValueError(
                "weights must contain exactly 3 comma-separated values, "
                "got %d" % len(parsed_weights)
            )

        self.model = AutoModelForSequenceClassification.from_pretrained(
            regularizer_lm, num_labels=3)
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        # The freshly built model lives on the CPU, so the checkpoint is
        # loaded there as well instead of being staged on the GPU first.
        # NOTE: strict=False because the checkpoint is known to carry extra
        # parameters; it also silently tolerates *missing* keys.
        state_dict = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(state_dict, strict=False)

        # Non-persistent buffers: they follow .to()/.cuda() together with the
        # module but stay out of state_dict(), so saved checkpoints keep the
        # same keys as before.
        self.register_buffer(
            "valence", torch.tensor([-1.0, -1.0, 1.0]), persistent=False)
        self.register_buffer(
            "threshold", torch.tensor(utility_thresh), persistent=False)
        self.register_buffer(
            "weights", torch.tensor(parsed_weights), persistent=False)

        # Not read anywhere in this class; kept because external code may
        # rely on these attributes.
        self.morality_threshold = -1.0
        self.utility_threshold = 1.0
        self.power_threshold = 3.0

    def forward(self, acts):
        """Alias for :meth:`power` so that ``regularizer(acts)`` works."""
        return self.power(acts)

    @torch.no_grad()
    def power(self, acts):
        """Return the weighted count of thresholded classifier scores.

        Args:
            acts: Mapping of keyword name to tensor, forwarded to the
                classifier as ``self.model(**acts)``.

        Returns:
            A scalar tensor on the module-level ``device``: the sum, over
            every row and every logit, of the weights whose
            ``logit * valence`` exceeds the corresponding threshold.
        """
        # Run on whatever device the classifier currently occupies instead of
        # assuming the CPU, so the module keeps working after .to(...).
        first_param = next(self.model.parameters(), None)
        model_device = (first_param.device if first_param is not None
                        else torch.device("cpu"))
        inputs = {name: tensor.to(model_device) for name, tensor in acts.items()}
        scores = self.model(**inputs).logits

        # .to() is a no-op when the buffers already share the logits' device;
        # it only copies if the inner model was moved on its own.
        valence = self.valence.to(scores.device)
        threshold = self.threshold.to(scores.device)
        weights = self.weights.to(scores.device)

        passed = scores * valence > threshold
        return torch.sum(passed * weights).to(device)

