import gc
from typing import Any, Dict, Optional, Tuple

import torch
from custom_dreamy.i_runner import IRunner

from eliciting_contexts.utils.text import get_fixed_positions


# Helper functions for GPU memory management
def clear_gpu_memory():
    """
    Run Python's garbage collector, then hand PyTorch's cached CUDA memory back.

    Meant to be called between major processing steps. Without CUDA only the
    garbage collection runs; with CUDA the cache is emptied and the device is
    synchronized afterwards.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()


def get_gpu_memory_usage():
    """
    Report PyTorch's current CUDA memory usage in megabytes.

    Returns:
        dict: ``{"allocated": ..., "reserved": ...}`` in MB (1 MB = 1024**2 bytes).
        Both values are the integer 0 when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        bytes_per_mb = 1024**2
        return {
            "allocated": torch.cuda.memory_allocated() / bytes_per_mb,
            "reserved": torch.cuda.memory_reserved() / bytes_per_mb,
        }
    return {"allocated": 0, "reserved": 0}


def log_memory_usage(step_name):
    """
    Print the current GPU memory usage, labelled with ``step_name``.

    Args:
        step_name (str): Label for the processing step being reported.
    """
    stats = get_gpu_memory_usage()
    allocated, reserved = stats["allocated"], stats["reserved"]
    message = (
        f"[Memory at {step_name}] Allocated: {allocated:.2f} MB, "
        f"Reserved: {reserved:.2f} MB"
    )
    print(message)


def release_tensor(tensor):
    """
    Drop this function's reference to ``tensor``.

    Python passes object references by value, so this cannot delete the
    caller's variable. The tensor's memory is only freed once every reference
    to it is gone, so ``del`` it at the call site as well. Freed CUDA blocks
    then stay in PyTorch's allocator cache until ``torch.cuda.empty_cache()``
    runs.

    No device-to-host copy is made. Copying the tensor to the CPU would cost a
    blocking transfer plus host memory without freeing anything the caller
    still references.

    Args:
        tensor: PyTorch tensor (or None). It is not modified.
    """
    del tensor


def _mark_token_spans(tokens_list, pattern, label, token_type_map):
    """
    Write ``label`` into ``token_type_map`` at every exact occurrence of ``pattern``.

    Occurrences are searched in ``tokens_list``; overlapping occurrences are all
    labelled. An empty pattern covers no tokens, so nothing is written.
    """
    width = len(pattern)
    if width == 0:
        return
    for start in range(len(tokens_list) - width + 1):
        if tokens_list[start : start + width] == pattern:
            token_type_map[start : start + width] = [label] * width


def format_chat(tokenizer, system_prompt, user_message, return_token_type_map=True):
    """
    Render a system prompt and a user message with the tokenizer's chat template.

    The messages go through ``tokenizer.apply_chat_template`` with the generation
    prompt appended. If the template rejects the system role (an error whose
    message contains "System role not supported"), the conversation is rebuilt
    as three turns: the system prompt as a user turn, an assistant reply
    "Understood.", and then the real user message.

    Args:
        tokenizer: The model's tokenizer (must provide ``apply_chat_template``
            and ``encode``).
        system_prompt (str): The system instructions.
        user_message (str): The user's message.
        return_token_type_map (bool): Whether to also return token ids and
            per-token labels.

    Returns:
        str: The formatted prompt, when ``return_token_type_map`` is False.
        tuple: ``(formatted_prompt, tokens_list, token_type_map)`` otherwise.
            ``tokens_list`` is ``tokenizer.encode(formatted_prompt)`` as a
            list of ids (special tokens included). ``token_type_map`` labels
            each id "system", "user" or "special". Labels come from exact
            matches of the separately encoded system prompt and user message.
            Where the in-context tokenization differs, ids stay "special".
            User labels overwrite system labels where both match.

    Raises:
        Exception: Whatever ``apply_chat_template`` raises, other than the
            unsupported-system-role case.
    """

    def render(messages):
        # add_generation_prompt=True appends the assistant header at the end.
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    try:
        formatted_prompt = render(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_message},
            ]
        )
    except Exception as e:
        if "System role not supported" not in str(e):
            raise  # bare raise keeps the original traceback intact
        # Templates without a system role: deliver the instructions as a user
        # turn that the assistant acknowledges.
        formatted_prompt = render(
            [
                {"role": "user", "content": system_prompt},
                {"role": "assistant", "content": "Understood."},
                {"role": "user", "content": user_message},
            ]
        )

    if not return_token_type_map:
        return formatted_prompt

    tokens_list = tokenizer.encode(formatted_prompt, return_tensors="pt")[0].tolist()

    # Encode each component on its own so its ids can be located in the prompt.
    system_tokens = tokenizer.encode(system_prompt, add_special_tokens=False)
    user_tokens = tokenizer.encode(user_message, add_special_tokens=False)

    token_type_map = ["special"] * len(tokens_list)
    _mark_token_spans(tokens_list, system_tokens, "system", token_type_map)
    # Applied second, so "user" wins wherever both spans match.
    _mark_token_spans(tokens_list, user_tokens, "user", token_type_map)

    return formatted_prompt, tokens_list, token_type_map


