import json
import os
import tempfile
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from custom_dreamy.epo import epo
from custom_dreamy.history import HistoryColumns
from transformer_lens import HookedTransformer
from transformer_lens.utils import USE_DEFAULT_VALUE

import wandb

# Reimport everything from system_utils to ensure we have the latest version
# import eliciting_contexts.fluent_dreaming.system_prompt_experiments.system_utils as system_utils
# reload(system_utils)
from eliciting_contexts.fluent_dreaming.system_prompt_experiments.system_utils import (
    TlensLogProbDiffRunner,
    process_text_with_placeholder,
)
from eliciting_contexts.fluent_dreaming.system_prompt_experiments.tiny_stories_dataset import (
    tiny_stories_dataset,
)
from eliciting_contexts.utils.constants import WANDB_ENTITY


def get_prompt_details(
    prompt: str,
    answer: Union[str, List[str]],
    model: HookedTransformer,
    prepend_space_to_answer: bool = True,
    prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
    top_k: int = 10,
) -> Dict[str, Union[int, float, List[Tuple[str, float]]]]:
    """
    Analyzes the model's prediction for the first token following a prompt,
    returning simplified details: answer rank/logit and top k tokens/logits.

    Ranking is done directly on the logits with a stable sort. Softmax is
    monotonic, so this is mathematically the same ordering as ranking by
    probability. However, in low-precision dtypes (e.g. bfloat16) softmax
    collapses many distinct logits onto identical (often zero) probabilities,
    which made ranks and the top-k order arbitrary among those ties.

    Args:
        prompt: The input prompt string.
        answer: The expected answer string (should correspond to a single token ideally).
                If a list is provided, only the first answer is analyzed.
        model: The HookedTransformer model.
        prepend_space_to_answer: Whether to prepend a space to the answer if not already present.
        prepend_bos: Whether to prepend the BOS token to the prompt.
        top_k: The number of top token predictions to return (values <= 0 give an empty list).

    Returns:
        A dictionary containing:
        - 'answer_rank': 0-indexed rank (int) of the provided answer token, or -1 if not found.
        - 'answer_logit': Logit (float) of the provided answer token.
        - 'top_tokens': List of tuples [(token_str, logit)] for the top k predictions,
          highest logit first.

    Raises:
        ValueError: If ``answer`` is an empty list.
    """
    # Normalise to a list and analyse only the first answer.
    answers = [answer] if isinstance(answer, str) else list(answer)
    if not answers:
        raise ValueError("answer must be a non-empty string or list of strings")
    first_answer = answers[0]

    if prepend_space_to_answer and not first_answer.startswith(" "):
        first_answer = " " + first_answer

    prompt_tokens_tensor = model.to_tokens(prompt, prepend_bos=prepend_bos)
    # The answer is a continuation, so it must not get its own BOS token.
    answer_tokens_tensor = model.to_tokens(first_answer, prepend_bos=False)

    # This helper only analyses a single next-token prediction.
    if answer_tokens_tensor.shape[1] > 1:
        print(
            f"Warning: Provided answer '{first_answer}' tokenized to multiple tokens. Only the first token will be analyzed."
        )
        answer_tokens_tensor = answer_tokens_tensor[:, :1]

    tokens_tensor = torch.cat((prompt_tokens_tensor, answer_tokens_tensor), dim=1)
    prompt_length = prompt_tokens_tensor.shape[1]
    tokens_tensor = tokens_tensor.to(model.cfg.device)

    with torch.no_grad():
        logits = model(tokens_tensor)
        # The prediction for the token right after the prompt is produced at
        # the last prompt position. Upcasting to float32 is exact for bf16/fp16.
        next_token_logits = logits[0, prompt_length - 1, :].float()

    # A stable sort gives deterministic (lowest-id-first) ordering among equal logits.
    sorted_logits, sorted_token_ids = next_token_logits.sort(
        descending=True, stable=True
    )

    # 1. Rank and logit of the provided answer token.
    answer_token_id = int(answer_tokens_tensor[0, 0].item())
    answer_logit = next_token_logits[answer_token_id].item()
    rank_matches = (sorted_token_ids == answer_token_id).nonzero(as_tuple=True)[0]
    answer_rank = int(rank_matches[0].item()) if rank_matches.numel() > 0 else -1

    # 2. Top-k tokens with their logits (clamped so a negative top_k yields nothing).
    num_top = max(0, min(top_k, next_token_logits.shape[-1]))
    top_tokens_info = [
        (model.to_string(token_id), logit)
        for token_id, logit in zip(
            sorted_token_ids[:num_top].tolist(), sorted_logits[:num_top].tolist()
        )
    ]

    return {
        "answer_rank": answer_rank,
        "answer_logit": answer_logit,
        "top_tokens": top_tokens_info,
    }


