import logging
import os
from typing import Any, Dict, List

import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    StoppingCriteria,
)

# Torch device string chosen once at import time: "cuda" when torch reports an
# available CUDA device, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Marker strings that signal the end of a useful generation: line breaks and
# prompt-style field labels in several spellings.
# NOTE(review): presumably passed as ``stop_sequences`` / ``stops`` to the
# truncation helpers and stopping criteria by callers -- confirm against usage.
# Order is not significant for ``truncate_text``, which cuts at the earliest
# occurrence of any entry.
STOP_SEQUENCES = [
    "\n",
    "\n\n",
    ".\n",
    "Question:",
    "Context:",
    "Answer:",
    "Q:",
    "A:",
    "question:",
    "context:",
    "answer:",
    "A possible answer",
]


class StoppingCriteriaSub(StoppingCriteria):  # from Semantic Entropy repo
    """Stop generation once the output matches a stop string or token sequence.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.
    """

    _MATCH_MODES = ("text", "tokens")

    def __init__(self, stops, tokenizer, match_on="text", initial_length=None):
        """Initialize stopping criteria for generation.

        Args:
            stops: List of strings that end generation. With
                ``match_on='tokens'`` each string is encoded to token ids once,
                here, and matched against the generated ids.
            tokenizer: Tokenizer used to decode the generation (``'text'``) or
                to encode ``stops`` (``'tokens'``).
            match_on (str, optional): Whether to match on raw ``'text'`` or ``'tokens'``.
            initial_length (int, optional): Index of the first generated token;
                tokens before it are ignored in ``'text'`` mode. ``None`` means
                the whole sequence is decoded.

        Raises:
            ValueError: If ``match_on`` is neither ``'text'`` nor ``'tokens'``.
        """
        super().__init__()
        if match_on not in self._MATCH_MODES:
            raise ValueError(
                f"match_on must be 'text' or 'tokens', got {match_on!r}"
            )
        self.stops = stops
        self.initial_length = initial_length
        self.tokenizer = tokenizer
        self.match_on = match_on
        if self.match_on == "tokens":
            # Encode without special tokens: a BOS/EOS inside the stop tensor
            # would never line up with the tail of a generation. The tensors
            # stay on the CPU and are moved to the device of ``input_ids`` at
            # comparison time, so construction works without a GPU.
            self.stops = [
                torch.tensor(
                    self.tokenizer.encode(stop, add_special_tokens=False),
                    dtype=torch.long,
                )
                for stop in self.stops
            ]

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ):
        """Return ``True`` when a stop sequence is encountered.

        Args:
            input_ids (torch.LongTensor): Current token ids; only row 0 is checked.
            scores (torch.FloatTensor): Token scores supplied by the generation loop (unused).
            **kwargs: Extra arguments forwarded by the generation loop (unused).

        Returns:
            bool: ``True`` if any stop criterion is satisfied, otherwise ``False``.

        Raises:
            ValueError: If ``match_on`` was changed to an unsupported value
                after construction.
        """
        sequence = input_ids[0]

        if self.match_on == "text":
            # Decode once; the decoded text does not depend on the stop string.
            generation = self.tokenizer.decode(
                sequence[self.initial_length :], skip_special_tokens=False
            )
            return any(stop in generation for stop in self.stops)

        if self.match_on == "tokens":
            # Can be dangerous due to tokenizer ambiguities: the same text may
            # be tokenized differently inside a longer generation.
            total = sequence.shape[-1]
            for stop in self.stops:
                width = stop.shape[-1]
                if width == 0 or width > total:
                    # Nothing to match, or the generation is still too short.
                    continue
                # The whole stop sequence must equal the tail of the output;
                # a single shared token is not a match.
                if torch.equal(sequence[-width:], stop.to(sequence.device)):
                    return True
            return False

        raise ValueError(
            f"match_on must be 'text' or 'tokens', got {self.match_on!r}"
        )


def truncate_text(stop_sequences: List[str], text: str) -> str:
    """Return ``text`` cut just before the earliest stop sequence it contains.

    Args:
        stop_sequences (List[str]): Candidate markers to search for.
        text (str): Text to shorten.

    Returns:
        str: The prefix of ``text`` preceding the first marker found, or
        ``text`` unchanged when no marker occurs.
    """
    # ``str.find`` yields -1 for a miss; keep only real hits.
    hits = [idx for idx in map(text.find, stop_sequences) if idx != -1]
    cut = min(hits, default=len(text))
    return text[:cut]


