import argparse
import json
import os

import torch
from contextbench.load_data import SAEKeys, download_sae_dataset
from contextbench.utils import (
    get_neuronpedia_info,
    get_saelens_release_and_id,
    load_model_tlens,
    load_sae_saelens,
)
from tqdm import tqdm


class Info:
    """Descriptive metadata for a single SAE feature.

    Holds the Neuronpedia explanation of the feature together with its
    per-feature annotations (density, vocabulary diversity, tags, grading,
    ...), so the metadata can be serialised next to the feature's
    evaluation results.
    """

    # Attribute names in constructor order; also the key order of ``to_dict``.
    _FIELDS = (
        "neuronpedia_description",
        "sae_index",
        "neuronpedia_id",
        "density",
        "vocab_diversity",
        "local_vs_global",
        "tags",
        "necessary_context",
        "necessary_condition",
        "success_criterion",
        "human_description",
        "feature_grade",
    )

    def __init__(
        self,
        neuronpedia_description,
        sae_index,
        neuronpedia_id,
        density,
        vocab_diversity,
        local_vs_global,
        tags,
        necessary_context,
        necessary_condition,
        success_criterion,
        human_description,
        feature_grade,
    ):
        field_values = (
            neuronpedia_description,
            sae_index,
            neuronpedia_id,
            density,
            vocab_diversity,
            local_vs_global,
            tags,
            necessary_context,
            necessary_condition,
            success_criterion,
            human_description,
            feature_grade,
        )
        # Each constructor argument becomes an attribute of the same name.
        for field_name, value in zip(self._FIELDS, field_values):
            setattr(self, field_name, value)

    def to_dict(self):
        """Return the feature metadata as a plain dict keyed by attribute name."""
        return {field_name: getattr(self, field_name) for field_name in self._FIELDS}


class Result:
    """Activation and cross-entropy metrics for one text scored on one SAE feature.

    Attributes mirror the constructor arguments. ``token_activations`` must
    support ``.tolist()`` (``from_dict`` rebuilds it as a ``torch.Tensor``).
    ``token_activation_pairs`` pairs each token string with its activation,
    or is ``None`` when either per-token input is missing.
    """

    def __init__(
        self,
        input_text,
        mean_activation,
        max_activation,
        normalized_mean_activation,
        normalized_max_activation,
        cross_entropy,
        token_activations=None,
        tokenized_text=None,
        bos_token_found=False,
    ):
        self.input_text = input_text
        self.mean_activation = mean_activation
        self.max_activation = max_activation
        self.normalized_mean_activation = normalized_mean_activation
        self.normalized_max_activation = normalized_max_activation
        self.cross_entropy = cross_entropy
        self.token_activations = token_activations
        self.tokenized_text = tokenized_text
        self.token_activation_pairs = self._pair_tokens_with_activations(
            tokenized_text, token_activations, bos_token_found
        )

    @staticmethod
    def _pair_tokens_with_activations(
        tokenized_text, token_activations, bos_token_found
    ):
        """Zip token strings with per-token activations, aligning around BOS.

        If ``bos_token_found`` is set and there is not one activation per
        token, the leading BOS token has no activation of its own and is
        dropped before pairing. If the lengths already match, an activation
        was recorded for the BOS position too, so tokens are paired one-to-one.
        In that case skipping BOS would shift every pair by one position.

        Returns:
            A list of ``(token, activation)`` tuples, or ``None`` if either
            input is missing.
        """
        if token_activations is None or tokenized_text is None:
            return None
        activations = token_activations.tolist()
        tokens = list(tokenized_text)
        if bos_token_found and len(tokens) != len(activations):
            tokens = tokens[1:]
        return list(zip(tokens, activations))

    def __str__(self):
        return (
            f"==== Result Summary ====\n"
            f"Text: {self.input_text}\n"
            f"Mean Activation: {self.mean_activation:.4f}\n"
            f"Max Activation: {self.max_activation:.4f}\n"
            f"Cross entropy: {self.cross_entropy:.4f}\n"
            f"Normalized Mean Activation: {self.normalized_mean_activation:.4f}\n"
            f"Normalized Max Activation: {self.normalized_max_activation:.4f}\n"
        )

    def to_dict(self):
        """Convert the Result object to a JSON-serialisable dictionary."""
        return {
            "input_text": self.input_text,
            "mean_activation": self.mean_activation,
            "max_activation": self.max_activation,
            "cross_entropy": self.cross_entropy,
            "normalized_mean_activation": self.normalized_mean_activation,
            "normalized_max_activation": self.normalized_max_activation,
            "token_activations": (
                self.token_activations.tolist()
                if self.token_activations is not None
                else None
            ),
            "tokenized_text": self.tokenized_text,
            "token_activation_pairs": self.token_activation_pairs,
        }

    @classmethod
    def from_dict(cls, data):
        """Create a Result object from a dictionary produced by ``to_dict``.

        ``bos_token_found`` is not serialised. It is inferred from the data:
        a BOS token without its own activation shows up as exactly one more
        token than activations. This keeps the rebuilt pairs aligned the same
        way as the original object's.
        """
        token_activations = (
            torch.tensor(data["token_activations"])
            if data.get("token_activations") is not None
            else None
        )
        tokenized_text = data.get("tokenized_text")
        bos_token_found = (
            token_activations is not None
            and tokenized_text is not None
            and len(tokenized_text) == len(token_activations) + 1
        )
        return cls(
            input_text=data["input_text"],
            mean_activation=data["mean_activation"],
            max_activation=data["max_activation"],
            normalized_mean_activation=data["normalized_mean_activation"],
            normalized_max_activation=data["normalized_max_activation"],
            cross_entropy=data["cross_entropy"],
            token_activations=token_activations,
            tokenized_text=tokenized_text,
            bos_token_found=bos_token_found,
        )


