import random
from models.classifier import Classifier
from transformers import RobertaTokenizer, \
    RobertaForSequenceClassification

from utils.constants import ROBERTA
from math import ceil
import numpy as np
from utils.constants import DEVICE
import torch


class Roberta(Classifier):
    """RoBERTa sequence classifier that also exposes its encoder for embeddings.

    The class plugs the Hugging Face ``RobertaTokenizer`` and
    ``RobertaForSequenceClassification`` into the ``Classifier`` base class
    through the ``get_*`` hooks defined below.
    """

    def __init__(self, pretrained_model_path=ROBERTA, num_labels=5, to_device=False):
        """Build the classifier and keep a handle on its RoBERTa encoder.

        Args:
            pretrained_model_path: Model name or path handed to
                ``RobertaForSequenceClassification.from_pretrained``.
            num_labels: Number of output labels of the classification head.
            to_device: Accepted for interface compatibility.
                NOTE(review): this flag is not used anywhere in this
                constructor and is not forwarded to the base class; confirm
                whether ``Classifier`` is expected to receive it.
        """
        # These attributes are set before calling the base constructor because
        # ``get_classifier`` reads them -- presumably the base constructor
        # invokes that hook; verify against ``Classifier``.
        self.num_labels = num_labels
        self.pretrained_model_path = pretrained_model_path

        super().__init__(pretrained_model_path=self.pretrained_model_path, num_labels=self.num_labels)
        # ``self.classifier`` is expected to exist once the base constructor
        # has run; its ``roberta`` submodule is the encoder used for embeddings.
        self.lm = self.classifier.roberta

    def get_representation_model(self):
        """Return the RoBERTa encoder stored in ``self.lm``."""
        return self.lm

    def get_tokenizer(self):
        """Load and return the RoBERTa tokenizer.

        NOTE(review): the tokenizer is always loaded from the ``ROBERTA``
        constant, not from ``self.pretrained_model_path`` -- confirm that this
        is intentional when a custom checkpoint path is supplied.
        """
        return RobertaTokenizer.from_pretrained(ROBERTA, return_tensors='pt')

    def get_classifier(self, to_device=False):
        """Instantiate the sequence-classification model.

        Args:
            to_device: When true, move the freshly loaded model to ``DEVICE``.

        Returns:
            A ``RobertaForSequenceClassification`` loaded from
            ``self.pretrained_model_path`` with ``self.num_labels`` labels.
        """
        classifier = RobertaForSequenceClassification.from_pretrained(self.pretrained_model_path,
                                                                      num_labels=self.num_labels)
        if to_device:
            classifier.to(DEVICE)
        return classifier

    def get_embeddings(self, lst_texts, model_already_in_device=False):
        """Embed texts with the hidden state of the first token of each sequence.

        The texts are tokenized in one call (with truncation and padding) and
        then pushed through the encoder in slices of ``self.batch_size`` rows
        with gradients disabled.

        Args:
            lst_texts: Texts to embed, in the form accepted by the tokenizer.
                An empty list or tuple yields an empty result.
            model_already_in_device: When false, the encoder is moved to
                ``DEVICE`` for the duration of the call and moved back to the
                CPU afterwards, even if inference raises.

        Returns:
            A list with one 1-D float64 ``numpy`` array per input row.
        """
        # Nothing to embed: skip tokenization and any device transfer.
        if isinstance(lst_texts, (list, tuple)) and not lst_texts:
            return []

        encoded = self.tokenizer(lst_texts, truncation=True, return_tensors='pt', padding=True)
        encoder = self.lm
        total_rows = len(encoded['input_ids'])
        embeddings = []

        try:
            if not model_already_in_device:
                encoder.to(DEVICE)
            encoder.eval()

            with torch.no_grad():
                for start in range(0, total_rows, self.batch_size):
                    stop = start + self.batch_size
                    batch = {key: tensor[start:stop].to(DEVICE) for key, tensor in encoded.items()}
                    # Position 0 of the sequence axis holds the first token's state.
                    first_token_states = encoder(**batch).last_hidden_state[:, 0, :]
                    # Upcast to float64 so the returned arrays keep the dtype
                    # that a Python-float round-trip would have produced.
                    batch_vectors = first_token_states.cpu().double().numpy()
                    # Copy each row so every returned array owns its own data.
                    embeddings.extend(row.copy() for row in batch_vectors)
                    del batch, first_token_states, batch_vectors
                    # Hand unused cached GPU memory back between batches.
                    torch.cuda.empty_cache()
        finally:
            # Restore the encoder to the CPU even when a batch fails, so an
            # error does not leave the model occupying the accelerator.
            if not model_already_in_device:
                encoder.cpu()

        return embeddings

    def get_model_group(self):
        """Return the identifier of the model group this classifier belongs to."""
        return 'group 2'
