import copy
import logging
from collections import Counter
import torch
import datasets
import evaluate
import numpy as np
import torch.nn.functional as F

import accelerate
import transformers

from transformers import AutoTokenizer
from transformers import AutoConfig
from transformers import AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from transformers import StoppingCriteria
from transformers import StoppingCriteriaList


# from model_util.base_model import BaseModel, STOP_SEQUENCES

from abc import ABC, abstractmethod
from typing import List, Text

# ROUGE metric object, loaded once at import time via `evaluate.load` and reused
# by `check_answer_correctness` (which reads its per-example "rougeL" scores).
ROUGE = evaluate.load('rouge')

def check_answer_correctness(
    correct_answers: List[str],
    model_answers: List[str],
    rouge_threshold: float = 0.3,
) -> List[bool]:
    """
    Check whether each model answer is correct. This uses the heuristic by Kuhn et al. (2023), checking whether the
    ROUGE-L score reaches some threshold.
    Additionally, an answer counts as correct when, ignoring case, the reference answer is contained in the model
    answer or the model answer is contained in the reference answer. This accommodates CoT answers that might be
    longer and therefore obtain lower ROUGE scores but are still correct.

    Parameters
    ----------
    correct_answers: List[str]
        Reference answers.
    model_answers: List[str]
        Model generations to compare to the reference answers, aligned by position with `correct_answers`.
    rouge_threshold: float
        Threshold of ROUGE-L scores at or above which an answer is deemed correct. Default is 0.3.

    Returns
    -------
    List[bool]
        For every (reference, generation) pair, whether the generation was deemed correct.
    """
    rouge_l_scores = ROUGE.compute(
        predictions=model_answers,
        references=correct_answers,
        use_aggregator=False,
    )["rougeL"]

    results = []
    for rouge_l, correct_answer, model_answer in zip(rouge_l_scores, correct_answers, model_answers):
        # The ROUGE criterion is evaluated first so the string fallback only runs when it is needed.
        if rouge_l >= rouge_threshold:
            results.append(True)
            continue

        reference = correct_answer.upper()
        prediction = model_answer.upper()
        # An empty (or whitespace-only) string is a substring of every string, so without this guard an
        # empty generation or an empty reference would always be marked as correct.
        if not reference.strip() or not prediction.strip():
            results.append(False)
            continue

        results.append(reference in prediction or prediction in reference)

    return results

class BaseModel(ABC):
    """Abstract base class for model wrappers that expose a `predict` method."""

    # Declared (not assigned) here: stop sequences are expected to be a list of text.
    stop_sequences: List[Text]

    @abstractmethod
    def predict(self, input_data, temperature):
        """Produce a prediction for `input_data` at the given `temperature`.

        Concrete subclasses must override this method; the base implementation
        does nothing and returns None.
        """


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generations when they match a particular text or token sequence.

    Parameters
    ----------
    stops:
        Stop sequences given as strings. With `match_on='tokens'` they are
        encoded to token-id tensors once, at construction time.
    tokenizer:
        Tokenizer used to decode the generated ids (text mode) or to encode the
        stop strings (tokens mode).
    match_on: str
        `'text'` (default) checks whether a stop string occurs in the decoded
        generation; `'tokens'` checks whether the sequence ends with the token
        ids of a stop sequence.
    initial_length:
        Number of prompt tokens at the start of `input_ids`. When None, the
        whole sequence is decoded in text mode.
    """

    def __init__(self, stops, tokenizer, match_on='text', initial_length=None):
        super().__init__()
        self.stops = stops
        self.initial_length = initial_length
        self.tokenizer = tokenizer
        self.match_on = match_on
        if self.match_on == 'tokens':
            # Encode without special tokens: a BOS/EOS id added by the tokenizer is
            # not part of the stop sequence and would prevent an exact suffix match.
            # The tensors stay on CPU here and are moved to the device of
            # `input_ids` at call time, so no particular accelerator is required.
            self.stops = [
                torch.tensor(self.tokenizer.encode(stop, add_special_tokens=False))
                for stop in self.stops
            ]
            logging.debug('Stop token ids: %s', self.stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return True when the first sequence in `input_ids` hits a stop sequence.

        Only `input_ids[0]` is inspected. Raises ValueError for an unknown
        `match_on` mode.
        """
        del scores, kwargs  # Required by the StoppingCriteria interface but unused by us.

        if self.match_on == 'text':
            # NOTE(review): callers in this file pass the prompt length as
            # `initial_length`, so the `+ 1` skips the first generated token when
            # looking for stop strings - confirm this is intended.
            start = 0 if self.initial_length is None else self.initial_length + 1
            # Decode once; the generated text does not depend on the stop string.
            generation = self.tokenizer.decode(input_ids[0][start:], skip_special_tokens=False)
            return any(stop in generation for stop in self.stops)

        if self.match_on == 'tokens':
            # Can be dangerous due to tokenizer ambiguities.
            sequence = input_ids[0]
            for stop in self.stops:
                stop_length = len(stop)
                if stop_length == 0 or stop_length > len(sequence):
                    continue
                # `stop in tensor` would be true if *any* position matched; a stop
                # sequence must match the whole tail of the sequence exactly.
                if torch.equal(sequence[-stop_length:], stop.to(sequence.device)):
                    return True
            return False

        raise ValueError(
            f"Unknown match_on mode {self.match_on!r}; expected 'text' or 'tokens'.")
    
