from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
import wandb
from custom_dreamy.callbacks import WandbEPOCallback
from custom_dreamy.epo import add_fwd_hooks, build_pareto_frontier
from custom_dreamy.i_runner import IRunner, transfomer_embed_one_hot_input
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from eliciting_contexts.extracting_latents.extracting_mean_difference import (
    get_mean_difference,
)
from eliciting_contexts.fluent_dreaming.run_fluent_dreaming import (
    epo,
    log_fluent_dreaming_to_wandb,
)
from eliciting_contexts.utils.auth import setup_wandb_auth
from eliciting_contexts.utils.constants import DATA_DIR, DEVICE, WANDB_ENTITY
from eliciting_contexts.utils.load_models import (
    TorchLogisticRegression,
    load_classifier,
    load_finetuned_model,
)


# uses linear probe and tries to optimize for sandbagging
class FinetunedProbeRunner(IRunner):
    """EPO runner whose optimization target is a linear probe's output.

    A forward hook on ``model.model.layers[layer]`` feeds the layer's
    last-position activations to ``probe`` and uses column 1 of the probe
    output as the per-example target.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        layer: int,
        probe: TorchLogisticRegression,
    ):
        """
        Initialize the runner with a specific layer and probe.

        Args:
            model: The language model to run
            tokenizer: Tokenizer for the model
            layer: Index into ``model.model.layers`` whose output is probed
            probe: Linear probe to use for classification
        """
        self.model = model
        self.tokenizer = tokenizer
        self.layer = layer
        self.probe = probe

    def run_with_embeddings(
        self, input_embeddings: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Run the model with embedded inputs and return the probe prediction.

        Args:
            input_embeddings: Embedded input tensor

        Returns:
            Tuple of (target, logits, extra_data) where ``target`` is column 1
            of the probe output, ``logits`` are the model's output logits and
            ``extra_data`` is an empty dict.
        """
        out = {}

        def get_target(
            module: torch.nn.Module,
            inputs: tuple[torch.Tensor, ...],
            output: torch.Tensor,
        ) -> None:
            # The layer output is indexed as a tuple whose first element is the
            # hidden-state tensor; [:, -1, :] keeps only the last position
            # (assumes (batch, seq, hidden) layout — TODO confirm).
            acts = output[0][:, -1, :].to(device=DEVICE)
            # Column 1 of the probe output — presumably the positive
            # (sandbagging) class; verify against load_classifier.
            out["target"] = self.probe(acts)[:, 1]

        with add_fwd_hooks(
            [
                (
                    self.model.model.layers[self.layer],
                    get_target,
                ),  # Hook on the layer directly
            ]
        ):
            model_logits = self.model(inputs_embeds=input_embeddings).logits

        return out["target"], model_logits, {}

    def run_with_one_hot(
        self, input_one_hot: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Run the model with one-hot encoded inputs and return the probe prediction.

        Args:
            input_one_hot: One-hot encoded input tensor

        Returns:
            Tuple of (target, logits, extra_data)
        """
        input_embed = self.one_hot_to_embed(input_one_hot)
        return self.run_with_embeddings(input_embed)

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Convert one-hot encoded input to embedded input."""
        return transfomer_embed_one_hot_input(one_hot, self.model)

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Convert integer token IDs to embedded input."""
        return self.model.get_input_embeddings()(int_ids)


class FinetunedCosineRunner(IRunner):
    """EPO runner that scores a layer's activations against a reference vector.

    NOTE(review): despite the class name, the target is not a cosine
    similarity. For every sequence position it takes the dot product of the
    activation with ``reference_tensor_dict[layer]``, divides by the
    reference's norm only (not the activation's), then averages over the
    sequence dimension. Confirm this scalar projection is what is intended.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        layer: int,
        reference_tensor_dict: dict,
    ):
        """
        Store the model, tokenizer, hooked layer and reference tensors.

        Args:
            model: The language model to run
            tokenizer: Tokenizer for the model
            layer: Index into ``model.model.layers`` whose output is scored
            reference_tensor_dict: Mapping from layer index to reference tensor
        """
        self.model, self.tokenizer, self.layer = model, tokenizer, layer
        self.reference_tensor_dict = reference_tensor_dict

    def run_with_embeddings(
        self, input_embeddings: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Forward ``input_embeddings`` and score the hooked layer's output.

        Args:
            input_embeddings: Embedded input tensor

        Returns:
            Tuple of (target, logits, extra_data); ``extra_data`` is empty.
        """
        captured = {}
        hooked_layer = self.model.model.layers[self.layer]

        def score_layer_output(
            module: torch.nn.Module,
            inputs: tuple[torch.Tensor, ...],
            output: torch.Tensor,
        ) -> None:
            hidden = output[0].to(device=DEVICE)
            direction = self.reference_tensor_dict[self.layer].to(hidden.device)
            # Assuming a 1-D reference, this is a (batch, 1, dim) view that
            # broadcasts against every sequence position of ``hidden``.
            direction_view = (
                direction.unsqueeze(0).unsqueeze(0).expand(hidden.shape[0], -1, -1)
            )
            per_position = (hidden * direction_view).sum(dim=-1) / direction.norm(
                dim=-1, keepdim=False
            )
            captured["target"] = per_position.mean(dim=1)

        with add_fwd_hooks([(hooked_layer, score_layer_output)]):
            logits = self.model(inputs_embeds=input_embeddings).logits

        return captured["target"], logits, {}

    def run_with_one_hot(
        self, input_one_hot: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Embed one-hot inputs, then delegate to ``run_with_embeddings``.

        Args:
            input_one_hot: One-hot encoded input tensor

        Returns:
            Tuple of (target, logits, extra_data)
        """
        return self.run_with_embeddings(self.one_hot_to_embed(input_one_hot))

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Map a one-hot tensor to input embeddings."""
        return transfomer_embed_one_hot_input(one_hot, self.model)

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Look up input embeddings for integer token IDs."""
        embedding = self.model.get_input_embeddings()
        return embedding(int_ids)


def get_fixed_positions(
    tokenizer,
    full_text: str = "",
    prefix_text: str = "",
    question_text: str = "",
    verbose: bool = True,
) -> tuple[torch.Tensor, List[bool]]:
    """Tokenize a chat prompt and mark which tokens the optimizer may change.

    The variable (optimizable) region is the text between
    ``"<start_of_turn>system log\\n"`` and ``"<end_of_turn>\\n<start_of_turn>user"``.
    Every token outside it is marked as fixed.

    Args:
        tokenizer: Object exposing ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str``.
        full_text: Complete prompt text. Mutually exclusive with
            ``prefix_text`` / ``question_text``.
        prefix_text: Text for the ``system log`` turn (the variable region),
            used when ``full_text`` is empty.
        question_text: Text for the ``user`` turn, used when ``full_text``
            is empty.
        verbose: If True (the default), print every token with its flag.

    Returns:
        ``(full_ids, fixed_positions)``: the 1-D tensor of token ids and one
        bool per id (True = fixed, False = optimizable).

    Raises:
        ValueError: if ``full_text`` is combined with a prefix/question, or if
            a token straddles the boundary of the variable region.
    """
    if full_text and (prefix_text or question_text):
        raise ValueError(
            "Both the full_text and a prefix_text/question_text were provided, when it should have been either or."
        )

    if not full_text:
        full_text = f"""<start_of_turn>system log
{prefix_text}<end_of_turn>
<start_of_turn>user
{question_text}<end_of_turn>
<start_of_turn>model
"""

    full_ids = torch.tensor(tokenizer.encode(full_text))
    # Decode each token exactly once; the texts are reused by the
    # classification loop, the error dump and the verbose listing.
    token_texts = [tokenizer.decode([token_id]) for token_id in full_ids.tolist()]

    variable_start = "<start_of_turn>system log\n"
    variable_end = "<end_of_turn>\n<start_of_turn>user"
    # Character offsets in full_text: first char inside the variable region,
    # and first char of the closing marker.
    # NOTE: if either marker is absent, str.find returns -1 and every token
    # ends up marked as fixed.
    variable_start_max = full_text.find(variable_start) + len(variable_start)
    variable_end_min = full_text.find(variable_end)

    fixed_positions: List[bool] = []
    reconstructed_text = ""
    start_seen = False
    token_start_pos = 0
    for i, token_text in enumerate(token_texts):
        if token_text.strip() in ("<bos>", "<eos>", "<pad>"):
            # Special tokens do not occur in full_text: mark them fixed and
            # do not advance the character offset.
            fixed_positions.append(True)
            continue

        if not start_seen:
            # The reconstruction only grows, so once the start marker has
            # appeared it stays present; stop accumulating after that.
            reconstructed_text += token_text
            start_seen = variable_start in reconstructed_text

        if start_seen:
            token_end_pos = token_start_pos + len(token_text)

            if (
                token_start_pos >= variable_start_max
                and token_end_pos <= variable_end_min
            ):
                fixed_positions.append(False)
            elif token_start_pos >= variable_end_min:
                fixed_positions.append(True)
            elif (
                token_start_pos < variable_start_max
                and token_end_pos <= variable_start_max
            ):
                fixed_positions.append(True)
            else:
                print("\nTokens up to error:")
                for j, prev_token_text in enumerate(token_texts[: i + 1]):
                    print(f"Token {j}: '{prev_token_text}'")
                print(f"\nProblem occurred at token {i}: '{token_text}'")
                raise ValueError(
                    f"Token '{token_text}' spans boundary between fixed and variable regions"
                )
        else:
            fixed_positions.append(True)

        token_start_pos += len(token_text)

    if verbose:
        for i, (token_text, is_fixed) in enumerate(zip(token_texts, fixed_positions)):
            print(f"Token {i}: '{token_text}' - Fixed: {is_fixed}")

    return full_ids, fixed_positions


def finetuned_cosine_runner(
    layer: int,
    model: torch.nn.Module,
    extract_at_system_prompt_position: bool = False,
    act_position: int = -1,
    path_to_differences: Optional[str] = None,
) -> callable:
    """Build a forward function whose target is cosine similarity to a mean difference.

    The returned ``f(*model_args, **model_kwargs)`` runs ``model`` with a
    forward hook on ``model.model.layers[layer]``. It returns a dict with
    ``"logits"`` (the model output logits) and ``"target"`` (per batch row,
    the cosine similarity between the layer's last-position activation and
    ``get_mean_difference(path_to_differences=...)[layer]``).

    Args:
        layer: Index into ``model.model.layers`` to hook.
        model: Model exposing ``.model.layers`` and returning ``.logits``.
        extract_at_system_prompt_position: Select the ``_system_prompt``
            differences file.
        act_position: If not -1 (and the flag above is False), select the
            ``_pos_{act_position}`` differences file.
        path_to_differences: Explicit path to the differences file. It
            overrides the two selectors above; when None, the built-in
            default path is used.

    Returns:
        The forward function described above.
    """
    if path_to_differences is None:
        if extract_at_system_prompt_position:
            filename_suffix = "_system_prompt"
        elif act_position != -1:
            filename_suffix = f"_pos_{act_position}"
        else:
            filename_suffix = ""
        path_to_differences = f"data/sandbagging/train-sandbagging-f1dpbban-step11264/latents/differences{filename_suffix}.h5"

    # Read the differences file on the first forward pass and reuse it.
    # Previously it was re-read inside the hook on every single forward pass.
    reference_cache: dict = {}

    def get_reference(device) -> torch.Tensor:
        if "reference" not in reference_cache:
            reference_cache["reference"] = get_mean_difference(
                path_to_differences=path_to_differences
            )[layer]
        return reference_cache["reference"].to(device)

    def f(*model_args, **model_kwargs) -> dict:
        out: dict = {}

        def get_target(
            module: torch.nn.Module,
            inputs: tuple[torch.Tensor, ...],
            output: tuple[torch.Tensor],
        ) -> None:
            # Last-position activation of the layer's first output element.
            # NOTE(review): this is always position -1, even when act_position
            # selected a differences file computed at another position —
            # confirm that is intended.
            acts = output[0][:, -1, :]
            reference_tensor = get_reference(acts.device)
            out["target"] = torch.nn.functional.cosine_similarity(
                acts, reference_tensor.unsqueeze(0).expand(acts.shape[0], -1), dim=1
            )

        with add_fwd_hooks(
            [
                (model.model.layers[layer], get_target),  # Hook on the layer directly
            ]
        ):
            out["logits"] = model(*model_args, **model_kwargs).logits
        return out

    return f


# Runner types accepted by run_sandbagging_fluent_dreaming.
_RUNNER_TYPES = ("probe", "cosine")
_PROMPT_REQUIRED_MSG = (
    "Either the full text prompt or the question text (and optionally the prefix) must be provided."
)


def _build_runner(
    runner_type: str,
    model,
    tokenizer,
    layer: int,
    return_logits: bool,
    path_to_probe: Optional[str],
):
    """Instantiate the probe- or cosine-based runner selected by ``runner_type``.

    Raises:
        ValueError: if ``runner_type`` is not one of ``_RUNNER_TYPES``.
    """
    if runner_type == "probe":
        if path_to_probe is None:
            # Path(...) first so this also works when DATA_DIR is a plain str.
            path_to_probe = (
                Path(DATA_DIR)
                / "sandbagging"
                / "train-sandbagging-f1dpbban-step11264"
                / "latents"
                / "probes"
                / f"probe_layer_{layer}.pkl"
            )
        probe = load_classifier(return_logits=return_logits, pickle_path=path_to_probe)
        return FinetunedProbeRunner(
            layer=layer, probe=probe, model=model, tokenizer=tokenizer
        )
    if runner_type == "cosine":
        # Load reference tensor dictionary before instantiating the runner
        reference_tensor_dict = get_mean_difference()
        return FinetunedCosineRunner(
            layer=layer,
            model=model,
            tokenizer=tokenizer,
            reference_tensor_dict=reference_tensor_dict,
        )
    raise ValueError(f"Invalid runner type: {runner_type}")


def _prepare_initial_prompt(
    tokenizer,
    population_size: int,
    full_prompt_text: str,
    prefix_text: str,
    question_text: str,
):
    """Return ``(initial_ids, fixed_positions, seq_len)`` for the EPO search.

    Raises:
        ValueError: if neither a full prompt nor a question text is given.
    """
    if full_prompt_text:
        initial_ids, fixed_positions = get_fixed_positions(
            tokenizer, full_text=full_prompt_text
        )
    elif prefix_text and question_text:
        initial_ids, fixed_positions = get_fixed_positions(
            tokenizer, prefix_text=prefix_text, question_text=question_text
        )
    elif question_text:
        # NOTE(review): question_text is not passed on in this branch. EPO
        # starts from no initial ids with a 12-token sequence — confirm intended.
        return None, None, 12
    else:
        raise ValueError(_PROMPT_REQUIRED_MSG)

    initial_ids = initial_ids.unsqueeze(0).repeat(population_size, 1)
    return initial_ids, fixed_positions, initial_ids.shape[-1]


def run_sandbagging_fluent_dreaming(
    lora_model_id: str = "contextmodification/gemma-sandbagging-0w4j7rba-step1024",
    base_model_name: str = "google/gemma-2-2b-it",
    population_size: int = 8,
    full_prompt_text: str = "",
    prefix_text: str = "",
    question_text: str = "",
    layer: int = 15,
    x_penalty_min: float = 0.1,
    x_penalty_max: float = 10.0,
    iters: int = 300,
    return_logits: bool = True,
    wandb_run_name: Optional[str] = None,
    use_wandb: bool = True,
    runner_type: str = "probe",
    device: str = "cuda",
    wandb_group: Optional[str] = None,
    path_to_probe: Optional[str] = None,
):
    """
    Run fluent dreaming (EPO) on a finetuned model against a probe or reference target.

    Args:
        lora_model_id / base_model_name: Passed to ``load_finetuned_model``.
        population_size: Number of candidate prompts per EPO iteration.
        full_prompt_text: Complete prompt; mutually exclusive with prefix/question.
        prefix_text: Text placed in the optimizable ``system log`` turn.
        question_text: User-turn text.
        layer: Model layer whose output is hooked.
        x_penalty_min / x_penalty_max: Cross-entropy penalty range forwarded to EPO.
        iters: Number of EPO iterations.
        return_logits: Forwarded to ``load_classifier`` (probe runner only).
        wandb_run_name / wandb_group: W&B run metadata.
        use_wandb: If False, W&B is initialised in ``"disabled"`` mode and no
            W&B callback or final W&B logging is performed.
        runner_type: ``"probe"`` or ``"cosine"``.
        device: Device the model is moved to.
        path_to_probe: Probe pickle path; defaults to one under ``DATA_DIR``.

    Returns:
        ``(pareto, history)`` from ``build_pareto_frontier`` and ``epo``.

    Raises:
        ValueError: on an unknown ``runner_type`` or when no prompt text is
            given. Both are checked before W&B init and model loading.
    """
    # Fail fast on bad arguments, before the slow W&B init and model load.
    if runner_type not in _RUNNER_TYPES:
        raise ValueError(f"Invalid runner type: {runner_type}")
    if not (full_prompt_text or question_text):
        raise ValueError(_PROMPT_REQUIRED_MSG)

    # Credentials are only needed when actually logging online.
    if use_wandb:
        setup_wandb_auth()

    run = wandb.init(
        project="fluent-dreaming",
        entity=WANDB_ENTITY,
        group=wandb_group,
        name=wandb_run_name,
        config={
            "lora_model_id": lora_model_id,
            "base_model_name": base_model_name,
            "population_size": population_size,
            "full_prompt_text": full_prompt_text,
            "prefix_text": prefix_text,
            "question_text": question_text,
            "layer": layer,
            "x_penalty_min": x_penalty_min,
            "x_penalty_max": x_penalty_max,
            "iters": iters,
            "return_logits": return_logits,
            "runner_type": runner_type,
        },
        mode="online" if use_wandb else "disabled",
        reinit=False,
    )

    model, tokenizer = load_finetuned_model(
        lora_model_id=lora_model_id,
        base_model_name=base_model_name,
        device=device,
    )
    model.to(device=device)

    runner = _build_runner(
        runner_type, model, tokenizer, layer, return_logits, path_to_probe
    )
    initial_ids, fixed_positions, seq_len = _prepare_initial_prompt(
        tokenizer, population_size, full_prompt_text, prefix_text, question_text
    )

    callbacks = (
        [
            WandbEPOCallback(
                runner,
                model,
                tokenizer,
                x_penalty_min=x_penalty_min,
                x_penalty_max=x_penalty_max,
            )
        ]
        if use_wandb
        else []
    )

    # Run evolutionary optimization
    history = epo(
        runner,
        model,
        tokenizer,
        iters=iters,
        initial_ids=initial_ids,
        fixed_positions=fixed_positions,
        population_size=population_size,
        seq_len=seq_len,
        callbacks=callbacks,
        x_penalty_min=x_penalty_min,
        x_penalty_max=x_penalty_max,
        device=device,
    )

    # Log final results to W&B only if use_wandb is True
    if use_wandb:
        log_fluent_dreaming_to_wandb(model, tokenizer, history, runner, "")

    # Print Pareto frontier results
    pareto = build_pareto_frontier(tokenizer, history)
    print("\nFinal outputs from Pareto frontier (ordered by cross-entropy):")
    for i in np.argsort(pareto.xentropy):
        print(f"\nText: {pareto.text[i]}")
        print(f"Cross-entropy: {pareto.xentropy[i]:.2f}")
        print(f"Target activation: {pareto.target[i]:.2f}")

    run.finish()

    return pareto, history
