from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    AutoModel,
    GPT2LMHeadModel,
    GPTJForCausalLM,
)
from typing import List
import numpy as np
import logging
import cohere
import openai
import torch
import time
import os

# No name is passed to getLogger(), so this is the root logger.
logger = logging.getLogger()
# Placeholders: replace both values with your own OpenAI organization id and
# API key; they are set globally on the `openai` module at import time.
openai.organization = "<INSERT PERSONAL KEY>"
openai.api_key = "<INSERT PERSONAL KEY>"


class Model:
    """
    Base scoring interface shared by the model wrappers in this file.

    Subclasses override ``_get_lm_score`` and/or ``_get_mlm_score`` (and
    ``get_mask_token`` when they support masked language modeling). The base
    implementations of those three methods only raise ``NotImplementedError``.
    """

    def _get_mlm_score(self, masked_input_text: str, input_text: str) -> float:
        """
        Uses a model to get a score for a masked prediction task.
        :param masked_input_text: A text with a mask token, e.g. "There is a [mask] on the table".
        :param input_text: The labels for the masked input, e.g. "There is a cat on the table".
        :return: the score of the labels at the masked position(s); lower is better
        :raises NotImplementedError: always, unless overridden by a subclass
        """
        raise NotImplementedError(
            "Masked language modeling not implemented for this class."
        )

    def _get_lm_score(self, input_text) -> float:
        """
        Uses a model to get a language modelling score for a given input text.
        :param input_text: the text to score
        :return: the score of the text; lower is better (get_prediction_bias
                 prefers the text with the minimum score)
        :raises NotImplementedError: always, unless overridden by a subclass
        """
        raise NotImplementedError("Language modeling not implemented for this class.")

    def get_mask_token(self):
        """
        Returns the mask token used for masked language modeling.
        :raises NotImplementedError: always, unless overridden by a subclass
        """
        raise NotImplementedError("Get mask token not implemented for this class.")

    def get_model_score(self, input_texts: List, labels: List, task: str) -> float:
        """
        Gets a masked language modeling (mlm) or language modeling (lm) score.
        :param input_texts: for "mlm" the masked text(s), e.g. "There is a [mask] on the table";
                            ignored for "lm"
        :param labels: the text(s) to score; for "mlm" the unmasked version of input_texts
        :param task: a task ("mlm" or "lm")
        :return: whatever _get_lm_score / _get_mlm_score returns for this model
        :raises ValueError: if task is neither "lm" nor "mlm"
        :raises AssertionError: if task is "mlm" and a masked text or label is missing (None)
        """
        if task == "lm":
            # Only the labels are scored for plain language modeling.
            return self._get_lm_score(labels)
        if task == "mlm":
            # MLM needs both sides: the masked texts and the labels to score
            # at the masked positions.
            for texts in (input_texts, labels):
                assert texts is not None and (
                    isinstance(texts, str) or None not in texts
                ), "Provide masked text if task is mlm."
            return self._get_mlm_score(input_texts, labels)
        raise ValueError("Unrecognized task %s" % task)

    def get_prediction_bias(
        self, input_texts: List[str], input_words: List[str], implicature: str
    ) -> dict:
        """
        Gets a language modeling score for the input texts associated with the input word of the same
        index. The input_texts should only differ in those words in input_words.
        Returns the preferred word based on the corresponding input_text with the best (lowest) LM score.
        :param input_texts: the texts to score
        :param input_words: the word associated with each text (matched case-insensitively)
        :param implicature: the correct implicature; must be one of input_words
        :return: dict with the score of the correct text ("correct_score"), the scores of
                 all other texts ("false_scores"), the word whose text scored lowest
                 ("preferred_word"), the given "implicature" and the "scored_texts".
        :raises AssertionError: if implicature is not in input_words or a word does not
                                occur in the text at the same index
        """
        assert implicature in input_words, (
            "Implicature %s not found in input_words." % implicature
        )
        loss_per_word = []
        for i, input_text in enumerate(input_texts):
            # Lower-case both sides: the text is lower-cased, so a word with
            # capital letters could otherwise never be found.
            assert (
                input_words[i].lower() in input_text.lower()
            ), "input_word(=%s) at idx %d not found in input_text(=%s) at same idx" % (
                input_words[i],
                i,
                input_text,
            )
            loss_per_word.append(self.get_model_score(input_text, input_text, "lm"))

        # Lowest score wins; on a tie the first text with that score is chosen.
        index_of_min = loss_per_word.index(min(loss_per_word))
        correct_index = input_words.index(implicature)
        false_scores = [
            loss for i, loss in enumerate(loss_per_word) if i != correct_index
        ]
        return {
            "correct_score": loss_per_word[correct_index],
            "false_scores": false_scores,
            "preferred_word": input_words[index_of_min],
            "implicature": implicature,
            "scored_texts": input_texts,
        }


