from typing import List

import pandas as pd  # Add pandas import
from custom_dreamy.history import HistoryColumns, log_dataframe_to_wandb
from custom_dreamy.i_runner import IRunner
from datasets import Dataset
from pydantic import BaseModel, Field
from tqdm import tqdm  # Add tqdm import
from transformers import PreTrainedTokenizer

from eliciting_contexts.fluent_dreaming.sae import SAERunnerTLens
from eliciting_contexts.utils.call_openai import ImageChatHistory, call_model
from eliciting_contexts.utils.load_models import load_sae_and_model


class MultiSampleOutput(BaseModel):
    """Structured-output schema asking the model for several samples at once.

    The field order is deliberate: ``thought`` is declared before ``samples``
    so it appears first in the generated schema.
    """

    # Free-form notes the model writes before producing its attempts.
    thought: str = Field(..., description="Think through the task and write notes here")
    # The attempts themselves; generate_openai_samples scores each one separately.
    samples: list[str] = Field(..., description="A list of attempts at the task")


def generate_openai_samples(
    runner: IRunner,
    tokenizer: PreTrainedTokenizer,
    dataset_name: str = "",
    hf_token: str | None = None,
    system_prompt: str = "",
    user_prompt: str = "",
    num_calls: int = 10,
    num_samples_per_call: int = 5,  # How many samples to get per API call
    temperature: float = 1.0,
    model: str = "gpt-4o-mini",
    device: str = "cuda",
    save_to: str = "hf",  # Options: "hf", "wandb", or "none"
):
    """Sample texts from an OpenAI model, score them with ``runner`` and save them.

    Each of the ``num_calls`` API calls requests a ``MultiSampleOutput``
    (a ``thought`` plus a list of ``samples``). Every sample is tokenized,
    run through ``runner`` to get its target value and cross-entropy, and
    appended to the results table.

    Args:
        runner: Runner used to score samples (``int_ids_to_embed``,
            ``run_with_embeddings`` and ``calc_xentropy`` are called on it).
        tokenizer: Tokenizer used to turn each sample into token ids.
        dataset_name: Hub repo id passed to ``push_to_hub``; required when
            ``save_to == "hf"``.
        hf_token: Token passed through to ``push_to_hub``.
        system_prompt: Optional system message added to the chat history.
        user_prompt: Optional user message added to the chat history.
        num_calls: Number of API calls to make.
        num_samples_per_call: Not used by this function; the number of samples
            per call is requested through the prompt text (callers format it
            into the system prompt). Kept for interface compatibility.
        temperature: Sampling temperature forwarded to ``call_model``.
        model: Model name forwarded to ``call_model``.
        device: Device the token ids are moved to before scoring.
        save_to: ``"hf"`` pushes to the Hub, ``"wandb"`` logs via
            ``log_dataframe_to_wandb``, ``"none"`` (or any other value) skips
            saving.

    Returns:
        A DataFrame with columns ``text``, ``thought``,
        ``HistoryColumns.TARGET`` and ``HistoryColumns.XENTROPY``. If no API
        call succeeded, it is empty and has only ``text`` and ``thought``.

    Raises:
        ValueError: If ``save_to == "hf"`` and ``dataset_name`` is empty. This
            is checked before any API call so no paid work is wasted.
    """
    if save_to == "hf" and not dataset_name:
        raise ValueError("dataset_name must be set when save_to='hf'")

    history = ImageChatHistory()
    if system_prompt:
        history.add_system_msg(system_prompt)
    if user_prompt:
        history.add_user_msg(user_prompt)

    # Gather one small frame per sample and concatenate once at the end;
    # concatenating inside the loop re-copies the growing table every time.
    frames: list[pd.DataFrame] = []

    for _ in tqdm(range(num_calls), desc="Generating samples"):
        success, content, error = call_model(
            history,
            structured_output=MultiSampleOutput,
            temperature=temperature,
            model=model,
        )
        if not success:
            # Best effort: a failed call is reported and the loop moves on.
            print(f"Error generating samples: {error}")
            continue

        for i, sample in enumerate(content.samples):
            print(f"==== sample {i} ====")
            print(sample)
            frames.append(
                _score_sample(runner, tokenizer, sample, content.thought, device)
            )

    if frames:
        results_df = pd.concat(frames, ignore_index=True)
    else:
        results_df = pd.DataFrame(columns=["text", "thought"])

    _save_results(results_df, save_to, dataset_name, hf_token)
    return results_df


