# the following code comes from the repo: https://github.com/jlko/semantic_uncertainty


"""Implement semantic entropy."""
import logging
import os
import pickle
from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

from uncertainty.models.huggingface_models import HuggingfaceModel
from uncertainty.utils import openai as oai
from uncertainty.utils import utils


# Device for the DeBERTa entailment model: the GPU when torch can see one, else the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

class StoppingCriteriaSub(StoppingCriteria):
    """Stop generations when they match a particular text or token sequence.

    Args:
        stops: Stop strings.
        tokenizer: Tokenizer used to decode the generation (``match_on='text'``)
            or to encode the stop strings (``match_on='tokens'``).
        match_on: ``'text'`` searches the decoded generation for each stop
            string; ``'tokens'`` compares the trailing token ids of the
            sequence against each encoded stop.
        initial_length: Number of leading (prompt) tokens to skip when decoding
            in ``'text'`` mode; ``None`` decodes the whole sequence.
    """

    def __init__(self, stops, tokenizer, match_on='text', initial_length=None):
        super().__init__()
        self.stops = stops
        self.initial_length = initial_length
        self.tokenizer = tokenizer
        self.match_on = match_on
        if self.match_on == 'tokens':
            # add_special_tokens=False: a BOS token prepended by encode() would
            # never appear at the end of a generation, so the stop could not match.
            # Tensors are kept on CPU and moved to the device of `input_ids` in
            # __call__ instead of assuming 'cuda'.
            self.stops = [
                torch.tensor(self.tokenizer.encode(stop, add_special_tokens=False))
                for stop in self.stops
            ]
            logging.debug('Token stop sequences: %s', self.stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return True if the first sequence in the batch hits any stop."""
        del scores, kwargs  # Required by the StoppingCriteria interface but unused.
        if self.match_on == 'text':
            # Decode once, not once per stop string.
            generation = self.tokenizer.decode(
                input_ids[0][self.initial_length:], skip_special_tokens=False)
            return any(stop in generation for stop in self.stops)
        if self.match_on == 'tokens':
            # Can be dangerous due to tokenizer ambiguities.
            sequence = input_ids[0]
            for stop in self.stops:
                n_stop = len(stop)
                if n_stop == 0 or n_stop > len(sequence):
                    continue
                # Require the whole stop sequence to match the tail; the previous
                # `stop in tail` matched if ANY single position was equal.
                if torch.equal(sequence[-n_stop:], stop.to(sequence.device)):
                    return True
            return False
        raise ValueError(
            f"match_on must be 'text' or 'tokens', got {self.match_on!r}")


class BaseEntailment:
    """Common base class for the entailment checkers in this module.

    Subclasses that hold a prediction cache override `save_prediction_cache`;
    the default does nothing.
    """

    def save_prediction_cache(self):
        """Persist the prediction cache; a no-op for the base class."""
        return None


class EntailmentDeberta(BaseEntailment):
    """NLI-based entailment checker using `microsoft/deberta-v2-xlarge-mnli`.

    `check_implication` returns the argmax class index of the NLI head. Per
    the comment below, index 1 is `neutral` and 2 is `entailment`; the rest
    of this module treats 0 as contradiction.
    """

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v2-xlarge-mnli")
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "microsoft/deberta-v2-xlarge-mnli").to(DEVICE)

    def check_implication(self, text1, text2, *args, **kwargs):
        """Classify whether `text1` entails `text2`; extra arguments are ignored.

        Returns:
            int class index (0, 1 or 2).
        """
        inputs = self.tokenizer(text1, text2, return_tensors="pt").to(DEVICE)
        # The model checks if text1 -> text2, i.e. if text2 follows from text1.
        # check_implication('The weather is good', 'The weather is good and I like you') --> 1
        # check_implication('The weather is good and I like you', 'The weather is good') --> 2
        # Inference only: no_grad avoids building an autograd graph for every pair.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Deberta-mnli returns `neutral` and `entailment` classes at indices 1 and 2.
        # Softmax is monotonic, so the argmax of the raw logits is the same class.
        prediction = torch.argmax(logits[0]).item()  # pylint: disable=no-member
        if os.environ.get('DEBERTA_FULL_LOG', False):
            logging.info('Deberta Input: %s -> %s', text1, text2)
            logging.info('Deberta Prediction: %s', prediction)

        return prediction


class EntailmentLLM(BaseEntailment):
    """Entailment checker that prompts a language model and caches its responses.

    Subclasses must set `self.name` and implement
    `equivalence_prompt(text1, text2, question) -> str` and
    `predict(prompt, temperature) -> str`. Raw responses are cached in
    `self.prediction_cache`, keyed by `oai.md5hash(prompt)`.
    """

    entailment_file = 'entailment_cache.pkl'
    # Temperature passed to `predict`; kept near zero for stable labels.
    entailment_temperature = 0.02

    def __init__(self, entailment_cache_id, entailment_cache_only):
        """
        Args:
            entailment_cache_id: W&B run path to restore a cache from, or None
                to start with an empty cache.
            entailment_cache_only: If True, never query the model; a cache miss
                raises ValueError.
        """
        self.prediction_cache = self.init_prediction_cache(entailment_cache_id)
        self.entailment_cache_only = entailment_cache_only

    def init_prediction_cache(self, entailment_cache_id):
        """Return an empty cache, or download and unpickle one from a W&B run."""
        if entailment_cache_id is None:
            return {}

        logging.info('Restoring prediction cache from %s', entailment_cache_id)

        api = wandb.Api()
        run = api.run(entailment_cache_id)
        run.file(self.entailment_file).download(
            replace=True, exist_ok=False, root=wandb.run.dir)

        # SECURITY NOTE: pickle.load can execute arbitrary code; only restore
        # caches from W&B runs you trust.
        with open(os.path.join(wandb.run.dir, self.entailment_file), "rb") as infile:
            return pickle.load(infile)

    def save_prediction_cache(self):
        # Write the dictionary to a pickle file.
        utils.save(self.prediction_cache, self.entailment_file)

    def check_implication(self, text1, text2, example=None):
        """Ask the LLM whether `text1` entails `text2` in the context of `example['question']`.

        Returns:
            2 for entailment, 1 for neutral (also the fallback for unparseable
            responses), 0 for contradiction.

        Raises:
            ValueError: if `example` is None, or on a cache miss while
                `entailment_cache_only` is set.
        """
        if example is None:
            raise ValueError(
                'EntailmentLLM.check_implication requires `example` with a "question" key.')
        prompt = self.equivalence_prompt(text1, text2, example['question'])

        logging.info('%s input: %s', self.name, prompt)

        hashed = oai.md5hash(prompt)
        if hashed in self.prediction_cache:
            logging.info('Restoring hashed instead of predicting with model.')
            response = self.prediction_cache[hashed]
        else:
            if self.entailment_cache_only:
                raise ValueError(
                    f'No cached entailment prediction for prompt hash {hashed} '
                    'and entailment_cache_only is set.')
            response = self.predict(prompt, temperature=self.entailment_temperature)
            self.prediction_cache[hashed] = response

        logging.info('%s prediction: %s', self.name, response)

        # Only the start of the response is inspected; labels are checked in
        # this order, so 'entailment' wins if several appear.
        binary_response = response.lower()[:30]
        if 'entailment' in binary_response:
            return 2
        elif 'neutral' in binary_response:
            return 1
        elif 'contradiction' in binary_response:
            return 0
        else:
            logging.warning('MANUAL NEUTRAL!')
            return 1


class EntailmentGPT4(EntailmentLLM):
    """LLM entailment checker that queries an OpenAI model through `oai.predict`."""

    def __init__(self, entailment_cache_id, entailment_cache_only):
        super().__init__(entailment_cache_id=entailment_cache_id,
                         entailment_cache_only=entailment_cache_only)
        self.name = 'gpt-4'

    def equivalence_prompt(self, text1, text2, question):
        """Build the one-directional entailment prompt for the answer pair (text1 -> text2)."""
        lines = (
            f'We are evaluating answers to the question "{question}"',
            'Here are two possible answers:',
            f'Possible Answer 1: {text1}',
            f'Possible Answer 2: {text2}',
            'Does Possible Answer 1 semantically entail Possible Answer 2? '
            'Respond with entailment, contradiction, or neutral.',
        )
        return '\n'.join(lines)

    def predict(self, prompt, temperature):
        """Return the raw text response of model `self.name` for `prompt`."""
        response = oai.predict(prompt, temperature, model=self.name)
        return response


class EntailmentGPT35(EntailmentGPT4):
    """`EntailmentGPT4` with the model name switched to 'gpt-3.5'."""

    def __init__(self, entailment_cache_id, entailment_cache_only):
        super().__init__(entailment_cache_id=entailment_cache_id,
                         entailment_cache_only=entailment_cache_only)
        # Override the name assigned by the parent constructor.
        self.name = 'gpt-3.5'


class EntailmentGPT4Turbo(EntailmentGPT4):
    """`EntailmentGPT4` with the model name switched to 'gpt-4-turbo'."""

    def __init__(self, entailment_cache_id, entailment_cache_only):
        super().__init__(entailment_cache_id=entailment_cache_id,
                         entailment_cache_only=entailment_cache_only)
        # Override the name assigned by the parent constructor.
        self.name = 'gpt-4-turbo'


class EntailmentLlama(EntailmentLLM):
    """LLM entailment checker backed by a local `HuggingfaceModel`."""

    def __init__(self, entailment_cache_id, entailment_cache_only, name):
        super().__init__(entailment_cache_id=entailment_cache_id,
                         entailment_cache_only=entailment_cache_only)
        self.name = name
        # Short generations suffice: only a one-word label is expected back.
        self.model = HuggingfaceModel(
            name, stop_sequences='default', max_new_tokens=30)

    def equivalence_prompt(self, text1, text2, question):
        """Build the entailment prompt, ending with a 'Response:' cue for completion."""
        parts = [
            f'We are evaluating answers to the question "{question}"\n',
            'Here are two possible answers:\n',
            f'Possible Answer 1: {text1}\n',
            f'Possible Answer 2: {text2}\n',
            'Does Possible Answer 1 semantically entail Possible Answer 2? '
            'Respond only with entailment, contradiction, or neutral.\n',
            'Response:',
        ]
        return ''.join(parts)

    def predict(self, prompt, temperature):
        """Return only the answer text from the wrapped model's 3-tuple prediction."""
        answer_text, _log_likelihoods, _embedding = self.model.predict(prompt, temperature)
        return answer_text


def context_entails_response(context, responses, model, example=None):
    """Score how strongly `context` fails to entail the given responses.

    Args:
        context: Premise text.
        responses: Hypothesis texts checked against `context`.
        model: Entailment model exposing `check_implication(text1, text2, example=...)`
            returning 0 (contradiction), 1 (neutral) or 2 (entailment).
        example: Optional dict forwarded to `check_implication`; LLM-based
            checkers in this module require one with a 'question' key.

    Returns:
        2 minus the mean label: 0 when every response is entailed, 2 when every
        response is contradicted. (nan for an empty `responses`, as `np.mean([])`.)
    """
    votes = [model.check_implication(context, response, example=example)
             for response in responses]
    return 2 - np.mean(votes)


def get_semantic_ids(strings_list, model, strict_entailment=False, example=None):
    """Cluster strings by meaning and return one cluster id per string.

    Strings are visited in order. Each string without an id becomes the root
    of a new cluster and is compared against every later string, which joins
    the root's cluster when the two are equivalent. A later root can re-claim
    a string that an earlier root already claimed.

    Equivalence calls `model.check_implication` in both directions
    (labels 0 = contradiction, 1 = neutral, 2 = entailment):
      * strict: both directions must be entailment;
      * lenient: no direction is a contradiction and not both are neutral.

    Returns:
        List of ints, numbered 0, 1, ... in order of each cluster's first member.
    """

    def _equivalent(first, second):
        forward = model.check_implication(first, second, example=example)
        backward = model.check_implication(second, first, example=example)  # pylint: disable=arguments-out-of-order
        assert (forward in [0, 1, 2]) and (backward in [0, 1, 2])
        if strict_entailment:
            return forward == 2 and backward == 2
        pair = [forward, backward]
        return 0 not in pair and pair != [1, 1]

    unassigned = -1
    cluster_of = [unassigned] * len(strings_list)
    n_clusters = 0
    for root, root_text in enumerate(strings_list):
        if cluster_of[root] != unassigned:
            continue
        cluster_of[root] = n_clusters
        for other in range(root + 1, len(strings_list)):
            if _equivalent(root_text, strings_list[other]):
                cluster_of[other] = n_clusters
        n_clusters += 1

    assert unassigned not in cluster_of

    return cluster_of


def _logsumexp(values):
    """Numerically stable log(sum(exp(values))) for a 1-D array (-inf if empty)."""
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return -np.inf
    peak = np.max(values)
    if not np.isfinite(peak):
        # All -inf (or an inf/nan present): shifting is meaningless, use the plain formula.
        return np.log(np.sum(np.exp(values)))
    return peak + np.log(np.sum(np.exp(values - peak)))


def logsumexp_by_id(semantic_ids, log_likelihoods, agg='sum_normalized'):
    """Sum probabilities with the same semantic id.

    Log-Sum-Exp because input and output probabilities are in log space. For
    each id ``c`` this returns ``log(sum_{i in c} p_i / sum_j p_j)`` with
    ``p_i = exp(log_likelihoods[i])``. A max-shifted log-sum-exp is used, so
    very negative log-likelihoods do not underflow to log(0).

    Args:
        semantic_ids: Cluster id per sample; the ids must be exactly 0..K-1.
        log_likelihoods: Log-likelihood per sample, same length as `semantic_ids`.
        agg: Aggregation mode; only 'sum_normalized' is supported.

    Returns:
        List with one normalised log-probability per id, ordered by id.

    Raises:
        ValueError: unsupported `agg`, length mismatch, or non-contiguous ids.
    """
    if agg != 'sum_normalized':
        raise ValueError(
            f"Unsupported agg {agg!r}; only 'sum_normalized' is implemented.")
    if len(semantic_ids) != len(log_likelihoods):
        raise ValueError(
            f'Got {len(semantic_ids)} semantic ids but '
            f'{len(log_likelihoods)} log likelihoods.')

    unique_ids = sorted(set(semantic_ids))
    if unique_ids != list(range(len(unique_ids))):
        raise ValueError(f'Semantic ids must be 0..K-1, got {unique_ids}.')

    ids = np.asarray(semantic_ids)
    lls = np.asarray(log_likelihoods, dtype=float)
    # Loop-invariant: log of the total probability mass over all samples.
    log_normalizer = _logsumexp(lls)
    return [_logsumexp(lls[ids == uid] - log_normalizer) for uid in unique_ids]


def predictive_entropy(log_probs):
    """Monte-Carlo estimate of the predictive entropy.

    Uses `E[-log p(x)] ~= -1/N sum_i log p(x_i)`, i.e. the negative mean of the
    N sampled log-probabilities in `log_probs`.
    """
    n_samples = len(log_probs)
    total_log_prob = np.sum(log_probs)
    return -total_log_prob / n_samples


def predictive_entropy_rao(log_probs):
    """Entropy `-sum_i p_i log p_i` of a distribution given as log-probabilities.

    Entries with probability 0 (log-prob -inf) contribute 0, following the
    convention 0 * log 0 = 0, instead of producing nan from 0 * -inf.
    """
    log_probs = np.asarray(log_probs, dtype=float)
    probs = np.exp(log_probs)
    # Compute p * log p only where p > 0; leave the other terms at 0.
    terms = np.multiply(probs, log_probs, out=np.zeros_like(probs), where=probs > 0)
    return -np.sum(terms)


def cluster_assignment_entropy(semantic_ids):
    """Estimate semantic uncertainty from how often different clusters get assigned.

    We estimate the categorical distribution over cluster assignments from the
    semantic ids. The uncertainty is then given by the entropy of that
    distribution. This estimate does not use token likelihoods; it relies solely
    on the cluster assignments. If probability mass is spread out between many
    clusters, entropy is larger. If probability mass is concentrated on a few
    clusters, entropy is small.

    Input:
        semantic_ids: List of non-negative integer semantic ids, e.g. [0, 1, 2, 1].
    Output:
        cluster_entropy: Entropy, e.g. (-p log p).sum() for p = [1/4, 2/4, 1/4].
    """

    n_generations = len(semantic_ids)
    counts = np.bincount(semantic_ids)
    # bincount yields 0 for ids that never occur (gaps in the numbering); drop
    # them so they contribute 0 (0 * log 0 := 0) instead of nan.
    counts = counts[counts > 0]
    probabilities = counts / n_generations
    assert np.isclose(probabilities.sum(), 1)
    entropy = - (probabilities * np.log(probabilities)).sum()
    return entropy




def generate_responses(model, prompt: str, num_generations: int = 10, temperature: float = 1.0,
                       first_temperature: float = 0.1) -> List[Tuple[str, List[float]]]:
    """
    Generate multiple responses for a given prompt.

    Args:
        model: Object whose `predict(prompt, temperature, return_full=False)` returns a
            3-tuple `(answer, token_log_likelihoods, embedding)`, e.g. a ModelWrapper.
        prompt: The input prompt.
        num_generations: Number of generation attempts.
        temperature: Temperature for every generation after the first.
        first_temperature: Temperature for the first generation (default 0.1, a
            near-greedy reference answer).

    Returns:
        List of tuples (response_text, token_log_likelihoods). Attempts that raise,
        or whose result is not a 3-tuple, are logged and skipped, so the list can
        be shorter than `num_generations`.
    """
    responses = []

    for i in range(num_generations):
        gen_temp = first_temperature if i == 0 else temperature

        try:
            logging.debug('Calling model.predict with prompt: %s...', prompt[:50])
            result = model.predict(prompt, gen_temp, return_full=False)
            logging.debug('Model.predict returned %d values: %s',
                          len(result), [type(x) for x in result])
            logging.debug('Result values: %s', result)

            if len(result) != 3:
                raise ValueError(f"Expected 3 values from model.predict, got {len(result)}")
            predicted_answer, token_log_likelihoods, _ = result
        except Exception as e:  # Best effort: one failed sample must not abort the batch.
            logging.warning('Error generating response %d for prompt: %s', i + 1, e,
                            exc_info=True)
            continue

        responses.append((predicted_answer, token_log_likelihoods))
        logging.debug('Generated response: %s...', predicted_answer[:50])

    return responses





def compute_semantic_entropy(responses: List[Tuple[str, List[float]]], entailment_model, question: str = None,
                             strict_entailment: bool = True) -> float:
    """
    Compute semantic entropy for a list of responses.

    Args:
        responses: List of (response_text, token_log_likelihoods) tuples.
        entailment_model: Model for checking semantic equivalence (`check_implication`).
        question: Optional question context; passed as `{'question': question}` to
            the entailment model (LLM-based checkers require it).
        strict_entailment: If True (default), two responses share a cluster only
            when each entails the other.

    Returns:
        Semantic entropy score. Returns 0.0 for fewer than 2 responses or when
        clustering or aggregation fails (the failure is logged with traceback).
    """
    if len(responses) < 2:
        logging.warning("Need at least 2 responses to compute semantic entropy")
        return 0.0

    # Extract response texts and log likelihoods
    response_texts = [r[0] for r in responses]
    log_likelihoods = [r[1] for r in responses]

    # Create example dict for entailment checking (if question provided)
    example = {'question': question} if question else None

    # Get semantic IDs (clusters of semantically equivalent responses)
    try:
        semantic_ids = get_semantic_ids(
            response_texts,
            model=entailment_model,
            strict_entailment=strict_entailment,
            example=example
        )
    except Exception as e:  # Best effort: report "no uncertainty" rather than crash.
        logging.warning('Error computing semantic IDs: %s', e, exc_info=True)
        return 0.0

    # Length-normalised score per response: mean token log-likelihood.
    log_liks_agg = [np.mean(log_lik) for log_lik in log_likelihoods]

    # Aggregate log likelihoods by semantic cluster, then take the entropy over clusters.
    try:
        log_likelihood_per_semantic_id = logsumexp_by_id(semantic_ids, log_liks_agg, agg='sum_normalized')
        return predictive_entropy_rao(log_likelihood_per_semantic_id)
    except Exception as e:  # Best effort, as above.
        logging.warning('Error computing semantic entropy: %s', e, exc_info=True)
        return 0.0


class ModelWrapper:
    """Wrapper class to adapt loaded models for semantic entropy calculations.

    `predict(input_data, temperature)` samples one continuation of `input_data`
    and returns `(answer_text, token_log_likelihoods, last_token_embedding)`.
    The prompt format is chosen from `model_name`: chat template for names
    containing 'llama' or 'qwen', plain tokenization otherwise.
    """

    def __init__(self, model, tokenizer, model_name, max_new_tokens=2000, enable_thinking=True,
                 stop_sequences=None, token_limit=10000):
        self.model = model  # object exposing generate() / compute_transition_scores()
        self.tokenizer = tokenizer
        self.model_name = model_name  # substring 'llama' / 'qwen' selects the prompt format
        self.max_new_tokens = max_new_tokens
        self.enable_thinking = enable_thinking  # forwarded to the Qwen chat template
        self.stop_sequences = stop_sequences  # optional stop strings for StoppingCriteriaSub
        self.token_limit = token_limit  # max total (prompt + generated) sequence length

    def _model_family(self):
        """Return 'llama', 'qwen' or 'default' based on `model_name` ('llama' checked first)."""
        name = self.model_name.lower()
        if 'llama' in name:
            return 'llama'
        if 'qwen' in name:
            return 'qwen'
        return 'default'

    def _input_device(self):
        """Device for prompt tensors: the model's own device, else the previous hard-coded 'cuda'."""
        return getattr(self.model, 'device', 'cuda')

    def _encode_prompt(self, input_data):
        """Tokenize `input_data`; return (generate() input kwargs, prompt length, eos override or None)."""
        family = self._model_family()
        device = self._input_device()
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": input_data},
        ]
        if family == 'llama':
            # apply_chat_template(return_tensors="pt") returns the input_ids tensor itself.
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(device)
            model_inputs = {'input_ids': input_ids}
            # Also stop on the chat end-of-turn token <|eot_id|>.
            eos_override = [
                self.tokenizer.eos_token_id,
                self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
            ]
        else:
            if family == 'qwen':
                text = self.tokenizer.apply_chat_template(
                    messages,
                    enable_thinking=self.enable_thinking,
                    tokenize=False,
                    add_generation_prompt=True
                )
                encoded = self.tokenizer([text], return_tensors="pt").to(device)
            else:
                # Mistral and other models use simple tokenization.
                encoded = self.tokenizer(
                    input_data, return_tensors="pt", padding=True, truncation=True).to(device)
            model_inputs = dict(encoded)
            model_inputs.pop('token_type_ids', None)  # Some HF models have changed.
            eos_override = None
        input_length = model_inputs['input_ids'].shape[1]
        return model_inputs, input_length, eos_override

    def _count_answer_tokens(self, generated_ids, eos_override):
        """Number of generated tokens, excluding trailing special/terminator tokens (e.g. EOS)."""
        stop_ids = set(getattr(self.tokenizer, 'all_special_ids', []))
        if eos_override is not None:
            stop_ids.update(eos_override)
        stop_ids.discard(None)
        n_tokens = len(generated_ids)
        while n_tokens > 0 and int(generated_ids[n_tokens - 1]) in stop_ids:
            n_tokens -= 1
        return n_tokens

    def predict(self, input_data, temperature, return_full=False):
        """Sample one answer for `input_data`.

        Args:
            input_data: Prompt text.
            temperature: Sampling temperature (sampling is always enabled).
            return_full: If True, return only the decoded full sequence
                (prompt + generation) as a string.

        Returns:
            (answer_text, log_likelihoods, last_token_embedding): the stripped
            generated text, the per-token log-likelihoods of the answer
            tokens (trailing special tokens excluded), and the last-layer
            hidden state of the selected step, moved to CPU.

        Raises:
            ValueError: if the sequence exceeds `token_limit` or no
                log-likelihoods are available.
        """
        model_inputs, input_length, eos_override = self._encode_prompt(input_data)

        if self.stop_sequences is not None:
            stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
                stops=self.stop_sequences,
                initial_length=input_length,
                tokenizer=self.tokenizer)])
        else:
            stopping_criteria = None

        generate_kwargs = dict(
            max_new_tokens=self.max_new_tokens,
            return_dict_in_generate=True,
            output_scores=True,
            output_hidden_states=True,
            temperature=temperature,
            do_sample=True,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        if eos_override is not None:
            generate_kwargs['eos_token_id'] = eos_override

        logging.debug('temperature: %f', temperature)
        with torch.no_grad():
            outputs = self.model.generate(**model_inputs, **generate_kwargs)

        sequence = outputs.sequences[0]
        if len(sequence) > self.token_limit:
            # Format eagerly: ValueError does not apply %-style arguments itself.
            raise ValueError(
                'Generation exceeding token limit %d > %d' % (len(sequence), self.token_limit))

        full_answer = self.tokenizer.decode(sequence, skip_special_tokens=True)

        if return_full:
            return full_answer

        # Slice off the prompt by token count. Text-prefix matching fails for
        # chat templates, whose decoded prompt differs from `input_data`.
        n_input_token = input_length
        generated_ids = sequence[n_input_token:]
        n_generated = self._count_answer_tokens(generated_ids, eos_override)
        token_stop_index = n_input_token + n_generated

        # No stop word removal - use the full generated answer, without surrounding whitespace.
        sliced_answer = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        logging.debug('Token calculation: token_stop_index=%s, n_input_token=%s, n_generated=%s',
                      token_stop_index, n_input_token, n_generated)
        logging.debug('Full answer length: %d, sliced answer length: %d',
                      len(full_answer), len(sliced_answer))

        if n_generated == 0:
            logging.warning('No tokens were generated. For likelihoods and embeddings, using minimum of 1.')
            n_generated = 1

        # Hidden states: one entry per generation step. Entry 0 covers the
        # prompt (and yields the first generated token); later entries cover one
        # token each. Each entry is indexed by layer; [-1] is the last layer.
        if 'decoder_hidden_states' in outputs.keys():
            hidden = outputs.decoder_hidden_states
        else:
            hidden = outputs.hidden_states

        logging.debug('Hidden states info: len(hidden)=%d, n_generated=%d, n_input_token=%d',
                      len(hidden), n_generated, n_input_token)

        if len(hidden) == 1:
            logging.warning(
                'Taking first and only generation for hidden! '
                'n_generated: %d, n_input_token: %d, token_stop_index %d, '
                'last_token: %s, generation was: %s',
                n_generated, n_input_token, token_stop_index,
                self.tokenizer.decode(outputs['sequences'][0][-1]),
                full_answer,
                )
            last_input = hidden[0]
        elif (n_generated - 1) >= len(hidden):
            # If access idx is larger/equal.
            logging.error(
                'Taking last state because n_generated is too large. '
                'n_generated: %d, n_input_token: %d, token_stop_index %d, '
                'last_token: %s, generation was: %s, slice_answer: %s',
                n_generated, n_input_token, token_stop_index,
                self.tokenizer.decode(outputs['sequences'][0][-1]),
                full_answer, sliced_answer
                )
            last_input = hidden[-1]
        else:
            last_input = hidden[n_generated - 1]

        # Last layer, last position of the selected step.
        last_layer = last_input[-1]
        last_token_embedding = last_layer[:, -1, :].cpu()

        # Per-token log-probabilities of the generated tokens (first batch element).
        transition_scores = self.model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True)

        log_likelihoods = [score.item() for score in transition_scores[0]]
        if len(log_likelihoods) == 1:
            logging.warning('Taking first and only generation for log likelihood!')
        else:
            # Drop scores of trailing special tokens (e.g. EOS).
            log_likelihoods = log_likelihoods[:n_generated]

        if len(log_likelihoods) == self.max_new_tokens:
            logging.warning('Generation interrupted by max_token limit.')

        if len(log_likelihoods) == 0:
            raise ValueError('No log-likelihoods were computed for the generation.')

        return sliced_answer, log_likelihoods, last_token_embedding