class HFModelWrapper(Model):
    """
    Wraps a Hugging Face ``transformers`` model and its tokenizer behind the
    ``Model`` scoring interface.

    The model class is looked up in ``CONFIG`` by the part of the model id in
    front of the first "-". Ids outside ``TESTED_MODEL_IDS`` are loaded with
    ``AutoModel`` after a warning. The model starts on the CPU in eval mode;
    use ``to`` to move it.
    """

    # Maps the prefix of a model id (text before the first "-") to the class
    # used to load it.
    CONFIG = {
        "gpt2": GPT2LMHeadModel,
        "bert": AutoModelForMaskedLM,
        "roberta": AutoModelForMaskedLM,
        "EleutherAI/gpt": GPTJForCausalLM,
    }
    TESTED_MODEL_IDS = {
        "gpt2-medium",
        "gpt2-large",
        "gpt2-xl",
        "bert-base-cased",
        "bert-base-uncased",
        "roberta-base",
        "roberta-large",
        "EleutherAI/gpt-j-6B",
    }
    # Models that are loaded with extra memory-saving arguments.
    BIG_MODELS = {"EleutherAI/gpt-j-6B"}

    def __init__(self, model_id: str):
        """
        Loads the model and tokenizer for model_id.
        :param model_id: a Hugging Face model id, e.g. "gpt2-medium"
        """
        model_family = model_id.split("-")[0]
        if model_id not in self.TESTED_MODEL_IDS:
            logger.warning("Requested model_id=%s is not tested." % model_id)
            model_class = AutoModel
        else:
            model_class = self.CONFIG[model_family]
        if model_id in self.BIG_MODELS:
            self._model = model_class.from_pretrained(
                model_id,
                revision="float32",
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
            )
        else:
            self._model = model_class.from_pretrained(model_id)
        self._tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Over-long inputs lose their beginning, not their end.
        self._tokenizer.truncation_side = "left"
        self._mask_token = self._tokenizer.mask_token
        self._model.eval()
        self._eos_token = self._tokenizer.eos_token
        self._sep_token = self._tokenizer.sep_token
        self._device = "cpu"
        self._model_id = model_id

        # _tokenize_function always pads, which requires a pad token. Reuse
        # the EOS token for gpt2 and for any other tokenizer that has no pad
        # token of its own but does have an EOS token.
        needs_pad_token = (
            self._tokenizer.pad_token is None
            and self._tokenizer.eos_token is not None
        )
        if model_family == "gpt2" or needs_pad_token:
            self._tokenizer.pad_token = self._tokenizer.eos_token
            self._model.config.pad_token_id = self._model.config.eos_token_id

    def to(self, device) -> None:
        """
        Moves the model to device and remembers it for future inputs.
        :param device: a torch device or device string, e.g. "cuda"
        """
        self._device = device
        self._model.to(device)

    def get_mask_token(self):
        """
        :return: the tokenizer's mask token (None if the tokenizer has none)
        """
        return self._mask_token

    def _tokenize_function(self, example_text: str):
        """
        Tokenizes a text (or list of texts) into padded, truncated PyTorch tensors.
        :param example_text: the text(s) to tokenize
        :return: the tokenizer output, with "pt" tensors
        """
        return self._tokenizer(
            example_text, return_tensors="pt", truncation=True, padding=True
        )

    def _get_mlm_score(
        self, masked_input_text: List[str], input_text: List[str]
    ) -> np.ndarray:
        """
        Uses a model to get a score for a masked prediction task.
        :param masked_input_text: Texts with a mask token, e.g. ["There is a [mask] on the table"].
        :param input_text: The labels for the masked inputs, e.g. ["There is a cat on the table"].
        :return: the model's loss, computed only at the masked positions
        :raises AssertionError: if the model has no mask token, a masked text lacks the
                                mask token, a label contains it, or masked texts and
                                labels do not tokenize to the same shape
        """
        assert self._mask_token, "Cannot do masked language modeling with this model."
        for masked_input, label in zip(masked_input_text, input_text):
            assert self._mask_token in masked_input, (
                "Trying to _get_mlm_score() for masked_text(=%s) without "
                "mask_token(=%s)." % (masked_input, self._mask_token)
            )
            assert (
                self._mask_token not in label
            ), "Trying to _get_mlm_score() for label(=%s) with " "mask_token(=%s)." % (
                label,
                self._mask_token,
            )
        with torch.no_grad():
            tokenized_input = self._tokenize_function(masked_input_text)
            label_ids = self._tokenize_function(input_text).input_ids
            # Labels are read position-by-position at the mask locations, so
            # both tokenizations must line up exactly.
            assert tokenized_input.input_ids.shape == label_ids.shape, (
                "Masked texts and labels tokenize to different shapes (%s vs %s); "
                "each mask token must stand for exactly one label token."
                % (tuple(tokenized_input.input_ids.shape), tuple(label_ids.shape))
            )
            is_masked = tokenized_input.input_ids == self._tokenizer.mask_token_id
            # -100 marks positions that are left out of the loss.
            tokenized_labels = torch.where(
                is_masked, label_ids, torch.full_like(label_ids, -100)
            )
            tokenized_input = tokenized_input.to(self._device)
            tokenized_labels = tokenized_labels.to(self._device)
            score = self._model(**tokenized_input, labels=tokenized_labels)
            return score.loss.cpu().numpy()

    def _get_lm_score(self, input_text) -> np.ndarray:
        """
        Uses a model to get a language modelling score (NLL) for a given input text
        :param input_text: a text (or list of texts) to score
        :return: the model's loss with the input ids as labels and padding
                 positions ignored; a single value even when a list is passed
        """
        with torch.no_grad():
            inputs = self._tokenize_function(input_text)
            labels = inputs["input_ids"].clone()
            # Ignore padding in the loss. Prefer the attention mask: when the
            # pad token is the EOS token, comparing ids would also drop real
            # EOS tokens that are part of the text.
            attention_mask = inputs.get("attention_mask")
            if attention_mask is not None:
                labels[attention_mask == 0] = -100
            else:
                labels[labels == self._tokenizer.pad_token_id] = -100
            inputs = inputs.to(self._device)
            labels = labels.to(self._device)
            outputs = self._model(**inputs, labels=labels)
            return outputs.loss.cpu().numpy()

    def get_model_logits(self, input_text, correct_word, wrong_word, mask_token=None):
        """
        Gets the logits the model assigns to two candidate words at one position.
        :param input_text: the text to run the model on
        :param correct_word: the word whose logit is returned first
        :param wrong_word: the word whose logit is returned second
        :param mask_token: if truthy, read the logits at the tokenizer's mask
                           position in input_text; otherwise at the last position
        :return: tuple (correct_score, wrong_score) of floats
        """

        def word_token_id(word):
            token_ids = self._tokenize_function(word)["input_ids"].squeeze()
            if token_ids.dim() > 0:
                # More than one id: take the one at index 1.
                # NOTE(review): presumably this skips a leading special token
                # such as [CLS]; for tokenizers without one it picks the second
                # sub-token of the word - confirm this is intended.
                token_ids = token_ids[1]
            # A plain int indexes the logits regardless of their device.
            return token_ids.item()

        with torch.no_grad():
            # Inputs must live on the same device as the model.
            tokenized_input = self._tokenize_function(input_text).to(self._device)
            outputs = self._model(**tokenized_input)
        correct_token = word_token_id(correct_word)
        wrong_token = word_token_id(wrong_word)
        if not mask_token:
            prediction_idx = -1
        else:
            prediction_idx = (
                tokenized_input.input_ids == self._tokenizer.mask_token_id
            )[0].nonzero(as_tuple=True)[0]
        correct_score = outputs.logits[:, prediction_idx, correct_token].item()
        wrong_score = outputs.logits[:, prediction_idx, wrong_token].item()
        return correct_score, wrong_score


