import sys
import torch
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

import json

class Perplexity:
    """
    Per-sample conditional perplexity of a causal language model.

    Each sample is a ``question``/``output`` pair loaded from a JSON-lines
    file. The model is run over ``question + output`` and only the tokens
    that follow the tokenized question are scored, so each value is the
    perplexity of the answer conditioned on the question. The scoring loop
    is adapted from the llama.cpp ``perplexity`` example.
    """

    def __init__(self, model, tokenizer, dataset_path='wikitext', dataset_name=None, split='test', text_column='text'):
        """
        Store the model/tokenizer and eagerly load the evaluation data.

        Parameters
        ----------
        model : AutoModelForCausalLM
            The language model to evaluate. Its ``device`` attribute decides
            where input ids are placed.
        tokenizer : AutoTokenizer
            The tokenizer corresponding to the model.
        dataset_path : str, optional
            Path to a local JSON-lines file; every non-blank line must be a
            JSON object with ``'question'`` and ``'output'`` string fields.
            NOTE(review): the default ``'wikitext'`` looks like a leftover
            from a Hugging Face hub loader and only works if a local file of
            that name exists.
        dataset_name : str, optional
            Stored but currently unused by the loader. Default is None.
        split : str, optional
            Stored but currently unused by the loader. Default is 'test'.
        text_column : str, optional
            Stored but currently unused by the loader. Default is 'text'.

        Raises
        ------
        OSError
            If ``dataset_path`` cannot be opened.
        json.JSONDecodeError
            If a non-blank line is not valid JSON.
        KeyError
            If a record lacks ``'question'`` or ``'output'``.
        """
        self._model = model
        self._tokenizer = tokenizer
        self._dataset_path = dataset_path
        self._dataset_name = dataset_name
        self._split = split
        self._text_column = text_column
        # Full texts (question + output) and their question-only prefixes,
        # index-aligned.
        self._text, self.prefix_text = self._prepare_data()

    def _get_device(self):
        """Return 'mps' if available, else 'cuda:0' if available, else 'cpu'.

        Not used by the other methods of this class.
        """
        if torch.backends.mps.is_available():
            return 'mps'
        elif torch.cuda.is_available():
            return 'cuda:0'
        else:
            return 'cpu'

    def _prepare_data(self):
        """
        Load the JSON-lines file at ``self._dataset_path``.

        Blank lines (e.g. a trailing newline at end of file) are skipped.

        Returns
        -------
        tuple[list[str], list[str]]
            ``(texts, prefixes)`` where ``texts[k]`` is
            ``question + output`` of record ``k`` and ``prefixes[k]`` is its
            ``question`` alone.
        """
        with open(self._dataset_path, 'r', encoding='utf-8') as f:
            records = [json.loads(line) for line in f if line.strip()]
        text_list = [item['question'] + item['output'] for item in records]
        prefix_text_list = [item['question'] for item in records]
        return text_list, prefix_text_list

    @staticmethod
    def softmax(logits):
        """
        Numerically stable softmax.

        The global maximum is subtracted before exponentiating and the result
        is normalized along axis 0 (the whole vector for 1-D input).

        Parameters
        ----------
        logits : np.ndarray
            The input to the softmax function.

        Returns
        -------
        np.ndarray
            The output of the softmax function.
        """
        e_x = np.exp(logits - np.max(logits))
        return e_x / e_x.sum(axis=0)

    def calculate_perplexity(self, n_ctx=512, n_batch=512):
        """
        Compute the answer perplexity of every loaded sample.

        Each sample is tokenized in full and processed in one forward pass.
        The first answer token is the one at index ``len(prefix_tokens)``.
        Every token from there to the end of the sample is scored.
        The mean over samples with at least one scored token is printed.

        Parameters
        ----------
        n_ctx : int
            Kept for interface compatibility; currently ignored (the context
            is always the full sample length).
        n_batch : int
            Kept for interface compatibility; currently ignored.

        Returns
        -------
        list
            One perplexity per sample, in dataset order. A sample with no
            token after its prefix yields ``nan``.
        """
        # Allow arbitrarily long samples without truncation warnings.
        self._tokenizer.model_max_length = sys.maxsize
        device = self._model.device

        all_perplexity = []
        samples = zip(self._text, self.prefix_text)
        with tqdm(samples, total=len(self._text), desc="Perplexity: - ") as progress:
            for text, prefix in progress:
                tokens = self._tokenizer(text, truncation=False, return_tensors='pt').input_ids.to(device)
                prefix_num = self._tokenizer(prefix, truncation=False, return_tensors='pt').input_ids.shape[1]
                seq_len = tokens.shape[1]

                # Logits at position t predict token t + 1, so the first answer
                # token (index prefix_num) is predicted by position prefix_num - 1.
                first = max(prefix_num - 1, 0)
                nll, count = self._process_batch(0, seq_len, seq_len, tokens, 0.0, 0, first)

                curr_ppl = np.exp(nll / count) if count else float('nan')
                all_perplexity.append(curr_ppl)
                progress.set_description(f"Perplexity: {curr_ppl:.4f}")

        scored = [p for p in all_perplexity if not np.isnan(p)]
        if scored:
            print(sum(scored) / len(scored))
        return all_perplexity

    def _process_batch(self, i, n_ctx, n_batch, tokens, nll, count, prefix_token=None):
        """
        Score one ``n_ctx``-token window of ``tokens`` and accumulate its NLL.

        The window ``tokens[0, i*n_ctx : (i+1)*n_ctx]`` is fed to the model in
        chunks of at most ``n_batch`` tokens. As in llama.cpp, the window's
        first token is replaced by BOS; this is done on a copy, so ``tokens``
        itself is never modified. Logits at window position ``j`` predict
        token ``j + 1``; positions ``prefix_token .. n_ctx - 2`` are scored.

        NOTE: chunks are run independently (no KV cache is carried between
        them), so when ``n_batch < n_ctx`` later chunks do not see the
        earlier part of the window.

        Parameters
        ----------
        i : int
            Window index; the window starts at token ``i * n_ctx``.
        n_ctx : int
            The window (context) size.
        n_batch : int
            Maximum number of tokens per forward pass.
        tokens : torch.Tensor
            Token ids; only row 0 is used.
        nll : float
            Running negative log-likelihood to add to.
        count : int
            Running number of scored tokens to add to.
        prefix_token : int, optional
            First window position whose prediction is scored. If None, uses
            ``min(512, n_ctx // 2)`` as in llama.cpp. Negative values are
            clamped to 0.

        Returns
        -------
        float
            The updated negative log-likelihood.
        int
            The updated count of scored tokens.
        """
        start = i * n_ctx
        end = start + n_ctx

        first = min(512, n_ctx // 2) if prefix_token is None else prefix_token
        first = max(first, 0)

        # Tokens to predict: window positions first+1 .. n_ctx-1.
        targets = tokens[0, start + first + 1:end]
        n_targets = targets.numel()
        if n_targets == 0:
            return nll, count

        num_batches = (n_ctx + n_batch - 1) // n_batch
        bos_id = self._tokenizer.bos_token_id

        chunk_logits = []
        for j in range(num_batches):
            batch_start = start + j * n_batch
            batch_size = min(end - batch_start, n_batch)
            batch = tokens[:, batch_start:batch_start + batch_size]
            if batch.shape[1] == 0:
                # Window runs past the end of `tokens`.
                break
            if j == 0 and bos_id is not None:
                batch = batch.clone()
                batch[0, 0] = bos_id
            chunk_logits.append(self._compute_batch_logits(batch, 0, batch.shape[1]))

        # (window_len, vocab) logits for the whole window.
        logits = torch.cat(chunk_logits, dim=1)[0]

        # Log-softmax in fp32 avoids underflow to log(0) and works for
        # fp16/bf16 models.
        log_probs = torch.log_softmax(logits[first:first + n_targets].float(), dim=-1)
        token_log_probs = log_probs.gather(-1, targets.to(log_probs.device).unsqueeze(-1)).squeeze(-1)

        nll += -token_log_probs.sum().item()
        count += n_targets
        return nll, count

    def _compute_batch_logits(self, tokens, batch_start, batch_size):
        """
        Run the model on ``tokens[:, batch_start:batch_start + batch_size]``.

        Parameters
        ----------
        tokens : torch.Tensor
            The tokenized text.
        batch_start : int
            The start index of the slice.
        batch_size : int
            The length of the slice.

        Returns
        -------
        torch.Tensor
            The model's ``logits`` output for the slice, computed without
            gradient tracking.
        """
        with torch.no_grad():
            outputs = self._model(tokens[:, batch_start:batch_start + batch_size])
        return outputs.logits.detach()