import argparse
import json
import os
from typing import Any, Dict

import torch
from custom_dreamy.callbacks import EPO_STORIES_HELPER_PROMPT, EPOHelperOutput
from transformer_lens import HookedTransformer

from eliciting_contexts.benchmark.external.tiny_stories.data.load_data import (
    StoryKeys,
    download_tiny_stories_dataset,
)
from eliciting_contexts.utils.call_openai import ImageChatHistory, call_model


def process_json(input_path: str, output_path: str) -> Dict[str, Any]:
    """
    Rank candidate texts per story and ask GPT-4o for a new batch of candidates.

    For every entry of the TinyStories ``"test"`` split, the candidate texts
    stored in the input JSON under the stringified entry index are inserted
    into the entry's template and scored with ``google/gemma-2-2b-it``. The
    score is the last-position logit of the first token of the desired text
    minus that of the first token of the undesired text. The candidates,
    sorted from highest to lowest score, are formatted into
    ``EPO_STORIES_HELPER_PROMPT`` and sent to ``gpt-4o``; the ``text`` of each
    returned sample is stored under the same key in the output.

    Args:
        input_path (str): Path to the input JSON file. It is indexed with
            ``str(idx)`` for each test entry and each value is iterated as a
            collection of candidate strings.
        output_path (str): Path to save the output JSON file.

    Returns:
        dict: Mapping from stringified entry index to the list of sample
        texts returned by the helper model. The same mapping is written to
        ``output_path``.

    Raises:
        KeyError: If the input JSON has no entry for a test-split index.
        RuntimeError: If ``call_model`` reports failure for an entry.
    """
    # Use one device value for both the model and the input tensors, and fall
    # back to CPU so the script does not hard-fail on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    evaluation_model = HookedTransformer.from_pretrained(
        "google/gemma-2-2b-it", dtype="bfloat16", device=device
    )
    tokenizer = evaluation_model.tokenizer

    # Load the input JSON file
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    new_data: Dict[str, Any] = {}

    dataset = download_tiny_stories_dataset(
        hf_token=os.environ.get("HUGGINGFACE_HUB_TOKEN")
    )

    for idx, datum in enumerate(dataset["test"]):
        key = str(idx)
        texts = data[key]

        template = datum[StoryKeys.TEMPLATE]
        undesired_text = datum[StoryKeys.UNDESIRED_TEXT]
        # Only the first desired text of the entry is used as the target.
        desired_text = datum[StoryKeys.DESIRED_TEXT][0]

        # The target token IDs depend only on the dataset entry, so compute
        # them once per entry instead of once per candidate text.
        undesired_first_token_id = tokenizer.encode(
            undesired_text, add_special_tokens=False
        )[0]
        desired_first_token_id = tokenizer.encode(
            desired_text, add_special_tokens=False
        )[0]

        items_with_targets = []
        for text in texts:
            new_story = template.format(text)

            story_ids = tokenizer.encode(new_story, add_special_tokens=False)
            # Add a leading batch dimension of size 1.
            story_ids = torch.tensor(story_ids, device=device).unsqueeze(0)

            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                model_logits = evaluation_model(story_ids, return_type="logits")

            # Extract logits from the last position
            last_pos_logits = model_logits[0, -1, :]

            # Convert to Python floats before subtracting so the difference
            # is not rounded to the model's bfloat16 precision.
            undesired_logit = last_pos_logits[undesired_first_token_id].item()
            desired_logit = last_pos_logits[desired_first_token_id].item()

            items_with_targets.append((desired_logit - undesired_logit, text))

        # Sort by target value (highest to lowest); ties fall back to
        # comparing the text itself.
        items_with_targets.sort(reverse=True)

        decoded_with_targets = [
            f"{text} (Target: {target:.4f})" for target, text in items_with_targets
        ]

        prompt_format_args: Dict[str, Any] = {
            "desired_text": desired_text,
            "undesired_text": undesired_text,
            "full_template": template.format("INSERT TEXT HERE"),
            "num_sentences": 8,
            "current_epo_str": "\n".join(
                f"{rank}. {example}"
                for rank, example in enumerate(decoded_with_targets, start=1)
            ),
        }

        prompt_text = EPO_STORIES_HELPER_PROMPT.format(**prompt_format_args)
        print(prompt_text)

        chat_history = ImageChatHistory()
        chat_history.add_system_msg(prompt_text)

        assist_success, assist_content, error = call_model(
            chat_history,
            structured_output=EPOHelperOutput,
            temperature=1.0,
            model="gpt-4o",
        )
        # Surface a failed call explicitly instead of failing later on a
        # missing ``samples`` attribute with no hint of the real cause.
        if not assist_success:
            raise RuntimeError(
                f"call_model failed for dataset index {idx}: {error}"
            )

        new_data[key] = [d.text for d in assist_content.samples]
        print(new_data[key])

    # Save the new dictionary to the output file
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(new_data, f, indent=2, ensure_ascii=False)

    print(f"Processed JSON saved to {output_path}")
    return new_data


def main() -> None:
    """Parse the command-line arguments and run ``process_json``.

    Both ``--input_file`` and ``--output_file`` are required; argparse exits
    with a usage error when either is missing.
    """
    parser = argparse.ArgumentParser(
        # Describe what the script actually does: it scores and regenerates
        # candidates rather than duplicating the input file.
        description=(
            "Rank candidate texts from a JSON file by model logit difference "
            "and query GPT-4o for new candidates"
        )
    )
    parser.add_argument(
        "--input_file", required=True, help="Path to the input JSON file"
    )
    parser.add_argument(
        "--output_file", required=True, help="Path to save the output JSON file"
    )

    args = parser.parse_args()
    process_json(args.input_file, args.output_file)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