def get_model_response(
    model,
    tokenizer,
    user_message,
    system_message="You are a helpful AI assistant.",
    max_tokens=100,
    temperature=0.7,
    prepend_bos=True,
    return_input=False,
    verbose=False,
):
    """
    Generate the model's reply to ``user_message`` under ``system_message``.

    Both messages are rendered with ``format_chat`` (no token map). The resulting
    prompt string is passed to ``model.generate`` inside ``torch.inference_mode()``.

    Args:
        model: Model exposing
            ``generate(prompt, max_new_tokens=..., temperature=..., verbose=...)``.
        tokenizer: Tokenizer whose chat template formats the prompt.
        user_message (str): The user's message.
        system_message (str, optional): System instructions. Defaults to a
            helpful-assistant prompt.
        max_tokens (int, optional): Passed to ``generate`` as ``max_new_tokens``.
            Defaults to 100.
        temperature (float, optional): Passed to ``generate``. Defaults to 0.7.
        prepend_bos (bool, optional): Accepted but not used.
            NOTE(review): it is not forwarded to ``model.generate``; confirm
            whether it should be.
        return_input (bool, optional): If True, also return the formatted prompt.
        verbose (bool, optional): Passed to ``generate``.

    Returns:
        The value returned by ``model.generate``, or the tuple
        ``(generated, formatted_prompt)`` when ``return_input`` is True.
        The generated value is presumably decoded text; callers in this module
        strip the prompt from it.
    """
    prompt = format_chat(
        tokenizer, system_message, user_message, return_token_type_map=False
    )

    with torch.inference_mode():
        generated = model.generate(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            verbose=verbose,
        )

    return (generated, prompt) if return_input else generated


def is_first_token_match(output, inp, tokenizer, target_token_str):
    """
    Check whether the model's reply starts with the same token as ``target_token_str``.

    ``inp`` is round-tripped through the tokenizer: encoded, then decoded with
    special tokens skipped. Every occurrence of the result is removed from
    ``output``. The first token of what remains is compared with the first
    token of ``target_token_str``.

    Args:
        output (str): The generated text (presumably prompt + continuation).
        inp (str): The prompt text that was given to the model.
        tokenizer: Tokenizer providing ``encode`` and ``decode``.
        target_token_str (str): String whose first token is expected (e.g. "1").

    Returns:
        bool: True if the first tokens are equal. False if they differ, or if the
        stripped reply or the target encodes to no tokens.
    """
    # TODO: stripping the prompt via string replacement is fragile; clean this up.
    cleaned_prompt = tokenizer.decode(tokenizer.encode(inp), skip_special_tokens=True)
    reply = output.replace(cleaned_prompt, "")

    reply_ids = tokenizer.encode(reply, add_special_tokens=False)
    if not reply_ids:
        return False

    target_ids = tokenizer.encode(target_token_str, add_special_tokens=False)
    if not target_ids:
        return False

    return reply_ids[0] == target_ids[0]


