from typing import Any, Dict, List, Optional, Union

import torch
from datasets import Dataset, DatasetDict
from transformer_lens import HookedTransformer
from transformer_lens.utils import USE_DEFAULT_VALUE
from transformers import PreTrainedTokenizerBase

import wandb
from eliciting_contexts.benchmark.external.tiny_stories.run_epo import Config
from eliciting_contexts.benchmark.external.utils.logger import logger
from eliciting_contexts.benchmark.internal.tiny_stories.raw_data_modified import (
    simple_stories_dataset,
)
from eliciting_contexts.utils.constants import DEVICE


def analyze_prompt_predictions(
    model: HookedTransformer,
    tokenizer: PreTrainedTokenizerBase,
    prompt: str,
    desired_texts: List[str],
    undesired_text: str,
    prepend_space_to_answers: bool = True,
    prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
) -> Dict[str, Any]:
    """
    Inspect the model's next-token scores at the end of a prompt.

    The prompt is tokenized with ``model.to_tokens``, run through the model
    without gradient tracking, and the logits at the last position of the
    first batch row are sorted in descending order. From that ordering the
    top-scoring token is decoded, and the logit and 0-based rank (0 = highest
    logit) are looked up for each desired word and for the undesired word.

    Only the *first* token of each answer word is considered; a warning is
    logged when a word tokenizes to more than one token.

    Args:
        model: The HookedTransformer model.
        tokenizer: The tokenizer used to encode the answer words.
        prompt: The input prompt string.
        desired_texts: Desired answer words.
        undesired_text: The undesired answer word.
        prepend_space_to_answers: If True, a leading space is added to every
            answer word that does not already start with one before it is
            tokenized.
        prepend_bos: Forwarded to ``model.to_tokens`` for the prompt.

    Returns:
        A dictionary with:
        - 'predicted_word': decoded string of the highest-logit next token.
        - 'desired_details': list of {'word', 'logit', 'rank'} dicts, one per
          desired word whose lookup succeeded (failed lookups are omitted).
        - 'undesired_details': a {'word', 'logit', 'rank'} dict for the
          undesired word; if its lookup failed, 'logit' is NaN and 'rank'
          is -1.
    """
    token_ids = model.to_tokens(prompt, prepend_bos=prepend_bos)
    token_ids = token_ids.to(model.cfg.device)

    with torch.no_grad():
        all_logits = model(token_ids)
        # Scores for whatever follows the final prompt token: [vocab_size]
        next_token_logits = all_logits[0, -1, :]

    vocab_size = next_token_logits.shape[-1]
    # Token ids ordered from highest to lowest logit; a token's position in
    # this tensor is its rank.
    _, ranked_ids = torch.sort(next_token_logits, dim=-1, descending=True)

    # Most likely next token, decoded back to text.
    predicted_word = model.to_string(ranked_ids[0].item())

    def lookup(word: str) -> Optional[Dict[str, Union[str, float, int]]]:
        """Return word/logit/rank for the first token of ``word``, or None on failure."""
        spaced = word
        if prepend_space_to_answers and not word.startswith(" "):
            spaced = " " + word

        try:
            # The tokenizer is called directly (no special tokens) so that
            # multi-token words can be detected; only the first id is used.
            ids = tokenizer.encode(spaced, add_special_tokens=False)
            if not ids:
                logger.warning(
                    f"Word '{word}' (as '{spaced}') produced no tokens."
                )
                return None
            if len(ids) > 1:
                logger.warning(
                    f"Word '{word}' (as '{spaced}') tokenized to multiple tokens ({ids}). "
                    "Using properties of the first token only."
                )
            first_id = ids[0]
        except Exception as e:
            logger.error(f"Error tokenizing word '{word}' (as '{spaced}'): {e}")
            return None

        # Guard against ids that fall outside the model's logit vector.
        if not 0 <= first_id < vocab_size:
            logger.warning(
                f"Word '{word}' (as '{spaced}') resulted in an invalid token ID: {first_id}"
            )
            return None

        logit = next_token_logits[first_id].item()

        positions = (ranked_ids == first_id).nonzero(as_tuple=True)[0]
        if positions.numel() > 0:
            rank = positions.item()
        else:
            # Not expected for an in-range id; kept as a sentinel.
            rank = -1

        return {"word": word, "logit": logit, "rank": rank}

    # Desired words: keep only the lookups that succeeded.
    desired_details = [info for info in map(lookup, desired_texts) if info]

    # Undesired word: always report a dict, using placeholders on failure.
    undesired_details = lookup(undesired_text)
    if undesired_details is None:
        undesired_details = {"word": undesired_text, "logit": float("nan"), "rank": -1}

    return {
        "predicted_word": predicted_word,
        "desired_details": desired_details,
        "undesired_details": undesired_details,
    }


