"""
This file is based on the same file from the saebench repository.
It has been modified to use the correct hookpoint for the FFKV/Transcoder evals.
We marked the modified lines with # MODIFIED.
"""

import argparse
import gc
import os
import random
import shutil
import time
from dataclasses import asdict
from datetime import datetime

import torch
from sae_lens import SAE
from tqdm import tqdm
from transformer_lens import HookedTransformer

import ff_kv_sae.evals.sparse_probing.probe_training as probe_training
import ff_kv_sae.sae_bench_utils.activation_collection as activation_collection
import ff_kv_sae.sae_bench_utils.dataset_info as dataset_info
import ff_kv_sae.sae_bench_utils.dataset_utils as dataset_utils
import ff_kv_sae.sae_bench_utils.general_utils as general_utils
from ff_kv_sae.evals.sparse_probing.eval_config import SparseProbingEvalConfig
from ff_kv_sae.evals.sparse_probing.eval_output import (
    EVAL_TYPE_ID_SPARSE_PROBING,
    SparseProbingEvalOutput,
    SparseProbingLlmMetrics,
    SparseProbingMetricCategories,
    SparseProbingResultDetail,
    SparseProbingSaeMetrics,
)
from ff_kv_sae.sae_bench_utils import (
    get_eval_uuid,
    get_sae_bench_version,
    get_sae_lens_version,
)
from ff_kv_sae.sae_bench_utils.sae_selection_utils import (
    get_saes_from_regex,
)


