# import pandas as pd
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import transformers
from custom_dreamy.epo import add_fwd_hooks, build_pareto_frontier, epo
from custom_dreamy.history import History
from custom_dreamy.i_runner import IRunner, transfomer_embed_one_hot_input

import wandb
from eliciting_contexts.utils.auth import setup_wandb_auth
from eliciting_contexts.utils.constants import WANDB_ENTITY


class TokenRunner(IRunner):
    """Runner whose objective is the logit margin of a single output token.

    The objective is evaluated at the final sequence position: the target
    token's logit minus the largest logit among all *other* tokens.
    """

    def __init__(self, model: transformers.PreTrainedModel, token_pos: int):
        """
        Initialize a TokenRunner that optimizes a specific output token.

        Args:
            model: The model to run
            token_pos: Index of the target token along the last axis of the
                model's logits. Despite the name, this is used as a token
                (vocabulary) index, not as a sequence position.
        """
        self.model = model
        self.token_pos = token_pos

    def run_with_embeddings(
        self, input_embeddings: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Run the model with embedded inputs and return the target token's logit difference.

        Args:
            input_embeddings: Embedded input tensor, passed to the model as
                ``inputs_embeds``

        Returns:
            Tuple of (target, logits, {}) where ``target`` holds, for each
            batch row, the target token's final-position logit minus the
            maximum final-position logit of every other token, and ``logits``
            is the model's full logits tensor.
        """
        model_logits = self.model(inputs_embeds=input_embeddings).logits

        # Get logits for the final position
        final_logits = model_logits[:, -1, :]
        target_logit = final_logits[:, self.token_pos]

        # Get the maximum logit excluding the target token. The target entry
        # is *replaced* by -inf rather than multiplied by it: multiplying
        # turns a negative target logit into +inf (and a zero logit into NaN),
        # which would corrupt the maximum.
        is_target = torch.zeros_like(final_logits, dtype=torch.bool)
        is_target[:, self.token_pos] = True
        other_logits = final_logits.masked_fill(is_target, float("-inf"))
        max_other_logit = torch.max(other_logits, dim=-1).values

        # Calculate logit difference (target - max_other)
        target = target_logit - max_other_logit

        return target, model_logits, {}

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Convert one-hot encoded input to embedded input."""
        return transfomer_embed_one_hot_input(one_hot, self.model)

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Convert integer token IDs to embedded input."""
        return self.model.get_input_embeddings()(int_ids)


class NeuronRunner(IRunner):
    """Runner whose objective is a single MLP neuron's activation."""

    def __init__(self, model: transformers.PreTrainedModel, layer: int, neuron: int):
        """
        Set up a runner that targets one neuron of one layer.

        Args:
            model: The model to run
            layer: Index into ``model.model.layers`` selecting the layer
            neuron: Index of the neuron, taken along the last axis of the
                input to that layer's ``mlp.fc2`` module
        """
        self.model = model
        self.layer = layer
        self.neuron = neuron

    def run_with_embeddings(
        self, input_embeddings: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Do a forward pass on embedded inputs and read out the neuron's value.

        The value is captured by a forward hook on ``mlp.fc2`` of the chosen
        layer: it is the hooked module's first input, at the last sequence
        position, at index ``neuron``.

        Args:
            input_embeddings: Embedded input tensor, passed to the model as
                ``inputs_embeds``

        Returns:
            Tuple of (target, logits, {})
        """
        captured = {}

        def capture_activation(
            _module: torch.nn.Module,
            hook_inputs: tuple[torch.Tensor, ...],
            _hook_output: torch.Tensor,
        ) -> None:
            # Only the module's first positional input is inspected.
            first_input = hook_inputs[0]
            captured["target"] = first_input[:, -1, self.neuron]

        hooked_module = self.model.model.layers[self.layer].mlp.fc2
        hooks = [(hooked_module, capture_activation)]

        # The hook is active only for the duration of this forward pass.
        with add_fwd_hooks(hooks):
            model_output = self.model(inputs_embeds=input_embeddings)
            logits = model_output.logits

        return captured["target"], logits, {}

    def one_hot_to_embed(self, one_hot: torch.Tensor) -> torch.Tensor:
        """Map a one-hot encoded input to its embedded form."""
        embedded = transfomer_embed_one_hot_input(one_hot, self.model)
        return embedded

    def int_ids_to_embed(self, int_ids: torch.Tensor) -> torch.Tensor:
        """Map integer token IDs to their embedded form."""
        embedding_layer = self.model.get_input_embeddings()
        return embedding_layer(int_ids)


def log_fluent_dreaming_to_wandb(
    model: transformers.PreTrainedModel,
    tokenizer: transformers.PreTrainedTokenizer,
    history: History,
    runner: IRunner,
    target_name: str,
    run_name: str = None,
) -> None:
    """
    Logs Fluent Dreaming results to Weights & Biases.
    Uses existing W&B run if one exists, otherwise creates a new one.

    The Pareto frontier built from ``history`` is logged as a table named
    ``pareto_frontier`` with one row per frontier entry, ordered by
    ascending cross-entropy.

    Args:
        model: Model whose ``config._name_or_path`` is recorded in the run
            config when a new run is created
        tokenizer: Tokenizer passed to ``build_pareto_frontier``
        history: Optimization history passed to ``build_pareto_frontier``
        runner: Currently unused; kept for interface compatibility
        target_name: Recorded in the run config when a new run is created
        run_name: Name for a newly created run. When omitted, a timestamped
            default name is used. Ignored if a run is already active.
    """

    # Remember whether this call opens the run: only a run created here may
    # be finished here. A run that was already active belongs to the caller.
    created_run = wandb.run is None
    if created_run:
        setup_wandb_auth()
        wandb.init(
            project="fluent-dreaming",
            entity=WANDB_ENTITY,
            name=run_name
            or f"fluent_dreaming_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            config={
                "target_name": target_name,
                "model_name": model.config._name_or_path,
            },
        )

    # Log Pareto frontier texts and metrics
    pareto = build_pareto_frontier(tokenizer, history)
    ordering = np.argsort(pareto.xentropy)

    pareto_rows = [
        {
            "text": pareto.text[i],
            "cross_entropy": float(pareto.xentropy[i]),
            "activation": float(pareto.target[i]),
        }
        for i in ordering
    ]

    # Explicit columns keep the table schema stable even for an empty frontier.
    pareto_frame = pd.DataFrame(
        pareto_rows, columns=["text", "cross_entropy", "activation"]
    )
    wandb.log({"pareto_frontier": wandb.Table(dataframe=pareto_frame)})

    # Don't finish the run if we didn't create it
    if created_run:
        wandb.finish()


def example_usage():
    """Demo: run EPO to find inputs that make the model predict a chosen token.

    Loads ``microsoft/phi-2`` onto a CUDA device, builds a ``TokenRunner``
    for the first token of ``target_text`` and starts ``epo`` from a
    population of identical copies of ``start_text``.

    Raises:
        ValueError: If ``target_text`` encodes to no tokens.
    """
    model_name = "microsoft/phi-2"
    device = "cuda"
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name,
        low_cpu_mem_usage=True,
        torch_dtype="auto",
        use_cache=False,
        device_map=device,
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

    # EX 1 optimize for a neuron in a layer
    # layer = 5
    # neuron = 1024
    # runner = NeuronRunner(model=model, layer=layer, neuron=neuron)

    # EX 2 optimize for a token
    target_text = "bye"  # You can change this to any word/phrase but ideally one token
    # Encode without special tokens so that index 0 is the first token of the
    # text itself, not a BOS token that some tokenizers prepend.
    target_tokens = tokenizer.encode(target_text, add_special_tokens=False)
    if not target_tokens:
        raise ValueError(f"target_text {target_text!r} encodes to no tokens")
    if len(target_tokens) > 1:
        print(
            f"Warning: '{target_text}' encodes to {len(target_tokens)} tokens; "
            "only the first one is optimized"
        )
    token_pos = target_tokens[0]
    print(f"Target token (ID {token_pos}): '{tokenizer.decode([token_pos])}'")
    runner = TokenRunner(model=model, token_pos=token_pos)

    population_size = 8

    # Every population member starts from the same prompt.
    start_text = "Don't mind me, I just want to say hello"
    initial_ids = torch.tensor(tokenizer.encode(start_text))
    initial_ids = initial_ids.unsqueeze(0).repeat(population_size, 1)

    epo(
        runner,
        model,
        tokenizer,
        initial_ids=initial_ids,
        population_size=population_size,
        seq_len=initial_ids.shape[-1],
        device=device,
    )


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    example_usage()