def generate_first_word(
    model: HookedTransformer,
    tokenizer,
    input_text: str,
    temperature: float = 0.0,
    max_tokens: int = 10,
    verbose: bool = False,
    probe_word: Optional[str] = " safe",
) -> Tuple[str, List[int], str]:
    """
    Generate a continuation of ``input_text`` and return the tokens of its first word.

    After generating, the rank/logit of ``probe_word`` as the next token after
    ``input_text`` is printed along with the model's top predictions (via
    ``get_prompt_details``). This probe used to be hard-wired to ``" safe"``.
    That is still the default, so existing output is unchanged; pass
    ``probe_word=None`` to skip it.

    Args:
        model: The model to use for generation.
        tokenizer: The tokenizer used to encode the first generated word.
        input_text: The input text to generate from.
        temperature: Temperature for generation, forwarded to ``model.generate``.
        max_tokens: Maximum number of new tokens to generate.
        verbose: Forwarded to ``model.generate``.
        probe_word: Token to print next-token diagnostics for, or ``None`` to disable.

    Returns:
        Tuple of (first_word_text, first_word_tokens, new_text):
        - first_word_text: first whitespace-separated word of the continuation,
          without any leading space.
        - first_word_tokens: token ids of that word, encoded with a leading space
          when the continuation itself started with a space.
        - new_text: the full generated continuation.
        Returns ("", [], new_text) when the continuation contains no words.
    """
    generated_text = model.generate(
        input_text,
        max_new_tokens=max_tokens,
        temperature=temperature,
        verbose=verbose,
    )

    if probe_word is not None:
        prompt_details = get_prompt_details(input_text.rstrip(), probe_word, model)
        probe_label = probe_word.strip()
        print(f"\nPrompt Details for '{probe_label}':")
        print(
            f"  Answer Token: '{probe_label}' | Rank: {prompt_details['answer_rank']} Logit: {prompt_details['answer_logit']:.2f}"
        )
        print("  Top Predictions:")
        for rank, (token, logit) in enumerate(prompt_details["top_tokens"]):
            print(f"    Rank {rank}: |{token}| Logit: {logit:.2f}")

    # NOTE(review): assumes generate() returns input_text verbatim as a prefix of
    # its output; a lossy tokenize/decode round-trip would shift this slice.
    new_text = generated_text[len(input_text) :]

    words = new_text.strip().split()
    if not words:
        return "", [], new_text

    first_word = words[0]
    # Match the tokenization the model produced: encode with a leading space only
    # if the continuation actually began with one.
    text_to_encode = " " + first_word if new_text.startswith(" ") else first_word
    first_word_tokens = tokenizer.encode(text_to_encode, add_special_tokens=False)

    return first_word, first_word_tokens, new_text


def _prediction_contains_word(prediction: str, word: str) -> bool:
    """Return True if ``word`` occurs in ``prediction`` as a whole word, ignoring case.

    A plain substring test would wrongly report e.g. "safe" as present in "unsafe".
    """
    import re

    target = word.strip()
    if not target:
        return False
    # Lookarounds instead of \b so targets starting/ending with punctuation still match.
    pattern = r"(?<!\w)" + re.escape(target) + r"(?!\w)"
    return re.search(pattern, prediction, flags=re.IGNORECASE) is not None