def get_dataset_activations(
    dataset_name: str,
    config: SparseProbingEvalConfig,
    model: HookedTransformer,
    llm_batch_size: int,
    layer: int,
    hook_point: str,
    device: str,
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
    """Collect LLM activations at ``hook_point`` for one dataset's train/test splits.

    The dataset is loaded with the sizes and seed from ``config``, restricted to
    the classes listed for it in ``dataset_info.chosen_classes_per_dataset``,
    tokenized with the model's tokenizer, and run through
    ``activation_collection.get_all_llm_activations`` with BOS/pad/EOS masking
    enabled.

    Returns:
        ``(train_activations, test_activations)``. The ``BLD`` suffix used for
        these values presumably means (batch, sequence length, model dim) -
        TODO confirm against ``activation_collection``.
    """
    # ``splits`` is ordered (train, test) throughout; each stage rebinds it so
    # the previous stage's data is not kept alive longer than necessary.
    splits = dataset_utils.get_multi_label_train_test_data(
        dataset_name,
        config.probe_train_set_size,
        config.probe_test_set_size,
        config.random_seed,
    )

    class_subset = dataset_info.chosen_classes_per_dataset[dataset_name]
    splits = [dataset_utils.filter_dataset(split, class_subset) for split in splits]

    splits = [
        dataset_utils.tokenize_data_dictionary(
            split,
            model.tokenizer,  # type: ignore
            config.context_length,
            device,
        )
        for split in splits
    ]

    train_acts_BLD, test_acts_BLD = [
        activation_collection.get_all_llm_activations(
            split,
            model,
            llm_batch_size,
            layer,
            hook_point,
            mask_bos_pad_eos_tokens=True,
        )
        for split in splits
    ]

    return train_acts_BLD, test_acts_BLD


def run_eval_single_dataset(
    dataset_name: str,
    config: SparseProbingEvalConfig,
    sae: SAE,
    model: HookedTransformer,
    layer: int,
    hook_point: str,
    device: str,
    artifacts_folder: str,
    save_activations: bool,
) -> tuple[dict[str, float], dict]:
    """Run sparse probing for one dataset against one SAE.

    LLM activations and LLM probe accuracies are computed once and, when
    ``save_activations`` is true, cached under ``artifacts_folder``; later calls
    that find the cache file load it instead of recomputing.

    Args:
        dataset_name: Dataset to probe; also determines the cache file name.
        config: Holds all hyperparameters needed to reproduce the evaluation.
        sae: SAE whose feature activations are probed. If ``sae.cfg`` has a truthy
            ``is_transcoder`` attribute, ``input_hook_name`` / ``output_hook_name``
            from ``sae.cfg`` are used (falling back to ``hook_point``).
        model: Model used to collect activations.
        layer: Layer index forwarded to activation collection.
        hook_point: Hook point used for both LLM and SAE probes for regular SAEs.
        device: Device used for tokenized data (and for the model when
            ``config.lower_vram_usage`` is set).
        artifacts_folder: Directory holding the cached activations.
        save_activations: Whether to write freshly computed activations to disk.

    Returns:
        ``(results_dict, per_class_results_dict)``. ``per_class_results_dict``
        maps each metric name to its per-class accuracies; ``results_dict`` maps
        the same metric names to the mean over classes.
    """

    per_class_results_dict = {}
    activations_filename = f"{dataset_name}_activations.pt".replace("/", "_")
    activations_path = os.path.join(artifacts_folder, activations_filename)
    # NOTE(review): the cache is keyed only by dataset name (plus whatever the
    # caller encoded in ``artifacts_folder``); a loaded cache is not validated
    # against the hook points of the current SAE.

    # MODIFIED
    # Check if the SAE is a transcoder
    is_transcoder = getattr(sae.cfg, "is_transcoder", False)

    if is_transcoder:
        # For transcoder SAEs, use different hook points for LLM and SAE probes
        input_hook_point = getattr(
            sae.cfg, "input_hook_name", hook_point
        )  # For SAE (mlp_in)
        output_hook_point = getattr(
            sae.cfg, "output_hook_name", hook_point
        )  # For LLM probes (mlp_out)
    else:
        # For regular SAEs, both use the same hook point
        input_hook_point = output_hook_point = hook_point

    if not os.path.exists(activations_path):
        if config.lower_vram_usage:
            model = model.to(device)  # type: ignore

        # Get LLM activations from output hook point (what we want to predict)
        llm_train_acts_BLD, llm_test_acts_BLD = get_dataset_activations(
            dataset_name,
            config,
            model,
            config.llm_batch_size,  # type: ignore
            layer,
            output_hook_point,
            device,
        )

        # For transcoder SAEs, get SAE input activations separately
        if is_transcoder and input_hook_point != output_hook_point:
            sae_train_acts_BLD, sae_test_acts_BLD = get_dataset_activations(
                dataset_name,
                config,
                model,
                config.llm_batch_size,  # type: ignore
                layer,
                input_hook_point,
                device,
            )
            print(
                f"\n====== Collect input activations from {input_hook_point} ======\n"
            )
        else:
            sae_train_acts_BLD, sae_test_acts_BLD = (
                llm_train_acts_BLD,
                llm_test_acts_BLD,
            )
            print(
                f"\n====== Use the same hook point for both LLM and SAE probes: {output_hook_point} ======\n"
            )

        # Continue with LLM probe training
        all_train_acts_BD = activation_collection.create_meaned_model_activations(
            llm_train_acts_BLD
        )
        all_test_acts_BD = activation_collection.create_meaned_model_activations(
            llm_test_acts_BLD
        )

        # Train LLM probes using output activations
        _, llm_test_accuracies = probe_training.train_probe_on_activations(
            all_train_acts_BD,
            all_test_acts_BD,
            select_top_k=None,
            use_sklearn=False,
            batch_size=250,
            epochs=100,
            lr=1e-2,
        )

        # Save the results
        llm_results = {"llm_test_accuracy": llm_test_accuracies}

        # Train top-k LLM probes as the baseline for the top-k SAE probes
        for k in config.k_values:
            _, llm_top_k_test_accuracies = probe_training.train_probe_on_activations(
                all_train_acts_BD,
                all_test_acts_BD,
                select_top_k=k,
            )
            llm_results[f"llm_top_{k}_test_accuracy"] = llm_top_k_test_accuracies

        # Save activations for later use
        acts = {
            "llm_train": llm_train_acts_BLD,
            "llm_test": llm_test_acts_BLD,
            "sae_train": sae_train_acts_BLD,
            "sae_test": sae_test_acts_BLD,
            "llm_results": llm_results,
            "is_transcoder": is_transcoder,
        }

        if save_activations:
            torch.save(acts, activations_path)

        # The LLM-side activations are only needed for the LLM probes above.
        del llm_train_acts_BLD, llm_test_acts_BLD
    else:
        # Load previously saved activations
        print(f"Loading activations from {activations_path}")
        acts = torch.load(activations_path)
        sae_train_acts_BLD = acts["sae_train"]
        sae_test_acts_BLD = acts["sae_test"]
        llm_results = acts["llm_results"]

    # Drop the container so that, for transcoders, the separate LLM-side
    # activations can be freed before the SAE is run.
    del acts

    # Process SAE input activations through the SAE
    all_sae_train_acts_BF = activation_collection.get_sae_meaned_activations(
        sae_train_acts_BLD, sae, config.sae_batch_size
    )
    all_sae_test_acts_BF = activation_collection.get_sae_meaned_activations(
        sae_test_acts_BLD, sae, config.sae_batch_size
    )

    # Free the raw activations now that the meaned SAE activations exist.
    for key in list(sae_train_acts_BLD.keys()):
        del sae_train_acts_BLD[key]
        del sae_test_acts_BLD[key]

    if not config.lower_vram_usage:
        # This is optional, checking the accuracy of a probe trained on the entire SAE activations
        # We use GPU here as sklearn.fit is slow on large input dimensions, all other probe training is done with sklearn.fit
        # (This is the only dense SAE probe: it is trained exactly once, and
        # not at all in lower_vram_usage mode.)
        _, sae_test_accuracies = probe_training.train_probe_on_activations(
            all_sae_train_acts_BF,
            all_sae_test_acts_BF,
            select_top_k=None,
            use_sklearn=False,
            batch_size=250,
            epochs=100,
            lr=1e-2,
        )
        per_class_results_dict["sae_test_accuracy"] = sae_test_accuracies
    else:
        # Placeholder marking the dense SAE probe as skipped.
        per_class_results_dict["sae_test_accuracy"] = {"-1": -1}

        for key in all_sae_train_acts_BF.keys():
            all_sae_train_acts_BF[key] = all_sae_train_acts_BF[key].cpu()
            all_sae_test_acts_BF[key] = all_sae_test_acts_BF[key].cpu()

        torch.cuda.empty_cache()
        gc.collect()

    for llm_result_key, llm_result_value in llm_results.items():
        per_class_results_dict[llm_result_key] = llm_result_value

    for k in config.k_values:
        _, sae_top_k_test_accuracies = probe_training.train_probe_on_activations(
            all_sae_train_acts_BF,
            all_sae_test_acts_BF,
            select_top_k=k,
        )
        per_class_results_dict[f"sae_top_{k}_test_accuracy"] = sae_top_k_test_accuracies

    # Collapse each per-class accuracy dict to its mean.
    results_dict = {}
    for key, test_accuracies_dict in per_class_results_dict.items():
        average_test_acc = sum(test_accuracies_dict.values()) / len(
            test_accuracies_dict
        )
        results_dict[key] = average_test_acc

    return results_dict, per_class_results_dict


def run_eval_single_sae(
    config: SparseProbingEvalConfig,
    sae: SAE,
    model: HookedTransformer,
    device: str,
    artifacts_folder: str,
    save_activations: bool = True,
) -> tuple[dict[str, float | dict[str, float]], dict]:
    """Evaluate one SAE on every dataset in ``config.dataset_names``.

    The layer and hook point are taken from ``sae.cfg.hook_layer`` and
    ``sae.cfg.hook_name``. By default the activations of every dataset are saved
    in ``artifacts_folder`` and reused for later SAEs, which avoids recomputing
    them and ensures all SAEs are probed on the same activations. This can use
    10s of GBs of disk space.

    Returns:
        ``(results, per_class)``: ``results`` holds the averages produced by
        ``general_utils.average_results_dictionaries`` plus one
        ``"<dataset>_results"`` entry per dataset; ``per_class`` holds the
        per-class results of each dataset under the same ``"<dataset>_results"``
        keys.
    """
    seed = config.random_seed
    random.seed(seed)
    torch.manual_seed(seed)
    os.makedirs(artifacts_folder, exist_ok=True)

    averaged_by_dataset = {}
    per_class_by_dataset = {}
    for dataset_name in config.dataset_names:
        averaged, per_class = run_eval_single_dataset(
            dataset_name,
            config,
            sae,
            model,
            sae.cfg.hook_layer,
            sae.cfg.hook_name,
            device,
            artifacts_folder,
            save_activations,
        )
        result_key = f"{dataset_name}_results"
        averaged_by_dataset[result_key] = averaged
        per_class_by_dataset[result_key] = per_class

    combined_results = general_utils.average_results_dictionaries(
        averaged_by_dataset, config.dataset_names
    )

    # Keep the individual dataset results next to the cross-dataset averages.
    for result_key, averaged in averaged_by_dataset.items():
        combined_results[f"{result_key}"] = averaged

    if config.lower_vram_usage:
        model = model.to(device)  # type: ignore

    return combined_results, per_class_by_dataset  # type: ignore


def run_eval(
    config: SparseProbingEvalConfig,
    selected_saes: list[tuple[str, SAE]] | list[tuple[str, str]],
    device: str,
    output_path: str,
    is_random_transformer: bool,
    model_instance: HookedTransformer | None,
    force_rerun: bool = False,
    clean_up_activations: bool = False,
    save_activations: bool = True,
    artifacts_path: str = "artifacts",
):
    """Run the sparse probing evaluation on every selected SAE.

    Args:
        config: Evaluation hyperparameters.
        selected_saes: List of tuples of either (sae_lens release, sae_lens id)
            or (sae_name, SAE object).
        device: Device the model and SAEs are placed on.
        output_path: Directory the per-SAE result JSON files are written to.
        is_random_transformer: Forwarded to ``general_utils.get_results_filepath``
            when building each result file path.
        model_instance: Model to evaluate with. If ``None``, the model named by
            ``config.model_name`` is loaded with
            ``HookedTransformer.from_pretrained_no_processing``.
        force_rerun: Re-evaluate SAEs whose result file already exists.
        clean_up_activations: If True, the cached activations are deleted after
            the evaluation is done. You may want to use this because activations
            for all datasets can easily be 10s of GBs.
        save_activations: Whether freshly computed activations are cached on disk.
        artifacts_path: Root directory for cached activations.

    Returns:
        Dict mapping ``"{sae_release}_{sae_id}"`` to the evaluation results of
        that SAE. SAEs skipped because results already exist are not included.
    """
    eval_instance_id = get_eval_uuid()
    sae_lens_version = get_sae_lens_version()
    sae_bench_commit_hash = get_sae_bench_version()

    # Every artifacts folder used during this run. The folder depends on the
    # SAE's hook name, so a single run may touch several of them.
    used_artifacts_folders: set[str] = set()
    os.makedirs(output_path, exist_ok=True)

    results_dict = {}

    llm_dtype = general_utils.str_to_dtype(config.llm_dtype)

    if model_instance is not None:
        model = model_instance
    else:
        model = HookedTransformer.from_pretrained_no_processing(
            config.model_name, device=device, dtype=llm_dtype
        )

    for sae_release, sae_object_or_id in tqdm(
        selected_saes, desc="Running SAE evaluation on all selected SAEs"
    ):
        sae_id, sae, sparsity = general_utils.load_and_format_sae(
            sae_release, sae_object_or_id, device
        )  # type: ignore
        sae = sae.to(device=device, dtype=llm_dtype)

        sae_result_path = general_utils.get_results_filepath(
            output_path, sae_release, sae_id, is_random_transformer
        )
        print(f"\n====== Current evaluated SAE: {sae_release} ======\n")

        if os.path.exists(sae_result_path) and not force_rerun:
            print(f"Skipping {sae_release}_{sae_id} as results already exist")
            continue

        artifacts_folder = os.path.join(
            artifacts_path,
            EVAL_TYPE_ID_SPARSE_PROBING,
            config.model_name,
            sae.cfg.hook_name,
        )
        used_artifacts_folders.add(artifacts_folder)

        sparse_probing_results, per_class_dict = run_eval_single_sae(
            config,
            sae,
            model,
            device,
            artifacts_folder,
            save_activations=save_activations,
        )
        eval_output = SparseProbingEvalOutput(
            eval_config=config,
            eval_id=eval_instance_id,
            datetime_epoch_millis=int(datetime.now().timestamp() * 1000),
            eval_result_metrics=SparseProbingMetricCategories(
                # Scalar entries are the cross-dataset metrics; dict entries
                # are the per-dataset details handled below.
                llm=SparseProbingLlmMetrics(
                    **{
                        k: v
                        for k, v in sparse_probing_results.items()
                        if k.startswith("llm_") and not isinstance(v, dict)
                    }
                ),
                sae=SparseProbingSaeMetrics(
                    **{
                        k: v
                        for k, v in sparse_probing_results.items()
                        if k.startswith("sae_") and not isinstance(v, dict)
                    }
                ),
            ),
            eval_result_details=[
                SparseProbingResultDetail(
                    dataset_name=dataset_name,
                    **result,
                )
                for dataset_name, result in sparse_probing_results.items()
                if isinstance(result, dict)
            ],
            eval_result_unstructured=per_class_dict,
            sae_bench_commit_hash=sae_bench_commit_hash,
            sae_lens_id=sae_id,
            sae_lens_release_id=sae_release,
            sae_lens_version=sae_lens_version,
            sae_cfg_dict=asdict(sae.cfg),
        )

        results_dict[f"{sae_release}_{sae_id}"] = asdict(eval_output)

        eval_output.to_json_file(sae_result_path, indent=2)

        gc.collect()
        torch.cuda.empty_cache()

    if clean_up_activations:
        # Remove the cache of every hook point evaluated, not just the last one.
        for used_folder in used_artifacts_folders:
            if os.path.exists(used_folder):
                shutil.rmtree(used_folder)

    return results_dict


def create_config_and_selected_saes(
    args,
) -> tuple[SparseProbingEvalConfig, list[tuple[str, str]]]:
    """Build the eval config and the SAE selection from parsed CLI arguments.

    Explicit CLI values override the config; when ``llm_batch_size`` or
    ``llm_dtype`` are not given they are looked up per model name in
    ``activation_collection``. SAEs are selected with ``get_saes_from_regex``;
    an empty selection trips an assertion.
    """
    config = SparseProbingEvalConfig(model_name=args.model_name)

    config.llm_batch_size = (
        args.llm_batch_size
        if args.llm_batch_size is not None
        else activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name]
    )
    config.llm_dtype = (
        args.llm_dtype
        if args.llm_dtype is not None
        else activation_collection.LLM_NAME_TO_DTYPE[config.model_name]
    )

    # These keep the config defaults unless explicitly overridden.
    if args.sae_batch_size is not None:
        config.sae_batch_size = args.sae_batch_size
    if args.random_seed is not None:
        config.random_seed = args.random_seed
    if args.lower_vram_usage:
        config.lower_vram_usage = True

    sae_pairs = get_saes_from_regex(args.sae_regex_pattern, args.sae_block_pattern)
    assert len(sae_pairs) > 0, "No SAEs selected"

    release_names = {release_name for release_name, _ in sae_pairs}
    print(f"Selected SAEs from releases: {release_names}")

    for release_name, sae_name in sae_pairs:
        print(f"Sample SAEs: {release_name}, {sae_name}")

    return config, sae_pairs


