import argparse
import gc
import json
import os
import pickle
import re
from contextlib import contextmanager
from copy import deepcopy
from typing import List, Optional

import dotenv
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

torch._dynamo.config.cache_size_limit = 1024

# Constants
STEERING_MULTIPLIER = 1.5
MAX_NEW_TOKENS = 25
TEMPERATURE = 0.7
DO_SAMPLE = True
TOKEN_POS_TO_STEER = -1
NUM_PROMPTS_FOR_STEERING_VECTOR = 85
LAYER_FRACTION = 0.8
# BATCH_SIZE = 200
BATCH_SIZE = 1000
GPU_KL_DIVERGENCE = True

# Global variables for steering
steering_vector_internals = None
steering_multiplier_internal = STEERING_MULTIPLIER
activation_storage = {}

# Rhyme families used throughout the experiments. Each key names a family and
# maps to the set of words treated as members of it. The words are written as
# whitespace-separated strings (no entry contains whitespace) and turned into
# sets by the comprehension, which also collapses any repeated entries.
rhyme_family_words = {
    family: set(words.split())
    for family, words in {
        "ing": """
            bring cling ding fling king ping ring sing sling spring sting
            string swing thing wing zing bling wring sping ming ting
        """,
        "air": """
            aer affair air aware bare bear blare brer care chair compare crare
            dare declare despair e'er ere everywhere fair fare fayre flair
            flare glair glare hair hare lair layer mare mayor ne'er nightmare
            pair pare prayer rare repair scare share snare spare square stair
            stare swear tare tear terror their there they're vair ware wear
            where unaware unfair yair yare pear heir err
        """,
        "ip": """
            blip chip clip crip dip drip equip fingertip flip grip gyp hip kip
            lip nip pip plip quip rip scrip ship sip skip slip snip strip tip
            trip whip yip zip script
        """,
        "oat": """
            afloat antidote bloat boat coat dote float gloat goat groat haute
            hote lote moat mote moth note oat quote rote scrote shoat smote
            stoat throat tote troat vote wrote yote promote remote denote devote
        """,
        "ird": """
            altered assured bird blerd conferred curd cured declared endured
            engendered ensured ferd flerd furred gird herd hindered Kurd merd
            nerd sherd slurred spurred stirred surd third word absurd blurred
            deferred deterred incurred inferred nurtured obscured occurred
            pondered preferred procured remembered referred transferred
            concurred heard turd un-stirred unblurred undeterred unheard
            well-stirred
        """,
        "ee": """
            'e ae B bee brae brie C chi cree D dee E ea fee flea flee free G gi
            gie glee gree he He key knee lea lee ley me oui P pea phi plea quay
            re schwi sea see she schchi ski spree T te tea tee thee three tree
            trie ti V we wee Wii xe ye Z ze zhe agree decree degree foresee
            trainee trustee jubilee be gee pee
        """,
        "ight": """
            bight bite blight bright byte cite delight dight dright drite fight
            flight flite fright gruit height hight ight kite knight light lite
            might mite moonlight night pight plight quite right rite shite
            sight site sleight slight slite smite spite spright sprite sunlight
            thwite tight tonight trite twite white wight wite wright write
            alight contrite excite ignite incite indite invite polite recite
            unite
        """,
        "ake": """
            ache bake blake brake break cake crake drake fake flake hake jake
            laik lake make nake naik pake quake raik rake remake sake shake
            sheik sheikh slake smaik snake spake splake stake steak straik
            strake wake take traik vraic awake betake forsake mistake partake
            retake opaque
        """,
        "ow": """
            'fro afterglow aglow ago beau blow bow bro cro crow doe dough dso
            dzo eau eaux faux floe flow foe fro glow go gro grow ho hoe joe
            know kou lo low mo moe mow no noh O oh owe po pro rho roe row schmo
            schmoe sew shew show sloe slow snow so sow stow tho tho' though
            throe throw toe tow voe whoa woe wough yeo yo yoe zho zo bestow
            below elbow fellow follow hollow mellow narrow shadow shallow
            window winnow yellow borrow sorrow tomorrow
        """,
        "it": """
            it bit blit brit Brit chit cit clit crit dit exhibit exquisite fit
            flit frit ghit git gnit grit hit infinite jit kit knit lit mitt nit
            pit quit rit shit sit skit slit smit snit spit split sprit squit
            tit twit whit wit writ zit admit commit emit habit hermit omit
            permit rabbit remit submit transmit
        """,
    }.items()
}

# Rhyme-family pairs evaluated by the experiments. Reading the families as the
# ring ing, air, ip, oat, ird, ee, ight, ake, ow, it (then back to ing), every
# family is paired with the next two families on that ring.
RHYME_FAMILY_PAIRS = [
    ("ing", "air"), ("ing", "ip"),
    ("air", "ip"), ("air", "oat"),
    ("ip", "oat"), ("ip", "ird"),
    ("oat", "ird"), ("oat", "ee"),
    ("ird", "ee"), ("ird", "ight"),
    ("ee", "ight"), ("ee", "ake"),
    ("ight", "ake"), ("ight", "ow"),
    ("ake", "ow"), ("ake", "it"),
    ("ow", "it"), ("ow", "ing"),
    ("it", "ing"), ("it", "air"),
]

# Reduced settings for quick debugging runs (uncomment to use):
# RHYME_FAMILY_PAIRS = [("ight", "ow"), ("ee", "ake")]
# BATCH_SIZE = 100
# LAYER_FRACTION = 0.1


# Word pairs whose two words belong to the same rhyme family.
WORD_PAIRS_SAME_RHYME_FAMILY = [
    ("king", "ring"), ("bear", "chair"), ("ship", "chip"), ("goat", "boat"),
    ("bird", "word"), ("bee", "tree"), ("light", "night"), ("snake", "rake"),
    ("snow", "crow"), ("rabbit", "habit"),
]

# Word pairs whose two words belong to different rhyme families.
WORD_PAIRS_DIFFERENT_RHYME_FAMILY = [
    ("king", "bear"), ("chair", "ship"), ("ship", "boat"), ("goat", "bird"),
    ("bird", "bee"), ("tree", "light"), ("night", "snake"), ("snake", "crow"),
    ("crow", "rabbit"), ("rabbit", "king"),
]

# All same-family pairs followed by all different-family pairs.
SPECIFIC_WORD_PAIRS = [
    *WORD_PAIRS_SAME_RHYME_FAMILY,
    *WORD_PAIRS_DIFFERENT_RHYME_FAMILY,
]

# Global store of per-run metrics, filled by add_to_results().
results_dict = {}


