from enum import Enum, auto
from typing import List, Optional, Any, Dict

import numpy as np
import torch
from llada.inpaint import LLaDAInpainter
from pydantic import BaseModel, Field
from transformers import PreTrainedTokenizerBase

import wandb
from custom_dreamy.selector import Selector
from custom_dreamy.state import State
from eliciting_contexts.utils.call_openai import ImageChatHistory, call_model
from eliciting_contexts.benchmark.external.utils.logger import logger


class ICallback:
    """Interface for hooks invoked by the EPO optimisation loop.

    Subclasses implement ``__call__``. ``CombinedCallback`` merges the
    return values of several callbacks with ``any``; the callbacks in this
    module return ``False``/``None`` to mean "continue".
    """

    def __call__(
        self,
        i: int,
        state: State,
        last_runtime: Optional[float],
        history: Any,
        selector: Optional[Selector],
        final: bool = False,
    ) -> Optional[bool]:
        """Handle optimisation step ``i``.

        Args:
            i: Current step index.
            state: Current population state.
            last_runtime: Runtime of the previous step (ParetoCallback prints
                it as seconds), or ``None``.
            history: Optimisation history object.
            selector: Selector used to rebuild ``state``; may be ``None``.
            final: Set by the caller for the final invocation.

        Raises:
            NotImplementedError: Always, on this base class.
        """
        message = "ICallback is an abstract class"
        raise NotImplementedError(message)


# System prompt for the SAE variant of ModelHelperCallback, filled with
# str.format. Placeholders:
#   {examples_str}  - numbered "text (Target: x)" lines, highest target first
#   {num_sentences} - number of samples requested (the population size)
# The reply is parsed into EPOHelperOutput. Any literal braces added to this
# text must be doubled because of str.format.
# NOTE(review): "GRAMATICAL" is misspelled inside the prompt; left as-is
# because editing the prompt text changes what the model is asked to do.
EPO_HELPER_PROMPT = """You are a specialized text generation assistant helping to create inputs that maximize neural network feature activation.

Your task is to generate examples that will strongly activate a specific feature.

Below are examples of text and their corresponding activation score (target)(ranked from highest to lowest activation):

{examples_str}

Please analyze these examples carefully to understand what causes high activation. Then, generate {num_sentences} new examples that could potentially achieve strong activation of this feature.

When creating your examples:
- Try not to be fooled by low activation examples that are lower on the list
- ENSURE THAT SOME OF YOUR SAMPLES ARE JUST GRAMATICAL PARAPHRASES OF THE HIGHER ACTIVATING EXAMPLES
- Look for patterns in the higher activating examples
- Make at least one of your examples closely resemble the highest-activating example in structure and content
- Create diverse examples that capture different aspects of what might activate the feature
- Ensure your examples are similar in length to the provided examples
- Use natural language that sounds realistic with good grammar (even if the examples are not realistic!)

Note that your examples will be truncated or padded to match a specific token length. For each example you create, please indicate whether to truncate from the left or right if necessary, based on which part you believe is less important for activating the target feature.

First, describe your understanding of what causes high activation for this feature and write down any notes, think through what you will do. Then, generate your {num_sentences} examples, each with a truncation preference (left/right).
"""

