import argparse
import json
import os
from pathlib import Path
from typing import List

from pydantic import BaseModel, Field

import wandb
from eliciting_contexts.benchmark.external.sae_activation.data.load_data import (
    SAEKeys,
    download_sae_dataset,
)
from eliciting_contexts.fluent_dreaming.baselines.prompts import (
    NEURONPEDIA_SYSTEM_PROMPT,
)
from eliciting_contexts.utils.call_openai import ImageChatHistory, call_model
from eliciting_contexts.utils.neuronpedia_api import get_feature_description


class Sentences(BaseModel):
    """Structure model output."""

    # Response schema handed to ``call_model`` as ``structured_output``; the
    # script then reads ``.thoughts`` and ``.samples`` off the returned object.
    # NOTE(review): pydantic typically exports the class docstring and the
    # Field descriptions into the JSON schema. If call_model forwards that
    # schema to the API, editing them changes what the model is shown; confirm
    # before rewording.

    # Scratchpad for the model's reasoning; the script prints it but does not
    # write it to the results file.
    thoughts: str = Field(
        ..., description="Think through the task and write notes here."
    )
    # The generated sentences; this list is what gets saved per dataset row.
    samples: List[str] = Field(..., description="Provide your new sentences here.")


def _parse_args() -> argparse.Namespace:
    """Parse command-line options for the Neuronpedia-prompted GPT baseline.

    Returns:
        Namespace with ``output``, ``temperature``, ``gpt_model_name``,
        ``num_sentences`` and ``neuronpedia_description_model_name``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Generate SAE-activating sentences by prompting an OpenAI model "
            "with Neuronpedia feature descriptions"
        )
    )

    parser.add_argument(
        "--output",
        type=str,
        help="Path to save output JSON results",
        default="results/sae_activation/gpt4o_sae_activation.json",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Temperature for model generation",
    )
    parser.add_argument(
        "--gpt_model_name", type=str, default="gpt-4o", help="OpenAI model to use"
    )
    parser.add_argument(
        "--num_sentences", type=int, default=6, help="Number of sentences to generate"
    )
    parser.add_argument(
        "--neuronpedia_description_model_name",
        type=str,
        default="gpt-4o-mini",
        help="Model name for neuronpedia descriptions",
    )
    return parser.parse_args()


def _format_numbered_list(items) -> str:
    """Render *items* as a 1-based numbered list, one item per line."""
    return "\n".join(f"{i}. {item}" for i, item in enumerate(items, start=1))


def _generate_samples(datum, args: argparse.Namespace):
    """Prompt the OpenAI model for sentences targeting one SAE feature.

    Args:
        datum: One row of the dataset's ``test`` split; it is indexed with
            ``SAEKeys.NEURONPEDIA_ID`` and ``SAEKeys.INDEX``.
        args: Parsed CLI options (see ``_parse_args``).

    Returns:
        The model's ``samples`` list, or ``None`` if fetching the Neuronpedia
        feature info or the model call failed. The error is printed in either
        case.
    """
    neuronpedia_id = datum[SAEKeys.NEURONPEDIA_ID]
    index = datum[SAEKeys.INDEX]

    try:
        neuronpedia_info = get_feature_description(neuronpedia_id, index=index)
        description = neuronpedia_info.get_explanation_by_model_name(
            args.neuronpedia_description_model_name
        )
        top_activations = neuronpedia_info.get_contexts_around_top_n_activations(
            n=10, window=20
        )
    except Exception as e:
        # Best-effort: a feature Neuronpedia cannot serve is skipped instead of
        # aborting the whole run.
        print(f"Error getting neuronpedia info: {e}")
        return None

    # The template is filled positionally with the feature description, the
    # numbered top-activation contexts, and the requested sentence count.
    prompt = NEURONPEDIA_SYSTEM_PROMPT.format(
        description,
        _format_numbered_list(top_activations),
        args.num_sentences,
    )
    history = ImageChatHistory()
    history.add_user_msg(prompt)

    # call_model returns (success flag, parsed response, error message).
    success, response, error_msg = call_model(
        history,
        structured_output=Sentences,
        temperature=args.temperature,
        model=args.gpt_model_name,
    )
    if not success:
        # Skip only this row; previously a single failed call aborted the loop
        # and silently dropped every remaining dataset row.
        print(f"Error: {error_msg}")
        return None

    print(response.thoughts)
    print(response.samples)
    return response.samples


def main() -> None:
    """Generate samples for every row of the test split and save them.

    The output file at ``--output`` is a JSON object mapping each row's
    position in the test split to its list of samples. Rows whose Neuronpedia
    lookup or model call failed are omitted.
    """
    args = _parse_args()

    output_path = Path(args.output)
    # Create the output directory if it doesn't exist.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    dataset = download_sae_dataset(hf_token=os.environ.get("HUGGINGFACE_HUB_TOKEN"))
    all_results = {}
    try:
        for idx, datum in enumerate(dataset["test"]):
            print(datum)
            samples = _generate_samples(datum, args)
            if samples is not None:
                all_results[idx] = samples
    finally:
        # Always persist what has been collected, so an exception or Ctrl-C
        # partway through a long run does not discard every earlier result.
        with output_path.open("w", encoding="utf-8") as f:
            json.dump(all_results, f, indent=2)

    # NOTE(review): no wandb.init() is visible in this file, so this is
    # presumably a no-op; kept for parity with the original script. Confirm.
    wandb.finish()


if __name__ == "__main__":
    main()
