import argparse
import time
from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd

# Import necessary components from your project
from ar import ActivationReasoning, LogicConfig
from ar.utils import load_implicit_train_data, load_train_data

from countries import (
    tri_color_countries_subset,
    prompt_r2c_prefix,
    prompt_r2c_CoT_prefix,
)

# Instruction text that is stripped from every raw prompt before a baseline
# prefix is prepended (see time_standard_model).
prompt_body_to_remove = "Trains are painted by their country’s flag of origin.\nYour goal is to identify the correct country based on the train’s distinctive color coding.\n"
# Answer lead-in that is additionally stripped from the prompts of the "cot"
# baseline (see time_standard_model).
prompt_cot_body_to_remove = "Due to the train's distinctive color coding, I believe it originates from the country of"
# --- Refactored timing functions for sequential execution ---

# Number of prompts per generation call while timing: forwarded as batch_size
# to the ActivationReasoning model and used as the slice width for the
# standard Transformers baselines.
BATCHSIZE = 1


def time_ar_model(
    model_name, steering, dataset_type, test_prompts, model_hyp, num_samples
):
    """
    Load, configure, time, and unload the ActivationReasoning model.

    The model is run on the RAW, unmodified prompts. Two phases are timed
    separately: the concept search over the training prompts ("setup") and
    the generation over ``test_prompts`` in batches of ``BATCHSIZE``.

    Note that the concept search always uses ``load_train_data()``,
    independent of ``dataset_type``.

    Args:
        model_name: Hugging Face model id. The substrings "llama" / "gemma"
            select the layer, SAE, hookpoint, steering settings and the
            hard-coded colour concept overrides.
        steering: If truthy, steering is enabled with model-specific factors;
            otherwise the steering factor stays at 0.
        dataset_type: Only consulted for the gemma steering factor
            (0.8 for "explicit", 0.75 otherwise).
        test_prompts: Prompts handed to ``ar_model.generate``.
        model_hyp: Generation hyperparameters forwarded as ``model_hyp``.
        num_samples: Divisor for the per-sample generation time; the caller
            in this file passes ``len(test_prompts)``.

    Returns:
        Tuple ``(avg_time, avg_setup_time)``: average generation seconds per
        test sample and average search seconds per training prompt.

    Raises:
        ValueError: If ``num_samples`` is not positive.
    """
    import gc  # function-scope import; only the cleanup step needs it

    if num_samples <= 0:
        # Fail before the expensive model load instead of dividing by zero
        # after the whole run has finished.
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    print("\n--- 🧠 Timing ActivationReasoning Model (on raw prompts) ---")

    is_llama = "llama" in model_name
    is_gemma = "gemma" in model_name

    layer = 20 if is_gemma else 23
    model_config = {
        "model_name": model_name,
        "layer": layer,
        "sae_name": "EleutherAI/sae-llama-3.1-8b-64x"
        if is_llama
        else "gemma-scope-9b-pt-res-canonical",
        "hookpoint": f"layers.{layer}"
        if is_llama
        else f"layer_{layer}/width_131k/canonical",
    }

    rules = tri_color_countries_subset
    search_config = LogicConfig(
        search_concept_type="word", search_strategy="top_k", search_top_k=10
    )

    print(f"Loading ActivationReasoning model: {model_name}...")
    ar_model = ActivationReasoning(
        rules=rules,
        **model_config,
        config=search_config,
        cache_dir=f"output/cache/{model_name}_countries",
        verbose=False,
    )

    print("Performing one-time concept search on training data...")
    train_data = load_train_data()
    train_prompts = [prompt for prompt, _ in train_data["train"]]

    # Synchronize around the timed region so pending GPU work is not
    # attributed to the wrong phase. perf_counter() is monotonic and is the
    # clock intended for measuring short intervals.
    torch.cuda.synchronize()
    setup_start = time.perf_counter()
    ar_model.search(inputs=train_prompts, reset_cache=True, batch_size=20)
    torch.cuda.synchronize()
    total_setup_time = time.perf_counter() - setup_start

    # Guard against an empty training split.
    avg_setup_time = total_setup_time / len(train_prompts) if train_prompts else 0.0

    # Steering is off by default; it is only enabled for known model families.
    logic_config_al = LogicConfig(steering_factor=0)
    if steering:
        if is_llama:
            logic_config_al = LogicConfig(steering_factor=0.4, steering_top_k_rule=10)
        elif is_gemma:
            factor = 0.8 if dataset_type == "explicit" else 0.75
            logic_config_al = LogicConfig(
                steering_factor=factor, steering_top_k_rule=10
            )

    ar_model.configure(config=logic_config_al)

    # Hard-coded SAE feature indices/weights that replace the searched
    # representation of the listed colour concepts.
    color_concept_dict = {}
    if is_llama:
        color_concept_dict = {
            "black": {
                "indices": [26080, 257901, 191610, 91171],
                "weights": [1.39, 1.17, 0.51, 0.30],
            },
            "blue": {"indices": [251272], "weights": [3.81]},
            "gold": {"indices": [37152, 24536, 1681], "weights": [0.58, 0.27, 0.27]},
            "green": {"indices": [38115, 166587], "weights": [2.51, 1.52]},
            "orange": {"indices": [250617, 91937], "weights": [0.60, 0.39]},
            "red": {"indices": [213660, 92980, 96689], "weights": [9.03, 1.79, 1.33]},
            "white": {"indices": [114310, 16107], "weights": [6.62, 2.47]},
            "yellow": {
                "indices": [237080, 178076, 26092],
                "weights": [1.62, 1.61, 0.76],
            },
        }
    elif "gemma-2-9b" in model_name:
        color_concept_dict = {
            "yellow": {"indices": [55891, 39669], "weights": [10.29, 4.20]}
        }

    # Each override is right-padded with zeros to this length.
    # NOTE(review): presumably this must match search_top_k=10 — confirm.
    feat_len = 10
    concept_dict = ar_model._al_concepts.concept_dict
    for color, features in color_concept_dict.items():
        # Only override concepts the search actually produced.
        if color not in concept_dict:
            continue
        indices = features["indices"]
        weights = features["weights"]
        concept_dict[color] = {
            "indices": indices + [0] * (feat_len - len(indices)),
            "weights": weights + [0.0] * (feat_len - len(weights)),
        }

    print("Model loaded and configured. Starting batched timing...")

    torch.cuda.synchronize()
    gen_start = time.perf_counter()
    with torch.no_grad():
        _ = ar_model.generate(
            test_prompts,
            model_hyp=model_hyp,
            verbose=False,
            return_meta_data=False,
            batch_size=BATCHSIZE,
        )
    torch.cuda.synchronize()
    total_time = time.perf_counter() - gen_start

    avg_time = total_time / num_samples
    print(
        f"Total setup time for ActivationReasoning ({len(train_prompts)} samples): {total_setup_time:.4f} seconds"
    )
    print(
        f"Average setup time per sample (ActivationReasoning): {avg_setup_time:.4f} seconds"
    )
    print(
        f"Total time for ActivationReasoning ({num_samples} samples): {total_time:.4f} seconds"
    )
    print(f"Average time per sample (ActivationReasoning): {avg_time:.4f} seconds")

    print("Unloading AL model and clearing cache...")
    # Drop every local reference into the model, then collect reference
    # cycles so the cached GPU memory can actually be released before the
    # next model is loaded.
    del concept_dict
    del ar_model
    gc.collect()
    torch.cuda.empty_cache()

    return avg_time, avg_setup_time