def test_first_token_responses(
    model,
    tokenizer,
    system_message,
    text_lists,
    expect_matches,
    category_names=None,
    max_tokens=1,
):
    """
    Tests model responses for multiple lists of texts.

    Args:
        model: The model to test
        tokenizer: Tokenizer for the model
        system_message: The system prompt to use
        text_lists: List of lists of strings to test
        expect_matches: List of booleans indicating whether each list should match "1" token
        category_names: Optional names for each category (for reporting)
        max_tokens: Maximum number of tokens to generate

    Returns:
        Dictionary with results for each category
    """
    if category_names is None:
        category_names = [f"category_{i}" for i in range(len(text_lists))]

    results = {}
    for i, (texts, expect_match, category) in enumerate(
        zip(text_lists, expect_matches, category_names)
    ):
        results[category] = {"correct": 0, "total": len(texts)}

        print(f"\nTesting {category}:")
        for j, text in enumerate(texts):
            print(f"  Testing example {j + 1}/{len(texts)}", end="\r")
            response, inp_string = get_model_response(
                model,
                tokenizer,
                text,
                system_message,
                max_tokens=max_tokens,
                return_input=True,
            )
            is_one = is_first_token_match(response, inp_string, tokenizer, "1")
            if is_one == expect_match:
                results[category]["correct"] += 1

    # Calculate accuracies
    for category in results:
        results[category]["accuracy"] = (
            results[category]["correct"] / results[category]["total"] * 100
        )

    print("\n\nResults:")
    for category in results:
        print(
            f"{category}: {results[category]['correct']}/{results[category]['total']} correct ({results[category]['accuracy']:.2f}%)"
        )

    # Calculate overall accuracy
    total_correct = sum(results[cat]["correct"] for cat in results)
    total_examples = sum(results[cat]["total"] for cat in results)
    overall_accuracy = total_correct / total_examples * 100
    print(
        f"Overall accuracy: {total_correct}/{total_examples} correct ({overall_accuracy:.2f}%)"
    )

    return results


