#!/usr/bin/env python3

import os
import sys
from datetime import datetime

from tqdm import tqdm

# Add the directory containing this file to sys.path before importing paper_experiments.shared_utils
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from paper_experiments.shared_utils import *


def _print_text_preview(header, label, texts, limit=3):
    """Print ``header`` followed by the ``repr`` of the first ``limit`` texts."""
    print(header)
    for i, text in enumerate(texts[:limit]):
        print(f"{label} {i}:")
        print(repr(text))
        print("---")


def main(
    mode,
    model_name,
    model,
    tokenizer,
    rhyme_family1,
    rhyme_family2,
    output_dir,
    num_prompts=None,
    LAYER_FRACTION=0.8,
    strip_newline=False,
):
    """Generate unsteered and per-layer steered texts and save them as JSON.

    The function is a no-op (apart from a log line) when the output file
    ``generated_lines.json`` already exists in the experiment directory, or
    when ``best_layer_and_token_pos.json`` is missing from the experiment
    directory's parent.

    Steering vectors are not computed here: for every layer returned by
    ``calculate_batch_parameters`` one is loaded from
    ``<exp_dir>/steering_vectors/steering_vector_<strip_newline>_<layer>.pkl``.

    Args:
        mode: ``"rhyme_family_steering"`` or ``"specific_word_steering"``.
        model_name: Used for logging, prompt loading (specific-word mode),
            batch-parameter calculation and output metadata.
        model: Model object forwarded to the shared generation helpers.
        tokenizer: Tokenizer forwarded to the helpers; its ``padding_side``
            decides how the steered token position is chosen.
        rhyme_family1: First element of the pair; the test prompts are
            loaded for this one.
        rhyme_family2: Second element of the pair (used for the output
            directory, logging and metadata).
        output_dir: Root directory handed to ``setup_output_directory``.
        num_prompts: If not ``None``, only the first ``num_prompts`` prompts
            are used.
        LAYER_FRACTION: Forwarded to ``calculate_batch_parameters`` and
            recorded in the output metadata.
        strip_newline: Kept for interface compatibility. The value passed in
            is overridden by ``"best_strip_newline"`` from
            ``best_layer_and_token_pos.json``.

    Raises:
        ValueError: If ``mode`` is not one of the two supported modes.
    """
    # Fail fast on an unsupported mode instead of hitting a NameError on
    # ``generation_prompts`` after the output directory has been set up.
    supported_modes = ("rhyme_family_steering", "specific_word_steering")
    if mode not in supported_modes:
        raise ValueError(
            f"Unsupported mode {mode!r}; expected one of {supported_modes}"
        )

    print(
        f"Starting line generation for {model_name}: {rhyme_family1} vs {rhyme_family2}"
    )

    # Setup output directory
    exp_dir = setup_output_directory(
        mode, output_dir, model_name, rhyme_family1, rhyme_family2
    )
    output_file = os.path.join(exp_dir, "generated_lines.json")

    # Check if already completed
    if os.path.exists(output_file):
        print(f"Line generation already completed: {output_file}")
        return

    best_layer_and_token_pos_filepath = os.path.join(
        os.path.dirname(exp_dir), "best_layer_and_token_pos.json"
    )
    if not os.path.exists(best_layer_and_token_pos_filepath):
        print(
            f"Best layer and token position file not found: {best_layer_and_token_pos_filepath}"
        )
        return

    best_layer_and_token_pos_data = load_data(best_layer_and_token_pos_filepath)
    best_layer = best_layer_and_token_pos_data["best_layer"]
    # The stored setting wins over the caller-supplied ``strip_newline``; it
    # also selects which steering-vector files are loaded below.
    strip_newline = best_layer_and_token_pos_data["best_strip_newline"]

    # Load the test prompts for the first family / word
    if mode == "rhyme_family_steering":
        generation_prompts = load_prompts(mode, "test", rhyme_family1)
    elif mode == "specific_word_steering":
        generation_prompts = load_prompts(mode, "test", rhyme_family1, model_name)

    if num_prompts is not None:
        generation_prompts = generation_prompts[:num_prompts]

    # Calculate batch parameters
    batch_params = calculate_batch_parameters(
        model,
        tokenizer,
        generation_prompts,
        LAYER_FRACTION=LAYER_FRACTION,
        model_name=model_name,
        layer=best_layer,
    )
    batch_size_small = batch_params["batch_size_small"]
    batch_size_per_prompt = batch_params["batch_size_per_prompt"]
    num_prompts_per_rollout = batch_params["num_prompts_per_rollout"]
    n_prompts = batch_params["n_prompts"]
    num_layers = batch_params["num_layers"]
    layers = batch_params["layers"]

    print("PARAMETERS")
    print(f"Batch size: {batch_size_small}")
    print(f"Batch size per prompt: {batch_size_per_prompt}")
    print(f"Number of prompts per rollout: {num_prompts_per_rollout}")
    print(f"Number of prompts: {n_prompts}")
    print(f"Number of layers: {num_layers}")
    print(f"Layers: {layers}")

    # Generate unsteered texts (no steering vector)
    print("Generating unsteered texts...")
    unsteered_texts = generate_steered_output(
        None,
        model,
        tokenizer,
        generation_prompts,
        batch_size_per_prompt,
        num_prompts_per_rollout,
    )
    unsteered_texts = [get_cleaned_up_text(text) for text in unsteered_texts]

    # Debug: Print first few unsteered texts
    _print_text_preview(
        "\n=== FIRST FEW UNSTEERED TEXTS ===", "Unsteered", unsteered_texts
    )

    # Choose the token position(s) to steer. With right padding a per-prompt
    # index is used, repeated once per sample generated for that prompt;
    # otherwise a single negative index is used for all sequences.
    if tokenizer.padding_side == "right":
        min_idxs, _ = get_min_and_max_idxs(generation_prompts, tokenizer)
        token_to_steer = [
            idx for idx in min_idxs.tolist() for _ in range(batch_size_per_prompt)
        ]
    elif strip_newline:
        token_to_steer = -2
    else:
        token_to_steer = -1

    # Generate steered texts for each layer
    steered_texts = {}
    steering_vector_dir = os.path.join(exp_dir, "steering_vectors")

    for layer in tqdm(layers, desc="Generating steered texts"):
        print(f"Processing layer {layer}...")

        # Load the precomputed steering vector for this layer.
        # NOTE(review): torch.load unpickles the file; only point this at
        # steering-vector files produced by a trusted pipeline run.
        steering_vector_file = os.path.join(
            steering_vector_dir, f"steering_vector_{strip_newline}_{layer}.pkl"
        )
        steering_vector = torch.load(steering_vector_file)

        # Generate steered text
        steered_text = generate_steered_output(
            steering_vector,
            model,
            tokenizer,
            generation_prompts,
            batch_size_per_prompt,
            num_prompts_per_rollout,
            layer=layer,
            steering_multiplier=STEERING_MULTIPLIER,
            token_to_steer=token_to_steer,
        )
        steered_text = [get_cleaned_up_text(text) for text in steered_text]

        # Debug: Print first few steered texts for this layer
        _print_text_preview(
            f"\n=== FIRST FEW STEERED TEXTS (Layer {layer}) ===",
            "Steered",
            steered_text,
        )

        # Store results. The steering vector itself is not kept: it is not
        # part of the saved output, and holding on to it would keep every
        # layer's vector alive until the function returns.
        steered_texts[layer] = steered_text
        del steering_vector

        # Clean up GPU memory after each layer
        cleanup_gpu_memory()

        print(f"Completed layer {layer}")

    # Prepare output data
    output_data = {
        "unsteered_texts": unsteered_texts,
        "layers": layers,
        "strip_newline": strip_newline,
        "steered_texts": steered_texts,
        "generation_prompts": generation_prompts,
        "batch_size_per_prompt": batch_size_per_prompt,
        "batch_size_small": batch_size_small,
        "metadata": {
            "model_name": model_name,
            "rhyme_family1": rhyme_family1,
            "rhyme_family2": rhyme_family2,
            "timestamp": datetime.now().isoformat(),
            "num_layers": num_layers,
            "layer_fraction": LAYER_FRACTION,
            "steering_multiplier": STEERING_MULTIPLIER,
        },
    }

    # Save results
    save_data(output_data, output_file)
    print(f"Line generation completed and saved to: {output_file}")