class CohereModelWrapper(Model):
    """
    Scores texts with a Cohere model through the ``cohere`` client.

    Only language modeling scores are supported; the masked language modeling
    methods raise ``ValueError``.
    """

    # Maps the accepted model_size argument to the Cohere model name.
    MODELS = {
        "xl": "xlarge",
        "large": "large",
        "medium": "medium",
        "small": "small",
    }

    def __init__(self, model_size: str):
        """
        Creates a Cohere client for the requested model size.
        :param model_size: one of the keys of MODELS
        :raises AssertionError: if model_size is not a key of MODELS
        """
        assert model_size in self.MODELS, (
            "Chosen model_size=%s not available." % model_size
        )
        api_key = self._get_api_key()
        self._cohere_client = cohere.Client(api_key)
        self._model_id = self.MODELS[model_size]

    @staticmethod
    def _get_api_key():
        """
        Reads the Cohere API key from static/cohere_api_key.txt, relative to
        the current working directory.
        :return: the key, without surrounding whitespace
        """
        key_path = os.path.join(os.getcwd(), "static/cohere_api_key.txt")
        with open(key_path, encoding="utf-8") as infile:
            # Text files usually end in a newline, which is not part of the key.
            return infile.read().strip()

    def _score_text(self, text):
        """
        Scores a single text with the Cohere API.
        :param text: the text to score
        :return: exp(-likelihood) of the first returned generation
        """
        # max_tokens=0 generates nothing; the call is only made for the
        # likelihood the API reports for the prompt.
        prediction = self._cohere_client.generate(
            model=self._model_id,
            prompt=text,
            max_tokens=0,
            temperature=1,
            k=0,
            p=0.75,
            frequency_penalty=0,
            presence_penalty=0,
            stop_sequences=[],
            return_likelihoods="ALL",
        )
        likelihood = prediction.generations[0].likelihood
        return np.exp(-1.0 * likelihood)

    def _get_lm_score(self, input_text):
        """
        Gets a language modeling score for one text or for each text in a list.
        :param input_text: a text, or a list of texts, to score
        :return: exp(-likelihood) of the text; a list of those values (one API
                 call per text) when a list is passed
        """
        if isinstance(input_text, list):
            return [self._score_text(text) for text in input_text]
        return self._score_text(input_text)

    def _get_mlm_score(self, masked_input_text: str, input_text: str) -> float:
        """
        Not supported.
        :raises ValueError: always
        """
        raise ValueError("Cannot do masked language modeling with Cohere models.")

    def get_mask_token(self):
        """
        Not supported.
        :raises ValueError: always
        """
        raise ValueError("Cannot do masked language modeling with Cohere models.")

    def to(self, device):
        """
        Does nothing except log that the device cannot be set.
        :param device: the requested device (only used in the log message)
        """
        logger.info(f"Cohere model is cloud-based, unable to set device to {device}")


