import datasets
import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer

import evaluate
from evaluate import logging
from datasets import Dataset


_CITATION = """"""

_DESCRIPTION = """"""

_KWARGS_DESCRIPTION = """
Args:
    model_id (str): model used for calculating Perplexity
    predictions (list of str): input text, each separate text snippet
        is one list entry.
    batch_size (int): the batch size to run texts through the model. Defaults to 16.
    add_start_token (bool): whether to add the start token to the texts,
        so the perplexity can include the probability of the first word. Defaults to True.
    device (str): device to run on, defaults to 'cuda' when available
Returns:
    perplexity: dictionary containing the perplexity scores for the texts
        in the input list, as well as the mean perplexity. If one of the input texts is
        longer than the max input length of the model, then it is truncated to the
        max length for the perplexity computation.

"""

class Perplexity:
    """Score a causal language model's perplexity on a pre-tokenized dataset.

    The model and tokenizer are loaded once in ``__init__``. ``compute`` can then
    be called on any dataset saved with ``datasets.Dataset.save_to_disk`` that has
    ``input_ids``, ``attention_mask`` and ``labels`` columns.
    """

    # Length every example is padded/truncated to when ``compute`` gets no ``max_length``.
    DEFAULT_MAX_LENGTH = 4096
    # Label value ``CrossEntropyLoss`` ignores by default; marks tokens excluded from scoring.
    IGNORE_INDEX = -100

    def __init__(self, model_id, device=None):
        """Load the model and tokenizer for ``model_id``.

        Args:
            model_id: Hugging Face hub id or local checkpoint directory.
            device: "cpu", "cuda" or "gpu" (alias for "cuda"). When None, "cuda"
                is used if available, otherwise "cpu".

        Raises:
            AssertionError: if ``device`` is not one of the accepted values.
        """
        self.device = self._resolve_device(device)

        self.model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(self.device)
        # Llama-family checkpoints (matched by substring of the id/path) get LlamaTokenizer explicitly.
        if "llama" in model_id or "alpaca" in model_id:
            self.tokenizer = LlamaTokenizer.from_pretrained(model_id)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)

    @staticmethod
    def _resolve_device(device):
        """Map a user-supplied device name to a torch device string ("cpu" or "cuda")."""
        if device is None:
            return "cuda" if torch.cuda.is_available() else "cpu"
        assert device in ["gpu", "cpu", "cuda"], "device should be one of 'gpu', 'cpu' or 'cuda'."
        return "cuda" if device == "gpu" else device

    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                }
            ),
            reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
        )

    @classmethod
    def _pad_labels(cls, labels, length, padding_side="right"):
        """Truncate/pad each label row to ``length`` with ``IGNORE_INDEX``.

        Padding goes on ``padding_side`` ("left" or "right") so the labels stay
        aligned with inputs padded by the tokenizer on that same side.
        Returns a ``torch.long`` tensor of shape ``(len(labels), length)``.
        """
        rows = []
        for label in labels:
            label = list(label)[:length]
            filler = [cls.IGNORE_INDEX] * (length - len(label))
            rows.append(filler + label if padding_side == "left" else label + filler)
        return torch.tensor(rows, dtype=torch.long)

    @classmethod
    def _batch_perplexity(cls, logits, labels, attention_mask):
        """Return one perplexity per row: exp(mean next-token loss over scored tokens).

        Logits at position t predict the label at t + 1. A position is scored
        when its label is not ``IGNORE_INDEX`` and its attention mask is set.
        A row with no scored positions yields NaN (0 / 0).
        """
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_mask = attention_mask[..., 1:].contiguous()

        loss_fct = CrossEntropyLoss(reduction="none", ignore_index=cls.IGNORE_INDEX)
        token_loss = loss_fct(shift_logits.transpose(1, 2), shift_labels)

        # Numerator and denominator must use the same set of positions; token id 0
        # is a real token, so "scored" is "label != IGNORE_INDEX", not "label > 0".
        scored = (shift_labels != cls.IGNORE_INDEX).to(token_loss.dtype) * shift_mask.to(token_loss.dtype)
        return torch.exp((token_loss * scored).sum(1) / scored.sum(1))

    def compute(
        self, dset_path, prompts, responses, batch_size: int = 16, add_start_token: bool = True, max_length=None
    ):
        """Compute per-example perplexity over a pre-tokenized dataset on disk.

        Args:
            dset_path: directory loadable with ``Dataset.load_from_disk`` containing
                ``input_ids``, ``attention_mask`` and ``labels`` columns.
            prompts, responses: unused; kept for call-site compatibility.
            batch_size: number of examples per forward pass.
            add_start_token: when True and ``max_length`` is given, assert the tokenizer
                has a BOS token. No token is inserted (the data is already tokenized).
            max_length: length every example is truncated/padded to; defaults to
                ``DEFAULT_MAX_LENGTH``.

        Returns:
            ``{"perplexities": [float per row], "mean_perplexity": float}``.
        """
        if self.tokenizer.pad_token is None:
            # Every example is padded to a fixed length below, so a pad token is
            # needed even when batch_size == 1; borrow an existing special token.
            existing_special_tokens = list(self.tokenizer.special_tokens_map_extended.values())
            assert (
                len(existing_special_tokens) > 0
            ), "Padding requires the tokenizer to have at least one special token to use as pad_token. Please use a different model."
            self.tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

        if add_start_token and max_length:
            assert (
                self.tokenizer.bos_token is not None
            ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"

        seq_len = max_length if max_length else self.DEFAULT_MAX_LENGTH

        dset = Dataset.load_from_disk(dset_path)
        # Truncate over-long examples so every row fits in seq_len after padding.
        input_ids = [ids[:seq_len] for ids in dset["input_ids"]]
        attention_mask = [mask[:seq_len] for mask in dset["attention_mask"]]
        padded_inputs = self.tokenizer.pad(
            {"input_ids": input_ids, "attention_mask": attention_mask},
            padding="max_length",
            max_length=seq_len,
            return_tensors="pt",
        )
        # Labels are padded on the tokenizer's padding side so they stay aligned with the inputs.
        encoded_labels = self._pad_labels(
            dset["labels"], seq_len, getattr(self.tokenizer, "padding_side", "right")
        )
        encoded_texts = padded_inputs["input_ids"]
        attn_masks = padded_inputs["attention_mask"]

        ppls = []
        for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
            # Move one batch at a time to the device instead of the whole dataset.
            batch = slice(start_index, start_index + batch_size)
            encoded_batch = encoded_texts[batch].to(self.device)
            attn_mask = attn_masks[batch].to(self.device)
            labels = encoded_labels[batch].to(self.device)

            with torch.no_grad():
                out_logits = self.model(encoded_batch, attention_mask=attn_mask).logits

            ppls += self._batch_perplexity(out_logits, labels, attn_mask).tolist()

        return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}


def main(argv=None):
    """Score a checkpoint on a tokenized dataset and print the resulting dict.

    Defaults reproduce the original hard-coded run, so ``python perplexity.py``
    with no arguments behaves as before.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Compute per-example perplexity of a causal LM on a tokenized dataset."
    )
    parser.add_argument(
        "--checkpoint",
        default='/data/b_ou/ckpts/output_11k_cllama_code_m2w_scrp_clueweb_wiki_2000/checkpoint-1770',
    )
    parser.add_argument("--dataset", default='/data/b_ou/ckpts/data_cache/m2w_task/')
    parser.add_argument("--batch-size", type=int, default=8)
    args = parser.parse_args(argv)

    ppl = Perplexity(args.checkpoint)
    res = ppl.compute(args.dataset, prompts=[], responses=[], batch_size=args.batch_size, add_start_token=False)
    print(res)


# Guard so importing this module does not load a model and run inference.
if __name__ == "__main__":
    main()


# nohup python perplexity.py > perplexity_output_11k_cllama_code_m2w_scrp_clueweb_wiki_2000_task_b8.out 2>&1 &