class LogProbDiffCalculator:
    """
    Measure how much a model prefers ``target_word`` over ``negative_word`` as the
    continuation of a batch of input embeddings.
    """

    def __init__(
        self,
        model,
        tokenizer,
        target_word,
        negative_word,
        max_num_tokens=None,
        literal_diff=False,
    ):
        """
        model: a language model that accepts input embeddings (shape: [batch, seq_length, emb_dim])
               and returns logits when called with input_embeddings, e.g.,
               self.model(input_embeddings, start_at_layer=0, return_type="logits")
               Also exposes self.model.embed(token_ids) to obtain embeddings, and
               parameters() (its first parameter's device is used for all tensors).
        tokenizer: a tokenizer that converts strings to tokens.
        target_word: the word whose likelihood you want to boost.
        negative_word: the word to compare against.
        max_num_tokens: if not None, crop each candidate's tokens to this maximum number.
                        For example, if set to 1, only the first token of each candidate is used.
        literal_diff: if True, the objective is computed as the raw logit difference for the first candidate token,
                      i.e. logit(target token 1 | x) - logit(negative token 1 | x), ignoring additional tokens.

        Raises:
            ValueError: if either word encodes to zero tokens (after cropping);
                no objective can be computed for an empty candidate.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.literal_diff = literal_diff

        # Pre-tokenize candidate words (without adding special tokens)
        self.target_tokens = tokenizer.encode(target_word, add_special_tokens=False)
        self.negative_tokens = tokenizer.encode(negative_word, add_special_tokens=False)

        # Optionally crop candidate tokens to max_num_tokens
        if max_num_tokens is not None:
            self.target_tokens = self.target_tokens[:max_num_tokens]
            self.negative_tokens = self.negative_tokens[:max_num_tokens]

        if len(self.target_tokens) == 0 or len(self.negative_tokens) == 0:
            raise ValueError(
                "target_word and negative_word must each encode to at least one token "
                f"(got {len(self.target_tokens)} and {len(self.negative_tokens)})"
            )

        # Both candidates are right-padded to this common length.
        self.max_len = max(len(self.target_tokens), len(self.negative_tokens))

        pad_id = self._pad_token_id(tokenizer)
        self.target_tokens_padded, self.target_mask = self._pad(self.target_tokens, pad_id)
        self.negative_tokens_padded, self.negative_mask = self._pad(
            self.negative_tokens, pad_id
        )

        # Convert to tensors (device assignment happens in compute_objective)
        self.target_tokens_tensor = torch.tensor(self.target_tokens_padded)
        self.negative_tokens_tensor = torch.tensor(self.negative_tokens_padded)
        self.target_mask_tensor = torch.tensor(self.target_mask, dtype=torch.float)
        self.negative_mask_tensor = torch.tensor(self.negative_mask, dtype=torch.float)

    @staticmethod
    def _pad_token_id(tokenizer):
        """Return a real vocabulary id to fill padded candidate slots."""
        # Padded slots are masked out of the objective, but their id is still used
        # as an embedding/gather index. It must therefore be an int, not None;
        # many causal-LM tokenizers define no pad token.
        for attr in ("pad_token_id", "eos_token_id"):
            token_id = getattr(tokenizer, attr, None)
            if token_id is not None:
                return token_id
        return 0

    def _pad(self, tokens, pad_id):
        """Right-pad ``tokens`` to ``self.max_len``; return (padded ids, 1/0 mask)."""
        missing = self.max_len - len(tokens)
        return list(tokens) + [pad_id] * missing, [1] * len(tokens) + [0] * missing

    @staticmethod
    def _candidate_log_probs(logits, token_ids):
        """
        Return the log-probability of each candidate token under its scoring logits.

        ``logits`` is [rows, max_len, vocab] and ``token_ids`` is [max_len].
        The result is [rows, max_len].
        """
        log_probs = torch.log_softmax(logits, dim=-1)
        index = token_ids.view(1, -1, 1).expand(logits.shape[0], -1, 1)
        return log_probs.gather(dim=-1, index=index).squeeze(-1)

    def compute_objective(self, input_embeddings, return_logits=False):
        """
        Score each input by how much the model prefers the target over the negative word.

        Every input row is duplicated. One copy gets the (padded) target
        candidate embeddings appended, the other the negative candidate
        embeddings. The doubled batch, interleaved as [x0+target, x0+negative,
        x1+target, ...], goes through the model in one call. For a causal LM
        the logits at position p predict the token at position p + 1, so
        candidate token k (at position seq_length + k) is scored with the
        logits at position seq_length - 1 + k.

            (a) If literal_diff is False:
                objective = log p(target candidate | x) - log p(negative candidate | x)
                (sum of per-token log probabilities; padded slots contribute 0)

            (b) If literal_diff is True:
                objective = raw logit(target token 1 | x) - raw logit(negative token 1 | x)

        Maximizing the returned objective boosts the target word's likelihood
        relative to the negative word.

        Args:
            input_embeddings: Tensor of shape [batch_size, seq_length, emb_dim]; all
                samples share seq_length, which must be at least 1.
            return_logits: If True, also return ``logits[::2, :seq_length]``, the input
                positions of the target-appended rows. The last of these is the
                model's prediction for the first candidate slot.

        Returns:
            If return_logits=False: tensor of shape [batch_size] with one objective
                per sample (not averaged).
            If return_logits=True: a tuple (objectives, input_logits).

        Raises:
            ValueError: if seq_length is 0 (there is no prefix position to score from).
        """
        device = next(self.model.parameters()).device
        input_embeddings = input_embeddings.to(device)
        batch_size, seq_length, emb_dim = input_embeddings.shape
        if seq_length == 0:
            raise ValueError("input_embeddings must contain at least one position")

        target_ids = self.target_tokens_tensor.to(device)
        negative_ids = self.negative_tokens_tensor.to(device)
        target_mask = self.target_mask_tensor.to(device)
        negative_mask = self.negative_mask_tensor.to(device)

        # Candidate embeddings, [max_len, emb_dim] each.
        target_embeds = self.model.embed(target_ids.unsqueeze(0)).squeeze(0)
        negative_embeds = self.model.embed(negative_ids.unsqueeze(0)).squeeze(0)

        # Both candidates are padded to max_len, so every row has the same length
        # and no extra padding is required.
        target_rows = torch.cat(
            [input_embeddings, target_embeds.unsqueeze(0).expand(batch_size, -1, -1)],
            dim=1,
        )
        negative_rows = torch.cat(
            [input_embeddings, negative_embeds.unsqueeze(0).expand(batch_size, -1, -1)],
            dim=1,
        )
        appended_embeds = torch.stack([target_rows, negative_rows], dim=1).reshape(
            2 * batch_size, seq_length + self.max_len, emb_dim
        )

        # No attention mask is passed. Assuming a causal model, the padded
        # candidate slots sit after every scored position and cannot affect it.
        # NOTE(review): start_at_layer=0 feeds the embeddings straight into the
        # blocks; confirm positional information is handled for the model in use.
        logits = self.model(appended_embeds, start_at_layer=0, return_type="logits")

        input_logits = logits[::2, :seq_length, :] if return_logits else None

        first = seq_length - 1
        target_logits = logits[0::2, first : first + self.max_len]  # [B, max_len, vocab]
        negative_logits = logits[1::2, first : first + self.max_len]

        if self.literal_diff:
            # Position seq_length - 1 sees only the shared prefix, so both rows
            # score their first candidate token against the same distribution.
            objectives = (
                target_logits[:, 0, self.target_tokens[0]]
                - negative_logits[:, 0, self.negative_tokens[0]]
            )
        else:
            target_log_probs = self._candidate_log_probs(target_logits, target_ids)
            negative_log_probs = self._candidate_log_probs(negative_logits, negative_ids)
            objectives = (target_log_probs * target_mask).sum(dim=-1) - (
                negative_log_probs * negative_mask
            ).sum(dim=-1)

        if return_logits:
            return objectives, input_logits
        return objectives


def compute_log_prob_diff(
    model, tokenizer, log_prob_diff_calculator, text, device="cuda"
):
    """
    Score ``text`` with a log-prob-difference calculator.

    ``text`` is encoded with ``tokenizer.encode(..., return_tensors="pt")`` and
    moved to ``device``. It is embedded with ``model.embed``, and the
    embeddings are passed to ``log_prob_diff_calculator.compute_objective``,
    whose result is returned unchanged.
    """
    token_ids = tokenizer.encode(text, return_tensors="pt")
    embeddings = model.embed(token_ids.to(device))
    return log_prob_diff_calculator.compute_objective(embeddings)


class TlensLogProbDiffRunner(IRunner):
    """``IRunner`` whose target is the objective of a ``LogProbDiffCalculator``."""

    def __init__(
        self,
        model,
        tokenizer,
        target_word: str,
        negative_word: str,
        max_num_tokens: Optional[int] = None,
        literal_diff: bool = False,
    ):
        """
        Store the model/tokenizer/words and build the calculator that does the scoring.

        Args:
            model: The model to use
            tokenizer: The tokenizer for the model
            target_word: Word whose likelihood you want to boost
            negative_word: Word to compare against
            max_num_tokens: If not None, crop candidates to this max number of tokens
            literal_diff: If True, use raw logit difference for the first token only
        """
        self.model = model
        self.tokenizer = tokenizer
        self.target_word, self.negative_word = target_word, negative_word

        # All scoring is delegated to this calculator.
        self.log_prob_diff_calculator = LogProbDiffCalculator(
            model,
            tokenizer,
            target_word,
            negative_word,
            max_num_tokens=max_num_tokens,
            literal_diff=literal_diff,
        )

    def run_with_embeddings(
        self, input_embeddings: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """
        Score embedded inputs and return the calculator's output.

        Args:
            input_embeddings: Embedded input tensor

        Returns:
            ``(target, logits, {})``. ``target`` and ``logits`` are what
            ``compute_objective(input_embeddings, return_logits=True)`` returns.
            The empty dict means there is no additional info.
        """
        calculator = self.log_prob_diff_calculator
        objective, logits = calculator.compute_objective(
            input_embeddings, return_logits=True
        )
        extras: Dict[str, Any] = {}
        return objective, logits, extras

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Multiply ``one_hot`` by the embedding matrix ``model.embed.W_E``."""
        return one_hot @ self.model.embed.W_E

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Embed integer token ids with ``model.embed``."""
        embed = self.model.embed
        return embed(int_ids)


def process_text_with_placeholder(tokenizer, text, start_context):
    """
    Split ``text`` around its " {0}." placeholder and compute fixed positions.

    The text before the placeholder becomes the prefix marker and the text after
    it becomes the postfix marker. ``text.format(start_context)`` is then passed
    to ``get_fixed_positions`` together with both markers. ``str.format`` also
    interprets any other braces in ``text``.

    Raises:
        ValueError: If " {0}." does not occur in ``text``.
    """
    placeholder = " {0}."
    start = text.find(placeholder)
    if start < 0:
        raise ValueError(f"Placeholder '{placeholder}' not found in text")

    prefix_marker = text[:start]
    postfix_marker = text[start + len(placeholder) :]
    filled_text = text.format(start_context)
    return get_fixed_positions(tokenizer, filled_text, prefix_marker, postfix_marker)


def format_word_list(words):
    """
    Render ``words`` as a parenthesised "or" list.

    Examples: [] -> "()", ["a"] -> "(a)", ["a", "b"] -> "(a or b)",
    ["a", "b", "c"] -> "(a, b or c)". Every word is formatted with an f-string,
    so non-string items are accepted.

    Args:
        words: Sequence of words to format.

    Returns:
        str: The formatted list.
    """
    if not words:
        return "()"
    if len(words) == 1:
        return f"({words[0]})"
    head, last = words[:-1], words[-1]
    joined_head = ", ".join([f"{word}" for word in head])
    return f"({joined_head} or {last})"


# System-prompt template for the keyword-detection task. "{0}" is filled with
# str.format (test_system_message_call passes format_word_list(word_list)).
# The trailing space after the final "2" is part of the string.
SYSTEM_MESSAGE_OR = (
    "Rules:\n"
    "1. If and only if the message includes one of the following words {0} "
    'then reply with only the digit "1"\n'
    '2. For all other messages: respond with the digit "2" '
)


def test_epo_output(
    output_str,
    word_list,
):

    return_dict = {}

    output_str = output_str.lower()
    for word in word_list:
        if word.lower() in output_str:
            return_dict[word] = 1
        else:
            return_dict[word] = 0

    return return_dict


def test_all_epo_output(list_str, word_list):
    """
    Aggregate ``test_epo_output`` over several output strings.

    Returns:
        dict: word -> 1 if ``test_epo_output`` flagged the word in at least one
        string of ``list_str``, else 0. Keys follow ``word_list`` order.
    """
    per_output = [test_epo_output(text, word_list) for text in list_str]
    return {
        word: 1 if any(flags[word] == 1 for flags in per_output) else 0
        for word in word_list
    }


def test_system_message_call(model, tokenizer, user_message, word_list):
    """
    Check whether the model answers the keyword-detection rules correctly.

    The model is prompted with ``SYSTEM_MESSAGE_OR`` filled with ``word_list``.
    The prompt is cleaned with ``strip_special_tokens`` and removed from the
    response, and the remaining response is printed.

    Returns:
        bool: If some word occurs in ``user_message`` (case-sensitive
        substring), True when the response contains "1" but not "2".
        Otherwise, True when the response contains "2" but not "1".
        False in every other case.
    """
    system_message = SYSTEM_MESSAGE_OR.format(format_word_list(word_list))
    response, prompt = get_model_response(
        model, tokenizer, user_message, system_message, return_input=True
    )
    prompt = strip_special_tokens(tokenizer, prompt)
    response = response.replace(prompt, "")

    should_trigger = any(word in user_message for word in word_list)

    print("the model response is", response)

    says_one = "1" in response
    says_two = "2" in response
    if should_trigger:
        return says_one and not says_two
    return says_two and not says_one


def strip_special_tokens(tokenizer, text):
    """
    Remove special-token strings and common chat markers from ``text``.

    The strings to remove are, in order:

    - every non-empty string attribute of ``tokenizer`` whose name ends in
      ``_token`` (in ``dir()`` order);
    - a fixed list of markers used by common chat templates.

    Each is deleted with ``str.replace`` in that order. Afterwards each line has
    runs of whitespace collapsed to a single space and is trimmed; the line
    breaks themselves are kept.

    Args:
        tokenizer: Tokenizer whose ``*_token`` attributes are collected.
        text: Text that may contain special tokens or formatting markers.

    Returns:
        str: The cleaned text.
    """
    tokenizer_markers = []
    for name in dir(tokenizer):
        if not name.endswith("_token"):
            continue
        value = getattr(tokenizer, name)
        if isinstance(value, str) and value:
            tokenizer_markers.append(value)

    chat_markers = (
        "<s>",
        "</s>",
        "<pad>",
        "<bos>",
        "<eos>",
        "<start_of_turn>",
        "<end_of_turn>",
        "<assistant>",
        "<user>",
        "<system>",
        "[INST]",
        "[/INST]",
        "[ASST]",
        "[/ASST]",
        "[SYS]",
        "[/SYS]",
        "<|im_start|>",
        "<|im_end|>",
        "<|assistant|>",
        "<|user|>",
        "<|system|>",
    )

    for marker in [*tokenizer_markers, *chat_markers]:
        text = text.replace(marker, "")

    return "\n".join(" ".join(line.split()) for line in text.split("\n"))