# Per-feature normalisation constants keyed by ``(neuronpedia_id, sae_index)``.
# Filled lazily by ``get_metrics`` so that each feature's top-activating
# contexts are fetched from Neuronpedia and re-scored at most once per process.
max_activations = {}

# Statistics accepted for ``get_metrics(..., epo_target=...)`` (upper-cased).
_EPO_TARGETS = ("MEAN", "MAX")


def _tokenize_without_special_tokens(model, text):
    """Return ``text`` as a 1-D tensor of token ids from ``model.tokenizer``.

    No special tokens are added, so a BOS id only appears if the text itself
    produces one.
    """
    return model.tokenizer(text, return_tensors="pt", add_special_tokens=False)[
        "input_ids"
    ][0]


def _feature_activations(evaluation_model, sae, sae_index, ids, device):
    """Return ``(activations, bos_token_found)`` for feature ``sae_index`` on ``ids``.

    Runs ``evaluation_model`` once on ``ids``, which must be a non-empty 1-D
    tensor of token ids. The residual at ``sae.cfg.hook_name`` is passed
    through the SAE, and the pre-activations at its ``hook_sae_acts_pre``
    hook point are read. The hook returns the residual unchanged, so the
    model's own forward pass is not altered.

    ``activations`` is a 1-D tensor with one value per token. If the first id
    is the tokenizer's BOS token, ``bos_token_found`` is True and that
    position is dropped. The returned values then cover exactly the tokens
    that mean/max statistics are computed over.
    """
    pre_acts = []

    def store_pre_acts(acts, hook):
        pre_acts.append(acts)
        return acts

    def run_sae(resid, hook):
        sae.run_with_hooks(
            resid,
            fwd_hooks=[("hook_sae_acts_pre", store_pre_acts)],
        )
        return resid

    with evaluation_model.hooks(fwd_hooks=[(sae.cfg.hook_name, run_sae)]):
        evaluation_model(ids.unsqueeze(0).to(device), return_type="logits")

    # A single sequence is fed (batch of 1), so take batch element 0.
    activations = pre_acts[0][0, :, sae_index]
    bos_token_found = bool(ids[0] == evaluation_model.tokenizer.bos_token_id)
    if bos_token_found:
        activations = activations[1:]
    return activations, bos_token_found