def arg_parser():
    """Build the command-line parser for the sparse probing evaluation.

    ``--model_name``, ``--sae_regex_pattern`` and ``--sae_block_pattern`` are
    required. Options defaulting to ``None`` are filled in later from the eval
    config or per-model lookup tables.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="Run sparse probing evaluation")
    parser.add_argument("--random_seed", type=int, default=None, help="Random seed")
    parser.add_argument("--model_name", type=str, required=True, help="Model name")
    parser.add_argument(
        "--sae_regex_pattern",
        type=str,
        required=True,
        help="Regex pattern for SAE selection",
    )
    parser.add_argument(
        "--sae_block_pattern",
        type=str,
        required=True,
        help="Regex pattern for SAE block selection",
    )
    parser.add_argument(
        "--output_folder",
        type=str,
        default="eval_results/sparse_probing",
        help="Output folder",
    )
    parser.add_argument(
        "--force_rerun", action="store_true", help="Force rerun of experiments"
    )
    parser.add_argument(
        "--clean_up_activations",
        action="store_true",
        help="Clean up activations after evaluation",
    )
    # store_false: the value defaults to True and passing the flag turns saving
    # OFF. The flag name is kept for backward compatibility; the help text
    # describes the actual behavior.
    parser.add_argument(
        "--save_activations",
        action="store_false",
        help="Generated LLM activations are saved for later use by default; "
        "pass this flag to disable saving",
    )
    parser.add_argument(
        "--llm_batch_size",
        type=int,
        default=None,
        help="Batch size for LLM. If None, will be populated using LLM_NAME_TO_BATCH_SIZE",
    )
    parser.add_argument(
        "--llm_dtype",
        type=str,
        default=None,
        choices=[None, "float32", "float64", "float16", "bfloat16"],
        help="Data type for LLM. If None, will be populated using LLM_NAME_TO_DTYPE",
    )
    parser.add_argument(
        "--sae_batch_size",
        type=int,
        default=None,
        help="Batch size for SAE. If None, will be populated using default config value",
    )
    parser.add_argument(
        "--lower_vram_usage",
        action="store_true",
        help="Lower GPU memory usage by doing more computation on the CPU. Useful on 1M width SAEs. Will be slower and require more system memory.",
    )
    parser.add_argument(
        "--artifacts_path",
        type=str,
        default="artifacts",
        help="Path to save artifacts",
    )

    return parser


if __name__ == "__main__":
    """
    python -m sae_bench.evals.sparse_probing.main \
    --sae_regex_pattern "sae_bench_pythia70m_sweep_standard_ctx128_0712" \
    --sae_block_pattern "blocks.4.hook_resid_post__trainer_10" \
    --model_name pythia-70m-deduped


    """
    args = arg_parser().parse_args()
    device = general_utils.setup_environment()

    start_time = time.time()

    config, selected_saes = create_config_and_selected_saes(args)

    print(selected_saes)

    # create output folder
    os.makedirs(args.output_folder, exist_ok=True)

    # run the evaluation on all selected SAEs
    # Everything after output_path is passed by keyword: run_eval takes
    # is_random_transformer and model_instance before force_rerun, so passing
    # the CLI flags positionally would bind them to the wrong parameters.
    results_dict = run_eval(
        config,
        selected_saes,
        device,
        args.output_folder,
        is_random_transformer=False,  # the CLI has no option for this
        model_instance=None,  # let run_eval load config.model_name itself
        force_rerun=args.force_rerun,
        clean_up_activations=args.clean_up_activations,
        save_activations=args.save_activations,
        artifacts_path=args.artifacts_path,
    )

    end_time = time.time()

    print(f"Finished evaluation in {end_time - start_time} seconds")


# Use this code snippet to use custom SAE objects
# if __name__ == "__main__":
#     import sae_bench.custom_saes.identity_sae as identity_sae
#     import sae_bench.custom_saes.jumprelu_sae as jumprelu_sae

#     """
#     python evals/sparse_probing/main.py
#     """
#     device = general_utils.setup_environment()

#     start_time = time.time()

#     random_seed = 42
#     output_folder = "eval_results/sparse_probing"

#     model_name = "gemma-2-2b"
#     hook_layer = 20

#     repo_id = "google/gemma-scope-2b-pt-res"
#     filename = f"layer_{hook_layer}/width_16k/average_l0_71/params.npz"
#     sae = jumprelu_sae.load_jumprelu_sae(repo_id, filename, hook_layer)
#     selected_saes = [(f"{repo_id}_{filename}_gemmascope_sae", sae)]

#     config = SparseProbingEvalConfig(
#         random_seed=random_seed,
#         model_name=model_name,
#     )

#     config.llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name]
#     config.llm_dtype = activation_collection.LLM_NAME_TO_DTYPE[config.model_name]

#     # create output folder
#     os.makedirs(output_folder, exist_ok=True)

#     # run the evaluation on all selected SAEs
#     results_dict = run_eval(
#         config,
#         selected_saes,
#         device,
#         output_folder,
#         is_random_transformer=False,
#         model_instance=None,
#         force_rerun=True,
#         clean_up_activations=False,
#         save_activations=True,
#     )

#     end_time = time.time()

#     print(f"Finished evaluation in {end_time - start_time} seconds")
