import argparse
import json
import os

import numpy as np
import torch
from contextbench.load_data import StoryKeys, download_inpainting_stories_dataset
from contextbench.utils import tokenize_text_with_placeholder
from tqdm import tqdm
from transformer_lens import HookedTransformer


class Info:
    """Static description of one story from the evaluation dataset.

    Holds the story template, its variable context, the undesired and desired
    texts, and the story type, exactly as passed to the constructor.
    """

    def __init__(
        self,
        template,
        variable_context,
        undesired_text,
        desired_text,
        story_type,
    ):
        # Every constructor argument is stored verbatim under its own name.
        self.template, self.variable_context = template, variable_context
        self.undesired_text, self.desired_text = undesired_text, desired_text
        self.story_type = story_type

    def to_dict(self):
        """Return the stored fields as a plain dict, in constructor order."""
        field_names = (
            "template",
            "variable_context",
            "undesired_text",
            "desired_text",
            "story_type",
        )
        return {name: getattr(self, name) for name in field_names}


class Result:
    """Metrics computed for one candidate fill-in of a story template.

    Round-trips through plain dicts via ``to_dict`` / ``from_dict``; note that
    ``top_tokens`` is not written by ``to_dict`` and defaults to ``[]`` when
    read back by ``from_dict``.
    """

    def __init__(
        self,
        input_text,
        desired_logit,
        undesired_logit,
        logit_diff_improvement,
        logit_diff_normalized,
        logit_diff,
        desired_rank,
        undesired_rank,
        top_tokens,
        cross_entropy,
        predicted_text,
    ):
        # The scored text and the model's continuation of it.
        self.input_text = input_text
        self.predicted_text = predicted_text
        # Raw logits and the derived differences.
        self.desired_logit = desired_logit
        self.undesired_logit = undesired_logit
        self.logit_diff = logit_diff
        self.logit_diff_improvement = logit_diff_improvement
        self.logit_diff_normalized = logit_diff_normalized
        # Ranking information.
        self.desired_rank = desired_rank
        self.undesired_rank = undesired_rank
        self.top_tokens = top_tokens
        # Cross-entropy metric.
        self.cross_entropy = cross_entropy

    def __str__(self):
        """Render a multi-line, human-readable summary of the metrics."""
        lines = [
            "==== Result Summary ====",
            f"Text: {self.input_text}",
            "Logit Analysis:",
            f"  Desired logit: {self.desired_logit:.4f}",
            f"  Undesired logit: {self.undesired_logit:.4f}",
            f"  Logit difference: {self.logit_diff:.4f}",
            f"  Logit difference improvement: {self.logit_diff_improvement:.4f}",
            "",
            f"  Logit difference normalized: {self.logit_diff_normalized:.4f}",
            "",
            "Ranking:",
            f"  Desired token rank: {self.desired_rank}",
            f"  Undesired token rank: {self.undesired_rank}",
            "",
            f"Top tokens: {', '.join(self.top_tokens)}",
            "",
            f"Cross entropy: {self.cross_entropy:.4f}",
            "",
            f"Predicted next tokens: '{self.predicted_text}'",
        ]
        return "\n".join(lines)

    def to_dict(self):
        """Return the metrics as a dict; ``top_tokens`` is deliberately omitted."""
        serialized_keys = (
            "input_text",
            "predicted_text",
            "logit_diff_improvement",
            "logit_diff_normalized",
            "logit_diff",
            "desired_logit",
            "undesired_logit",
            "desired_rank",
            "undesired_rank",
            "cross_entropy",
        )
        return {key: getattr(self, key) for key in serialized_keys}

    @classmethod
    def from_dict(cls, data):
        """Build a Result from a dict such as ``to_dict`` produces.

        Every metric key is required (a missing one raises ``KeyError``);
        ``top_tokens`` is optional and defaults to an empty list.
        """
        required_keys = (
            "input_text",
            "desired_logit",
            "undesired_logit",
            "logit_diff_improvement",
            "logit_diff_normalized",
            "logit_diff",
            "desired_rank",
            "undesired_rank",
            "cross_entropy",
            "predicted_text",
        )
        kwargs = {key: data[key] for key in required_keys}
        kwargs["top_tokens"] = data.get("top_tokens", [])
        return cls(**kwargs)