def time_standard_model(model_name, raw_prompts, model_hyp, num_samples, baseline_type):
    """
    Load, time, and unload a standard Transformers causal LM baseline.

    The raw prompts are first rewritten for the baseline: the shared
    instruction body is removed (and, for "cot", the answer lead-in as well)
    and the baseline-specific prefix is prepended. The timed region covers
    tokenization of all prompts plus generation in slices of ``BATCHSIZE``.

    Args:
        model_name: Hugging Face model id to load.
        raw_prompts: Unmodified dataset prompts.
        model_hyp: Keyword arguments forwarded to ``model.generate``.
        num_samples: Divisor for the per-sample time; the caller in this file
            passes ``len(raw_prompts)``.
        baseline_type: "default" or "cot".

    Returns:
        Average seconds per sample.

    Raises:
        ValueError: If ``num_samples`` is not positive or ``baseline_type``
            is unknown.
    """
    import gc  # function-scope import; only the cleanup step needs it

    if num_samples <= 0:
        # Fail before the expensive model load instead of dividing by zero
        # after the whole run has finished.
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    run_label = model_name + "_" + baseline_type
    print(f"\n--- 🚀 Timing Standard Transformers Model ({run_label}) ---")

    # --- Prompt Modification Logic ---
    print(f"Creating modified prompts for the {run_label} model...")
    if baseline_type == "default":
        modified_prompts = [
            prompt_r2c_prefix + q.replace(prompt_body_to_remove, "")
            for q in raw_prompts
        ]
    elif baseline_type == "cot":
        modified_prompts = [
            prompt_r2c_CoT_prefix
            + q.replace(prompt_body_to_remove, "").replace(
                prompt_cot_body_to_remove, ""
            )
            for q in raw_prompts
        ]
    else:
        raise ValueError(f"Unknown baseline type: {baseline_type}")

    # --- Load Standard Model ---
    print(f"Loading {run_label} model...")
    # Left padding keeps the prompt end adjacent to the generated tokens.
    standard_tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    standard_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    ).eval()

    if standard_tokenizer.pad_token is None:
        standard_tokenizer.pad_token = standard_tokenizer.eos_token
        # Take the id from the tokenizer: config.eos_token_id may be a list
        # of ids or differ from the tokenizer's EOS, neither of which is a
        # valid single pad id. Set generation_config too so the same pad id
        # is used during generation.
        pad_id = standard_tokenizer.pad_token_id
        standard_model.config.pad_token_id = pad_id
        standard_model.generation_config.pad_token_id = pad_id

    print("Model loaded.")
    print("Starting batched timing...")
    # Synchronize around the timed region so pending GPU work is not
    # attributed to it. perf_counter() is monotonic and is the clock intended
    # for measuring short intervals.
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    # NOTE: all prompts are tokenized together, so every batch below is
    # padded to the longest prompt of the whole set, not of its own batch.
    inputs = standard_tokenizer(modified_prompts, return_tensors="pt", padding=True).to(
        standard_model.device
    )

    with torch.no_grad():
        for start in tqdm(range(0, len(modified_prompts), BATCHSIZE)):
            batch_inputs = {
                key: tensor[start : start + BATCHSIZE]
                for key, tensor in inputs.items()
            }
            _ = standard_model.generate(**batch_inputs, **model_hyp)
    torch.cuda.synchronize()
    total_time = time.perf_counter() - start_time

    avg_time = total_time / num_samples
    print(f"Total time for {run_label} ({num_samples} samples): {total_time:.4f} seconds")
    print(f"Average time per sample ({run_label}): {avg_time:.4f} seconds")

    print("Unloading standard model and clearing cache...")
    del standard_model
    del standard_tokenizer
    # Collect reference cycles so the cached GPU memory can actually be
    # released before the next model is loaded.
    gc.collect()
    torch.cuda.empty_cache()

    return avg_time