class HuggingfaceModel(BaseModel):
    """Hugging Face Model."""

    def __init__(self, model_name, stop_sequences=None, max_new_tokens=None, **kwargs):
        """Load the tokenizer and the causal LM for `model_name`.

        Parameters
        ----------
        model_name:
            Identifier passed to `AutoTokenizer.from_pretrained` and
            `AutoModelForCausalLM.from_pretrained`.
        stop_sequences:
            None, the string `'default'`, a single stop string, or a sequence of
            stop strings. The tokenizer's EOS token is always appended.
        max_new_tokens:
            Maximum number of tokens to generate per call. Required.
        **kwargs:
            Forwarded to `AutoModelForCausalLM.from_pretrained`.

        Raises
        ------
        ValueError
            If `max_new_tokens` is not given.
        """
        if max_new_tokens is None:
            raise ValueError('max_new_tokens must be provided.')
        self.max_new_tokens = max_new_tokens

        if stop_sequences == 'default':
            # NOTE(review): STOP_SEQUENCES is not defined in the visible part of this
            # file (the `model_util.base_model` import is commented out) - confirm it
            # is available at module level, otherwise this raises NameError.
            stop_sequences = STOP_SEQUENCES

        self.tokenizer = AutoTokenizer.from_pretrained(
            f"{model_name}", device_map="auto",
            token_type_ids=None)

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, device_map="auto", **kwargs)

        self.model_name = model_name

        # Normalise to a fresh list so tuples / a single string work and the
        # caller's object is never aliased.
        if stop_sequences is None:
            user_stops = []
        elif isinstance(stop_sequences, str):
            user_stops = [stop_sequences]
        else:
            user_stops = list(stop_sequences)
        self.stop_sequences = user_stops + [self.tokenizer.eos_token]

        self.token_limit = 4096 if 'Llama-2' in model_name else 2048
        
    def predict_with_question_embedding(self, batch, embedding_model, temperature, return_full=False, do_branch=False, topk=1):
        """Generate an answer for `batch["prompt"]`, score it and embed the question.

        Parameters
        ----------
        batch:
            Mapping with the keys `"prompt"` (text passed to the tokenizer),
            `"question"` (passed to `embedding_model.encode`) and `"answer"`
            (reference answer(s) for the correctness check).
        embedding_model:
            Object with an `encode` method, applied to `batch["question"]`.
        temperature:
            Forwarded to `generate`, which is called with `do_sample=False`.
        return_full, do_branch, topk:
            Accepted for signature compatibility; not used by this method.

        Returns
        -------
        tuple
            `(seq_likelihoods, question_embeddings, answers_correctness, model_answers)`:
            the exponentiated mean log-probability of the non-special generated
            tokens (one value per generated sequence), the result of
            `embedding_model.encode`, the list returned by
            `check_answer_correctness`, and the decoded text of the first
            generated sequence with stop sequences removed.
        """
        inputs = batch["prompt"]
        questions = batch["question"]
        ground_truth = batch["answer"]
        # The correctness check zips over references, so a bare string must be
        # wrapped to be compared as one answer rather than character by character.
        if isinstance(ground_truth, str):
            ground_truth = [ground_truth]

        # Tokenize unconditionally (the tokenizer accepts a string or a list) and
        # place the ids on the model's device instead of assuming CUDA.
        inputs_token = self.tokenizer(inputs, return_tensors="pt")['input_ids'].to(self.model.device)

        lowered_name = self.model_name.lower()
        if any(family in lowered_name for family in ('llama', 'falcon', 'mistral')):
            pad_token_id = self.tokenizer.eos_token_id
        else:
            pad_token_id = None

        input_length = len(inputs_token[0])
        if self.stop_sequences is not None:
            stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
                stops=self.stop_sequences,
                initial_length=input_length,
                tokenizer=self.tokenizer)])
        else:
            stopping_criteria = None

        with torch.no_grad():
            # NOTE(review): attentions and hidden states are requested but never read
            # in this method; kept so the generate call is unchanged.
            outputs = self.model.generate(
                inputs_token,
                max_new_tokens=self.max_new_tokens,
                return_dict_in_generate=True,
                output_scores=True,
                output_attentions=True,
                output_hidden_states=True,
                temperature=temperature,
                do_sample=False,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
            )

        ###### Obtain seq likelihoods ###########
        generated_ids = outputs.sequences[:, input_length:]
        step_scores = torch.stack(outputs.scores, dim=1)  # (btz, gen_seq_len, vocab_size)
        # log_softmax avoids the underflow of log(softmax(x)) for very unlikely tokens.
        vocab_log_probs = F.log_softmax(step_scores, dim=-1)
        # index = (btz, gen_seq_len, 1) for torch.gather
        log_probs = torch.gather(vocab_log_probs, dim=-1, index=generated_ids.unsqueeze(-1)).squeeze(-1)
        # token_mask = (btz, gen_seq_len), True for every non-special generated token.
        special_ids = torch.tensor(
            self.tokenizer.all_special_ids,
            dtype=generated_ids.dtype,
            device=generated_ids.device,
        )
        token_mask = ~torch.isin(generated_ids, special_ids)
        num_tokens = token_mask.sum(dim=-1)
        # Zero out special tokens with masked_fill rather than a multiplication, so a
        # -inf log-probability at a masked position cannot turn the sum into NaN.
        # Note: if every generated token is special, this divides by zero.
        seq_likelihoods = log_probs.masked_fill(~token_mask, 0.0).sum(-1) / num_tokens
        seq_likelihoods = torch.exp(seq_likelihoods)
        #########################################

        ###### Check correctness ###########
        # Only the first generated sequence is decoded and checked.
        model_answers = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        for stop_word in self.stop_sequences or []:
            model_answers = model_answers.replace(stop_word, "")

        answers_correctness = check_answer_correctness(
            correct_answers=ground_truth,
            model_answers=[model_answers],
        )
        #########################################

        ########## Obtain question embeddings ##########
        question_embeddings = embedding_model.encode(questions)
        #########################################

        # Drop the reference to the (large) generation output before returning.
        del outputs

        return seq_likelihoods, question_embeddings, answers_correctness, model_answers

    def general_pridiction(self, batch, embedding_model, temperature, return_full=False, do_branch=False, topk=1):
        """Generate an answer for `batch["prompt"]` and return it as text.

        Parameters
        ----------
        batch:
            Mapping whose `"prompt"` entry is passed to the tokenizer.
        temperature:
            Forwarded to `generate`, which is called with `do_sample=False`.
        embedding_model, return_full, do_branch, topk:
            Accepted for signature compatibility; not used by this method.

        Returns
        -------
        str
            Decoded text of the first generated sequence (special tokens skipped)
            with every stop sequence removed.
        """
        inputs = batch["prompt"]

        # Tokenize unconditionally (the tokenizer accepts a string or a list) and
        # place the ids on the model's device instead of assuming CUDA.
        inputs_token = self.tokenizer(inputs, return_tensors="pt")['input_ids'].to(self.model.device)

        lowered_name = self.model_name.lower()
        if any(family in lowered_name for family in ('llama', 'falcon', 'mistral')):
            pad_token_id = self.tokenizer.eos_token_id
        else:
            pad_token_id = None

        input_length = len(inputs_token[0])
        if self.stop_sequences is not None:
            stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
                stops=self.stop_sequences,
                initial_length=input_length,
                tokenizer=self.tokenizer)])
        else:
            stopping_criteria = None

        with torch.no_grad():
            # NOTE(review): scores, attentions and hidden states are requested but
            # never read in this method; kept so the generate call is unchanged.
            outputs = self.model.generate(
                inputs_token,
                max_new_tokens=self.max_new_tokens,
                return_dict_in_generate=True,
                output_scores=True,
                output_attentions=True,
                output_hidden_states=True,
                temperature=temperature,
                do_sample=False,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
            )

        # Keep only the newly generated tokens (everything after the prompt).
        generated_ids = outputs.sequences[:, input_length:]

        # Only the first generated sequence is decoded.
        model_answers = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        for stop_word in self.stop_sequences or []:
            model_answers = model_answers.replace(stop_word, "")

        # Drop the reference to the (large) generation output before returning.
        del outputs

        return model_answers