def upload_hf_dataset(
    dataset: Dataset | DatasetDict,
    repo_id: str,
    hf_token: str | None = None,
    private: bool = False,
    **kwargs: Any,
) -> None:
    """
    Uploads a dataset to the Hugging Face Hub.

    Args:
        dataset: The dataset object (Dataset or DatasetDict) to upload.
        repo_id: The target repository ID on the Hugging Face Hub
            (e.g., "username/my-dataset"). Must consist of exactly two
            non-empty parts separated by a single "/".
        hf_token: Hugging Face API token, passed to ``push_to_hub`` as
            ``token``. May be None, in which case ``push_to_hub`` receives
            ``token=None``.
        private: Passed to ``push_to_hub`` as ``private``. Defaults to False.
        **kwargs: Additional keyword arguments passed to `dataset.push_to_hub()`.

    Raises:
        ValueError: If ``repo_id`` is malformed, or if ``push_to_hub`` raises
            a ValueError (re-raised unchanged after logging).
        Exception: Any other error from ``push_to_hub`` is logged and
            re-raised unchanged, preserving its original type and message.
    """
    parts = repo_id.split("/") if repo_id else []
    # Reject ids such as "owner/" or "/name" as well as a wrong number of parts.
    if len(parts) != 2 or not all(parts):
        raise ValueError(
            "Invalid repo_id format. Expected 'username/repository_name' or 'org_name/repository_name'."
        )

    try:
        logger.info(
            f"Uploading dataset to Hugging Face Hub repository: '{repo_id}' (Private: {private})"
        )
        dataset.push_to_hub(
            repo_id=repo_id,
            token=hf_token,
            private=private,
            **kwargs,
        )
        logger.info(f"Successfully uploaded dataset to '{repo_id}'.")
    except ValueError as e:
        logger.error(
            f"Failed to upload dataset to '{repo_id}'. Check repository ID and authentication. Details: {e}"
        )
        # Bare raise keeps the original exception (type, message, traceback)
        # instead of replacing it with an empty ValueError.
        raise
    except Exception as e:
        logger.error(f"An unexpected error occurred during upload to '{repo_id}': {e}")
        # Re-raise the original error so callers can still catch its specific type.
        raise


def create_tiny_stories_dataset_dict():
    """Build a DatasetDict with a single "test" split from ``simple_stories_dataset``.

    Each raw item must unpack into exactly six values: template, variable,
    desired, undesired, story type and human answer. A desired value given as
    a single string is wrapped in a one-element list so that the
    "desired_text" column always holds lists.

    Returns:
        A DatasetDict whose "test" split has the columns "template",
        "variable_text", "desired_text", "undesired_text", "story_type" and
        "human_answer".

    Raises:
        ValueError: If an item does not unpack into six values.
    """
    columns = (
        "template",
        "variable_text",
        "desired_text",
        "undesired_text",
        "story_type",
        "human_answer",
    )

    rows = []
    for index, record in enumerate(simple_stories_dataset):
        try:
            template, variable, desired, undesired, story_type, human_answer = record
        except ValueError as e:
            logger.exception(f"Error unpacking item at index {index}: {record}")
            raise ValueError(
                f"Expected 6 elements, but got {len(record)} at index {index}"
            ) from e

        # Normalize a lone desired string into a list.
        desired_list = [desired] if isinstance(desired, str) else desired
        rows.append(
            (template, variable, desired_list, undesired, story_type, human_answer)
        )

    # Pivot the row tuples into one list per column.
    data_by_column = {
        name: [row[position] for row in rows]
        for position, name in enumerate(columns)
    }
    test_split = Dataset.from_dict(data_by_column)
    logger.info("Created DatasetDict.")
    return DatasetDict({"test": test_split})