def get_metrics(
    template,
    result,
    undesired_text,
    desired_text,
    evaluation_model,
    cross_entropy_model,
    default_results=None,
    human_results=None,
    device="cuda",
):
    """Evaluate one candidate fill-in for a story template.

    The candidate is inserted with ``template.format(result)`` and three groups
    of metrics are computed:

    * at the last position of the filled story, the ``evaluation_model`` logits
      of the first token of ``desired_text`` and of ``undesired_text``, their
      difference, and each token's 1-based rank among all next-token logits
      (plus the 10 highest-scoring tokens);
    * the mean cross-entropy of ``cross_entropy_model`` over the positions that
      ``tokenize_text_with_placeholder`` marks as not fixed;
    * a 10-token continuation generated by ``evaluation_model`` at
      temperature 0.

    Args:
        template: Story text with a placeholder filled via ``template.format``.
        result: Candidate fill-in text; one leading space is stripped.
        undesired_text: Text whose first token is the undesired target.
        desired_text: Text whose first token is the desired target.
        evaluation_model: Model used for logits, ranks and generation.
        cross_entropy_model: Model used for the cross-entropy metric.
        default_results: Optional baseline ``Result``. When given,
            ``logit_diff_improvement`` is this ``logit_diff`` minus the
            baseline's; otherwise it is 0.0.
        human_results: Optional reference ``Result``. Together with
            ``default_results`` it scales ``logit_diff_normalized`` as
            improvement / (human.logit_diff - default.logit_diff); the value is
            NaN when that gap is zero. With only ``default_results`` the
            normalized value is 1.0, and with no baseline it is 0.0.
        device: Device on which input token tensors are created.

    Returns:
        A ``Result`` whose ``input_text`` is the space-stripped ``result``.
    """
    # Strip a single leading space before insertion. ``startswith`` also
    # handles an empty candidate, where indexing ``result[0]`` would raise.
    if result.startswith(" "):
        result = result[1:]

    new_story = template.format(result)
    tokenizer = evaluation_model.tokenizer

    # Pure evaluation: no autograd graph is needed for any forward pass here.
    with torch.no_grad():
        #### Next-token logits after the filled story
        story_ids = torch.tensor(
            tokenizer.encode(new_story, add_special_tokens=False), device=device
        ).unsqueeze(0)
        last_pos_logits = evaluation_model(story_ids, return_type="logits")[0, -1, :]

        # Only the first token of each target text is scored.
        undesired_first_token_id = tokenizer.encode(
            undesired_text, add_special_tokens=False
        )[0]
        desired_first_token_id = tokenizer.encode(
            desired_text, add_special_tokens=False
        )[0]

        undesired_logit = last_pos_logits[undesired_first_token_id].item()
        desired_logit = last_pos_logits[desired_first_token_id].item()
        logit_diff = desired_logit - undesired_logit

        # Ranks are 1-based positions in the descending logit order.
        sorted_indices = torch.sort(last_pos_logits, descending=True)[1].cpu().numpy()
        desired_rank = int(np.where(sorted_indices == desired_first_token_id)[0][0]) + 1
        undesired_rank = (
            int(np.where(sorted_indices == undesired_first_token_id)[0][0]) + 1
        )

        top_k = 10
        top_tokens = [
            tokenizer.decode([int(token_id)]) for token_id in sorted_indices[:top_k]
        ]

        #### Cross entropy over the non-fixed positions
        story_ids_cross, fixed_positions = tokenize_text_with_placeholder(
            cross_entropy_model.tokenizer,
            template,
            result,
            skip_special_tokens=False,
        )
        story_ids_cross = story_ids_cross.unsqueeze(0).to(device)
        model_logits_cross_entropy = cross_entropy_model(
            story_ids_cross, return_type="logits"
        )

        # NOTE(review): fixed_positions presumably flags template tokens (True)
        # vs. fill-in tokens (False); confirm in tokenize_text_with_placeholder.
        # Casting to bool keeps ``~`` a logical NOT even for 0/1 inputs.
        fixed_mask = torch.as_tensor(fixed_positions, dtype=torch.bool, device=device)
        unfixed_indices = torch.where(~fixed_mask)[0]

        # Token t is predicted by the logits at position t-1, so shift by one.
        # Position 0 has no preceding logits and is dropped.
        logits_offset = model_logits_cross_entropy[:, :-1]
        predicted_positions = unfixed_indices[unfixed_indices != 0] - 1

        cross_entropy = (
            torch.nn.CrossEntropyLoss(reduction="none")(
                logits_offset[:, predicted_positions].reshape(
                    -1, logits_offset.shape[-1]
                ),
                story_ids_cross[:, 1:][:, predicted_positions].reshape(-1),
            )
            .view(logits_offset.shape[0], -1)
            .mean(dim=-1)
            .item()
        )

        #### Predicted next tokens (temperature 0)
        # The slice assumes the generated string starts with ``new_story``.
        predicted_text = evaluation_model.generate(
            new_story,
            max_new_tokens=10,
            temperature=0.0,
            verbose=False,
        )[len(new_story) :]

    #### Metrics relative to the baselines
    if default_results is not None:
        logit_diff_improvement = logit_diff - default_results.logit_diff
    else:
        logit_diff_improvement = 0.0

    if human_results is not None and default_results is not None:
        human_gap = human_results.logit_diff - default_results.logit_diff
        # A zero gap gives no scale to normalize against.
        logit_diff_normalized = (
            logit_diff_improvement / human_gap if human_gap != 0 else float("nan")
        )
    elif default_results is not None:
        logit_diff_normalized = 1.0
    else:
        logit_diff_normalized = 0.0

    return Result(
        input_text=result,
        desired_logit=desired_logit,
        undesired_logit=undesired_logit,
        logit_diff_improvement=logit_diff_improvement,
        logit_diff_normalized=logit_diff_normalized,
        logit_diff=logit_diff,
        desired_rank=desired_rank,
        undesired_rank=undesired_rank,
        top_tokens=top_tokens,
        cross_entropy=cross_entropy,
        predicted_text=predicted_text,
    )