def add_to_results(
    model_name,
    rhyme_family1,
    rhyme_family2,
    layer,
    steering_multiplier,
    prompt,
    last_word_correct_unsteered_rhyme_family1,
    last_word_correct_steered_rhyme_family1,
    last_word_correct_unsteered_rhyme_family2,
    last_word_correct_steered_rhyme_family2,
    last_word_regeneration_unsteered_rhyme_family1,
    last_word_regeneration_steered_rhyme_family1,
    last_word_regeneration_unsteered_rhyme_family2,
    last_word_regeneration_steered_rhyme_family2,
    top_1_difference,
    idx_of_first_token_difference,
    high_kl_divergences,
    idx_of_first_high_kl_difference,
):
    """Store one run's metrics as a new entry of the module-level ``results_dict``.

    The entry key is ``"<model>_<family1>_<family2>_<layer>_<n>"``, where ``n``
    is the number of entries already stored, so calls with identical settings
    add new entries instead of overwriting earlier ones.

    Some arguments are stored under different names:
    ``last_word_correct_*`` -> ``last_word_fraction_*``,
    ``top_1_difference`` -> ``fraction_top_1_difference``,
    ``idx_of_first_token_difference`` -> ``avg_idx_of_first_token_difference``,
    ``high_kl_divergences`` -> ``fraction_high_kl_difference``,
    ``idx_of_first_high_kl_difference`` -> ``avg_idx_of_first_high_kl_difference``.
    """
    global results_dict

    entry_index = len(results_dict)
    entry_key = f"{model_name}_{rhyme_family1}_{rhyme_family2}_{layer}_{entry_index}"

    entry = dict(
        model_name=model_name,
        rhyme_family1=rhyme_family1,
        rhyme_family2=rhyme_family2,
        layer=layer,
        steering_multiplier=steering_multiplier,
        prompt=prompt,
        last_word_fraction_unsteered_rhyme_family1=last_word_correct_unsteered_rhyme_family1,
        last_word_fraction_steered_rhyme_family1=last_word_correct_steered_rhyme_family1,
        last_word_fraction_unsteered_rhyme_family2=last_word_correct_unsteered_rhyme_family2,
        last_word_fraction_steered_rhyme_family2=last_word_correct_steered_rhyme_family2,
        last_word_regeneration_unsteered_rhyme_family1=last_word_regeneration_unsteered_rhyme_family1,
        last_word_regeneration_steered_rhyme_family1=last_word_regeneration_steered_rhyme_family1,
        last_word_regeneration_unsteered_rhyme_family2=last_word_regeneration_unsteered_rhyme_family2,
        last_word_regeneration_steered_rhyme_family2=last_word_regeneration_steered_rhyme_family2,
        fraction_top_1_difference=top_1_difference,
        avg_idx_of_first_token_difference=idx_of_first_token_difference,
        fraction_high_kl_difference=high_kl_divergences,
        avg_idx_of_first_high_kl_difference=idx_of_first_high_kl_difference,
    )
    results_dict[entry_key] = entry


def get_script_dir():
    """Return the absolute path of the directory holding this source file."""
    this_file = os.path.abspath(__file__)
    return os.path.dirname(this_file)


def load_rhyme_family_data():
    """Load the shared rhyme-family JSON files.

    Both files live in the parent directory of ``get_script_dir()``:
    ``rhyme_family_lines.json`` and ``rhyme_family_words.json``.

    Returns:
        A ``(rhyme_family_lines, rhyme_family_words_data)`` tuple holding the two
        parsed JSON documents, in that order.
    """
    data_dir = os.path.dirname(get_script_dir())

    lines_path = os.path.join(data_dir, "rhyme_family_lines.json")
    words_path = os.path.join(data_dir, "rhyme_family_words.json")

    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
    with open(lines_path, "r", encoding="utf-8") as f:
        rhyme_family_lines = json.load(f)

    with open(words_path, "r", encoding="utf-8") as f:
        rhyme_family_words_data = json.load(f)

    return rhyme_family_lines, rhyme_family_words_data


def load_prompts(mode, test_or_train, rhyme_family, model_name=None, strip=False):
    """Load the prompt lines of one rhyme family and wrap them as couplet prompts.

    Args:
        mode: ``"rhyme_family_steering"`` reads
            ``data/<test_or_train>/rhyme_family_lines.json``;
            ``"specific_word_steering"`` reads
            ``data/<test_or_train>/specific_word_lines.json``, or
            ``data/<test_or_train>/<model_name>/specific_word_lines_suggestive.json``
            when ``model_name`` is given. Paths are relative to ``get_script_dir()``.
        test_or_train: name of the data split sub-directory.
        rhyme_family: key selected from the loaded JSON object.
        model_name: optional model-specific sub-directory
            (``"specific_word_steering"`` only).
        strip: if True, strip leading/trailing ``"\\n"`` characters from each line.

    Returns:
        A list of strings of the form ``"A rhyming couplet:\\n<line>"``.

    Raises:
        ValueError: if ``mode`` is not one of the two supported values.
    """
    # Resolve the file first so an unsupported mode fails clearly instead of
    # leaving ``lines`` unbound.
    if mode == "rhyme_family_steering":
        relative_parts = ("rhyme_family_lines.json",)
    elif mode == "specific_word_steering":
        if model_name is None:
            relative_parts = ("specific_word_lines.json",)
        else:
            relative_parts = (model_name, "specific_word_lines_suggestive.json")
    else:
        raise ValueError(f"Unsupported prompt mode: {mode!r}")

    data_path = os.path.join(get_script_dir(), "data", test_or_train, *relative_parts)
    with open(data_path, "r", encoding="utf-8") as f:
        lines = json.load(f)[rhyme_family]

    if strip:
        lines = [line.strip("\n") for line in lines]
    return [f"A rhyming couplet:\n{line}" for line in lines]