def _feature_normalizer(
    evaluation_model, sae, sae_index, neuronpedia_id, device, epo_target
):
    """Return the normalisation constant for a feature, caching it on first use.

    The constant is the largest ``epo_target`` statistic (mean or max of the
    non-BOS pre-activations). It is taken over the contexts returned by
    Neuronpedia's ``get_contexts_around_top_n_activations(n=10, window=15)``.

    Falls back to 1.0 (i.e. no normalisation) in three cases:
    - the contexts cannot be fetched or scored;
    - none of them contains a scorable token;
    - the constant would be 0, which would make normalisation divide by zero.
    """
    key = (neuronpedia_id, sae_index)
    if key in max_activations:
        return max_activations[key]

    scores = []
    try:
        neuronpedia_info = get_neuronpedia_info(neuronpedia_id, index=int(sae_index))
        top_window = neuronpedia_info.get_contexts_around_top_n_activations(
            n=10, window=15
        )
        for message in top_window:
            ids = _tokenize_without_special_tokens(evaluation_model, message)
            if len(ids) == 0:
                continue
            activations, _ = _feature_activations(
                evaluation_model, sae, sae_index, ids, device
            )
            if activations.numel() == 0:
                continue  # the context consisted only of a BOS token
            statistic = (
                activations.mean() if epo_target == "MEAN" else activations.max()
            )
            scores.append(statistic.item())
    except Exception as e:
        # Best effort: a Neuronpedia/network failure must not abort the run.
        print(f"Error getting neuronpedia info: {e}")
        scores = []
    else:
        if not scores:
            print(
                f"No usable top-activating contexts for feature {sae_index} "
                f"of {neuronpedia_id}; defaulting normalisation to 1.0"
            )

    feature_max_activation = max(scores) if scores else 1.0
    if feature_max_activation == 0:
        # A zero constant would make the normalised values divide by zero.
        feature_max_activation = 1.0
    max_activations[key] = feature_max_activation
    return feature_max_activation


@torch.no_grad()
def get_metrics(
    result,
    evaluation_model,
    sae,
    sae_index,
    cross_entropy_model,
    device="cuda",
    neuronpedia_id=None,
    epo_target="MEAN",
):
    """Score one text on one SAE feature and measure its cross entropy.

    Args:
        result: The text to evaluate.
        evaluation_model: Model the SAE reads from. It must expose
            ``tokenizer`` and TransformerLens-style ``hooks``.
        sae: SAE whose ``cfg.hook_name`` is hooked on ``evaluation_model``.
        sae_index: Index of the SAE feature to score.
        cross_entropy_model: Model used for the mean next-token cross entropy
            of ``result``, tokenized with special tokens.
        device: Device the token ids are moved to.
        neuronpedia_id: Neuronpedia id of the SAE. It is used to fetch the
            feature's top-activating contexts for normalisation.
        epo_target: ``"MEAN"`` or ``"MAX"`` (case-insensitive). Selects which
            statistic of the top-activating contexts is the normalisation
            constant.

    Returns:
        A ``Result`` containing:
        - raw and normalised mean/max pre-activations (BOS position excluded);
        - per-token activations (BOS position excluded);
        - the token strings (BOS included);
        - the cross entropy.

    Raises:
        ValueError: If ``epo_target`` is not recognised, or if ``result`` has
            no tokens to score.
    """
    epo_target = str(epo_target).upper()
    if epo_target not in _EPO_TARGETS:
        raise ValueError(f"Invalid EPO target: {epo_target}")

    #### Feature activations on the evaluation model
    ids = _tokenize_without_special_tokens(evaluation_model, result)
    if len(ids) == 0:
        raise ValueError(f"Text produced no tokens: {result!r}")
    # Token strings for every id, BOS included.
    tokenized_text = evaluation_model.tokenizer.convert_ids_to_tokens(ids.tolist())
    token_activations, bos_token_found = _feature_activations(
        evaluation_model, sae, sae_index, ids, device
    )
    if token_activations.numel() == 0:
        raise ValueError(f"Text has no tokens besides BOS: {result!r}")
    mean_activation = token_activations.mean().item()
    max_activation = token_activations.max().item()

    # Both statistics are normalised by the same per-feature constant.
    feature_max_activation = _feature_normalizer(
        evaluation_model, sae, sae_index, neuronpedia_id, device, epo_target
    )
    normalized_mean_activation = mean_activation / feature_max_activation
    normalized_max_activation = max_activation / feature_max_activation

    #### Cross entropy
    ids_cross = cross_entropy_model.tokenizer(
        result, return_tensors="pt", add_special_tokens=True
    )["input_ids"][0]
    ids_cross = ids_cross.unsqueeze(0).to(device)
    model_logits_cross_entropy = cross_entropy_model(ids_cross, return_type="logits")

    # Logits at position t predict token t + 1: drop the last logit position
    # and the first target token so the two line up.
    logits_offset = model_logits_cross_entropy[:, :-1]
    cross_entropy = (
        torch.nn.CrossEntropyLoss(reduction="none")(
            logits_offset.reshape(-1, logits_offset.shape[-1]),
            ids_cross[:, 1:].reshape(-1),
        )
        .view(*logits_offset.shape[:2])
        .mean(dim=-1)
        .item()
    )

    return Result(
        result,
        mean_activation,
        max_activation,
        normalized_mean_activation,
        normalized_max_activation,
        cross_entropy,
        token_activations=token_activations,
        tokenized_text=tokenized_text,
        bos_token_found=bos_token_found,
    )