if __name__ == "__main__":
    parser = get_common_args()
    args = parser.parse_args()

    # Resolve which pairs to process before loading the model, so that an
    # unsupported configuration fails immediately instead of after the load.
    # An explicitly given pair takes precedence over the per-mode default
    # list; if only one of the two families is given it is ignored.
    if args.rhyme_family1 is not None and args.rhyme_family2 is not None:
        pairs = [(args.rhyme_family1, args.rhyme_family2)]
    elif args.mode == "rhyme_family_steering":
        pairs = RHYME_FAMILY_PAIRS
    elif args.mode == "specific_word_steering":
        pairs = SPECIFIC_WORD_PAIRS
    else:
        raise ValueError(
            f"Unsupported mode {args.mode!r}: no default pairs are defined for it "
            "and no explicit rhyme_family1/rhyme_family2 pair was given"
        )

    # Load model once and reuse it for every pair
    print("Loading model...")
    model, tokenizer = get_model(args.model_name)

    try:
        for rhyme_family1, rhyme_family2 in pairs:
            main(
                args.mode,
                args.model_name,
                model,
                tokenizer,
                rhyme_family1,
                rhyme_family2,
                args.output_dir,
                num_prompts=args.num_prompts,
                LAYER_FRACTION=args.LAYER_FRACTION,
                strip_newline=args.strip,
            )
    finally:
        # Clean up model even if one of the runs raised
        del model, tokenizer
        cleanup_gpu_memory()
