from math import ceil

import numpy as np
import torch

from models.classifier import Classifier
from transformers import RobertaTokenizer, \
    RobertaForSequenceClassification, AutoTokenizer, AutoModel

from utils.constants import SENTENCE_TRANS, DEVICE

from models.model import Model


class SentenceTrans(Model):
    """Sentence-embedding model backed by ``sentence-transformers/all-mpnet-base-v2``.

    Wraps a Hugging Face encoder and its tokenizer and exposes mean-pooled
    sentence embeddings through :meth:`get_embeddings`.

    NOTE(review): the encoder and tokenizer checkpoint name is hard-coded in
    ``__init__``; ``pretrained_model_path`` is only stored and forwarded to the
    base class. Confirm this is intended.
    """

    def get_model_group(self):
        """Return the group label of this model (``'group 2'``)."""
        return 'group 2'

    def __init__(self, pretrained_model_path=SENTENCE_TRANS, num_labels=5):
        """Load the encoder and tokenizer, then initialise the base ``Model``.

        Args:
            pretrained_model_path: Path stored on the instance and passed to
                the base-class constructor.
            num_labels: Number of labels, stored on the instance.
        """
        self.num_labels = num_labels
        self.pretrained_model_path = pretrained_model_path
        self.lm = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
        self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
        # Attributes are assigned before the base constructor runs, as in the
        # original ordering -- presumably the base class reads them; TODO confirm.
        super().__init__(pretrained_model_path=self.pretrained_model_path)

    def get_representation_model(self):
        """Return the underlying encoder model."""
        return self.lm

    def get_tokenizer(self):
        """Return the tokenizer paired with the encoder."""
        return self.tokenizer

    def get_embeddings(self, lst_texts):
        """Compute one mean-pooled embedding per input text.

        All texts are tokenized in one call (truncated, padded to a common
        length), then pushed through the encoder in slices of
        ``self.batch_size`` rows on ``DEVICE`` with gradients disabled.
        ``self.batch_size`` is not set in this class -- presumably provided by
        the base ``Model``; TODO confirm.

        Args:
            lst_texts: Texts to embed, in the form accepted by the tokenizer.

        Returns:
            A list of 1-D ``numpy`` arrays, one per tokenized row, in input
            order. An empty list or tuple yields ``[]`` without invoking the
            tokenizer or the model.
        """
        # Nothing to encode: skip tokenization and device transfers entirely.
        if isinstance(lst_texts, (list, tuple)) and not lst_texts:
            return []

        tokenized_set = self.tokenizer(lst_texts, truncation=True, return_tensors='pt', padding=True)
        n_rows = len(tokenized_set['input_ids'])
        batch_size = self.batch_size

        lm = self.lm
        np_embeddings = []
        try:
            lm.to(DEVICE)
            lm.eval()
            # get the predictions batch per batch
            for start in range(0, n_rows, batch_size):
                stop = start + batch_size
                x_batch = {k: v[start:stop].to(DEVICE) for k, v in tokenized_set.items()}
                with torch.no_grad():
                    model_output = lm(**x_batch)
                    pooled = mean_pooling(model_output, x_batch['attention_mask'])
                # Round-trip through Python floats so each row becomes an
                # independent float64 array, exactly as before.
                np_embeddings.extend(np.array(row) for row in pooled.cpu().tolist())
                del x_batch
                torch.cuda.empty_cache()
        finally:
            # Always move the encoder back to the CPU, even if a batch failed,
            # so an error does not leave it occupying DEVICE memory.
            lm.cpu()

        return np_embeddings


def mean_pooling(model_output, attention_mask):
    """Average token embeddings along dim 1, weighted by the attention mask.

    Args:
        model_output: Indexable model output whose first element holds the
            token embeddings (assumed ``(batch, seq, hidden)`` -- TODO confirm).
        attention_mask: Mask tensor that, after gaining a trailing axis, is
            expanded to the shape of the token embeddings; masked-out
            positions (zeros) do not contribute to the average.

    Returns:
        Tensor of the mask-weighted sum over dim 1 divided by the mask total,
        with the divisor clamped to at least ``1e-9`` to avoid division by zero.
    """
    hidden_states = model_output[0]  # token embeddings are the first element
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()

    weighted_total = (hidden_states * weights).sum(1)
    weight_total = weights.sum(1).clamp(min=1e-9)
    return weighted_total / weight_total
