from abc import ABC
import torch
from transformers import BertForMaskedLM


class AntBrain(torch.nn.Module, ABC):
    """Wrapper around a pretrained ``BertForMaskedLM`` backbone.

    Exposes three entry points:

    * ``forward`` -- masked-language-modelling loss for a training batch,
    * ``get_embedding`` -- mask-weighted mean of the last hidden layer,
    * ``get_fingerprints`` -- attention maps post-processed by an external miner.

    NOTE(review): ``ABC`` is mixed in but no abstract methods are declared,
    so the class is still directly instantiable.
    """

    def __init__(self,
                 model_name,
                 prediction_objective='all',
                 tokenizer=None):
        """Load the pretrained masked-LM backbone.

        Args:
            model_name: Identifier passed to ``BertForMaskedLM.from_pretrained``.
            prediction_objective: Stored on the instance; not read by any
                method of this class.
            tokenizer: Optional tokenizer, handed to the miner in
                ``get_fingerprints``.
        """
        super().__init__()
        self.model_name = model_name
        self.model = BertForMaskedLM.from_pretrained(self.model_name)
        # NOTE(review): 768 is hard-coded (BERT-base hidden size) and this head
        # is not used by any method of this class. It is kept unchanged so that
        # existing checkpoints containing ``predict_3d`` weights still load;
        # consider ``self.model.config.hidden_size`` if the backbone may vary.
        self.predict_3d = torch.nn.Linear(768, 3)
        self.prediction_objective = prediction_objective
        self.config = self.model.config
        self.tokenizer = tokenizer

    def forward(self, input_ids, attention_mask, labels, mask_input_ids):
        """Compute the masked-language-modelling loss for one batch.

        Args:
            input_ids: Original token ids; only used to report the batch size.
            attention_mask: Attention mask passed to the backbone.
            labels: Targets passed to the backbone's MLM loss.
            mask_input_ids: Token ids fed to the backbone (presumably
                ``input_ids`` with masking applied -- confirm against caller).

        Returns:
            dict with ``"loss"`` (currently identical to the LM loss),
            ``"lm_loss"`` and ``"batch_size"`` (``len(input_ids)``).
        """
        # Hidden states are not requested: the loss does not need them, and
        # requesting them keeps every layer's activations in memory.
        lm_outputs = self.model(input_ids=mask_input_ids,
                                attention_mask=attention_mask,
                                labels=labels)
        lm_loss = lm_outputs.loss
        loss = lm_loss
        return {"loss": loss, "lm_loss": lm_loss, 'batch_size': len(input_ids)}

    def get_embedding(self, input_ids, attention_mask):
        """Return the mask-weighted mean of the last hidden layer, without gradients.

        Args:
            input_ids: Token ids, assumed shaped (batch, seq_len).
            attention_mask: Mask with the same shape as ``input_ids``; its
                values weight each position in the average.

        Returns:
            One pooled vector per sequence (the sequence axis, dim 1, is
            reduced). Rows whose mask is entirely zero yield zeros rather
            than NaN.

        NOTE(review): only gradients are disabled; dropout still applies if
        the module is in training mode -- call ``.eval()`` first if needed.
        """
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask,
                                 output_hidden_states=True)
            last_hidden_states = outputs.hidden_states[-1]
            # Broadcast the mask over the hidden dimension, in the activations' dtype.
            mask = attention_mask.unsqueeze(-1).to(last_hidden_states.dtype)
            summed = (last_hidden_states * mask).sum(1)
            # An all-padding row would otherwise compute 0 / 0 = NaN.
            counts = mask.sum(1).clamp(min=torch.finfo(mask.dtype).eps)
            return summed / counts

    def get_fingerprints(self, ant_miner, input_ids, attention_mask):
        """Compute fingerprints from the backbone's attention maps, without gradients.

        Args:
            ant_miner: Object exposing ``get_fingerprints(attentions, input_ids,
                attention_mask, tokenizer)``; defined elsewhere.
            input_ids: Token ids passed to the backbone and the miner.
            attention_mask: Attention mask passed to the backbone and the miner.

        Returns:
            Whatever ``ant_miner.get_fingerprints`` returns.
        """
        with torch.no_grad():
            # Only the attention maps are consumed, so hidden states are not requested.
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask,
                                 output_attentions=True)
            return ant_miner.get_fingerprints(outputs.attentions, input_ids,
                                              attention_mask, self.tokenizer)

