from abc import ABC
from model.utils import pca_twice
import torch
from transformers import BertForMaskedLM


class AntBrain(torch.nn.Module, ABC):
    """BERT masked-LM backbone with two linear multi-label "hash" heads.

    Both heads read the backbone's last hidden states at caller-supplied
    token positions and are trained with ``BCEWithLogitsLoss``.  The class
    also exposes two inference helpers that turn the last hidden states into
    fixed-size embeddings.
    """

    def __init__(self,
                 model_name,
                 hash_size,
                 tokenizer=None):
        """Load the pretrained backbone and create the two prediction heads.

        Args:
            model_name: Identifier or path handed to
                ``BertForMaskedLM.from_pretrained``.
            hash_size: Output width of each linear head.
            tokenizer: Optional tokenizer, stored as-is on the instance; it
                is not used by any method of this class.
        """
        super().__init__()
        self.model_name = model_name
        self.model = BertForMaskedLM.from_pretrained(self.model_name)
        self.config = self.model.config
        # Size the heads from the checkpoint instead of assuming 768, so
        # backbones with a different hidden width work as well.
        hidden_size = self.config.hidden_size
        self.predict_hash = torch.nn.Linear(hidden_size, hash_size)
        self.predict_hash_singleton = torch.nn.Linear(hidden_size, hash_size)
        self.tokenizer = tokenizer

    @staticmethod
    def _gathered_bce_loss(hidden_states, indices, targets, head):
        """Return the BCE-with-logits loss of ``head`` at the given positions.

        ``indices`` is a 2-D tensor selecting positions along dim 1 of the
        3-D ``hidden_states``; ``targets`` is reshaped to one row per
        selected position and cast to float.
        """
        batch_size, num_indices = indices.size(0), indices.size(1)
        # Broadcast every position index across the feature dimension so that
        # gather picks up the complete hidden vector at that position.
        gather_index = indices.unsqueeze(-1).expand(
            batch_size, num_indices, hidden_states.size(2))
        selected = torch.gather(hidden_states, 1, gather_index)

        logits = head(selected)
        logits = logits.view(-1, logits.shape[-1])
        targets = targets.view(-1, logits.shape[-1]).float()
        return torch.nn.BCEWithLogitsLoss()(logits, targets)

    def forward(self, input_ids, attention_mask, labels, label_indices, singleton_labels, singleton_label_indices):
        """Compute the summed training loss of the two hash heads.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Attention mask passed to the backbone.
            labels: Targets for ``predict_hash``; reshaped to
                ``(-1, hash_size)`` and cast to float.
            label_indices: 2-D tensor of positions (along the sequence
                dimension of the hidden states) at which ``predict_hash``
                is evaluated.
            singleton_labels: Targets for ``predict_hash_singleton``, handled
                like ``labels``.
            singleton_label_indices: Positions for ``predict_hash_singleton``,
                handled like ``label_indices``.

        Returns:
            A dict with the single key ``"loss"``: the sum of the two
            BCE-with-logits losses.
        """
        outputs = self.model(input_ids=input_ids,
                             attention_mask=attention_mask,
                             output_hidden_states=True)
        last_hidden_states = outputs.hidden_states[-1]

        loss = self._gathered_bce_loss(
            last_hidden_states, label_indices, labels, self.predict_hash)
        loss_singleton = self._gathered_bce_loss(
            last_hidden_states, singleton_label_indices, singleton_labels,
            self.predict_hash_singleton)

        return {"loss": loss + loss_singleton}

    def get_embedding(self, input_ids, attention_mask):
        """Return the attention-masked mean of the last hidden states.

        Runs without gradient tracking.  The result has one row per input
        sequence and one column per hidden feature.  A sequence whose mask is
        entirely zero yields a zero vector instead of NaN.
        """
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
            last_hidden_states = outputs.hidden_states[-1]
            mask = attention_mask.unsqueeze(-1)
            summed = (last_hidden_states * mask).sum(1)
            # Guard against 0/0 for fully masked rows.  Assumes mask entries
            # are 0/1, in which case clamping the count at 1 leaves every
            # non-empty row unchanged.
            token_counts = mask.sum(1).clamp(min=1)
            return summed / token_counts

    def get_pca_embedding(self, input_ids, attention_mask):
        """Return flattened ``pca_twice`` reductions of the leading hidden states.

        Runs without gradient tracking.  For every sequence in the batch the
        first 10 rows of the last hidden states are passed to
        ``pca_twice(hidden_states, 64, 10)``; the per-sequence results are
        stacked and flattened to one row per sequence.

        NOTE(review): the meaning of the ``64`` and ``10`` arguments is
        defined by ``model.utils.pca_twice`` and is not visible here.  The
        attention mask is only used for the backbone call, so the 10 leading
        positions are taken regardless of padding -- confirm this is intended.
        """
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
            last_hidden_states = outputs.hidden_states[-1]
            reduced = [
                pca_twice(sequence_states[:10], 64, 10)
                for sequence_states in last_hidden_states
            ]
            stacked = torch.stack(reduced)
            return stacked.view(stacked.size(0), -1)