def get_model(model_name: str):
    """Load a supported causal LM and its tokenizer by short alias.

    Args:
        model_name: a key of ``model_mapping`` below, e.g. ``"Gemma2_2B"`` or
            ``"Qwen3_8B_Base"``.

    Returns:
        ``(model, tokenizer)``.
        - The tokenizer uses its EOS token as pad token when it has none, and
          pads on the left.
        - The model is loaded with ``device_map="auto"``, in bfloat16 on CUDA
          devices with compute capability >= 8 and float32 otherwise.

    Raises:
        ValueError: if ``model_name`` is not a supported alias.
    """
    model_mapping = {
        "Gemma2_2B": "google/gemma-2-2b-it",
        "Gemma2_9B": "google/gemma-2-9b-it",
        "Gemma2_27B": "google/gemma-2-27b-it",
        "Gemma3_1B": "google/gemma-3-1b-it",
        "Gemma3_4B": "google/gemma-3-4b-it",
        "Gemma3_12B": "google/gemma-3-12b-it",
        "Gemma3_27B": "google/gemma-3-27b-it",
        "Llama3.2_3B": "meta-llama/Llama-3.2-3B-Instruct",
        "Llama3.1_8B": "meta-llama/Llama-3.1-8B-Instruct",
        "Llama3.3_70B": "meta-llama/Llama-3.3-70B-Instruct",
        "Qwen3_8B": "Qwen/Qwen3-8B",
        "Qwen3_14B": "Qwen/Qwen3-14B",
        "Qwen3_32B": "Qwen/Qwen3-32B",
        "Gemma2_2B_Base": "google/gemma-2-2b",
        "Gemma2_9B_Base": "google/gemma-2-9b",
        "Gemma2_27B_Base": "google/gemma-2-27b",
        "Gemma3_1B_Base": "google/gemma-3-1b-pt",
        "Gemma3_4B_Base": "google/gemma-3-4b-pt",
        "Gemma3_12B_Base": "google/gemma-3-12b-pt",
        "Gemma3_27B_Base": "google/gemma-3-27b-pt",
        "Llama3.2_3B_Base": "meta-llama/Llama-3.2-3B",
        "Llama3.1_8B_Base": "meta-llama/Llama-3.1-8B",
        "Llama3.3_70B_Base": "meta-llama/Llama-3.3-70B",
        "Qwen3_8B_Base": "Qwen/Qwen3-8B-Base",
        "Qwen3_14B_Base": "Qwen/Qwen3-14B-Base",
        "Qwen3_32B_Base": "Qwen/Qwen3-32B-Base",
    }

    if model_name not in model_mapping:
        raise ValueError(f"Model name {model_name} not supported")

    MODEL_ID = model_mapping[model_name]

    # Determine device and dtype
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = (
        torch.bfloat16
        if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
        else torch.float32
    )

    print(f"Loading model: {MODEL_ID}")
    print(f"Using device: {device}")
    print(f"Using dtype: {dtype}")

    # The token is read from HUGGINGFACE_TOKEN, optionally populated from ./.env.
    dotenv.load_dotenv(".env")
    from huggingface_hub import login

    HUGGING_FACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
    if HUGGING_FACE_TOKEN:
        login(token=HUGGING_FACE_TOKEN)
        print("Hugging Face login successful (using provided token).")
    else:
        # login(token=None) would prompt interactively; rely on cached credentials instead.
        print("HUGGINGFACE_TOKEN not set; skipping login (cached credentials are used if present).")

    # Load Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # The activation capture and steering hooks read sequence position -1 of every
    # row, which is only each prompt's last real token with left padding.
    if tokenizer.padding_side == "right":
        print("Changing padding from right to left")
        tokenizer.padding_side = "left"

    # Load Model. trust_remote_code executes repository code from the Hub; only
    # the fixed, well-known model IDs above are loaded here.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        quantization_config=None,
        device_map="auto",
        trust_remote_code=True,
    )

    # device_map="auto" already placed the weights (possibly split across GPUs,
    # or on CPU when there is no GPU). An unconditional model.to("cuda") crashes on
    # CPU-only hosts and collapses a multi-GPU split onto one device, so only move
    # a model that was not dispatched. Multimodal checkpoints exposing
    # ``language_model`` are handled by get_model_config / get_model_layer_name.
    if getattr(model, "hf_device_map", None) is None:
        model.to(device)
    return model, tokenizer

def get_model_config(model):
    """Return ``model.language_model.config`` if that sub-model exists, else ``model.config``."""
    config_owner = getattr(model, "language_model", model)
    return config_owner.config

def get_model_layer_name(model, layer):
    """Return the dotted module path of decoder layer ``layer``.

    The path is rooted at ``language_model`` when the model has that attribute
    and at ``model`` otherwise; callers resolve it with get_module_by_name.
    """
    root = "language_model" if hasattr(model, "language_model") else "model"
    return f"{root}.layers.{layer}"