def _load_results_jsons(paths):
    """Load each results JSON in ``paths`` and return one entry per path.

    A file that cannot be read or parsed is reported and replaced by ``None``,
    so the returned list stays index-aligned with ``--result_names``.
    """
    loaded = []
    for json_path in paths:
        try:
            with open(json_path, "r", encoding="utf-8") as f:
                json_data = json.load(f)
        except (OSError, ValueError) as e:
            print(f"Error loading JSON file {json_path}: {e}")
            json_data = None
        else:
            print(f"Loaded JSON file: {json_path}")
        loaded.append(json_data)
    return loaded


def _load_existing_results(path):
    """Return a previous evaluation output to reuse, or None if unavailable."""
    if not path:
        return None
    if not os.path.exists(path):
        print(f"Existing results file not found, recomputing everything: {path}")
        return None
    try:
        with open(path, "r", encoding="utf-8") as f:
            existing_results = json.load(f)
    except (OSError, ValueError) as e:
        print(f"Error loading existing results from {path}: {e}")
        return None
    print(f"Loaded existing results from: {path}")
    return existing_results


def _find_existing_result(candidates, input_text):
    """Return the first stored result dict matching ``input_text``, else None."""
    # get_metrics stores input_text with one leading space removed, so accept
    # either the raw candidate or that stripped form.
    stripped = input_text[1:] if input_text.startswith(" ") else input_text
    return next(
        (
            entry
            for entry in candidates
            if entry["input_text"] in (input_text, stripped)
        ),
        None,
    )


