import argparse
import json
import os
from pathlib import Path
from typing import List

import wandb
from pydantic import BaseModel, Field

from eliciting_contexts.benchmark.external.tiny_stories.data.load_data import (
    StoryKeys,
    download_tiny_stories_dataset,
)
from eliciting_contexts.utils.call_openai import ImageChatHistory, call_model


class Sentences(BaseModel):
    """Structure model output."""

    # NOTE: the docstring above is intentionally left verbatim — pydantic can
    # surface a model's docstring as the JSON-schema description, which may be
    # forwarded to the model as part of the structured-output request.
    #
    # Fields are listed in generation order: free-form reasoning first
    # (presumably so the model thinks before committing to the samples),
    # then the candidate sentences themselves. Both fields are required.

    # Scratchpad for the model's reasoning about the task.
    thoughts: str = Field(
        default=...,
        description="Think through the task and write notes here.",
    )

    # The candidate sentences to insert into the story template.
    samples: List[str] = Field(
        default=...,
        description="Provide your new sentences here.",
    )


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options for a sentence-generation run.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        A namespace with ``output``, ``temperature``, ``gpt_model_name`` and
        ``num_sentences``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Ask an OpenAI model for candidate sentences that steer TinyStories "
            "completions towards a desired word"
        )
    )

    # Required: the JSON results file is the script's only output and there is
    # no sensible default location for it.
    parser.add_argument(
        "--output", type=str, required=True, help="Path to save output JSON results"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Temperature for model generation",
    )
    parser.add_argument(
        "--gpt_model_name", type=str, default="gpt-4o", help="OpenAI model to use"
    )
    parser.add_argument(
        "--num_sentences", type=int, default=6, help="Number of sentences to generate"
    )
    return parser.parse_args(argv)


def build_prompt(
    template: str, desired_text: str, undesired_text: str, num_sentences: int
) -> str:
    """Build the user message asking the model for insertion sentences.

    Args:
        template: Story text containing a literal ``{0}`` placeholder where the
            generated sentence will be inserted. It is embedded verbatim.
        desired_text: Word the inserted sentence should make the model predict.
        undesired_text: Word the model would otherwise predict.
        num_sentences: How many candidate sentences to request.

    Returns:
        The prompt string.
    """
    return (
        "Your task is to modify the following story to make the model predict "
        f"the word '{desired_text}' instead of '{undesired_text}'.\n"
        "You can insert one sentence in the story and that sentence will be provided.\n"
        # ``{{0}}`` renders as a literal ``{0}`` so it matches the template text.
        f"Here is the story: {template}, where the {{0}} placeholder is where "
        "your sentence will be inserted.\n\n"
        "Write down any thoughts you have about the task and then provide "
        f"{num_sentences} sentences.\n"
        "Each new sentence should aim to make the model predict the word "
        f"'{desired_text}' instead of '{undesired_text}'.\n"
        "Make whatever sentence you want to maximise the desired word!"
    )


def main(argv=None) -> None:
    """Generate candidate sentences for each TinyStories test example and save them.

    For every example in the ``test`` split, the model is asked (through
    ``call_model`` with the ``Sentences`` schema) for ``--num_sentences``
    insertions. Those insertions should push the prediction from the undesired
    text to the first desired text. The results are written to ``--output`` as
    a JSON object that maps the example index to its list of samples. Examples
    whose model call fails, or that have no desired text, are omitted.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`.
    """
    args = parse_args(argv)
    output_path = Path(args.output)
    # For a bare filename the parent is Path("."), which already exists.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    dataset = download_tiny_stories_dataset(
        hf_token=os.environ.get("HUGGINGFACE_HUB_TOKEN")
    )

    all_results = {}
    try:
        for idx, datum in enumerate(dataset["test"]):
            template = datum[StoryKeys.TEMPLATE]
            variable_context = datum[StoryKeys.VARIABLE_TEXT]
            undesired_text = datum[StoryKeys.UNDESIRED_TEXT]

            # Only the first desired text is used for now.
            desired_texts = list(datum[StoryKeys.DESIRED_TEXT])
            if not desired_texts:
                continue
            desired_text = desired_texts[0]

            print(
                f"Optimizing for {desired_text} vs {undesired_text}\n"
                f"{template.format(variable_context)}"
            )
            history = ImageChatHistory()
            history.add_user_msg(
                build_prompt(template, desired_text, undesired_text, args.num_sentences)
            )
            success, gpt_new_text, error_msg = call_model(
                history,
                structured_output=Sentences,
                temperature=args.temperature,
                model=args.gpt_model_name,
            )
            if not success:
                # Skip this example; continue with the rest of the split.
                print(f"Error: {error_msg}")
                continue

            # NOTE(review): presumably a parsed Sentences instance on success —
            # confirm against call_model.
            print(gpt_new_text.thoughts)
            print(gpt_new_text.samples)
            all_results[idx] = gpt_new_text.samples
    finally:
        # Persist everything collected so far, even if a later call raised or
        # the run was interrupted. json turns the int keys into strings.
        with output_path.open("w", encoding="utf-8") as f:
            json.dump(all_results, f, indent=2)

    # NOTE(review): no wandb.init() is visible in this script; confirm whether a
    # run is started elsewhere (e.g. inside call_model) or drop this call.
    wandb.finish()


if __name__ == "__main__":
    main()