if __name__ == "__main__":
    # Set up argument parser
    parser = argparse.ArgumentParser(
        description="Evaluate SAE feature activations and cross entropy for sets of texts"
    )
    parser.add_argument(
        "--results_jsons",
        type=str,
        nargs="+",
        default=[],
        help="Paths to JSON files containing results to evaluate",
    )
    parser.add_argument(
        "--result_names",
        type=str,
        nargs="+",
        help="Names to identify each results JSON file (must match number of results_jsons)",
    )

    parser.add_argument(
        "--cross_entropy_model_name",
        type=str,
        default="google/gemma-2-2b",
        help="The model to use for calculating cross entropy",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to run the model on (e.g., 'cuda', 'cpu')",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        help="Data type for model weights (e.g., 'bfloat16', 'float16')",
    )
    parser.add_argument(
        "--output_json",
        type=str,
        default="sae_evaluation_results.json",
        help="Path to save the evaluation results as a JSON file",
    )
    parser.add_argument(
        "--neuronpedia_description_model_name",
        type=str,
        default="gpt-4o-mini",
        help="The model name used for neuronpedia descriptions",
    )

    parser.add_argument(
        "--existing_json",
        type=str,
        default=None,
        help="Path to existing JSON file with results to reuse instead of recomputing",
    )
    parser.add_argument(
        "--recompute",
        type=str,
        nargs="+",
        default=None,
        help="List of names to recompute",
    )
    parser.add_argument(
        "--epo_target",
        type=str.upper,
        choices=["MEAN", "MAX"],
        default="MEAN",
        help="Target for EPO (overrides config) and what it gets normalized by",
    )

    parser.add_argument(
        "--load_with_no_processing",
        action="store_true",
        default=True,
        help="Load the model with no processing (overrides config; this is the default)",
    )
    # store_true with default=True can never be switched off, so offer an
    # explicit opposite flag writing to the same destination.
    parser.add_argument(
        "--load_with_processing",
        dest="load_with_no_processing",
        action="store_false",
        help="Load the model with processing (disables --load_with_no_processing)",
    )
    parser.add_argument(
        "--indices",
        type=int,
        nargs="+",
        default=None,
        help="Only evaluate these indices of the test split (default: all)",
    )

    args = parser.parse_args()
    # Kept as a module-level global (this block runs at module scope).
    EPO_TARGET = args.epo_target
    # ``--recompute`` is optional; treat "not given" as "recompute nothing".
    recompute = set(args.recompute or [])

    # Verify that result_names matches results_jsons in length
    if not args.result_names:
        args.result_names = [f"result_set_{i}" for i in range(len(args.results_jsons))]
    if len(args.result_names) != len(args.results_jsons):
        raise ValueError(
            f"Number of result names ({len(args.result_names)}) must match number of result JSONs ({len(args.results_jsons)})"
        )

    # (result_name, parsed JSON) pairs. ``None`` stands for the feature's
    # max-activating examples fetched from Neuronpedia. Names and data are
    # kept together so that a file which fails to load cannot shift later
    # names onto the wrong data.
    result_sets = [("max_activating_examples", None)]
    for result_name, json_path in zip(args.result_names, args.results_jsons):
        try:
            with open(json_path, "r") as f:
                json_data = json.load(f)
        except (OSError, ValueError) as e:
            print(f"Error loading JSON file {json_path}: {e}")
            continue
        print(f"Loaded JSON file: {json_path}")
        result_sets.append((result_name, json_data))

    dataset = download_sae_dataset(hf_token=os.environ.get("HUGGINGFACE_HUB_TOKEN"))

    # Dictionary to store all results
    all_results = {}
    # Previously computed results, keyed like ``all_results``; empty if none.
    existing_results = {}
    if args.existing_json and os.path.exists(args.existing_json):
        try:
            with open(args.existing_json, "r") as f:
                existing_results = json.load(f)
            print(f"Loaded existing results from: {args.existing_json}")
        except (OSError, ValueError) as e:
            print(f"Error loading existing results from {args.existing_json}: {e}")

    # Group data points by SAE so each SAE (and its model) is loaded once.
    selected_indices = set(args.indices) if args.indices is not None else None
    neuronpedia_groups = {}
    for idx, datum in enumerate(dataset["test"]):
        if selected_indices is not None and idx not in selected_indices:
            continue
        neuronpedia_groups.setdefault(datum[SAEKeys.NEURONPEDIA_ID], []).append(
            (idx, datum)
        )

    current_model_name = None

    cross_entropy_model = load_model_tlens(
        args.cross_entropy_model_name,
        args.device,
        args.dtype,
        args.load_with_no_processing,
    )

    for neuronpedia_id, examples in neuronpedia_groups.items():

        sae_lens_release, saelens_sae_id = get_saelens_release_and_id(neuronpedia_id)

        sae = load_sae_saelens(
            sae_lens_release, saelens_sae_id, args.device, args.dtype
        )

        model_name = sae.cfg.model_name

        # Reuse the previous evaluation model (or the cross-entropy model)
        # instead of reloading it when consecutive SAEs share a base model.
        if model_name != current_model_name:
            if model_name == args.cross_entropy_model_name:
                evaluation_model = cross_entropy_model
            else:
                evaluation_model = load_model_tlens(
                    model_name, args.device, args.dtype, args.load_with_no_processing
                )
            current_model_name = model_name

        for original_idx, datum in tqdm(examples):

            sae_index = datum[SAEKeys.INDEX]
            key = str(original_idx)

            try:
                neuronpedia_info = get_neuronpedia_info(
                    neuronpedia_id, index=int(sae_index)
                )

                description = neuronpedia_info.get_explanation_by_model_name(
                    args.neuronpedia_description_model_name
                )

                top_activations = (
                    neuronpedia_info.get_contexts_around_top_n_activations(
                        n=10, window=15
                    )
                )
            except Exception as e:
                print(f"Error getting neuronpedia info: {e}")
                description = "NA"
                top_activations = []

            # Create info object once per data point
            info = Info(
                neuronpedia_description=description,
                sae_index=sae_index,
                neuronpedia_id=neuronpedia_id,
                density=datum[SAEKeys.DENSITY],
                vocab_diversity=datum[SAEKeys.VOCAB_DIVERSITY],
                local_vs_global=datum[SAEKeys.LOCAL_VS_GLOBAL],
                tags=datum[SAEKeys.TAGS],
                necessary_context=datum[SAEKeys.NECESSARY_CONTEXT],
                necessary_condition=datum[SAEKeys.NECESSARY_CONDITION],
                success_criterion=datum[SAEKeys.SUCCESS_CRITERION],
                human_description=datum[SAEKeys.HUMAN_EXPLANATION],
                feature_grade=datum[SAEKeys.FEATURE_GRADE],
            )
            all_results[key] = {"info": info.to_dict()}

            # Previous results for this data point (missing -> nothing to reuse).
            existing_for_idx = existing_results.get(key, {})

            for result_name, json_results in result_sets:
                if json_results is None:
                    current_results = top_activations
                elif key in json_results:
                    current_results = json_results[key]
                else:
                    continue

                reusable = (
                    [] if result_name in recompute
                    else existing_for_idx.get(result_name, [])
                )
                all_results[key][result_name] = []

                for current_result in current_results:
                    # Reuse the first stored result for the same input text.
                    result_dict = next(
                        (
                            existing_result
                            for existing_result in reusable
                            if existing_result["input_text"] == current_result
                        ),
                        None,
                    )
                    if result_dict is None:
                        result_dict = get_metrics(
                            current_result,
                            evaluation_model,
                            sae,
                            sae_index,
                            cross_entropy_model,
                            device=args.device,
                            neuronpedia_id=neuronpedia_id,
                            epo_target=EPO_TARGET,
                        ).to_dict()
                    all_results[key][result_name].append(result_dict)

    # Save all results to JSON file
    print(f"Saving all evaluation results to {args.output_json}")
    with open(args.output_json, "w") as f:
        json.dump(all_results, f, indent=2)
    print("Results saved successfully!")