def _score_sample(
    runner: IRunner,
    tokenizer: PreTrainedTokenizer,
    sample: str,
    thought: str,
    device: str,
) -> pd.DataFrame:
    """Score one generated text with ``runner`` and return it as a small DataFrame."""
    input_ids = tokenizer(sample, return_tensors="pt").input_ids.to(device)

    # Target activations and logits for the sample.
    target, logits, _ = runner.run_with_embeddings(runner.int_ids_to_embed(input_ids))

    # TODO fixed positions! (needs better refactoring )
    xentropy = runner.calc_xentropy(logits, input_ids)

    # The scalar text/thought values broadcast against the array columns, so
    # the row count equals the length of target/xentropy (presumably 1 for a
    # single tokenized sample — TODO confirm).
    return pd.DataFrame(
        {
            "text": sample,
            "thought": thought,
            HistoryColumns.TARGET: target.detach().float().cpu().numpy(),
            HistoryColumns.XENTROPY: xentropy.detach().float().cpu().numpy(),
        }
    )


def _save_results(
    results_df: pd.DataFrame,
    save_to: str,
    dataset_name: str,
    hf_token: str | None,
) -> None:
    """Persist ``results_df`` to the Hub, to wandb, or nowhere, per ``save_to``."""
    if save_to == "hf":
        if results_df.empty:
            # Pushing an empty "train" split would presumably overwrite the
            # existing dataset on the Hub, so skip it.
            print(f"No samples generated; not pushing to {dataset_name}")
            return
        hf_dataset = Dataset.from_pandas(results_df)
        print("saving to", dataset_name)
        hf_dataset.push_to_hub(
            dataset_name,
            split="train",
            private=True,
            token=hf_token,
        )
    elif save_to == "wandb":
        # Log the DataFrame only to the wandb summary.
        log_dataframe_to_wandb(df=results_df, summary_key="openaicall_samples")
    elif save_to == "none":
        print("Results not saved (save_to=none)")
    else:
        print(f"Saving skipped (save_to={save_to})")


# TODO add args
def generate_openai_samples_sae(
    system_prompt: str = "",
    user_prompt: str = "",
    num_calls: int = 10,
    num_samples_per_call: int = 5,  # How many samples to get per API call
    temperature: float = 1.0,
    model: str = "gpt-4o-mini",
    device: str = "cuda",
    feature: int = 321,
    dataset_name: str = "model_with_neuronpedia",
    hf_token: str | None = None,
    save_to: str = "hf",  # Add save_to parameter
):
    """Load the model and SAE, then generate and score OpenAI samples for one feature.

    The model, tokenizer and SAE come from ``load_sae_and_model(device=device)``.
    They are wrapped in an ``SAERunnerTLens`` for SAE feature ``feature``, and
    all other arguments are forwarded unchanged to ``generate_openai_samples``.

    Returns:
        The results DataFrame produced by ``generate_openai_samples``.
    """
    pytorch_model, tokenizer, sae, _, _ = load_sae_and_model(device=device)

    runner = SAERunnerTLens(model=pytorch_model, sae=sae, feature=feature)

    # Return the results instead of discarding them so callers can use them
    # without re-downloading from the Hub.
    return generate_openai_samples(
        runner=runner,
        tokenizer=tokenizer,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        num_calls=num_calls,
        num_samples_per_call=num_samples_per_call,
        temperature=temperature,
        model=model,
        dataset_name=dataset_name,
        hf_token=hf_token,
        device=device,
        save_to=save_to,
    )


# Example usage:
if __name__ == "__main__":
    # Real run (not dummy data): builds the Neuronpedia-style system prompt
    # from the animal-shelter example, then makes 60 gpt-4o calls. With the
    # default save_to="hf" the results are pushed to the Hub.
    from eliciting_contexts.fluent_dreaming.baselines.prompts import (
        NEURONPEDIA_SYSTEM_PROMPT,
        animal_shelter_examples,
        animal_shelter_explanation,
    )

    samples_per_call = 5

    # The per-call sample count is formatted into the prompt text itself.
    neuronpedia_prompt = NEURONPEDIA_SYSTEM_PROMPT.format(
        animal_shelter_explanation,
        animal_shelter_examples,
        samples_per_call,
    )

    generate_openai_samples_sae(
        model="gpt-4o",
        system_prompt=neuronpedia_prompt,
        num_calls=60,
        num_samples_per_call=samples_per_call,
        dataset_name="model_with_neuronpedia",
    )