# System prompt for the STORIES variant of ModelHelperCallback, filled with
# str.format. Placeholders:
#   {desired_text} / {undesired_text} - next words to promote / suppress
#   {full_template}   - story template with "INSERT TEXT HERE" in the
#                       editable slot
#   {num_sentences}   - number of samples requested (the population size)
#   {current_epo_str} - numbered "text (Target: x)" lines, highest first
# NOTE(review): the last two paragraphs are copied from EPO_HELPER_PROMPT and
# still talk about "activating the target feature", while this task is about
# next-word logits. Confirm whether that wording is intended before changing
# it, because edits alter model behaviour.
EPO_STORIES_HELPER_PROMPT = """You are a specialized text generation assistant helping to create inputs that maximize certain next word logits and minimise other next word logits for neural networks.

The task is to modify a story to make the model predict the word '{desired_text}' and minimise the logits of '{undesired_text}.' You can only modify one sentence in the story that is placed within the story.

Here is the template for the story: {full_template}. The sentence you are to modify is placed inside the 'INSERT TEXT HERE' section of this template.

The goal is to improve upon the current versions of the variable text that you should use as a starting point for you to generate {num_sentences} new sentences.

Here are the current sentences: {current_epo_str}

Each of your sentences should be based on one of these current sentences.

When creating your examples:
- Focus on making each sentence as fluent (i.e. grammatically correct and natural sounding) as possible
- Make a few of your examples closely resemble a current variable context in semantic meaning, but correct for fluency.
- Make a few of your examples keep words that you deem most relevant to the goal of predicting the desired text, but change everything else.
- For a few of your examples, you have license to change the semantic meaning more to maximise desired text.
- Create diverse examples that capture different aspects of what might cause the model to predict the desired text.
- Ensure your examples are similar in length to the provided examples.
- Use natural language that sounds realistic with good grammar (even if the examples are not realistic!)

Write down any thoughts you have about the task and then provide {num_sentences} new sentences that are different from the original sentence and are different from each other, that could potentially achieve this goal.

Note that your examples will be truncated or padded to match a specific token length. For each example you create, please indicate whether to truncate from the left or right if necessary, based on which part you believe is less important for activating the target feature.

First, describe your understanding of what causes high activation for this feature and write down any notes, think through what you will do. Then, generate your {num_sentences} examples, each with a truncation preference (left/right).
"""


# Which end of a helper-model sample to cut when its tokenization is too long.
# It is the field type of SampleWithTruncation.truncation_preference. Thanks
# to the ``str`` mixin, members compare equal to "left"/"right" and support
# str methods (ModelHelperCallback calls ``.lower()`` on them).
# NOTE(review): this is documented with comments rather than a class
# docstring because pydantic may copy an enum's docstring into the
# structured-output JSON schema.
class TruncationPreference(str, Enum):
    LEFT = "left"  # keep the tail and drop tokens from the start
    RIGHT = "right"  # keep the head and drop tokens from the end


# One candidate proposed by the helper model (an element of
# EPOHelperOutput.samples).
# NOTE(review): the class docstring and the Field descriptions are runtime
# text. Pydantic emits them into the JSON schema that call_model presumably
# forwards as the structured-output format, so editing them changes the
# model's instructions.
class SampleWithTruncation(BaseModel):
    """Structure for a single sample with truncation preference."""

    # Candidate text. ModelHelperCallback re-tokenizes it and fits it to the
    # variable-length window.
    text: str = Field(
        ..., description="The sample text that might activate the feature"
    )
    # End to cut from when the tokenized text is longer than the window.
    truncation_preference: TruncationPreference = Field(
        ...,
        description="Whether to truncate from 'left' or 'right' if necessary",
    )


# Structured reply expected from the helper model; passed as
# ``structured_output`` to call_model by ModelHelperCallback.
# NOTE(review): the docstring and Field descriptions are part of the schema
# shown to the model, so treat them as prompt text.
class EPOHelperOutput(BaseModel):
    """Structure for EPO helper prompt output with samples and truncation preferences."""

    # Free-form reasoning. ModelHelperCallback only logs it.
    thoughts: str = Field(
        ..., description="Think through the task and write notes here"
    )
    # Proposed samples. Extra ones beyond the population size are dropped,
    # and missing ones are filled from the current population.
    samples: List[SampleWithTruncation] = Field(
        ..., description="List of samples with truncation preferences"
    )


class PromptType(Enum):
    """Selects which helper prompt ``ModelHelperCallback`` builds.

    ``SAE`` formats ``EPO_HELPER_PROMPT`` (maximise a feature activation).
    ``STORIES`` formats ``EPO_STORIES_HELPER_PROMPT`` (rewrite one sentence of
    a story template to steer next-word logits).
    """

    SAE = "sae"  # feature-activation prompt
    STORIES = "stories"  # story-template prompt