def preprocess_tiny_stories_dataset(dataset_dict: DatasetDict, config: Config):
    """Add next-token analysis columns to every example of the dataset.

    Loads the model named by ``config.model_name``, then maps over the
    dataset one example at a time. For each example the prompt is built by
    formatting "template" with "variable_text", and the result of
    ``analyze_prompt_predictions`` is stored in the new columns
    "predicted_word", "desired_details" and "undesired_details".

    Args:
        dataset_dict: Dataset whose examples provide "template",
            "variable_text", "desired_text" and "undesired_text".
        config: Supplies ``model_name`` and, optionally, ``num_proc``
            (1 is used when the attribute is absent).

    Returns:
        The result of ``dataset_dict.map`` with the analysis columns added.
    """
    logger.info(f"Loading model: {config.model_name}")
    model = HookedTransformer.from_pretrained_no_processing(
        config.model_name,
        device=DEVICE,
        trust_remote_code=True,
    )
    tokenizer = model.tokenizer
    if tokenizer.pad_token is None:
        logger.info("Tokenizer pad token not set, using EOS token as pad token.")
        tokenizer.pad_token = tokenizer.eos_token
    logger.info("Model and tokenizer loaded.")

    # Keys copied from the analysis result into each example.
    output_keys = ("predicted_word", "desired_details", "undesired_details")

    def process_example(example: dict[str, Any]) -> dict[str, Any]:
        """Run the prediction analysis for one example and return the new columns."""
        filled_prompt = example["template"].format(example["variable_text"])
        analysis = analyze_prompt_predictions(
            model=model,
            tokenizer=tokenizer,
            prompt=filled_prompt,
            desired_texts=example["desired_text"],
            undesired_text=example["undesired_text"],
        )
        return {key: analysis[key] for key in output_keys}

    # Fall back to a single process when the config does not define num_proc.
    worker_count = getattr(config, "num_proc", 1)

    logger.info("Applying analysis function to dataset (non-batched)...")
    processed_dataset_dict = dataset_dict.map(
        process_example,
        batched=False,  # each example is analyzed with its own forward pass
        num_proc=worker_count,
    )
    logger.info("Dataset preprocessing finished.")

    return processed_dataset_dict


def upload_tiny_stories_dataset(config: Config):
    """Create, preprocess and upload the stories dataset as a private Hub repo.

    Args:
        config: Supplies ``hf_repo_id`` and ``hf_token`` for the upload and is
            forwarded to ``preprocess_tiny_stories_dataset``.
    """
    raw_splits = create_tiny_stories_dataset_dict()

    logger.info("Preprocessing dataset before upload...")
    analyzed_splits = preprocess_tiny_stories_dataset(raw_splits, config)
    logger.info("Dataset preprocessing complete.")

    target_repo = config.hf_repo_id
    logger.info(f"Attempting to upload processed dataset to {target_repo}")
    upload_hf_dataset(
        analyzed_splits,
        repo_id=target_repo,
        hf_token=config.hf_token,
        private=True,
    )


if __name__ == "__main__":
    # Script entry point: configure the run, start wandb, then build and
    # upload the dataset.
    config = Config()
    config.wandb_project = "simple_stories_optimization"

    wandb.init(
        project=config.wandb_project,
        entity=config.wandb_entity,
        mode=config.wandb_mode,
    )

    # Record every non-dunder, non-callable config attribute in wandb.
    config_dict = {}
    for attr in dir(config):
        if attr.startswith("__"):
            continue
        value = getattr(config, attr)
        if callable(value):
            continue
        config_dict[attr] = value
    wandb.config.update(config_dict)
    logger.info("Wandb initialized.")

    upload_tiny_stories_dataset(config)