# Original prediction
    def predict(self, system_prompt, questions, temperature, return_full=False, do_branch=False, topk=1):
        """Generate an answer for `system_prompt` + `questions` and return it with per-token statistics.

        The system prompt and the question(s) are tokenized separately, concatenated
        along the sequence dimension and passed to `self.model.generate`. Only the
        first sequence of the output (index 0) is analysed.

        Parameters
        ----------
        system_prompt:
            Text tokenized and placed in front of the question tokens.
        questions:
            Question text (a string or a list accepted by the tokenizer). Its token
            batch must be concatenable with the system prompt along dim 1.
        temperature:
            Forwarded to `generate`, which is called with `do_sample=False`.
        return_full: bool
            If True, return only the decoded full sequence (prompt + generation).
        do_branch, topk:
            Forwarded unchanged to `self.model.generate`.
            NOTE(review): presumably handled by a customised `generate`; confirm
            the installed transformers build accepts these arguments.

        Returns
        -------
        str or tuple
            `full_answer` alone when `return_full` is True, otherwise
            `(full_answer, sliced_answer, sliced_score, answer_token, log_likelihoods,
            last_token_embedding)`: the decoded full sequence, the stripped answer
            without prompt and trailing stop sequence, the transition scores (last
            one dropped when a stop sequence was removed), the generated token ids
            without the final one, the per-token log-likelihoods as floats, and the
            last-layer hidden state of the last position of the selected step (on CPU).

        Raises
        ------
        ValueError
            If the output exceeds `self.token_limit`, the decoded output does not
            start with the decoded input, a stop sequence remains in the answer
            (non-falcon models), or no log-likelihoods are available.
        """
        device = self.model.device
        sys_inputs = self.tokenizer(system_prompt, return_tensors="pt")['input_ids'].to(device)
        # Tokenize unconditionally so a plain-string question no longer leaves
        # `inputs` undefined.
        question_inputs = self.tokenizer(questions, return_tensors="pt")['input_ids'].to(device)
        inputs = torch.cat((sys_inputs, question_inputs), dim=1)
        n_input_token = len(inputs[0])

        lowered_name = self.model_name.lower()
        if any(family in lowered_name for family in ('llama', 'falcon', 'mistral')):
            pad_token_id = self.tokenizer.eos_token_id
        else:
            pad_token_id = None

        if self.stop_sequences is not None:
            stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
                stops=self.stop_sequences,
                initial_length=n_input_token,
                tokenizer=self.tokenizer)])
        else:
            stopping_criteria = None

        ################  Change the generation parameters here ##############
        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                max_new_tokens=self.max_new_tokens,
                return_dict_in_generate=True,
                output_scores=True,
                output_attentions=True,
                output_hidden_states=True,
                temperature=temperature,
                do_sample=False,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                do_branch=do_branch,
                topk=topk,
            )
        ###########################################################

        generated_sequence = outputs.sequences[0]
        if len(generated_sequence) > self.token_limit:
            raise ValueError(
                'Generation exceeding token limit %d > %d'
                % (len(generated_sequence), self.token_limit))

        full_answer = self.tokenizer.decode(
            generated_sequence, skip_special_tokens=True)
        input_text = self.tokenizer.decode(
            inputs[0], skip_special_tokens=True
        )

        if return_full:
            return full_answer

        # Generated token ids without the final one (the slot of the stop token).
        # Slicing from the prompt length also yields an empty tensor, rather than
        # a slice of the prompt, when nothing was generated.
        answer_token = generated_sequence[n_input_token:-1]

        # Log-probabilities of the generated tokens.
        # outputs.scores is a tuple of len = n_generated_tokens, each entry of
        # shape (bs, vocabulary size); outputs.sequences holds input and generated
        # tokens. The result has shape (batch_size*num_return_sequences, sequence_length).
        transition_scores = self.model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True)

        # For some models, we need to remove the input_data from the answer.
        if full_answer.startswith(input_text):
            input_data_offset = len(input_text)
        else:
            raise ValueError('Have not tested this in a while.')

        # Remove input from answer.
        answer = full_answer[input_data_offset:]

        # Remove stop_words from answer.
        stop_at = len(answer)
        sliced_answer = answer
        # Default when no stop sequence is stripped; also keeps `sliced_score`
        # defined when there are no stop sequences at all.
        sliced_score = transition_scores[0]
        if self.stop_sequences is not None:
            # Check every stop sequence and stop at the first one that matches.
            for stop in self.stop_sequences:
                if answer.endswith(stop):
                    stop_at = len(answer) - len(stop)
                    sliced_answer = answer[:stop_at]
                    # Drop the score of the final (stop) token.
                    # NOTE(review): this assumes the stop sequence is a single
                    # token (e.g. "\n"); confirm for multi-token stop sequences.
                    sliced_score = transition_scores[0][:-1]
                    assert len(sliced_score) == len(answer_token)
                    break
            if any(stop in sliced_answer for stop in self.stop_sequences):
                error_msg = 'Error: Stop words not removed successfully!'
                error_msg += f'Answer: >{answer}< '
                error_msg += f'Sliced Answer: >{sliced_answer}<'
                if 'falcon' not in self.model_name.lower():
                    raise ValueError(error_msg)
                else:
                    logging.error(error_msg)

        # Remove whitespaces from answer (in particular from beginning.)
        sliced_answer = sliced_answer.strip()

        # Get the number of tokens until the stop word comes up.
        # Note: Indexing with `stop_at` already excludes the stop_token.
        # Note: It's important we do this with full answer, since there might be
        # non-trivial interactions between the input_data and generated part
        # in tokenization (particularly around whitespaces.)
        token_stop_index = self.tokenizer(
            full_answer[:input_data_offset + stop_at], return_tensors="pt")['input_ids'].shape[1]
        n_generated = token_stop_index - n_input_token

        if n_generated == 0:
            logging.warning('Only stop_words were generated. For likelihoods and embeddings, taking stop word instead.')
            n_generated = 1

        # Get the last hidden state (last layer) and the last token's embedding of the answer.
        # Note: We do not want this to be the stop token.

        # outputs.hidden_state is a tuple of len = n_generated_tokens.
        # The first hidden state is for the input tokens and is of shape
        #     (n_layers) x (batch_size, input_size, hidden_size).
        # (Note this includes the first generated token!)
        # The remaining hidden states are for the remaining generated tokens and is of shape
        #    (n_layers) x (batch_size, 1, hidden_size).

        # Note: The output embeddings have the shape (batch_size, generated_length, hidden_size).
        # We do not get embeddings for input_data! We thus subtract the n_tokens_in_input from
        # token_stop_index to arrive at the right output.

        if 'decoder_hidden_states' in outputs.keys():
            hidden = outputs.decoder_hidden_states
        else:
            # Tuple indexed by 1. generated token, 2. layer, then a tensor of
            # (batch_size, generated_length, hidden_size).
            hidden = outputs.hidden_states

        if len(hidden) == 1:
            logging.warning(
                'Taking first and only generation for hidden! '
                'n_generated: %d, n_input_token: %d, token_stop_index %d, '
                'last_token: %s, generation was: %s',
                n_generated, n_input_token, token_stop_index,
                self.tokenizer.decode(outputs['sequences'][0][-1]),
                full_answer,
                )
            last_input = hidden[0]

        elif ((n_generated - 1) >= len(hidden)):
            # If access idx is larger/equal.
            logging.error(
                'Taking last state because n_generated is too large'
                'n_generated: %d, n_input_token: %d, token_stop_index %d, '
                'last_token: %s, generation was: %s, slice_answer: %s',
                n_generated, n_input_token, token_stop_index,
                self.tokenizer.decode(outputs['sequences'][0][-1]),
                full_answer, sliced_answer
                )
            last_input = hidden[-1]

        else:
            last_input = hidden[n_generated - 1]

        # Then access last layer for input
        last_layer = last_input[-1]
        # Then access last token in input: the embedding of the last token at the last layer.
        last_token_embedding = last_layer[:, -1, :].cpu()

        # Get log_likelihoods from the transition scores of the first sequence.
        log_likelihoods = [score.item() for score in transition_scores[0]]
        if len(log_likelihoods) == 1:
            logging.warning('Taking first and only generation for log likelihood!')
        else:
            log_likelihoods = log_likelihoods[:n_generated]

        if len(log_likelihoods) == self.max_new_tokens:
            logging.warning('Generation interrupted by max_token limit.')

        if len(log_likelihoods) == 0:
            raise ValueError('No log likelihoods were obtained for the generation.')

        return full_answer, sliced_answer, sliced_score, answer_token, log_likelihoods, last_token_embedding