def cleanup_gpu_memory():
    """Run Python garbage collection, then release cached CUDA memory when a GPU exists."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    # Wait for pending kernels, reclaim IPC-shared blocks, then drop the allocator cache.
    torch.cuda.synchronize()
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()


def get_module_by_name(model, module_name):
    """Resolve a dotted attribute path such as ``"model.layers.3"`` starting at ``model``.

    Each dot-separated component is looked up with ``getattr``; an
    ``AttributeError`` propagates if any component is missing.
    """
    current = model
    remaining = module_name
    while True:
        head, separator, remaining = remaining.partition(".")
        current = getattr(current, head)
        if not separator:
            return current


def capture_activation_hook_fast(module, input, output, layer_name):
    """Forward-hook body that records the last sequence position of a layer output.

    ``output`` may be a tensor or a tuple whose first item is the hidden-state
    tensor. Its ``[:, -1, :]`` slice is detached, copied to CPU and stored in the
    module-level ``activation_storage`` under ``layer_name``, replacing any
    earlier value. Any other output type only prints a warning.
    """
    global activation_storage
    if isinstance(output, torch.Tensor):
        hidden = output
    elif isinstance(output, tuple):
        hidden = output[0]
    else:
        print(
            f"Warning: Unexpected output type from layer {layer_name}: {type(output)}"
        )
        return
    activation_storage[layer_name] = hidden[:, -1, :].detach().cpu()


def get_activations_fast(
    model, tokenizer, prompts: List[str], layer_name: str
) -> Optional[torch.Tensor]:
    """Return the batch-mean last-position activation of ``layer_name`` over ``prompts``.

    All prompts are tokenized as one padded, truncated batch and run through
    ``model`` once under ``torch.no_grad()``. A temporary forward hook
    (capture_activation_hook_fast) stores ``[:, -1, :]`` of the target module's
    output on CPU. With left padding (as configured by get_model), position -1
    is each prompt's final token.

    Returns:
        The mean over the batch dimension (``.squeeze()``-d), or None if the hook
        captured nothing for ``layer_name``.
    """
    global activation_storage
    activation_storage = {}

    target_module = get_module_by_name(model, layer_name)
    hook_handle = target_module.register_forward_hook(
        lambda module, input, output: capture_activation_hook_fast(
            module, input, output, layer_name
        )
    )
    # Remove the hook on every exit path (missing capture or exception) so it does
    # not keep firing on later forward passes of the shared model.
    try:
        with torch.no_grad():
            inputs = tokenizer(
                prompts, return_tensors="pt", padding=True, truncation=True
            ).to(model.device)
            model(**inputs)
    finally:
        hook_handle.remove()

    last_token_activations = activation_storage.pop(layer_name, None)
    if last_token_activations is None:
        print(
            f"Warning: Activation for layer {layer_name} not captured for prompts"
        )
        return None

    return last_token_activations.mean(dim=0).squeeze()


def get_steering_vector_fast(
    model, tokenizer, negative_prompts, positive_prompts, layer=20
):
    """Build a difference-of-means steering vector at decoder layer ``layer``.

    Returns ``mean(positive) - mean(negative)``, where each mean is the
    last-position activation averaged by get_activations_fast. Positive prompts
    are processed first. Returns None, after printing an error, if either side
    could not be captured.
    """
    layer_name = get_model_layer_name(model, layer)

    positive_mean = get_activations_fast(model, tokenizer, positive_prompts, layer_name)
    negative_mean = get_activations_fast(model, tokenizer, negative_prompts, layer_name)

    if positive_mean is None or negative_mean is None:
        print("Error: Could not compute steering vector due to missing activations.")
        direction = None
    else:
        direction = positive_mean - negative_mean

    del positive_mean, negative_mean
    cleanup_gpu_memory()
    return direction


def steering_hook(module, input, output):
    """Forward hook that adds the active steering tensor to a layer's output.

    Reads the module-level ``steering_vector_internals``,
    ``steering_multiplier_internal`` and ``TOKEN_POS_TO_STEER``. When a steering
    tensor is active (not None), ``steering_vector_internals *
    steering_multiplier_internal`` is added in place to the hidden state at
    sequence position ``TOKEN_POS_TO_STEER`` of every batch row. The steering
    tensor must broadcast against ``hidden[:, TOKEN_POS_TO_STEER, :]``: either one
    row per batch element or a single vector.

    Passes whose hidden state has sequence length 1 are left untouched
    (presumably the KV-cached decoding steps of ``generate``), so only the full
    prompt pass is steered.

    ``output`` may be a tensor or a tuple whose first item is the hidden state;
    the (modified) output is returned. Other output types are returned unchanged
    after printing a warning.
    """
    if steering_vector_internals is None:
        return output

    if isinstance(output, torch.Tensor):
        _add_steering_in_place(output)
        return output

    if isinstance(output, tuple):
        hidden_state = output[0]
        _add_steering_in_place(hidden_state)
        return (hidden_state,) + output[1:]

    print(
        f"Warning: Steering hook encountered unexpected output type: {type(output)}"
    )
    return output


def _add_steering_in_place(hidden_state):
    """Add the scaled global steering tensor at TOKEN_POS_TO_STEER unless seq length is 1."""
    if hidden_state.shape[1] == 1:
        return
    delta = steering_vector_internals * steering_multiplier_internal
    # Follow the activation's device so a CPU-resident steering vector also works;
    # this is a no-op when the tensor already lives there.
    hidden_state[:, TOKEN_POS_TO_STEER, :] += delta.to(hidden_state.device)


@contextmanager
def apply_steering(model, layer, steering_vectors, multiplier):
    """Temporarily register steering_hook on decoder layer ``layer`` of ``model``.

    Inside the context, the module-level ``steering_vector_internals`` and
    ``steering_multiplier_internal`` are set to ``steering_vectors`` and
    ``multiplier``. On exit, even on error:
    - the hook is removed;
    - the previous steering state is restored, so a nested use does not switch
      off steering for an enclosing context;
    - cleanup_gpu_memory() runs.
    """
    global steering_vector_internals, steering_multiplier_internal
    layer_name = get_model_layer_name(model, layer)

    previous_vectors = steering_vector_internals
    previous_multiplier = steering_multiplier_internal
    handle = None
    try:
        steering_vector_internals = steering_vectors
        steering_multiplier_internal = multiplier
        target_module = get_module_by_name(model, layer_name)
        handle = target_module.register_forward_hook(steering_hook)
        yield
    finally:
        if handle is not None:
            handle.remove()
        steering_vector_internals = previous_vectors
        steering_multiplier_internal = previous_multiplier
        cleanup_gpu_memory()


def generate_steered_output(
    steering_vectors,
    model,
    tokenizer,
    generation_prompts,
    batch_size_per_prompt,
    num_prompts_per_rollout,
    layer=20,
    steering_multiplier=STEERING_MULTIPLIER,
    return_type="text",
    max_new_tokens=MAX_NEW_TOKENS,
    token_to_steer=TOKEN_POS_TO_STEER,
    num_prompts_per_steering_vector=None,
):
    """Generate text (or final hidden states) for many prompts, optionally steered.

    Every prompt is repeated ``batch_size_per_prompt`` times. The repeated list
    is tokenized once as a padded batch and processed in rollouts of
    ``num_prompts_per_rollout`` prompts, i.e. that many times
    ``batch_size_per_prompt`` rows per model call.

    Steering: consecutive groups of ``num_prompts_per_steering_vector`` prompts
    (default: all prompts) share one vector. Prompt ``p`` is steered with
    ``steering_vectors[p // num_prompts_per_steering_vector]``. Rollout boundaries
    do not need to line up with these groups.

    Args:
        steering_vectors: None for no steering, a single vector, or a list/tuple
            of vectors. They are stacked into one row per batch row and cast to
            the model's device and dtype.
        model: Hugging Face causal LM.
        tokenizer: its tokenizer.
        generation_prompts: prompt strings.
        batch_size_per_prompt: copies of each prompt in the batch.
        num_prompts_per_rollout: prompts processed per model call.
        layer: decoder layer index passed to apply_steering.
        steering_multiplier: scale passed to apply_steering.
        return_type: ``"text"`` or ``"last_hidden_state"``.
        max_new_tokens: generation length for ``return_type="text"``.
        token_to_steer: sequence position written to the global
            ``TOKEN_POS_TO_STEER`` (read by steering_hook) for the duration of the
            call. The previous value is restored afterwards, even on error.
        num_prompts_per_steering_vector: prompts sharing one steering vector.

    Returns:
        For ``"text"``: one decoded string per batch row, with special tokens
        skipped. For decoder-only models the prompt is included.
        For ``"last_hidden_state"``: ``model.model(...).last_hidden_state`` of all
        rows, concatenated along the batch and cast to ``model.dtype``.

    Raises:
        ValueError: if ``return_type`` is not supported.
    """
    global TOKEN_POS_TO_STEER

    if return_type not in ("text", "last_hidden_state"):
        raise ValueError(f"Invalid return type: {return_type}")
    if num_prompts_per_steering_vector is None:
        num_prompts_per_steering_vector = len(generation_prompts)
    if steering_vectors is not None and not isinstance(steering_vectors, (list, tuple)):
        steering_vectors = [steering_vectors]

    def _forward(batch):
        # One model call per rollout: sampled generation, or a plain forward pass
        # through the base model to obtain its last hidden state.
        if return_type == "text":
            return model.generate(
                **batch,
                max_new_tokens=max_new_tokens,
                temperature=TEMPERATURE,
                do_sample=DO_SAMPLE,
                pad_token_id=tokenizer.eos_token_id,
            )
        return model.model(**batch)

    previous_token_to_steer = TOKEN_POS_TO_STEER
    TOKEN_POS_TO_STEER = token_to_steer
    try:
        input_prompts = [
            prompt
            for prompt in generation_prompts
            for _ in range(batch_size_per_prompt)
        ]
        inputs = tokenizer(input_prompts, return_tensors="pt", padding=True).to(
            model.device
        )

        num_prompts = len(generation_prompts)
        text = []
        hidden_chunks = []
        for start in range(0, num_prompts, num_prompts_per_rollout):
            end = min(start + num_prompts_per_rollout, num_prompts)
            rows = slice(start * batch_size_per_prompt, end * batch_size_per_prompt)
            # Slice into a plain dict. Assigning attributes on a BatchEncoding does
            # not change what ``**inputs`` unpacks, so every field is sliced here.
            inputs_rollout = {name: tensor[rows] for name, tensor in inputs.items()}

            with torch.no_grad():
                if steering_vectors is None:
                    outputs = _forward(inputs_rollout)
                else:
                    # One steering row per batch row, chosen by the prompt's group.
                    steering_vectors_rollout = torch.stack(
                        [
                            steering_vectors[prompt_idx // num_prompts_per_steering_vector]
                            for prompt_idx in range(start, end)
                            for _ in range(batch_size_per_prompt)
                        ],
                        dim=0,
                    ).to(model.device, model.dtype)
                    with apply_steering(
                        model, layer, steering_vectors_rollout, steering_multiplier
                    ):
                        outputs = _forward(inputs_rollout)

            if return_type == "text":
                text.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))
            else:
                hidden_chunks.append(outputs.last_hidden_state)
            del outputs, inputs_rollout

        if return_type == "text":
            result = text
        elif len(hidden_chunks) == 1:
            result = hidden_chunks[0].to(dtype=model.dtype)
        elif hidden_chunks:
            result = torch.cat(hidden_chunks, dim=0).to(dtype=model.dtype)
        else:
            result = torch.zeros(
                (0, inputs["input_ids"].shape[1], get_model_config(model).hidden_size),
                device=model.device,
                dtype=model.dtype,
            )
        del inputs, hidden_chunks
        cleanup_gpu_memory()
        return result
    finally:
        TOKEN_POS_TO_STEER = previous_token_to_steer


def get_batch_size_fitting_in_memory(
    model, tokenizer, generation_prompts, max_tokens_to_generate
):
    """Estimate how many sequences fit in the free memory of CUDA device 0.

    Args:
        model: Model whose config (via ``get_model_config``) supplies layer,
            head and vocabulary sizes.
        tokenizer: Tokenizer used to measure the longest prompt.
        generation_prompts: Prompts that will be generated from.
        max_tokens_to_generate: Number of new tokens added to the longest prompt.

    Returns:
        int: estimated batch size with a 10% safety margin, never below 1.
    """
    prompt_lengths = [len(tokenizer.encode(prompt)) for prompt in generation_prompts]
    # Guard against an empty prompt list and a zero-length sequence (division below).
    seq_len = max(max(prompt_lengths, default=0) + max_tokens_to_generate, 1)
    config = get_model_config(model)

    n_layers = config.num_hidden_layers
    n_heads = config.num_attention_heads
    # Some configs do not expose head_dim; derive it from hidden_size in that case.
    d_head = getattr(config, "head_dim", None) or config.hidden_size // n_heads
    d_vocab = config.vocab_size
    # Rough per-sequence footprint; the leading 2 presumably assumes 2-byte
    # (fp16/bf16) elements — TODO confirm against model.dtype.
    required_bytes = 2 * seq_len * (n_layers * n_heads * d_head * 3 + d_vocab)
    memory_left = torch.cuda.get_device_properties(
        0
    ).total_memory - torch.cuda.memory_allocated(0)
    # Apply the safety margin before truncating so an int is returned, and never
    # report 0 (a zero batch size breaks every caller's division/modulo).
    return max(1, int(memory_left / required_bytes * 0.9))


# Text processing functions
def remove_non_alphanumeric_characters_from_right(text):
    """Return ``text`` with every trailing non-alphanumeric character removed.

    Characters are tested with ``str.isalnum``; scanning stops at the first
    alphanumeric character found from the right.
    """
    end = len(text)
    for char in reversed(text):
        if char.isalnum():
            break
        end -= 1
    return text[:end]


def remove_non_alphanumeric_characters_from_left(text):
    """Return ``text`` with every leading non-alphanumeric character removed.

    Characters are tested with ``str.isalnum``; if none is alphanumeric the
    result is the empty string.
    """
    start = next(
        (position for position, char in enumerate(text) if char.isalnum()),
        len(text),
    )
    return text[start:]


def get_cleaned_up_text(text, num_examples=0, num_lines=3):
    """Trim generated text to its first ``num_lines`` lines.

    Args:
        text: Generated text.
        num_examples: When > 0, ``text`` is treated as blank-line-separated
            sections (few-shot examples followed by the target). The section at
            index ``num_examples`` is selected and cleaned recursively. If there
            is no blank line, or too few sections, ``text`` is returned unchanged.
        num_lines: Number of lines to keep.

    Returns:
        The text unchanged if it has fewer than ``num_lines`` lines. Otherwise,
        the first ``num_lines`` lines with trailing non-alphanumeric characters
        stripped. A newline is appended when the result contains at most one
        line break.
    """
    if num_examples > 0:
        if "\n\n" not in text:
            return text
        sections = text.split("\n\n")
        if num_examples >= len(sections):
            return text
        # The old code passed the section positionally *and* num_examples=0 by
        # keyword, which always raised TypeError; pass only the section.
        return get_cleaned_up_text(
            sections[num_examples],
            num_examples=0,
            num_lines=num_lines,
        )
    lines = text.split("\n")
    if len(lines) < num_lines:
        return text
    cleaned_up_text = "\n".join(lines[:num_lines])
    cleaned_up_text = remove_non_alphanumeric_characters_from_right(cleaned_up_text)
    if cleaned_up_text.count("\n") <= 1:
        cleaned_up_text += "\n"
    return cleaned_up_text


# Metric calculation functions
def get_word_correct(last_words, rhyme_family: str):
    """Score each word as 1.0 if it belongs to ``rhyme_family``, else 0.0.

    Args:
        last_words: Words to score.
        rhyme_family: A key of ``rhyme_family_words``, or a single target word
            (e.g. "king") that must match exactly, case-insensitively.

    Returns:
        float NumPy array with one 0.0/1.0 entry per word (empty for no words).
    """
    if rhyme_family in rhyme_family_words:
        accepted = rhyme_family_words[rhyme_family]
    else:
        # This case is for single words e.g. "king"
        accepted = {rhyme_family}
    # Extra suffix heuristics for families whose word lists are not exhaustive.
    suffixes = {
        "ing": ("ing",),
        "air": ("where",),
        "ee": ("y", "ee"),
    }.get(rhyme_family, ())
    # Lowercase once so the set lookup and the suffix checks agree on case.
    lowered = [w.lower() for w in last_words]
    # dtype=bool keeps empty input from becoming a float array (old `|=` crash).
    correct = np.array(
        [w in accepted or w.endswith(suffixes) for w in lowered], dtype=bool
    )
    return correct.astype(float)


def get_last_word_correct(texts, rhyme_families, num_words=1):
    """Score the trailing word(s) of each text against every rhyme family.

    Args:
        texts: Texts whose ending is checked.
        rhyme_families: Rhyme family keys (or single target words) to test.
        num_words: Number of space-separated tokens kept from the end of each
            text and re-joined with single spaces.

    Returns:
        List with one ``get_word_correct`` result per rhyme family.
    """
    tails = []
    for text in texts:
        pieces = text.split(" ")
        tails.append(" ".join(pieces[-num_words:]))
    return [get_word_correct(tails, family) for family in rhyme_families]


def get_last_word_regeneration_correct(
    model, tokenizer, texts, rhyme_families: List[str], batch_size_small: int
):
    """Drop each text's final word, let the model regenerate it, and score it.

    The last line of every text loses its final space-separated word. It is
    then prefixed with "A short sentence:" and a newline, and continued for up
    to 3 unsteered tokens via ``generate_steered_output``. The first word of
    each continuation is stripped of surrounding non-alphanumeric characters
    and scored per rhyme family.

    Returns:
        List with one ``get_word_correct`` result per rhyme family.
    """
    prompts = []
    for text in texts:
        final_line = text.split("\n")[-1]
        truncated_line = final_line.rsplit(" ", 1)[0]
        prompts.append("A short sentence:\n" + truncated_line)

    regenerated = generate_steered_output(
        None,
        model,
        tokenizer,
        prompts,
        batch_size_per_prompt=1,
        num_prompts_per_rollout=batch_size_small,
        max_new_tokens=3,
    )

    regenerated_words = []
    for full_output, prompt in zip(regenerated, prompts):
        # Assumes the decoded output starts with the prompt text verbatim.
        continuation = full_output[len(prompt) :]
        # The sentinel word guarantees split() has at least one token; it is
        # never in any rhyme family, so an empty continuation scores 0.
        first_token = remove_non_alphanumeric_characters_from_left(
            continuation + " wrongwongwongxxx"
        ).split()[0]
        regenerated_words.append(
            remove_non_alphanumeric_characters_from_right(first_token)
        )

    return [get_word_correct(regenerated_words, family) for family in rhyme_families]


def calculate_kl_divergence(p, q, use_gpu=GPU_KL_DIVERGENCE):
    """Return KL(p || q) summed over the last axis, as a NumPy array.

    ``use_gpu`` is accepted for backward compatibility but is currently
    ignored. Computation is always delegated to
    ``calculate_kl_divergence_GPU``, which expects torch tensors.
    """
    kl_values = calculate_kl_divergence_GPU(p, q)
    return kl_values


def calculate_kl_divergence_GPU(p, q):
    """Compute KL(p || q) = sum p * log(p / q) over the last dim with torch.

    Terms where ``p == 0`` contribute 0, which also masks 0 * log(0) NaNs.
    The result is upcast to float32 and returned as a CPU NumPy array.
    """
    pointwise = torch.where(p != 0, p * torch.log(p / q), 0)
    kl = pointwise.sum(dim=-1)
    return kl.float().cpu().numpy()


def calculate_kl_divergence_gpu_chunked(p, q, chunk_size=20, device="cuda"):
    """Calculate KL(p || q) over the last axis in row chunks to bound peak memory.

    Args:
        p, q: Probability tensors or NumPy arrays of the same shape (at least
            2-D). The KL divergence is summed over the last axis. Chunking is
            along dim 0.
        chunk_size: Number of dim-0 rows processed per chunk.
        device: Currently unused. Chunks are processed on whatever device the
            inputs already live on.

    Returns:
        float32 NumPy array of shape ``p.shape[:-1]``.
    """
    # Convert first so every path below operates on tensors. The old CPU
    # fallback forwarded raw NumPy arrays to a torch-only helper and crashed.
    if not isinstance(p, torch.Tensor):
        p = torch.from_numpy(p).float()
    if not isinstance(q, torch.Tensor):
        q = torch.from_numpy(q).float()

    cuda_available = torch.cuda.is_available()
    if not cuda_available:
        # The chunked loop below is device-agnostic, so it runs fine on CPU.
        print("CUDA not available, falling back to CPU")

    if p.shape[0] == 0:
        # torch.cat([]) would raise; return an empty result of the right shape.
        return np.zeros(tuple(p.shape[:-1]), dtype=np.float32)

    results = []
    with torch.no_grad():  # Save memory by not tracking gradients
        for start in range(0, p.shape[0], chunk_size):
            p_chunk = p[start : start + chunk_size]
            q_chunk = q[start : start + chunk_size]
            # p == 0 terms contribute 0 (masks the 0 * log(0) NaN).
            kl_chunk = torch.sum(
                torch.where(p_chunk != 0, p_chunk * torch.log(p_chunk / q_chunk), 0),
                dim=-1,
            )
            # Upcast so bfloat16/float16 inputs can be converted to NumPy,
            # matching calculate_kl_divergence_GPU's float32 output.
            results.append(kl_chunk.to(torch.float32))
            del p_chunk, q_chunk, kl_chunk
            if cuda_available:
                torch.cuda.empty_cache()

    return torch.cat(results).cpu().numpy()


def get_nth_true_per_batch(arr_2d, n, reversed=False):
    """Return, per row, the column index of the n-th True entry.

    With ``reversed=True`` the count starts from the right end of the row.
    A row with fewer than ``n`` True entries gets the last column index,
    ``arr_2d.shape[1] - 1``. Despite the -1 initial fill, every row is
    assigned a value.
    """
    last_col = arr_2d.shape[1] - 1
    picks = np.full(arr_2d.shape[0], -1)  # -1 indicates "not found"
    for row_idx, row in enumerate(arr_2d):
        hits = np.where(row)[0]
        if len(hits) < n:
            # raise ValueError(f"Not enough true values in batch {i}")
            picks[row_idx] = last_col
            continue
        picks[row_idx] = hits[-n] if reversed else hits[n - 1]
    return picks


def get_min_and_max_idxs(texts, tokenizer, end_character="\n"):
    """Locate per-text token bounds after batch tokenization with padding.

    Returns:
        (min_idxs, max_idxs) as NumPy arrays:
        - min_idxs: index of the last occurrence of the first token of
          ``end_character``'s encoding. Rows without it get the last column.
        - max_idxs: with left padding, the last column for every row.
          Otherwise, the index just before the first pad token. The last
          column is used when the argmax-based search yields -1, i.e. no pad
          token or a pad token at position 0.

    NOTE(review): some tokenizers encode a lone newline with a leading
    space-marker token, in which case element [0] may not be the newline
    token — confirm for each model.
    """
    encoded = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
    token_idxs = encoded.input_ids
    seq_len = token_idxs.shape[1]

    if tokenizer.padding_side != "left":
        is_pad = (token_idxs == tokenizer.pad_token_id).to(torch.int)
        max_idxs = (is_pad.argmax(dim=-1) - 1).cpu().numpy()
        max_idxs[max_idxs == -1] = seq_len - 1
    else:
        max_idxs = torch.full((len(texts),), seq_len - 1, dtype=torch.long).numpy()

    end_token_id = tokenizer.encode(end_character, add_special_tokens=False)[0]
    min_idxs = get_nth_true_per_batch(token_idxs == end_token_id, 1, reversed=True)
    return min_idxs, max_idxs


def get_kl_above_threshold(kl_divergences, threshold=1.0):
    """Return a float mask: 1.0 where KL divergence exceeds ``threshold``, else 0.0."""
    exceeds = kl_divergences > threshold
    return exceeds.astype(float)


def get_top_1_difference(
    probs_with_steering, probs_no_steering, use_gpu=GPU_KL_DIVERGENCE
):
    """Flag positions where steered and unsteered argmax tokens differ.

    With ``use_gpu`` the inputs are treated as torch tensors and the work is
    delegated to ``get_top_1_difference_GPU``. Otherwise they are treated as
    NumPy arrays and a float 0.0/1.0 array is returned.
    """
    if not use_gpu:
        steered_top = probs_with_steering.argmax(axis=-1)
        baseline_top = probs_no_steering.argmax(axis=-1)
        return (steered_top != baseline_top).astype(float)
    return get_top_1_difference_GPU(probs_with_steering, probs_no_steering)


def get_top_1_difference_GPU(probs_with_steering, probs_no_steering):
    """Flag positions where steered and unsteered argmax tokens differ (torch inputs).

    Args:
        probs_with_steering, probs_no_steering: torch tensors of equal shape;
            argmax is taken over the last dim.

    Returns:
        float NumPy array of 0.0/1.0 flags on the CPU. The float dtype matches
        the CPU branch of ``get_top_1_difference``; previously this path
        returned a bool array.
    """
    steered_top = probs_with_steering.argmax(dim=-1)
    baseline_top = probs_no_steering.argmax(dim=-1)
    return (steered_top != baseline_top).cpu().numpy().astype(float)


def get_idx_of_first_top_1_difference(top_1_differences, max_idxs):
    """Return, per row, the index of the first position whose top-1 token differs.

    Args:
        top_1_differences: 2-D array-like of 0/1 (or bool) flags, one row per
            sequence.
        max_idxs: Per-row fallback index used when a row has no difference.

    Returns:
        float NumPy array with one index per row.
    """
    diffs = np.asarray(top_1_differences)
    first_idxs = np.argmax(diffs, axis=-1)
    # argmax is 0 both for an all-zero row and for a real difference at index 0.
    # Detect empty rows explicitly instead of treating 0 as "none found".
    has_difference = np.any(diffs, axis=-1)
    return np.where(has_difference, first_idxs, np.asarray(max_idxs)).astype(float)


def get_idx_of_first_high_kl_divergence(high_kl_divergences, max_idxs):
    """Return, per row, the index of the first position flagged as high-KL.

    Args:
        high_kl_divergences: 2-D array-like of 0/1 (or bool) flags, e.g. from
            ``get_kl_above_threshold``, one row per sequence.
        max_idxs: Per-row fallback index used when a row has no flagged position.

    Returns:
        float NumPy array with one index per row.
    """
    flags = np.asarray(high_kl_divergences)
    first_idxs = np.argmax(flags, axis=-1)
    # argmax is 0 both for an all-zero row and for a real hit at index 0.
    # Detect empty rows explicitly instead of treating 0 as "none found".
    has_hit = np.any(flags, axis=-1)
    return np.where(has_hit, first_idxs, np.asarray(max_idxs)).astype(float)


def get_top_1_difference_fraction(top_1_differences, min_idxs, max_idxs):
    """Return total top-1 differences divided by total span ``max_idxs - min_idxs``."""
    token_span_total = np.sum(max_idxs - min_idxs)
    return np.sum(top_1_differences) / token_span_total


def get_high_kl_divergence_fraction(high_kl_divergences, min_idxs, max_idxs):
    """Return total high-KL flags divided by total span ``max_idxs - min_idxs``."""
    token_span_total = np.sum(max_idxs - min_idxs)
    return np.sum(high_kl_divergences) / token_span_total


# NOTE(review): disabled legacy code kept for reference only. The helper below
# is commented out, and the two functions after it sit inside a bare
# module-level string literal, so none of them is ever defined at runtime.
# def get_second_new_line_token_idx(text):
#     return text.find("\n", text.find("\n") + 1)


"""
def get_idx_of_first_token_difference_concrete(unsteered_text, steered_text, start_idx):
    for i in range(start_idx, len(unsteered_text)):
        if unsteered_text[i] != steered_text[i]:
            return i - start_idx
    return len(unsteered_text) - start_idx


def get_idx_of_first_token_difference(tokenizer, unsteered_texts, steered_texts):
    second_new_line_token_idxs = [
        get_second_new_line_token_idx(text) for text in unsteered_texts
    ]
    idx_of_first_token_difference = np.array(
        [
            get_idx_of_first_token_difference_concrete(
                unsteered_text, steered_text, second_new_line_token_idx
            )
            for unsteered_text, steered_text, second_new_line_token_idx in zip(
                unsteered_texts, steered_texts, second_new_line_token_idxs
            )
        ]
    )
    max_idx = np.array(
        [
            len(unsteered_text) - second_new_line_token_idx
            for unsteered_text, second_new_line_token_idx in zip(
                unsteered_texts, second_new_line_token_idxs
            )
        ]
    )
    return idx_of_first_token_difference, max_idx
"""


def get_prompts(texts):
    """Prefix every text with the "A rhyming couplet:" header line."""
    prefix = "A rhyming couplet:\n"
    return [prefix + text for text in texts]


def setup_output_directory(mode, output_dir, model_name, rhyme_family1, rhyme_family2):
    """Create and return ``output_dir/mode/model_name/<family1>_<family2>``.

    Intermediate directories are created as needed; an existing directory is
    not an error.
    """
    pair_dir = f"{rhyme_family1}_{rhyme_family2}"
    exp_dir = os.path.join(output_dir, mode, model_name, pair_dir)
    os.makedirs(exp_dir, exist_ok=True)
    return exp_dir


def save_data(data, filepath, use_json=True):
    """Write ``data`` to disk as JSON (default) or as a pickle.

    The extension is adjusted with a plain string replace: ".pkl" becomes
    ".json" for JSON, and ".json" becomes ".pkl" for pickle.
    """
    if not use_json:
        target = filepath.replace(".json", ".pkl")
        with open(target, "wb") as handle:
            pickle.dump(data, handle)
        return
    target = filepath.replace(".pkl", ".json")
    with open(target, "w") as handle:
        json.dump(data, handle)


def load_data(filepath, use_json=True):
    """Read data written by ``save_data``: JSON by default, pickle otherwise.

    The extension is adjusted with the same plain string replace that
    ``save_data`` uses.
    """
    if not use_json:
        # pickle.load can execute arbitrary code; only load trusted files.
        with open(filepath.replace(".json", ".pkl"), "rb") as handle:
            return pickle.load(handle)
    with open(filepath.replace(".pkl", ".json"), "r") as handle:
        return json.load(handle)


def get_common_args():
    """Build the (unparsed) argparse parser shared by the experiment scripts."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    # Other model_name defaults used previously: Llama3.2_3B, Qwen3_8B, Gemma2_9B.
    add("--model_name", default="Gemma3_4B", help="Model name")
    add(
        "--mode",
        default="rhyme_family_steering",
        help="Mode (rhyme_family_steering, specific_word_steering)",
    )
    add("--rhyme_family1", default=None, help="First rhyme family")
    add("--rhyme_family2", default=None, help="Second rhyme family")
    add("--output_dir", default="results", help="Output directory")
    add("--multiplier", type=float, default=None, help="Steering multiplier")
    add(
        "--num_prompts",
        type=int,
        default=None,
        help="Number of prompts to generate (for debugging)",
    )
    add("--LAYER_FRACTION", type=float, default=0.8, help="Model layer fraction to use")
    add(
        "--strip",
        action="store_true",
        default=False,
        help=(
            "Whether to steer on the token before the end character such as newline, "
            "ignoring end character for steering"
        ),
    )
    return parser


def calculate_batch_parameters(
    model,
    tokenizer,
    generation_prompts,
    max_new_tokens=MAX_NEW_TOKENS,
    LAYER_FRACTION=LAYER_FRACTION,
    model_name=None,
    layer=None,
):
    """Calculate batch parameters for generation.

    Args:
        model: Model whose config (via ``get_model_config``) gives the layer count.
        tokenizer, max_new_tokens: Accepted for interface compatibility; unused.
        generation_prompts: Prompts to generate from (must be non-empty).
        LAYER_FRACTION: Fraction of central layers to return in ``layers``.
        model_name: Used only to pick the per-rollout batch size ("70B" -> 250,
            "27B" -> 500, otherwise 1000). ``None`` is treated as "".
        layer: If given, ``layers`` is just ``[layer]``.

    Returns:
        dict with batch_size_small, batch_size_per_prompt,
        num_prompts_per_rollout, n_prompts, num_layers and layers.

    Raises:
        ValueError: if ``generation_prompts`` is empty.
    """
    batch_size = BATCH_SIZE
    config = get_model_config(model)
    num_layers = config.num_hidden_layers
    # Keep the central LAYER_FRACTION of layers, trimming equally from both ends.
    margin = (1 - LAYER_FRACTION) / 2
    layers = list(
        range(
            int(np.ceil(margin * num_layers)),
            int(np.ceil((LAYER_FRACTION + margin) * num_layers)),
        )
    )

    n_prompts = len(generation_prompts)
    if n_prompts == 0:
        raise ValueError("generation_prompts must not be empty")

    # Per-rollout batch size is chosen from the model size. The memory estimate
    # (get_batch_size_fitting_in_memory) used to be computed here but was always
    # overwritten, so that call (which also required CUDA) has been dropped.
    size_key = model_name or ""
    if "70B" in size_key:
        batch_size_small = 250
    elif "27B" in size_key:
        batch_size_small = 500
    else:
        batch_size_small = 1000
    batch_size_small = min(batch_size_small, batch_size)

    # At least one generation per prompt, even if n_prompts exceeds BATCH_SIZE.
    batch_size_per_prompt = max(1, batch_size // n_prompts)
    # At least one prompt per rollout. Without this clamp, few prompts with a
    # small rollout cap (e.g. 70B) gave 0 and a ZeroDivisionError below.
    num_prompts_per_rollout = max(1, batch_size_small // batch_size_per_prompt)

    # Round down num_prompts_per_rollout to a divisor of n_prompts
    while n_prompts % num_prompts_per_rollout != 0:
        num_prompts_per_rollout -= 1
    batch_size_small = num_prompts_per_rollout * batch_size_per_prompt

    if layer is not None:
        layers = [layer]

    return {
        "batch_size_small": batch_size_small,
        "batch_size_per_prompt": batch_size_per_prompt,
        "num_prompts_per_rollout": num_prompts_per_rollout,
        "n_prompts": n_prompts,
        "num_layers": num_layers,
        "layers": layers,
    }


# NOTE(review): disabled legacy code. Both helpers below live inside a bare
# module-level string literal, so they are never defined at runtime.
'''
def calculate_batch_parameters_for_regeneration(model, tokenizer, texts):
    """Calculate batch parameters for regeneration tasks (small max_new_tokens)."""
    batch_size_small = get_batch_size_fitting_in_memory(
        model,
        tokenizer,
        texts,
        3,  # Small max_new_tokens for regeneration
    )
    while BATCH_SIZE % batch_size_small != 0:
        batch_size_small -= 1
    return batch_size_small


def calculate_batch_parameters_for_probs(model, tokenizer, texts):
    """Calculate batch parameters for probability computation (no new tokens)."""
    batch_size_small = get_batch_size_fitting_in_memory(
        model,
        tokenizer,
        texts,
        0,  # No new tokens for prob computation
    )
    while BATCH_SIZE % batch_size_small != 0:
        batch_size_small -= 1
    return batch_size_small
'''
