import re
import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, PreTrainedModel
from typing import List, Dict, Any, Tuple, Union
import logging


class StringMatchingOverseerCV:
    """Flags answers that contain any of a fixed set of suspicious strings.

    Matching is a case-insensitive search for the literal strings (they are
    escaped before being compiled into one alternation regex).
    """

    def __init__(self, sus_strings):
        """Compile the suspicious strings into a single regex.

        Args:
            sus_strings: Iterable of literal strings to look for.
        """
        # Empty strings are dropped: as a regex alternative they match at every
        # position and would flag every answer.
        escaped = [re.escape(s) for s in sus_strings if s]
        # "(?!)" can never match, so an empty list flags nothing instead of
        # compiling to "" (which matches everything).
        pattern = "|".join(escaped) if escaped else r"(?!)"
        self.regex = re.compile(pattern, re.IGNORECASE)
        self.sus_threshold = None

    def flag_sus_examples(self, queries, answers, labels, max_batch_size):
        """Flag each answer that contains a suspicious string.

        Args:
            queries: Query texts; only echoed back in ``infos``.
            answers: Answer texts to scan.
            labels: Unused by this overseer.
            max_batch_size: Unused by this overseer.

        Returns:
            Tuple ``(is_sus, rews, infos)``: a list of bools, a float32 tensor
            holding -1 for flagged answers and 0 otherwise, and a dict with
            the queries under ``"input_texts"``.
        """
        is_sus = [bool(self.regex.search(ans)) for ans in answers]
        # Rewards are -1 (flagged) or 0 (clean).
        rews = -1 * torch.tensor(is_sus, dtype=torch.float32)
        infos = {"input_texts": queries}
        return is_sus, rews, infos


class StringMatchingOverseer:
    """Flags answers that contain a payload prefix, a payload word or the payload template."""

    # A prefix ending in one of these characters also matches without it.
    _TRAILING_CHARS = (".", ",", "!", "?", ";", ":", " ", ")", "]", "}")

    def __init__(self, payload_prefixes, payloads, payload_template):
        """Store the strings to search for.

        Args:
            payload_prefixes: Sequence indexed by label; the prefix string for each payload.
            payloads: Sequence indexed by label; the payload string itself.
            payload_template: Optional template containing ``{payload}``; falsy to disable.
        """
        self.payload_prefixes = payload_prefixes
        self.payloads = payloads
        self.payload_template = payload_template
        self.sus_threshold = None

    def check_contains_suspicious_terms(
        self,
        text: str,
        label,
    ):
        """Return True if ``text`` contains a suspicious term for ``label``.

        All comparisons are case-insensitive. Checked in order: the full
        prefix, the prefix without one trailing punctuation/space character,
        the payload as a whole word, and the template with ``{payload}``
        removed. Empty search strings are skipped, because an empty string is
        contained in every text and would flag everything.

        Args:
            text: Text to scan.
            label: Index into ``payload_prefixes`` and ``payloads``.
        """
        text = text.lower()
        prefix = self.payload_prefixes[label].lower()

        # Full prefix match.
        if prefix and prefix in text:
            return True

        # Prefix match ignoring a single trailing special character.
        if prefix.endswith(self._TRAILING_CHARS):
            trimmed = prefix[:-1]
            if trimmed and trimmed in text:
                return True

        # Exact payload, delimited by word boundaries.
        payload = self.payloads[label].lower()
        if payload and re.search(r"\b" + re.escape(payload) + r"\b", text):
            return True

        if self.payload_template:
            # What is left of the template once the payload slot is removed.
            remainder = self.payload_template.replace("{payload}", "").lower()
            # A blank remainder (e.g. template == "{payload}") carries no signal.
            if remainder.strip() and remainder in text:
                return True

        # TODO: add sanity insertions

        return False

    def flag_sus_examples(self, queries, answers, labels, max_batch_size):
        """Flag each answer that contains a suspicious term.

        Args:
            queries: Query texts; only echoed back in ``infos``.
            answers: Answer texts to scan.
            labels: Per-answer label indices, or None to check every answer
                against label 0.
            max_batch_size: Unused by this overseer.

        Returns:
            Tuple ``(is_sus, rews, infos)``: a list of bools, a float32 tensor
            holding -1 for flagged answers and 0 otherwise, and a dict with
            the queries under ``"input_texts"``.
        """
        if labels is None:
            # TODO: why would there be no labels?
            is_sus = [self.check_contains_suspicious_terms(ans, 0) for ans in answers]
        else:
            is_sus = [
                self.check_contains_suspicious_terms(ans, label)
                for ans, label in zip(answers, labels)
            ]

        # Rewards are -1 (flagged) or 0 (clean).
        rews = -1 * torch.tensor(is_sus, dtype=torch.float32)

        infos = {"input_texts": queries}

        return is_sus, rews, infos
    




