"""

Perplexity Metric:
-------------------------------------------------------
Class for calculating perplexity from AttackResults

"""

import torch

from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.metrics import Metric
import textattack.shared.utils


class Perplexity(Metric):
    """Language-model perplexity of original vs. successfully perturbed texts.

    For every successful attack result, the lower-cased original text and the
    lower-cased perturbed text are collected. Each group is joined into one
    long sequence and scored with a pre-trained language model. The two
    resulting perplexities are reported in ``all_metrics``.

    Args:
        model_name (str): ``"gpt2"`` loads ``GPT2LMHeadModel``. Any other name
            is loaded with ``AutoModelForMaskedLM``/``AutoTokenizer``.
    """

    def __init__(self, model_name="gpt2"):
        self.all_metrics = {}
        self.original_candidates = []
        self.successful_candidates = []

        if model_name == "gpt2":
            from transformers import GPT2LMHeadModel, GPT2Tokenizer

            self.ppl_model = GPT2LMHeadModel.from_pretrained("gpt2")
            self.ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            self.max_length = self.ppl_model.config.n_positions
        else:
            # NOTE(review): a masked LM scored with ``labels=input_ids`` and no
            # masking sees each target token in its own input, so the value is
            # not a true causal perplexity; treat it as a relative measure.
            from transformers import AutoModelForMaskedLM, AutoTokenizer

            self.ppl_model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.ppl_tokenizer = AutoTokenizer.from_pretrained(model_name)
            # ``max_position_embeddings`` can exceed the usable input length
            # (e.g. RoBERTa: 514 vs. 512), so also respect the tokenizer limit.
            # If the tokenizer has no real limit (huge sentinel value), the
            # config value wins.
            self.max_length = min(
                self.ppl_model.config.max_position_embeddings,
                getattr(self.ppl_tokenizer, "model_max_length", float("inf")),
            )

        self.ppl_model.to(textattack.shared.utils.device)
        self.ppl_model.eval()

        # Number of new tokens scored per sliding window in ``calc_ppl``.
        self.stride = 512

    def calculate(self, results):
        """Calculates perplexity of the original and the perturbed texts of all
        successful attacks using the configured pre-trained language model.

        Failed and skipped results are ignored. Texts are lower-cased. All
        texts of one group are concatenated, so each metric is the perplexity
        of that concatenation rather than a mean of per-example perplexities.
        If there are no successful attacks, both metrics are ``nan``.

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset

        Returns:
            dict: ``all_metrics`` with keys ``"avg_original_perplexity"`` and
            ``"avg_attack_perplexity"``, each rounded to 2 decimals.

        Example::


            >> import textattack
            >> import transformers
            >> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
            >> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
            >> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
            >> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
            >> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
            >> attack_args = textattack.AttackArgs(
                num_examples=1,
                log_to_csv="log.csv",
                checkpoint_interval=5,
                checkpoint_dir="checkpoints",
                disable_stdout=True
            )
            >> attacker = textattack.Attacker(attack, dataset, attack_args)
            >> results = attacker.attack_dataset()
            >> ppl = textattack.metrics.quality_metrics.Perplexity().calculate(results)
        """
        self.results = results
        # Reset on every call so repeated calls don't accumulate texts from
        # previously evaluated result sets.
        self.original_candidates = []
        self.successful_candidates = []
        self.original_candidates_ppl = []
        self.successful_candidates_ppl = []

        for result in self.results:
            if isinstance(result, (FailedAttackResult, SkippedAttackResult)):
                continue
            self.original_candidates.append(
                result.original_result.attacked_text.text.lower()
            )
            self.successful_candidates.append(
                result.perturbed_result.attacked_text.text.lower()
            )

        ppl_orig = self.calc_ppl(self.original_candidates)
        ppl_attack = self.calc_ppl(self.successful_candidates)

        self.all_metrics["avg_original_perplexity"] = round(ppl_orig, 2)
        self.all_metrics["avg_attack_perplexity"] = round(ppl_attack, 2)

        return self.all_metrics

    def calc_ppl(self, texts):
        """Return the perplexity of ``texts`` joined by single spaces.

        Uses strided sliding-window evaluation. Each window holds at most
        ``self.max_length`` tokens. Only the tokens new to that window are
        used as targets; earlier ones are masked with ``-100``.

        Args:
            texts (list[str]): Texts to concatenate and score.

        Returns:
            float: ``exp(total NLL / number of tokens)``, or ``nan`` when
            ``texts`` is empty or encodes to no tokens.
        """
        if not texts:
            return float("nan")

        text = " ".join(texts)
        # Explicit long dtype: an empty encoding would otherwise become a float tensor.
        input_ids = torch.tensor(
            self.ppl_tokenizer.encode(text, add_special_tokens=True),
            dtype=torch.long,
        ).unsqueeze(0)
        seq_len = input_ids.size(1)
        if seq_len == 0:
            return float("nan")

        # A stride larger than the context window would make trg_len exceed
        # the window, so the -100 masking below would no longer line up.
        stride = min(self.stride, self.max_length)

        eval_loss = []
        end_loc = 0
        with torch.no_grad():
            # Strided perplexity calculation from huggingface.co/transformers/perplexity.html
            for i in range(0, seq_len, stride):
                begin_loc = max(i + stride - self.max_length, 0)
                end_loc = min(i + stride, seq_len)
                trg_len = end_loc - i  # may be shorter than stride on the last window
                input_ids_t = input_ids[:, begin_loc:end_loc].to(
                    textattack.shared.utils.device
                )
                target_ids = input_ids_t.clone()
                # Context-only tokens (already scored by a previous window) are ignored.
                target_ids[:, :-trg_len] = -100

                outputs = self.ppl_model(input_ids_t, labels=target_ids)
                # outputs[0] is the mean loss over the targets; rescale to a sum.
                eval_loss.append(outputs[0] * trg_len)

        return torch.exp(torch.stack(eval_loss).sum() / end_loc).item()