class ModelHelperCallback(ICallback):
    """Periodically asks an external LLM to propose a fresh population.

    The callback runs on every ``run_every``-th step (never step 0), and also
    on the final call when ``run_on_last_step`` is set. Each run:

    1. Decodes the current population and ranks it by target, highest first.
    2. Formats the ranking into ``EPO_HELPER_PROMPT`` (``PromptType.SAE``) or
       ``EPO_STORIES_HELPER_PROMPT`` (``PromptType.STORIES``).
    3. Re-tokenizes the model's ``EPOHelperOutput`` samples and truncates or
       EOS-pads each one to the variable-length window.
    4. Tops the set up with the best current members if there are too few.
    5. Loads the result into ``state`` through ``selector.setup``.

    Args:
        tokenizer: Tokenizer used to decode the population and to encode the
            model's suggestions.
        run_every: Step interval between helper calls.
        model_name: Model identifier forwarded to ``call_model``.
        temperature: Sampling temperature forwarded to ``call_model``.
        prompt_type: A ``PromptType`` member or its string value.
        stories_context: Required for ``PromptType.STORIES``. Keys read:
            ``desired_text``, ``undesired_text``, ``template``,
            ``fixed_positions`` (mask over the full sequence; its complement
            marks the editable slots) and ``initial_token_ids``.
        run_on_last_step: Also run when invoked with ``final=True``.
        num_mutations: On the final step, ``num_mutations * pop_size`` zero
            rows are appended to what is inserted into ``history``.

    Raises:
        ValueError: If ``prompt_type`` is not a valid ``PromptType``, or if
            ``stories_context`` is missing for ``PromptType.STORIES``.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        run_every: int = 76,
        model_name: str = "gpt-4o",
        temperature: float = 1.0,
        prompt_type: PromptType = PromptType.SAE.value,
        stories_context: Optional[Dict[str, Any]] = None,
        run_on_last_step: bool = False,
        num_mutations: int = 32,
    ):
        # Callers may pass either the enum member or its string value. A raw
        # string never equals a plain ``Enum`` member, so normalise to the
        # string value once. The validation below and the ``.value``
        # comparisons in ``__call__`` then work for both input forms.
        prompt_type = PromptType(prompt_type).value
        if prompt_type == PromptType.STORIES.value and stories_context is None:
            raise ValueError("stories_context must be provided for PromptType.STORIES")

        self.model_name = model_name
        self.tokenizer = tokenizer
        self.run_every = run_every
        self.temperature = temperature
        self.stories_context = stories_context
        self.prompt_type = prompt_type
        self.run_on_last_step = run_on_last_step
        self.num_mutations = num_mutations

    def __call__(
        self,
        i: int,
        state: State,
        last_runtime: Optional[float],
        history: Any,
        selector: Selector,
        final: bool = False,
    ) -> None:
        """Replace the population with LLM suggestions on scheduled steps.

        Mutates ``state`` in place via ``state.update_from``. On the final
        step it also inserts the new population into ``history``. If the
        model call fails, the error is printed and ``state`` is left
        untouched.
        """
        if not self._should_run(i, final):
            return

        pop_size, seq_len = state.ids.shape[0], state.ids.shape[1]
        is_stories = self.prompt_type == PromptType.STORIES.value

        prompt_text, variable_seq_len = self._build_prompt(
            self._ranked_examples(state), pop_size, seq_len
        )
        logger.info(prompt_text)

        chat_history = ImageChatHistory()
        chat_history.add_system_msg(prompt_text)
        assist_success, assist_content, error = call_model(
            chat_history,
            structured_output=EPOHelperOutput,
            temperature=self.temperature,
            model=self.model_name,
        )
        if not assist_success:
            print(error)
            return

        logger.info(chat_history)
        logger.info(assist_content.thoughts)

        # Only the SAE path asks the tokenizer for special tokens. STORIES
        # sentences are spliced into a template that presumably already has
        # them.
        add_special_tokens = not is_stories
        new_samples_list = []
        for sample_idx, assist_sample in enumerate(assist_content.samples):
            logger.info(f"Sample {sample_idx}:")
            logger.info(assist_sample.text)
            new_samples_list.append(
                self._fit_to_length(assist_sample, variable_seq_len, add_special_tokens)
            )

        new_samples = self._complete_population(new_samples_list, state, pop_size)
        if is_stories:
            new_samples = self._embed_in_template(new_samples)

        state.update_from(selector.setup(new_samples))

        if final:
            self._insert_final_into_history(state, history, pop_size)

    def _should_run(self, i: int, final: bool) -> bool:
        """Return whether this invocation is a scheduled helper step."""
        if final:
            return self.run_on_last_step
        return i > 0 and i % self.run_every == 0

    def _ranked_examples(self, state: State) -> List[str]:
        """Decode the population into unique "text (Target: x)" lines, best first."""
        # NOTE(review): this strips every BOS/EOS id, assuming they only occur
        # as padding. BOS is re-added when SAE suggestions are re-encoded.
        eos_id = self.tokenizer.eos_token_id
        bos_id = self.tokenizer.bos_token_id
        scored = []
        for seq, target in zip(state.ids, state.target):
            keep = (seq != eos_id) & (seq != bos_id)
            scored.append((target.item(), self.tokenizer.decode(seq[keep])))
        scored.sort(reverse=True)
        formatted = [f"{text} (Target: {target:.4f})" for target, text in scored]
        # dict.fromkeys removes duplicates while keeping the ranked order.
        return list(dict.fromkeys(formatted))

    def _build_prompt(self, examples: List[str], pop_size: int, seq_len: int):
        """Format the helper prompt for ``self.prompt_type``.

        Returns:
            ``(prompt_text, variable_seq_len)``, where ``variable_seq_len`` is
            the number of token positions each suggestion must be fitted to.

        Raises:
            ValueError: For an unsupported prompt type.
        """
        numbered = "\n".join(
            f"{idx + 1}. {example}" for idx, example in enumerate(examples)
        )
        if self.prompt_type == PromptType.STORIES.value:
            ctx = self.stories_context
            prompt_text = EPO_STORIES_HELPER_PROMPT.format(
                desired_text=ctx["desired_text"],
                undesired_text=ctx["undesired_text"],
                full_template=ctx["template"].format("INSERT TEXT HERE"),
                num_sentences=pop_size,
                current_epo_str=numbered,
            )
            # Only the non-fixed (editable) positions are rewritten.
            return prompt_text, (~ctx["fixed_positions"]).sum().item()
        if self.prompt_type == PromptType.SAE.value:
            prompt_text = EPO_HELPER_PROMPT.format(
                examples_str=numbered, num_sentences=pop_size
            )
            return prompt_text, seq_len
        raise ValueError(f"Unsupported PromptType: {self.prompt_type}")

    def _fit_to_length(
        self, sample: Any, target_len: int, add_special_tokens: bool
    ) -> torch.Tensor:
        """Encode ``sample.text`` and truncate or EOS-pad it to ``target_len`` ids."""
        encoded = self.tokenizer.encode(
            sample.text, return_tensors="pt", add_special_tokens=add_special_tokens
        )[0]
        length = encoded.size(0)
        logger.info(f"The length of the encoded text is {length} and the length of "
                    f"the variable sequence is {target_len}")
        if length > target_len:
            logger.info("The length of the generated text is longer than the template so it needs to be truncated.")
            if sample.truncation_preference.lower() == "left":
                logger.info("Truncating to the left.")
                # Use an explicit start index: ``encoded[-0:]`` would keep
                # every token when target_len == 0.
                return encoded[length - target_len:]
            logger.info("Truncating to the right.")
            return encoded[:target_len]
        if length < target_len:
            logger.info("The length of the generated text is shorter than the template so it needs to be padded with EOS tokens.")
            padding = torch.full(
                (target_len - length,), self.tokenizer.eos_token_id, dtype=torch.long
            )
            return torch.cat([encoded, padding])
        return encoded

    def _complete_population(
        self, samples: List[torch.Tensor], state: State, pop_size: int
    ) -> torch.Tensor:
        """Trim or top up ``samples`` to ``pop_size`` rows and stack them on ``state.ids.device``.

        Missing rows are copied from the highest-target current members,
        cycling through them if more rows are needed than exist. For STORIES
        only their editable (non-fixed) positions are copied, so every row has
        the same width as the model's suggestions.
        """
        device = state.ids.device
        samples = list(samples)
        if len(samples) > pop_size:
            logger.info(f"Too many samples generated, keeping only the first {pop_size}")
            samples = samples[:pop_size]
        elif len(samples) < pop_size:
            logger.info("Too few samples, fill with highest-target samples from current state")
            ranked = sorted(
                ((state.target[k].item(), k) for k in range(state.ids.shape[0])),
                reverse=True,
            )
            variable_mask = None
            if self.prompt_type == PromptType.STORIES.value:
                variable_mask = ~self.stories_context["fixed_positions"].to(device)
            for fill_idx in range(pop_size - len(samples)):
                _, member_idx = ranked[fill_idx % len(ranked)]
                row = state.ids[member_idx]
                samples.append(row if variable_mask is None else row[variable_mask])

        # Suggestions are encoded on the CPU, while top-up rows come from
        # ``state.ids``. Move everything to one device before stacking.
        return torch.stack([sample.to(device) for sample in samples])

    def _embed_in_template(self, variable_ids: torch.Tensor) -> torch.Tensor:
        """Write ``variable_ids`` into the editable slots of the STORIES template."""
        initial_ids = self.stories_context["initial_token_ids"]
        fixed_positions = self.stories_context["fixed_positions"]
        assert len(initial_ids) == len(fixed_positions), \
            f"Length mismatch: initial_ids ({len(initial_ids)}) vs fixed_positions ({len(fixed_positions)})"
        assert variable_ids.shape[1] == initial_ids[~fixed_positions].shape[0], \
            f"The length of the new samples ({variable_ids.shape[1]}) does not match the length of the variable context window ({initial_ids[~fixed_positions].shape[0]})"

        device = variable_ids.device
        full_samples = initial_ids.clone().to(device).unsqueeze(0).repeat(variable_ids.shape[0], 1)
        full_samples[:, ~fixed_positions.to(device)] = variable_ids
        return full_samples

    def _insert_final_into_history(self, state: State, history: Any, pop_size: int) -> None:
        """Insert the final population into ``history``, followed by zero rows.

        NOTE(review): the zero block has ``num_mutations * pop_size`` rows,
        presumably to match the row count ``history._insert`` receives on a
        regular mutation step. Confirm against the history implementation.
        """
        size = self.num_mutations * pop_size
        ids, target, xentropy = state.ids, state.target, state.xentropy
        history._insert(
            torch.cat((ids, torch.zeros((size, ids.shape[1]), device=ids.device, dtype=ids.dtype)), dim=0).to(torch.int64),
            torch.cat((target, torch.zeros((size,), device=target.device, dtype=target.dtype)), dim=0).to(torch.float32),
            torch.cat((xentropy, torch.zeros((size,), device=xentropy.device, dtype=xentropy.dtype)), dim=0).to(torch.float32),
            torch.zeros((pop_size,)).to(torch.bool).to(ids.device),
            0.0,
            torch.zeros((pop_size,)).to(torch.float32).to(ids.device),
        )


class InpaintingCallback(ICallback):
    """Periodically re-generates the lowest-scoring tokens with a LLaDA inpainter.

    The callback runs on every ``inpaint_every``-th step, never on step 0 and
    never on the final call. In each row, the ``replace_percent`` fraction of
    positions with the lowest ``state.extra["per_token_target"]`` become
    candidates, and each candidate is kept independently with probability
    ``replace_prob``. The kept positions are re-filled by
    ``LLaDAInpainter.inpaint_with_alt_tokenizer``, and the population is
    reloaded through ``selector.setup``.
    """

    def __init__(
        self,
        tokenizer,
        inpaint_every=5,
        replace_percent=0.75,
        replace_prob=0.5,
        token_per_step=4,
        unmasking="high_confidence",
        model_name="GSAI-ML/LLaDA-8B-Instruct",
        device="cuda",
        torch_dtype=torch.bfloat16,
    ):
        self.tokenizer = tokenizer
        self.inpainter = LLaDAInpainter(
            model_name=model_name,
            device=device,
            torch_dtype=torch_dtype,
        )
        self.inpaint_every = inpaint_every
        self.replace_percent = replace_percent
        self.replace_prob = replace_prob
        self.token_per_step = token_per_step
        self.unmasking = unmasking

    def __call__(self, i, state, last_runtime, history, selector, final=False):
        """Inpaint ``state`` in place on scheduled steps; otherwise do nothing.

        Raises:
            ValueError: If ``state.extra`` has no ``"per_token_target"`` on a
                step where inpainting runs.
        """
        # Check the schedule first so off-steps do not decode and print the
        # whole population.
        if final or i == 0 or i % self.inpaint_every != 0:
            return

        print("\n--- Debugging tokenized sequences ---")
        print("INPUT =============================")
        self._print_sequences(state)

        if "per_token_target" not in state.extra:
            raise ValueError("per_token_target not found in state.extra")

        with torch.no_grad():
            # Expected shape (batch, seq), matching state.ids. TODO confirm.
            per_token_target = state.extra["per_token_target"]
            print(per_token_target.shape)

            mask = self._replacement_mask(per_token_target)
            new_ids = self.inpainter.inpaint_with_alt_tokenizer(
                self.tokenizer,
                state.ids,  # NOTE(review): original author flagged this input as suspicious
                mask,
                token_per_step=self.token_per_step,
                unmasking=self.unmasking,
            )

        # selector.setup recomputes gradients for the new population, so it
        # runs outside the no_grad block.
        state.update_from(selector.setup(new_ids))

        print("OUTPUT =============================")
        self._print_sequences(state)
        print("--- End of debugging output ---\n")

    def _print_sequences(self, state):
        """Print every decoded row of ``state.ids``."""
        for idx, seq in enumerate(state.ids):
            print(f"Sequence {idx}: {self.tokenizer.decode(seq)}")

    def _replacement_mask(self, per_token_target):
        """Return a bool mask (same shape as the input) of positions to re-generate."""
        seq_len = per_token_target.shape[1]
        # Clamp so an out-of-range replace_percent cannot make topk fail.
        num_candidates = min(max(int(self.replace_percent * seq_len), 0), seq_len)
        _, lowest = torch.topk(per_token_target, k=num_candidates, dim=1, largest=False)
        mask = torch.zeros_like(per_token_target, dtype=torch.bool)
        mask.scatter_(1, lowest, torch.ones_like(lowest, dtype=torch.bool))
        # Each candidate survives independently with probability replace_prob.
        return mask & (torch.rand_like(per_token_target) < self.replace_prob)


class WandbEPOCallback(ICallback):
    """Log EPO population metrics to Weights & Biases.

    Metrics live under ``{name}/...`` and use ``{name}_step`` (the step index
    ``i``) as a custom x-axis, registered with ``wandb.define_metric``. With
    ``per_pop_log`` set, each member's xentropy, target and decoded text are
    logged as well. Mean, min and max of xentropy and target are always
    logged.
    """

    def __init__(
        self,
        cache_run,
        model,
        tokenizer,
        x_penalty_min,
        x_penalty_max,
        name="metrics",
        per_pop_log=True,
    ):
        self.cache_run = cache_run
        self.model = model
        self.tokenizer = tokenizer
        self.x_penalty_min = x_penalty_min
        self.x_penalty_max = x_penalty_max
        # Prefix for every logged key.
        self.name = name
        self.per_pop_log = per_pop_log

        # Custom x-axis for this set of metrics.
        self.x_axis_name = f"{self.name}_step"

        # Plot every metric with this prefix against the custom x-axis.
        wandb.define_metric(f"{self.name}/*", step_metric=self.x_axis_name)
        if self.per_pop_log:
            wandb.define_metric(
                f"{self.name}/population/*", step_metric=self.x_axis_name
            )

    def __call__(self, i, state, last_runtime, history, selector, final=False):
        """Log step ``i`` with a single ``wandb.log`` call; always returns False.

        Each ``wandb.log`` call commits a new W&B history row. Collecting
        everything into one dict gives one row per optimisation step, with the
        x-axis value in the same row. The alternative is one row per
        population member plus separate rows for the step, runtime and
        summary statistics.
        """
        metrics = {self.x_axis_name: i}

        if last_runtime is not None:
            metrics[f"{self.name}/runtime"] = last_runtime

        if self.per_pop_log:
            # One host transfer per tensor instead of one .item() per member.
            xentropies = state.xentropy.tolist()
            targets = state.target.tolist()
            for pop_idx in range(state.ids.shape[0]):
                member = f"{self.name}/population/member_{pop_idx}"
                metrics[f"{member}/xentropy"] = xentropies[pop_idx]
                metrics[f"{member}/target"] = targets[pop_idx]
                metrics[f"{member}/text"] = self.tokenizer.decode(state.ids[pop_idx])

        metrics.update(
            {
                f"{self.name}/xentropy/mean": state.xentropy.mean().item(),
                f"{self.name}/xentropy/min": state.xentropy.min().item(),
                f"{self.name}/xentropy/max": state.xentropy.max().item(),
                f"{self.name}/target/mean": state.target.mean().item(),
                f"{self.name}/target/min": state.target.min().item(),
                f"{self.name}/target/max": state.target.max().item(),
            }
        )

        wandb.log(metrics)

        return False  # Continue training


class ParetoCallback(ICallback):
    """Print the population members on the target/cross-entropy Pareto frontier.

    The callback sweeps 200 log-spaced penalties from ``x_penalty_min / 10``
    to ``x_penalty_max * 10``. For each penalty it prints the member that
    minimises ``-target + penalty * xentropy``, skipping consecutive repeats.
    With ``fixed_positions`` given, only the non-fixed tokens are decoded.
    The member's ``final_token`` is shown in brackets.

    NOTE(review): both penalty bounds must be positive, or ``np.log`` yields
    -inf/NaN.
    """

    def __init__(
        self, cache_run, model, tokenizer, x_penalty_min, x_penalty_max, fixed_positions
    ):
        self.cache_run = cache_run
        self.model = model
        self.tokenizer = tokenizer
        self.x_penalty_min = x_penalty_min
        self.x_penalty_max = x_penalty_max
        # torch.as_tensor accepts lists and tensors alike. Unlike
        # torch.tensor, it does not copy or warn when given an existing
        # tensor.
        self.fixed_positions = (
            None
            if fixed_positions is None
            else torch.as_tensor(fixed_positions, dtype=torch.bool)
        )

    def __call__(self, i, state, last_runtime, history, selector, final=False):
        """Print the frontier at step ``i``; always returns False (continue)."""
        if last_runtime is not None:
            print("runtime: {:.2f} seconds".format(last_runtime))
        print(f"\nbeginning step {i}, current pareto frontier prompts:")

        penalties = torch.exp(
            torch.linspace(
                np.log(self.x_penalty_min / 10.0),
                np.log(self.x_penalty_max * 10.0),
                200,
            )
        ).to(state.target.device)
        # loss[p, m] is the penalised loss of member m under penalty p.
        loss = -state.target[None] + penalties[:, None] * state.xentropy[None]
        best_members = loss.argmin(dim=1).tolist()

        variable_mask = None
        if self.fixed_positions is not None:
            # Move the mask once per call, not once per printed row.
            self.fixed_positions = self.fixed_positions.to(state.ids.device)
            variable_mask = ~self.fixed_positions

        last_idx = None
        for penalty, idx in zip(penalties.tolist(), best_members):
            if idx == last_idx:
                continue
            if variable_mask is None:
                text = self.tokenizer.decode(state.ids[idx])
            else:
                text = self.tokenizer.decode(state.ids[idx, variable_mask])
            last_token = self.tokenizer.decode(state.final_token[idx])
            print(
                f"penalty={penalty:.2f} xentropy={state.xentropy[idx]:.2f} target={state.target[idx]:.2f} {repr(text + '[' + last_token + ']')}"
            )
            last_idx = idx

        return False  # Continue training


class CombinedCallback(ICallback):
    """Fan one callback invocation out to several callbacks.

    ``None`` entries in ``callbacks`` are skipped. Every remaining callback
    is invoked in order, with no short-circuiting. The combined result is
    ``True`` when at least one callback returned a truthy value.
    """

    def __init__(self, callbacks, verbose=False):
        self.callbacks = callbacks
        self.verbose = verbose

    def __call__(self, i, state, last_runtime, history, selector=None, final=False):
        # Build the full list first so that every callback runs even after
        # one of them has returned True.
        outcomes = [
            callback(i, state, last_runtime, history, selector, final=final)
            for callback in self.callbacks
            if callback is not None
        ]
        return any(outcomes)



# periodically adds a term to the runner
class DiverseCallback(ICallback):
    """Periodically adds a repulsion vector to ``divergence_runner``.

    When ``i - 1`` is a positive multiple of ``run_every``, the callback:

    1. Runs the first population member through ``divergence_runner.runner``.
    2. Averages its ``intermediate_residual`` over dim 1.
    3. Appends the result to ``divergence_runner.repulsion_vectors``.
    4. Rebuilds the population via ``selector.setup``, from ``state_ids`` if
       given, otherwise from the current ``state.ids``.
    """

    def __init__(self, divergence_runner, state_ids=None, run_every=10):
        self.divergence_runner = divergence_runner
        self.run_every = run_every
        # Optional ids to reset the population to after each new vector.
        self.reset_state_ids = state_ids

    def __call__(
        self,
        i,
        state: State,
        last_runtime,
        history,
        selector: Selector,
        final=False,
    ):
        """Add a repulsion vector on scheduled steps; always returns False.

        Raises:
            ValueError: If the runner's additional outputs lack
                ``"intermediate_residual"``.
        """
        # Fire only when (i - 1) is a positive multiple of run_every.
        step = i - 1
        if step % self.run_every or step == 0:
            return False

        runner = self.divergence_runner.runner
        # Only the first population member is used; keep a batch dim of 1.
        first_ids = state.ids[0].unsqueeze(0)
        input_embeddings = runner.int_ids_to_embed(first_ids)

        # Run the wrapped runner to get intermediate residuals.
        _, _, additional_outputs = runner.run_with_embeddings(input_embeddings)
        if "intermediate_residual" not in additional_outputs:
            raise ValueError("intermediate_residual not found in additional_outputs")

        # Dim 1 is presumably the sequence axis (batch, seq, hidden). Confirm
        # against the runner.
        avg_residual = additional_outputs["intermediate_residual"].mean(dim=1)
        if avg_residual.shape[0] == 1:
            avg_residual = avg_residual.squeeze(0)

        vectors = self.divergence_runner.repulsion_vectors
        vectors.append(avg_residual.detach())
        print(f"Added new target vector from intermediate residual. Total vectors: {len(vectors)}")

        # The new vector changes the objective, so recompute gradients.
        if self.reset_state_ids is not None:
            state.update_from(selector.setup(self.reset_state_ids.clone()))
        else:
            state.update_from(selector.setup(state.ids))
        return False  # Continue training