def main():
    """
    Run the timing comparison for the countries dataset and save the results.

    For each of three repetitions the ActivationReasoning model and three
    standard baselines (default, CoT on the instruct model, CoT on the R1
    distill) are timed sequentially. Intermediate results are written to
    ``generation_timing_r2c_results_r<k>.csv`` after every repetition, the
    full table to ``generation_timing_r2c_results.csv``, and the per-model
    mean is printed as a markdown table.

    Raises:
        RuntimeError: If the selected dataset has no test prompts.
    """
    parser = argparse.ArgumentParser(
        description="Time and compare generation functions sequentially for the countries dataset."
    )
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="explicit",
        choices=["explicit", "implicit"],
        help="The dataset to use for timing.",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=100,
        help="Number of samples to run the timing test on.",
    )
    args = parser.parse_args()
    if args.num_samples <= 0:
        # A negative value would silently slice prompts off the end of the
        # test set, and zero would divide by zero in the timing functions.
        parser.error("--num_samples must be a positive integer")

    model_name_map = {
        "gemma_base": "google/gemma-2-9b",
        "gemma_it": "google/gemma-2-9b-it",
        "llama_31_base": "meta-llama/Meta-Llama-3.1-8B",
        # NOTE(review): the key says 3.1 but this id is the Llama 3 instruct
        # checkpoint — confirm this is intended.
        "llama_31_it": "meta-llama/Meta-Llama-3-8B-Instruct",
        "llama_R1": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    }

    print("--- Script Configuration ---")
    print(f"Dataset: {args.dataset_type}")
    print(f"Samples: {args.num_samples}")
    print("--------------------------\n")

    print(f"Loading {args.dataset_type} dataset...")
    is_explicit = args.dataset_type == "explicit"
    data = load_train_data() if is_explicit else load_implicit_train_data()

    test_prompts = [sample[0] for sample in data["test"]]
    num_samples = min(args.num_samples, len(test_prompts))
    if num_samples == 0:
        raise RuntimeError(
            f"No test prompts found in the {args.dataset_type} dataset."
        )
    test_prompts = test_prompts[:num_samples]
    if num_samples < args.num_samples:
        print(
            f"Warning: Requested {args.num_samples}, but dataset only has {num_samples}. Using {num_samples}."
        )

    # time_ar_model always runs its concept search on load_train_data() and
    # averages over that split, so the total setup time must be rebuilt with
    # the same size even when the implicit dataset is being timed.
    search_train_size = (
        len(data["train"]) if is_explicit else len(load_train_data()["train"])
    )

    # Greedy decoding for every run. Each configuration gets its own dict so
    # a larger token budget never leaks into other runs or later repetitions.
    base_hyp = {
        "do_sample": False,
        "temperature": None,
        "top_k": None,
        "top_p": None,
        "max_new_tokens": 5,
    }
    cot_hyp = {**base_hyp, "max_new_tokens": 100}
    # The R1 distill uses the same greedy settings with a larger token budget.
    r1_hyp = {**base_hyp, "max_new_tokens": 1000}

    # (model key, generation settings, baseline type), in execution order.
    baselines = [
        ("llama_31_base", base_hyp, "default"),
        ("llama_31_it", cot_hyp, "cot"),
        ("llama_R1", r1_hyp, "cot"),
    ]

    num_repetitions = 3
    records = []
    for i in range(num_repetitions):
        print(f"\n--- 🔄 Repetition {i + 1}/{num_repetitions} ---")
        model_name = model_name_map["llama_31_base"]
        avg_al_time, avg_setup_time = time_ar_model(
            model_name,
            True,
            args.dataset_type,
            test_prompts,
            base_hyp,
            num_samples,
        )
        total_al_time = avg_al_time * num_samples
        total_setup_time = avg_setup_time * search_train_size
        records.append(
            {
                "model": model_name + "_al",
                "avg_time_per_sample": avg_al_time,
                "avg_setup_time_per_sample": avg_setup_time,
                "total_time_with_setup": total_al_time + total_setup_time,
                "total_setup_time": total_setup_time,
                "total_time": total_al_time,
                "runs": i + 1,
            }
        )

        for model_key, hyp, baseline_type in baselines:
            model_name = model_name_map[model_key]
            avg_standard_time = time_standard_model(
                model_name, test_prompts, hyp, num_samples, baseline_type
            )
            records.append(
                {
                    "model": model_name + "_" + baseline_type,
                    "avg_time_per_sample": avg_standard_time,
                    "total_time": avg_standard_time * num_samples,
                    "runs": i + 1,
                }
            )

        # Checkpoint after every repetition so a crash in a later (long)
        # run does not lose the measurements collected so far.
        print("\n--- 💾 Saving Results ---")
        pd.DataFrame(records).to_csv(
            f"generation_timing_r2c_results_r{i + 1}.csv", index=False
        )

    results = pd.DataFrame(records)
    results.to_csv("generation_timing_r2c_results.csv", index=False)
    summary = results.groupby("model").mean(numeric_only=True).reset_index().to_markdown()

    print("\n--- 📊 Final Comparison ---")
    print(summary)
    print("---------------------------------")


# Run the benchmark only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