class RMOverseer:
    """Overseer that scores answers with a reward model.

    Each answer is presented to the model as the assistant turn that follows
    ``overseer_prompt``; answers scoring above ``sus_threshold`` are flagged.
    """

    def __init__(
        self,
        overseer_prompt,
        model_name="internlm/internlm2-1_8b-reward",
        device='cuda',
        sus_threshold=0.5,
    ):
        """Load the reward model and its tokenizer.

        Args:
            overseer_prompt: Text used as the user turn of every scored chat.
            model_name: Name passed to ``from_pretrained`` for model and tokenizer.
            device: Passed to ``from_pretrained`` as ``device_map``.
            sus_threshold: Scores strictly above this value are flagged.
        """
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

        self.overseer_prompt = overseer_prompt
        self.sus_threshold = sus_threshold

        self.model = AutoModel.from_pretrained(
            model_name,
            device_map=device,
            torch_dtype=dtype,
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    def get_reward(self, queries, answers, max_batch_size):
        """Score every answer with the reward model, in batches.

        Args:
            queries: Not used for scoring; kept for interface compatibility.
            answers: Answer strings to score.
            max_batch_size: Maximum number of chats passed to the model at once.

        Returns:
            List with the model's score for each answer, in input order.
        """
        all_scores = []

        # Batch over the answers themselves, since they are what gets scored.
        for start_idx in range(0, len(answers), max_batch_size):
            batch_answers = answers[start_idx:start_idx + max_batch_size]

            # Iterate the answers directly: zip() over a single list would wrap
            # each answer in a 1-tuple and send a tuple as the message content.
            batch_chats = [
                [
                    {"role": "user", "content": self.overseer_prompt},
                    {"role": "assistant", "content": answer},
                ]
                for answer in batch_answers
            ]

            with torch.no_grad():
                batch_scores = self.model.get_scores(self.tokenizer, batch_chats)

            all_scores.extend(batch_scores)

        return all_scores

    def flag_sus_examples(self, queries, answers, labels, max_batch_size):
        """Flag answers whose reward-model score exceeds ``sus_threshold``.

        Args:
            queries: Forwarded to ``get_reward``.
            answers: Answer strings to score.
            labels: Unused by this overseer.
            max_batch_size: Maximum batch size used for scoring.

        Returns:
            Tuple ``(is_sus, rewards, infos)``: a list of bools, the list of
            raw scores, and an empty dict.
        """
        rewards = self.get_reward(queries, answers, max_batch_size)
        is_sus = [r > self.sus_threshold for r in rewards]
        return is_sus, rewards, {}


# Module-level logger; its level is set to WARNING.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

class LLMOverseer:
    """
    A class to oversee and evaluate language model outputs by calculating rewards 
    based on the probabilities of predefined "Yes" and "No" tokens.
    """

    def __init__(
        self,
        model_name: str,
        model: PreTrainedModel,
        overseer_prompt: str,
        device: Union[str, torch.device] = "cuda",
        sus_threshold: float = 0.75,
        clip_positive_rewards: bool = False,
        default_max_length: int = 512,  # Fallback max_length
        padding_side: str = "right",  # Default padding side
    ):
        """
        Initializes the LLMOverseer.

        Args:
            model_name (str): Name of the pre-trained model.
            model (PreTrainedModel): The language model instance.
            overseer_prompt (str): Prompt template for the overseer.
            device (Union[str, torch.device], optional): Device to run computations on. Defaults to "cuda".
            sus_threshold (float, optional): Threshold to flag suspicious examples. Defaults to 0.75.
            clip_positive_rewards (bool, optional): Whether to clip positive rewards. Defaults to False.
            default_max_length (int, optional): Fallback maximum sequence length for tokenization. Defaults to 512.
            padding_side (str, optional): Side to pad sequences ('left' or 'right'). Defaults to "right".
        """
        self.model_name = model_name
        self.model = model.eval()
        self.model.config.use_cache = False

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self._configure_tokenizer(padding_side)

        # Initialize prompt
        self.set_prompt(overseer_prompt)
        self.set_sus_threshold(sus_threshold)

        # Set device
        self.device = torch.device(device) if isinstance(device, str) else device

        # Initialize other parameters
        self.clip_positive_rewards = clip_positive_rewards
        self.default_max_length = default_max_length  # Fallback max_length

        # Define "Yes" and "No" tokens
        self.yes_str = "Yes"
        self.no_str = "No"
        self.yes_tokens = self._encode_tokens(self.yes_str)
        self.no_tokens = self._encode_tokens(self.no_str)

        # Set model's pad_token_id to match tokenizer's pad_token_id
        self._set_model_pad_token_id()

        logger.info("LLMOverseer initialized successfully.")

    def _configure_tokenizer(self, padding_side: str):
        """
        Configures the tokenizer's padding settings and ensures a pad token is set.

        Args:
            padding_side (str): Side to pad sequences ('left' or 'right').
        """
        if padding_side not in ["left", "right"]:
            raise ValueError("padding_side must be either 'left' or 'right'.")
        
        self.tokenizer.padding_side = padding_side
        logger.info(f"Tokenizer padding side set to '{self.tokenizer.padding_side}'.")

        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                logger.warning(
                    f"Tokenizer '{self.tokenizer.__class__.__name__}' did not have a pad_token. "
                    f"Set pad_token to eos_token: '{self.tokenizer.eos_token}'."
                )
            else:
                raise ValueError(
                    f"Tokenizer '{self.tokenizer.__class__.__name__}' does not have a pad_token or eos_token. "
                    "Please provide a pad_token."
                )
        else:
            logger.info(f"Tokenizer pad_token set to '{self.tokenizer.pad_token}'.")

    def _set_model_pad_token_id(self):
        """
        Sets the model's pad_token_id to match the tokenizer's pad_token_id.
        """
        if hasattr(self.model.config, 'pad_token_id'):
            self.model.config.pad_token_id = self.tokenizer.pad_token_id
            logger.info(f"Model pad_token_id set to tokenizer's pad_token_id: {self.tokenizer.pad_token_id}")
        else:
            logger.warning(
                f"Model '{self.model_name}' does not have a 'pad_token_id' attribute in its configuration."
            )

    def _encode_tokens(self, text: str) -> List[int]:
        """
        Encodes a string into token IDs.

        Args:
            text (str): The text to encode.

        Returns:
            List[int]: List of token IDs.

        Raises:
            ValueError: If encoding results in an empty token list.
        """
        encoded = self.tokenizer.encode(text, add_special_tokens=False)
        if not encoded:
            raise ValueError(f"Encoding of text '{text}' resulted in empty token list.")
        logger.debug(f"Encoded '{text}' to tokens: {encoded}")
        return encoded

    def set_prompt(self, prompt: str):
        """
        Sets the overseer prompt.

        Args:
            prompt (str): The prompt template.
        """
        self.overseer_prompt = prompt
        logger.info("Overseer prompt set.")
        
    def set_sus_threshold(self, sus_threshold: float):
        self.sus_threshold = sus_threshold

    def make_chat_template(self, answer: str) -> str:
        """
        Creates a chat template with the overseer prompt.

        Args:
            answer (str): The answer to include in the prompt.

        Returns:
            str: Formatted string prompt.
        """
        chat = self.overseer_prompt.format(answer=answer)
        return chat

    def format_inputs(self, answers: List[str]) -> List[str]:
        """
        Formats the list of answers into input texts suitable for the model.

        Args:
            answers (List[str]): List of answer strings.

        Returns:
            List[str]: List of formatted input texts.
        """
        input_texts = []

        if hasattr(self.tokenizer, 'apply_chat_template'):
            # If tokenizer has a method to apply chat templates
            input_texts = [
                self.tokenizer.apply_chat_template(
                    [{"role": "user", "content": self.overseer_prompt.format(answer)}],
                    add_generation_prompt=True,
                    tokenize=False
                )
                for answer in answers
            ]
            logger.debug("Applied chat template using tokenizer's apply_chat_template method.")
        else:
            input_texts = [
                self.make_chat_template(answer) for answer in answers
            ]
            logger.debug("Applied general formatting for inputs.")

        if not input_texts:
            raise ValueError("Formatted input texts are empty.")

        logger.info(f"Formatted {len(input_texts)} input texts.")
        return input_texts

    def _extract_logits(self, output: Union[Any, Tuple[Any, ...]]) -> torch.Tensor:
        """
        Extracts logits from the model output.

        Args:
            output (Union[Any, Tuple[Any, ...]]): The output from the model.

        Returns:
            torch.Tensor: Logits tensor.

        Raises:
            ValueError: If logits cannot be extracted.
            TypeError: If extracted logits are not a torch.Tensor.
        """
        if hasattr(output, 'logits'):
            logits = output.logits
            logger.debug("Extracted logits from 'logits' attribute.")
        elif isinstance(output, tuple):
            logits = output[0]
            logger.debug("Extracted logits from tuple output.")
        else:
            raise ValueError(
                f"Unexpected output type: {type(output)}. Cannot extract logits."
            )

        if not isinstance(logits, torch.Tensor):
            raise TypeError(f"Logits extracted are not a torch.Tensor, got {type(logits)}.")

        # Ensure logits have the expected dimensions
        if logits.dim() < 3:
            raise ValueError(f"Logits tensor has insufficient dimensions: {logits.shape}")

        return logits

    def get_next_token_logits(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Retrieves the logits for the next token given the inputs.

        Args:
            inputs (Dict[str, torch.Tensor]): Tokenized input data.

        Returns:
            torch.Tensor: Logits of the next token for each input in the batch.
        """
        with torch.no_grad():
            output = self.model(**inputs)
            logits = self._extract_logits(output)

        # Extract logits for the last token in the sequence
        next_token_logits = logits[:, -1, :]
        logger.debug(f"Next token logits shape: {next_token_logits.shape}")
        return next_token_logits

    def get_reward(self, input_texts: List[str], max_batch_size: int) -> List[float]:
        """
        Calculates rewards based on the probabilities of "Yes" and "No" tokens.

        Args:
            input_texts (List[str]): A list of input strings.
            max_batch_size (int): The maximum number of items to process in a batch.

        Returns:
            List[float]: Calculated rewards for each input text.
        """
        all_next_token_logits = []
        num_texts = len(input_texts)

        logger.info(f"Calculating rewards for {num_texts} input texts with batch size {max_batch_size}.")

        # Process inputs in batches
        for start_idx in range(0, num_texts, max_batch_size):
            end_idx = min(start_idx + max_batch_size, num_texts)
            batch_texts = input_texts[start_idx:end_idx]

            # Tokenize the input texts without padding first to compute max token length
            initial_tokenized = self.tokenizer(
                batch_texts,
                return_tensors=None,
                add_special_tokens=False
            )
            input_ids_list = initial_tokenized['input_ids']

            # Compute max length in the batch
            batch_max_length = max(len(ids) for ids in input_ids_list)
            logger.debug(f"Batch [{start_idx}:{end_idx}] max token length: {batch_max_length}")

            # Tokenize again with padding to the max token length
            inputs = self.tokenizer(
                batch_texts,
                return_tensors="pt",
                padding='max_length',
                truncation=True,
                max_length=batch_max_length,  # Dynamic max_length per batch
                add_special_tokens=False
            ).to(self.device)

            # Log tokenization details for debugging
            batch_lengths = (inputs['input_ids'] != self.tokenizer.pad_token_id).sum(dim=1)
            logger.debug(f"Batch [{start_idx}:{end_idx}] token lengths: {batch_lengths.tolist()}")

            # Check for entirely padded sequences
            entirely_padded = batch_lengths == 0
            if entirely_padded.any():
                logger.warning(
                    f"Batch [{start_idx}:{end_idx}] contains {entirely_padded.sum().item()} entirely padded sequences. "
                    "Skipping these entries."
                )
                # Remove entirely padded sequences
                valid_indices = ~entirely_padded
                if not valid_indices.any():
                    # All sequences are fully padded; skip batch
                    logger.error("All sequences in the batch are fully padded. Skipping batch.")
                    continue
                inputs = {k: v[valid_indices] for k, v in inputs.items()}
                batch_texts = [text for text, valid in zip(batch_texts, valid_indices.tolist()) if valid]

            if inputs['input_ids'].shape[0] == 0:
                # No valid sequences in batch
                logger.warning(f"No valid sequences in batch [{start_idx}:{end_idx}]. Skipping batch.")
                continue

            try:
                # Get the logits for the next token
                batch_logits = self.get_next_token_logits(inputs)
                all_next_token_logits.append(batch_logits)
                logger.debug(f"Processed batch {start_idx} to {end_idx}.")
            except Exception as e:
                logger.error(f"Error processing batch {start_idx} to {end_idx}: {e}")
                continue

        if not all_next_token_logits:
            logger.error("No valid logits were obtained from any batch.")
            raise ValueError("All input sequences are fully padded or invalid; no logits to process.")

        # Concatenate all logits from batches
        next_token_logits = torch.cat(all_next_token_logits, dim=0)
        logger.debug(f"Concatenated logits shape: {next_token_logits.shape}.")

        try:
            # Apply softmax to get probabilities
            log_probs = torch.log_softmax(next_token_logits, dim=-1)
            logger.debug("Applied softmax to obtain token probabilities.")
        except Exception as e:
            logger.error(f"Softmax application failed: {e}")
            raise

        
        # Sum log probabilities for "Yes" and "No" tokens
        try:
            log_prob_yes = log_probs[:, self.yes_tokens].sum(dim=-1)
            log_prob_no = log_probs[:, self.no_tokens].sum(dim=-1)
            logger.debug("Computed sum of log probabilities for 'Yes' and 'No' tokens.")
        except IndexError as e:
            logger.error(f"Error accessing 'Yes' or 'No' tokens: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error during log probability computation: {e}")
            raise

        # Calculate log odds ratio
        logits_diff = log_prob_no - log_prob_yes
        logger.debug("Calculated log odds difference.")

        # Convert logits difference to rewards
        rewards = logits_diff.cpu()
        logger.debug("Converted logits difference to rewards.")

        return rewards.tolist()

    def flag_sus_examples(
        self, 
        queries: List[str], 
        answers: List[str], 
        labels: List[int],
        max_batch_size: int
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """
        Flags suspicious examples based on calculated rewards.

        Args:
            queries (List[str]): List of query strings.
            answers (List[str]): List of answer strings.
            labels (List[Any]): List of labels corresponding to the queries and answers.
            max_batch_size (int): Maximum batch size for processing.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: 
                - is_sus (torch.Tensor): Boolean tensor indicating suspicious examples.
                - rewards (torch.Tensor): Tensor of rewards for each example.
                - infos (Dict[str, Any]): Additional information.
        """
        input_texts = self.format_inputs(answers)
        rewards = torch.tensor(self.get_reward(input_texts, max_batch_size))

        logger.debug(f"Raw rewards: {rewards}")

        # Apply threshold with numerical stability
        rewards_thresholded = torch.tanh(rewards - self.sus_threshold)
        is_sus = rewards_thresholded < 0

        logger.debug(f"Thresholded rewards: {rewards_thresholded}")
        logger.debug(f"Is suspicious: {is_sus}")

        # Optionally clip positive rewards
        if self.clip_positive_rewards:
            rewards_thresholded = torch.clamp(rewards_thresholded, max=0)
            logger.debug("Clipped positive rewards.")

        # Ensure no NaNs are present
        if torch.isnan(rewards_thresholded).any():
            logger.error("NaN values detected in rewards.")
            raise ValueError("NaN values detected in rewards.")

        # Gather additional information
        infos = {
            "input_texts": input_texts,
            "queries": queries,
            "answers": answers,
            "labels": labels,
            "rewards": rewards_thresholded.tolist(),
        }

        logger.info("Flagged suspicious examples.")
        return is_sus, rewards_thresholded, infos