def truncate_output(
    text: str,
    token_scores: List,
    tokenizer: AutoTokenizer,
    stop_sequences: List[str],
) -> Dict[str, Any]:
    """Cut generated text at the first stop sequence and trim scores to match.

    Args:
        text (str): Generated text to truncate.
        token_scores (List): Token-level scores aligned with the original text.
        tokenizer (AutoTokenizer): Tokenizer used to re-tokenize the truncated text.
        stop_sequences (List[str]): Sequences that signal truncation points.

    Returns:
        Dict[str, Any]: Dictionary with truncated ``text``, ``token_scores``, and ``text_encoded`` ids.
        When fewer scores than re-encoded ids are available, the ids are
        shortened to the number of scores; ``text`` is not shortened further.
    """
    # Drop everything from the first stop sequence onwards.
    truncated = truncate_text(stop_sequences, text)
    # Re-tokenize the kept text, without special tokens.
    token_ids = tokenizer.encode(truncated, add_special_tokens=False)
    # Keep one score (log likelihood) per token of the truncated text.
    scores = token_scores[: len(token_ids)]

    n_scores = len(scores)
    n_ids = len(token_ids)
    # After the slice above the lengths can only differ when there are fewer
    # scores than ids, e.g. generation was stopped early and the
    # decode/encode round trip produced a different tokenization.
    if n_scores != n_ids:
        logging.warning(
            f"Token scores length {n_scores} does not match text length {n_ids}. "
            "Truncating token scores to match text length."
        )
        keep = min(n_scores, n_ids)
        scores = scores[:keep]
        token_ids = token_ids[:keep]

    return {"text": truncated, "token_scores": scores, "text_encoded": token_ids}


# ---# Code from Semantic Entropy repo #---#


class EntailmentDeberta:
    """Entailment classifier backed by ``microsoft/deberta-v2-xlarge-mnli``.

    Predictions are class indices: 0 = contradiction, 1 = neutral,
    2 = entailment.
    """

    def __init__(self, selector_model=None):
        """Load the tokenizer and model and move the model to ``DEVICE``.

        Args:
            selector_model (str, optional): When set to ``'comma'``,
                ``check_implication`` bypasses the model and uses exact string
                comparison instead.
        """
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v2-xlarge-mnli")
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "microsoft/deberta-v2-xlarge-mnli"
        ).to(DEVICE)
        self.selector_model = selector_model

    def check_implication(self, text1, text2, *args, **kwargs):
        """Classify whether ``text2`` follows from ``text1``.

        Args:
            text1 (str): Premise.
            text2 (str): Hypothesis.
            *args: Ignored; accepted for interface compatibility.
            **kwargs: Ignored; accepted for interface compatibility
                (e.g. ``example=``).

        Returns:
            int: 0 (contradiction), 1 (neutral) or 2 (entailment).
        """
        if self.selector_model == "comma":
            # Exact string match (ignoring surrounding whitespace) counts as
            # entailment, anything else as contradiction.
            return 2 if text1.strip() == text2.strip() else 0

        inputs = self.tokenizer(text1, text2, return_tensors="pt").to(DEVICE)
        # The model checks if text1 -> text2, i.e. if text2 follows from text1.
        # check_implication('The weather is good', 'The weather is good and I like you') --> 1
        # check_implication('The weather is good and I like you', 'The weather is good') --> 2
        # Pure inference: disable autograd so no graph or activations are
        # retained for each of the many pairwise calls.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Deberta-mnli returns `neutral` and `entailment` classes at indices 1 and 2.
        # Softmax is monotonic, so the argmax of the raw logits is the same class.
        prediction = torch.argmax(logits).cpu().item()  # pylint: disable=no-member
        if os.environ.get("DEBERTA_FULL_LOG"):
            logging.info("Deberta Input: %s -> %s", text1, text2)
            logging.info("Deberta Prediction: %s", prediction)

        return prediction


def get_semantic_ids(
    strings_list: list[str],
    model,
    strict_entailment: bool = False,
    example: str | None = None,
) -> list[int]:
    """Group list of predictions into semantic meaning."""

    def are_equivalent(text1, text2):
        implication_1 = model.check_implication(text1, text2, example=example)
        implication_2 = model.check_implication(text2, text1, example=example)  # pylint: disable=arguments-out-of-order
        assert (implication_1 in [0, 1, 2]) and (implication_2 in [0, 1, 2])

        if strict_entailment:
            semantically_equivalent = (implication_1 == 2) and (implication_2 == 2)

        else:
            implications = [implication_1, implication_2]
            # Check if none of the implications are 0 (contradiction) and not both of them are neutral.
            semantically_equivalent = (0 not in implications) and (
                [1, 1] != implications
            )

        return semantically_equivalent

    # Initialise all ids with -1.
    semantic_set_ids = [-1] * len(strings_list)
    # Keep track of current id.
    next_id = 0
    for i, string1 in enumerate(strings_list):
        # Check if string1 already has an id assigned.
        if semantic_set_ids[i] == -1:
            # If string1 has not been assigned an id, assign it next_id.
            semantic_set_ids[i] = next_id
            for j in range(i + 1, len(strings_list)):
                # Search through all remaining strings. If they are equivalent to string1, assign them the same id.
                if are_equivalent(string1, strings_list[j]):
                    semantic_set_ids[j] = next_id
            next_id += 1

    assert -1 not in semantic_set_ids

    return semantic_set_ids