class OpenAIModel(Model):
    """
    Scores texts with an OpenAI completion engine through the ``openai`` module.

    Only language modeling scores for lists of texts are supported; the masked
    language modeling methods raise ``ValueError``.
    """

    # Maps the accepted model_engine argument to the OpenAI engine name.
    MODELS = {
        "davinci": "text-davinci-002",
        "davinci1": "text-davinci-001",
        "gpt3davinci": "davinci",
        "gpt3ada": "ada",
        "gpt3curie": "curie",
        "gpt3babbage": "babbage",
        "babbage": "text-babbage-001",
        "curie": "text-curie-001",
        "ada": "text-ada-001"
    }

    # The last token of every scored prompt must be one of these answer
    # options (with or without a leading space).
    _BINARY_ANSWER_TOKENS = frozenset(
        {
            "yes", "no", " no", " yes",
            "one", "two", " one", " two",
            "1", "2", " 1", " 2",
            "A", "B", " A", " B",
        }
    )

    def __init__(self, model_engine: str, rate_limit=False):
        """
        :param model_engine: one of the keys of MODELS
        :param rate_limit: if truthy, sleep 4 seconds after every API call
        :raises AssertionError: if model_engine is not a key of MODELS
        """
        assert model_engine in self.MODELS, (
            "Chosen model_engine=%s not available." % model_engine
        )
        self._model_id = self.MODELS[model_engine]
        self._rate_limit = rate_limit

    def _get_lm_score(self, input_text):
        """
        Scores the last token of each text in a list.
        :param input_text: a list of texts, each ending in one of the binary
                           answer options (e.g. "yes" / "no")
        :return: a list with exp(-logprob) of the last token of each text
        :raises ValueError: if input_text is not a list (scoring a single text
                            is no longer supported)
        :raises AssertionError: if the last token of a text is not one of the
                                binary answer options
        """
        if not isinstance(input_text, list):
            raise ValueError("This code doesn't work anymore.")

        ppl = []
        for _input_text in input_text:
            # max_tokens=0 with echo=True generates nothing and returns the
            # log-probabilities of the prompt tokens themselves.
            prediction = openai.Completion.create(engine=self._model_id,
                                                  prompt=_input_text,
                                                  max_tokens=0,
                                                  logprobs=1,
                                                  echo=True)
            logprobs = prediction.choices[0].logprobs
            assert logprobs.tokens[-1] in self._BINARY_ANSWER_TOKENS, (
                "Last token is not one of binary implicature options."
            )
            likelihood = logprobs.token_logprobs[-1]
            ppl.append(np.exp(-1.0 * likelihood))
            if self._rate_limit:
                time.sleep(4)
        return ppl

    def _get_mlm_score(self, masked_input_text: str, input_text: str) -> float:
        """
        Not supported.
        :raises ValueError: always
        """
        raise ValueError("Cannot do masked language modeling with OpenAI models.")

    def get_mask_token(self):
        """
        Not supported.
        :raises ValueError: always
        """
        raise ValueError("Cannot do masked language modeling with OpenAI models.")

    def to(self, device):
        """
        Does nothing except log that the device cannot be set.
        :param device: the requested device (only used in the log message)
        """
        logger.info(f"OpenAI model is cloud-based, unable to set device to {device}")