def main():
    """Evaluate the default, human and loaded result sets on the test split.

    For every test story, the "default" set scores the dataset's variable
    context and the "human" set scores the human answer. Each loaded results
    JSON contributes the candidates stored under the story's index. Results
    found in ``--existing_json`` are reused unless their set name is listed in
    ``--recompute``. Everything is written to ``--output_json``.
    """
    parser = argparse.ArgumentParser(description="Run EPO optimization for TinyStories")
    parser.add_argument(
        "--results_jsons",
        type=str,
        nargs="+",
        default=[],
        help="Paths to JSON files containing results to evaluate",
    )
    parser.add_argument(
        "--result_names",
        type=str,
        nargs="+",
        help="Names to identify each results JSON file (must match number of results_jsons)",
    )
    parser.add_argument(
        "--evaluation_model_name",
        type=str,
        default="google/gemma-2-2b-it",
        help="The model to use for evaluation",
    )
    parser.add_argument(
        "--cross_entropy_model_name",
        type=str,
        default="google/gemma-2-2b-it",
        help="The model to use for calculating cross entropy",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to run the model on (e.g., 'cuda', 'cpu')",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        help="Data type for model weights (e.g., 'bfloat16', 'float16')",
    )
    parser.add_argument(
        "--output_json",
        type=str,
        default="evaluation_results.json",
        help="Path to save the evaluation results as a JSON file",
    )
    parser.add_argument(
        "--existing_json",
        type=str,
        default=None,
        help="Path to existing JSON file with results to reuse instead of recomputing",
    )
    parser.add_argument(
        "--recompute",
        type=str,
        nargs="+",
        default=None,
        help="List of names to recompute",
    )
    args = parser.parse_args()

    # Verify that result_names matches results_jsons in length.
    if not args.result_names:
        args.result_names = [f"result_set_{i}" for i in range(len(args.results_jsons))]
    if len(args.result_names) != len(args.results_jsons):
        raise ValueError(
            f"Number of result names ({len(args.result_names)}) must match number of result JSONs ({len(args.results_jsons)})"
        )
    # --recompute is optional; when it is omitted, nothing is forced to recompute.
    recompute = set(args.recompute or [])

    loaded_jsons = _load_results_jsons(args.results_jsons)

    dataset = download_inpainting_stories_dataset(
        hf_token=os.environ.get("HUGGINGFACE_HUB_TOKEN"),
    )

    evaluation_model = HookedTransformer.from_pretrained(
        args.evaluation_model_name, dtype=args.dtype, device=args.device
    )
    if args.cross_entropy_model_name != args.evaluation_model_name:
        cross_entropy_model = HookedTransformer.from_pretrained(
            args.cross_entropy_model_name, dtype=args.dtype, device=args.device
        )
    else:
        cross_entropy_model = evaluation_model

    existing_results = _load_existing_results(args.existing_json)

    # "default" and "human" have no JSON (None); their inputs come from the dataset.
    result_names = ["default", "human", *args.result_names]
    all_json_results = [None, None, *loaded_jsons]

    all_results = {}
    test_split = dataset["test"]
    for idx, datum in tqdm(enumerate(test_split), total=len(test_split)):
        key = str(idx)
        template = datum[StoryKeys.TEMPLATE]
        variable_context = datum[StoryKeys.VARIABLE_TEXT]
        undesired_text = datum[StoryKeys.UNDESIRED_TEXT]
        story_type = datum[StoryKeys.STORY_TYPE]
        human_answer = datum[StoryKeys.HUMAN_ANSWER]
        # Only the first desired text is used for now.
        desired_text = datum[StoryKeys.DESIRED_TEXT][0]

        info = Info(
            template=template,
            variable_context=variable_context,
            undesired_text=undesired_text,
            desired_text=desired_text,
            story_type=story_type,
        )
        story_results = {"info": info.to_dict()}
        all_results[key] = story_results
        existing_for_story = (existing_results or {}).get(key, {})

        # Baselines are filled in as "default" and then "human" are processed,
        # and are passed to get_metrics for every later result set.
        default_results = None
        human_results = None
        for result_name, json_results in zip(result_names, all_json_results):
            if json_results is not None:
                if key not in json_results:
                    continue
                current_results = json_results[key]
            elif result_name == "default":
                current_results = [variable_context]
            elif result_name == "human":
                current_results = [human_answer]
            else:
                # A results JSON that failed to load.
                continue

            reusable = (
                []
                if result_name in recompute
                else existing_for_story.get(result_name, [])
            )
            entries = []
            story_results[result_name] = entries

            for current_result in current_results:
                existing_result = _find_existing_result(reusable, current_result)
                if existing_result is None:
                    result = get_metrics(
                        template,
                        current_result,
                        undesired_text,
                        desired_text,
                        evaluation_model,
                        cross_entropy_model,
                        default_results=default_results,
                        human_results=human_results,
                        device=args.device,
                    )
                    entries.append(result.to_dict())
                else:
                    entries.append(existing_result)
                    # Only the baselines are needed as Result objects.
                    result = (
                        Result.from_dict(existing_result)
                        if result_name in ("default", "human")
                        else None
                    )

                if result_name == "default" and default_results is None:
                    default_results = result
                elif result_name == "human" and human_results is None:
                    human_results = result

    print(f"Saving all evaluation results to {args.output_json}")
    with open(args.output_json, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)
    print("Results saved successfully!")


if __name__ == "__main__":
    main()