def optimize_for_word_prediction(
    model: HookedTransformer,
    tokenizer,
    desired_word: str,
    undesired_word: str,
    start_context: str,
    story_template: str,
    device: str = "cuda",
    population_size: int = 8,
    explore_per_pop: int = 32,
    restart_frequency: int = 51,
    batch_size: int = 32,
    iters: int = 100,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Optimize context to make model predict a desired word instead of an undesired word.

    Runs EPO over the placeholder part of ``story_template``. Then, for every
    distinct candidate in the final iteration's history, the function decodes
    the optimized context, regenerates a continuation and records it.

    Args:
        model: The model to use
        tokenizer: The tokenizer for the model
        desired_word: The word we want the model to predict
        undesired_word: The word we want to avoid predicting
        start_context: Initial context to start optimization from
        story_template: Template string with {0} placeholder for context
        device: Device to run optimization on
        population_size: Population size for EPO
        explore_per_pop: Exploration parameter for EPO
        restart_frequency: Restart frequency for EPO
        batch_size: Batch size for processing
        iters: Number of optimization iterations
        verbose: Forwarded to EPO

    Returns:
        dict with keys:
        - "cleaned_outputs": decoded optimized contexts
        - "targets", "xentropies": the matching history columns
        - "new_predictions": continuation generated for each context
        - "success_score": 1 if any continuation contains ``desired_word`` as a
          whole word (case-insensitive), else 0
    """
    runner = TlensLogProbDiffRunner(
        model,
        tokenizer,
        undesired_word,
        desired_word,
        max_num_tokens=1,
        literal_diff=True,
    )

    # Show what the model predicts before any optimization.
    story_start = story_template.format(start_context)
    _, _, new_text = generate_first_word(model, tokenizer, story_start)
    print(f"\nModel's original prediction: '{new_text}'")

    initial_ids, fixed_positions = process_text_with_placeholder(
        tokenizer, story_template, start_context
    )

    # Every population member starts from the same initial sequence.
    initial_ids = initial_ids.unsqueeze(0).repeat(population_size, 1)
    seq_len = initial_ids.shape[-1]

    history = epo(
        runner,
        model,
        tokenizer,
        iters=iters,
        initial_ids=initial_ids,
        fixed_positions=fixed_positions,
        population_size=population_size,
        seq_len=seq_len,
        explore_per_pop=explore_per_pop,
        restart_frequency=restart_frequency,
        callbacks=[],
        batch_size=batch_size,
        device=device,
        verbose=verbose,
    )

    # Keep only the last iteration's candidates, one row per distinct text.
    history_df = history.to_dataframe(tokenizer, iter=iters - 1, child=0)
    filtered_df = history_df[
        [
            HistoryColumns.TEXT,
            HistoryColumns.TARGET,
            HistoryColumns.XENTROPY,
            HistoryColumns.TOKEN_IDS,
        ]
    ].drop_duplicates(subset=[HistoryColumns.TEXT])

    targets = filtered_df[HistoryColumns.TARGET].tolist()
    xentropies = filtered_df[HistoryColumns.XENTROPY].tolist()

    # ``~fixed_mask`` selects the optimized (non-template) positions.
    # as_tensor avoids the copy-construct warning when fixed_positions is already
    # a tensor. .cpu() puts the mask on the same device as the CPU token tensors
    # built below.
    # NOTE(review): assumes fixed_positions is a boolean mask; verify against
    # process_text_with_placeholder.
    fixed_mask = torch.as_tensor(fixed_positions).cpu()

    cleaned_outputs = []
    new_predictions = []
    for token_ids in filtered_df[HistoryColumns.TOKEN_IDS].tolist():
        tokens = torch.tensor([int(t) for t in token_ids])
        text = tokenizer.decode(tokens[~fixed_mask])
        cleaned_outputs.append(text)
        _, _, prediction = generate_first_word(
            model, tokenizer, story_template.format(text)
        )
        new_predictions.append(prediction)

    success_score = int(
        any(_prediction_contains_word(pred, desired_word) for pred in new_predictions)
    )

    for target, xentropy, text, new_prediction in zip(
        targets, xentropies, cleaned_outputs, new_predictions
    ):
        print(
            f"Target: {target}, Xentropy: {xentropy}\nText: {text}\nModel now predicts {new_prediction}"
        )
    print(f"Success score: {success_score}")

    return {
        "cleaned_outputs": cleaned_outputs,
        "targets": targets,
        "xentropies": xentropies,
        "new_predictions": new_predictions,
        "success_score": success_score,
    }


# WandB config setup
class Config:
    """Experiment settings: model/hardware, EPO search, WandB logging and HF upload.

    Every setting is a plain instance attribute assigned in ``__init__``.
    """

    def __init__(self):
        # --- model / hardware ---
        self.model_name = "google/gemma-2-2b-it"
        self.dtype = "bfloat16"
        self.device = "cuda"

        # --- EPO search hyper-parameters ---
        self.iters = 100
        self.batch_size = 32
        self.restart_frequency = 51
        self.explore_per_pop = 32
        self.population_size = 8

        # --- Weights & Biases ---
        self.wandb_mode = "online"
        self.wandb_entity = WANDB_ENTITY
        self.wandb_project = "simple_stories_optimization"

        # --- Hugging Face Hub ---
        # Token comes from the environment (None when HF_TOKEN is unset).
        self.hf_token = os.environ.get("HF_TOKEN")
        self.hf_repo_id = "contextmodification/simple_stories_dataset"


def init_wandb(config: Config):
    """Initialize Weights & Biases with config parameters"""
    wandb.init(
        project=config.wandb_project,
        entity=config.wandb_entity,
        mode=config.wandb_mode,
    )
    config_dict = {
        attr: getattr(config, attr)
        for attr in dir(config)
        if not attr.startswith("__") and not callable(getattr(config, attr))
    }
    wandb.config.update(config_dict)


if __name__ == "__main__":
    # Script entry point: optimize every story in the dataset, save the results
    # locally as JSON and log them to WandB as an artifact.
    config = Config()

    init_wandb(config)

    try:
        model = HookedTransformer.from_pretrained(
            config.model_name, dtype=config.dtype, device=config.device
        )
        tokenizer = model.tokenizer

        all_results = []

        for (
            story_template,
            default_context,
            default_word,
            negative_word,
        ) in tiny_stories_dataset:
            # Push the model from the story's default word towards its negative counterpart.
            desired_word = negative_word
            undesired_word = default_word
            start_context = default_context
            print(f"{story_template.format(start_context)}")

            results = optimize_for_word_prediction(
                model=model,
                tokenizer=tokenizer,
                iters=config.iters,
                desired_word=desired_word,
                undesired_word=undesired_word,
                start_context=start_context,
                story_template=story_template,
                device=config.device,
                population_size=config.population_size,
                explore_per_pop=config.explore_per_pop,
                restart_frequency=config.restart_frequency,
                batch_size=config.batch_size,
                verbose=True,
            )

            # Store results with metadata
            all_results.append(
                {
                    "story_template": story_template,
                    "default_context": default_context,
                    "desired_word": desired_word,
                    "undesired_word": undesired_word,
                    "optimization_results": results,
                }
            )

        # Never overwrite an earlier run: optimization_results.json, then _01, _02, ...
        base_filename = "optimization_results"
        filename = f"{base_filename}.json"
        n = 1
        while os.path.exists(filename):
            filename = f"{base_filename}_{n:02d}.json"
            n += 1

        with open(filename, "w", encoding="utf-8") as f:
            json.dump(all_results, f, indent=2)

        results_artifact = wandb.Artifact(
            "optimization_results",
            type="results",
            description="Results from simple stories optimization",
        )
        # Attach the file just written instead of a NamedTemporaryFile. The
        # temporary file was deleted as soon as its `with` block exited, possibly
        # before wandb had read it, and it cannot be reopened by name on Windows.
        results_artifact.add_file(filename, "optimization_results.json")
        wandb.log_artifact(results_artifact)

        print(f"Results saved to {filename} and logged to WandB")
    except BaseException:
        # Close the run even on failure, and mark it as failed.
        wandb.finish(exit_code=1)
        raise
    else:
        wandb.finish()
