#!/usr/bin/env python3

import ast
import os
import sys
from datetime import datetime

from tqdm import tqdm

# Add this file's own directory to sys.path so the paper_experiments package below can be imported
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from paper_experiments.shared_utils import *

# Number of generation prompts taken per pair: BATCH_SIZE split evenly (integer
# division) across RHYME_FAMILY_PAIRS. It is computed once at import time, so it
# does not adapt to the `pairs` actually passed to main() (e.g. a single custom
# pair still gets this many prompts).
# NOTE(review): BATCH_SIZE and RHYME_FAMILY_PAIRS presumably come from the
# star-import of paper_experiments.shared_utils -- confirm.
NUM_PROMPTS_PER_PAIR = BATCH_SIZE // len(RHYME_FAMILY_PAIRS)


def _token_positions_to_steer(tokenizer, generation_prompts, batch_size_per_prompt, strip_newline):
    """Build the ``token_to_steer`` argument passed to ``generate_steered_output``.

    With a right-padding tokenizer the result is a flat list containing one
    index per generated sample: each prompt's index repeated
    ``batch_size_per_prompt`` times. Otherwise it is a single negative index
    (-2 when ``strip_newline`` is true, -1 when it is false).
    """
    if tokenizer.padding_side == "right":
        min_idxs, max_idxs = get_min_and_max_idxs(generation_prompts, tokenizer)
        if strip_newline:
            # NOTE(review): presumably this targets the token just before the
            # prompt's final token -- confirm against get_min_and_max_idxs.
            min_idxs = max_idxs - 1
        token_to_steer = []
        for idx in min_idxs.tolist():
            token_to_steer += [idx] * batch_size_per_prompt
        return token_to_steer
    return -2 if strip_newline else -1


def _score_steered_texts(steered_text, pairs):
    """Average, over ``pairs``, the mean "second family" hit rate of the steered texts.

    ``steered_text`` is consumed in consecutive chunks of
    ``NUM_PROMPTS_PER_PAIR`` entries, one chunk per pair, in the order of
    ``pairs``. Returns 0 when ``pairs`` is empty.
    """
    score = 0
    for i, (rhyme_family1, rhyme_family2) in enumerate(pairs):
        chunk = steered_text[i * NUM_PROMPTS_PER_PAIR : (i + 1) * NUM_PROMPTS_PER_PAIR]
        # Only the result for the second (steering target) family counts.
        _, correct_family2 = get_last_word_correct(
            chunk,
            [rhyme_family1, rhyme_family2],
            num_words=1,
        )
        score += correct_family2.mean() / len(pairs)
    return score


def _save_steering_vectors(exp_dir, pairs, layers, vectors_by_pair):
    """Write every collected steering vector to disk, skipping files that already exist.

    Files go to ``<exp_dir>/<pair[0]>_<pair[1]>/steering_vectors/`` and are
    named ``steering_vector_<strip_newline>_<layer>.pkl``.
    """
    # Kept from the original layout; nothing in this function writes into it.
    os.makedirs(os.path.join(exp_dir, "steering_vectors"), exist_ok=True)
    for pair in pairs:
        save_dir = os.path.join(exp_dir, f"{pair[0]}_{pair[1]}", "steering_vectors")
        os.makedirs(save_dir, exist_ok=True)
        for strip_newline in [False, True]:
            for layer in layers:
                save_file = os.path.join(save_dir, f"steering_vector_{strip_newline}_{layer}.pkl")
                if os.path.exists(save_file):
                    print(
                        f"Steering vector for {pair} for strip_newline: {strip_newline} at layer {layer} already exists: {save_file}"
                    )
                    continue
                steering_vector = vectors_by_pair[pair][(strip_newline, layer)]
                torch.save(steering_vector, save_file)
                print(f"Saved steering vector for {pair} for strip_newline: {strip_newline} at layer {layer} to {save_file}")


def main(
    mode,
    model_name,
    model,
    tokenizer,
    pairs,
    output_dir,
    num_prompts=None,
    LAYER_FRACTION=0.8,
    steering_multiplier=None,
):
    """Sweep layers and steering-token positions and record the best-scoring combination.

    For each ``strip_newline`` setting (True, then False) and each layer
    returned by ``calculate_batch_parameters``, a steering vector is computed
    per pair, steered text is generated for the generation prompts, and the
    configuration is scored by how often the steered text hits the second
    family of each pair. Results are written to
    ``<output_dir>/<mode>/<model_name>/best_layer_and_token_pos.json`` and the
    steering vectors to per-pair subdirectories next to it.

    If the results file already exists the function returns immediately; the
    file therefore acts as the "run completed" marker and is written last.

    Args:
        mode: Experiment mode; used as a path component and forwarded to
            ``load_prompts``.
        model_name: Model identifier; used as a path component and forwarded
            to the prompt/batch helpers.
        model: Model object forwarded to the shared helper functions.
        tokenizer: Tokenizer forwarded to the helpers; its ``padding_side``
            decides how the token position to steer is expressed.
        pairs: Sequence of hashable ``(family1, family2)`` 2-tuples. Vectors
            are built from ``family1`` (negative) and ``family2`` (positive)
            prompts.
        output_dir: Root directory for all outputs.
        num_prompts: Accepted for interface compatibility; currently unused.
            The number of prompts per pair is the module-level
            ``NUM_PROMPTS_PER_PAIR``.
        LAYER_FRACTION: Forwarded to ``calculate_batch_parameters`` and stored
            in the output metadata.
        steering_multiplier: Steering strength; falls back to
            ``STEERING_MULTIPLIER`` when ``None``.

    Raises:
        ValueError: If no (strip_newline, layer) configuration was scored,
            i.e. ``calculate_batch_parameters`` returned no layers.
    """
    multiplier = steering_multiplier if steering_multiplier is not None else STEERING_MULTIPLIER
    print(f"Finding best layer and token position for {model_name}, steering multiplier {multiplier}")

    # Setup output directory
    exp_dir = os.path.join(output_dir, mode, model_name)
    os.makedirs(exp_dir, exist_ok=True)
    output_file = os.path.join(exp_dir, "best_layer_and_token_pos.json")

    # The results file doubles as the completion marker.
    if os.path.exists(output_file):
        print(f"Line generation already completed: {output_file}")
        return

    # Generation prompts always come from the first family of each pair.
    generation_prompts = []
    for rhyme_family1, _ in pairs:
        generation_prompts += load_prompts(mode, "train", rhyme_family1, model_name)[
            :NUM_PROMPTS_PER_PAIR
        ]

    # Calculate batch parameters
    batch_params = calculate_batch_parameters(
        model,
        tokenizer,
        generation_prompts,
        LAYER_FRACTION=LAYER_FRACTION,
        model_name=model_name,
    )
    batch_size_small = batch_params["batch_size_small"]
    batch_size_per_prompt = batch_params["batch_size_per_prompt"]
    num_prompts_per_rollout = batch_params["num_prompts_per_rollout"]
    n_prompts = batch_params["n_prompts"]
    num_layers = batch_params["num_layers"]
    layers = batch_params["layers"]

    print("PARAMETERS")
    print(f"Batch size: {batch_size_small}")
    print(f"Batch size per prompt: {batch_size_per_prompt}")
    print(f"Number of prompts per rollout: {num_prompts_per_rollout}")
    print(f"Number of prompts: {n_prompts}")
    print(f"Number of layers: {num_layers}")
    print(f"Layers: {layers}")
    print(f"num_prompts_per_pair: {NUM_PROMPTS_PER_PAIR}")

    # Keyed by "<strip_newline>_<layer>" so they can be stored directly in the results file.
    steered_texts = {}
    scores = {}

    # pair -> {(strip_newline, layer): steering vector}; kept until the end so
    # all vectors can be written to disk in one pass.
    steering_vectors_all_pairs_and_layers = {pair: {} for pair in pairs}

    for strip_newline in [True, False]:
        negative_prompts_of_pair = {
            pair: load_prompts(mode, "train", pair[0], model_name, strip=strip_newline)
            for pair in pairs
        }
        positive_prompts_of_pair = {
            pair: load_prompts(mode, "train", pair[1], model_name, strip=strip_newline)
            for pair in pairs
        }

        token_to_steer = _token_positions_to_steer(
            tokenizer, generation_prompts, batch_size_per_prompt, strip_newline
        )

        for layer in tqdm(layers, desc="Generating steered texts"):
            print(f"Processing layer {layer}...")
            config_key = f"{strip_newline}_{layer}"

            # One steering vector per pair, in the same order as `pairs`.
            steering_vectors = []
            for pair in tqdm(pairs, desc="Generating steering vectors"):
                steering_vector = get_steering_vector_fast(
                    model,
                    tokenizer,
                    negative_prompts_of_pair[pair],
                    positive_prompts_of_pair[pair],
                    layer=layer,
                )
                steering_vectors.append(steering_vector)
                steering_vectors_all_pairs_and_layers[pair][(strip_newline, layer)] = steering_vector

            # Generate steered text
            steered_text = generate_steered_output(
                steering_vectors,
                model,
                tokenizer,
                generation_prompts,
                batch_size_per_prompt,
                num_prompts_per_rollout,
                layer=layer,
                steering_multiplier=multiplier,
                token_to_steer=token_to_steer,
                num_prompts_per_steering_vector=NUM_PROMPTS_PER_PAIR,
            )
            steered_text = [get_cleaned_up_text(text) for text in steered_text]
            steered_texts[config_key] = steered_text

            # Debug: print the first few steered texts of every layer
            print(f"\n=== FIRST FEW STEERED TEXTS (Layer {layer}) ===")
            for i, text in enumerate(steered_text[:3]):
                print(f"Steered {i}:")
                print(repr(text))
                print("---")

            score = _score_steered_texts(steered_text, pairs)
            scores[config_key] = score

            # Clean up GPU memory after each layer
            cleanup_gpu_memory()

            print(f"Completed layer {layer}, score: {score}")

    if not scores:
        raise ValueError(
            "No (strip_newline, layer) configuration was evaluated; "
            "calculate_batch_parameters returned no layers."
        )

    # Keys look like "True_12"; the first maximal entry in insertion order wins.
    best_key = max(scores, key=scores.get)
    best_strip_newline, best_layer = [ast.literal_eval(part) for part in best_key.split("_")]

    # Save the vectors BEFORE the results file: the results file marks the run
    # as completed, so writing it first would make an interrupted vector save
    # unrecoverable on the next invocation (it would return early above).
    _save_steering_vectors(exp_dir, pairs, layers, steering_vectors_all_pairs_and_layers)

    # Prepare output data
    output_data = {
        "layers": layers,
        "steered_texts": steered_texts,
        "scores": scores,
        "generation_prompts": generation_prompts,
        "batch_size_per_prompt": batch_size_per_prompt,
        "batch_size_small": batch_size_small,
        "metadata": {
            "model_name": model_name,
            "timestamp": datetime.now().isoformat(),
            "num_layers": num_layers,
            "layer_fraction": LAYER_FRACTION,
            "steering_multiplier": multiplier,
        },
        "best_strip_newline": best_strip_newline,
        "best_layer": best_layer,
    }

    # Save results (completion marker, written last)
    save_data(output_data, output_file)
    print(f"Line generation completed and saved to: {output_file}")


if __name__ == "__main__":
    parser = get_common_args()
    args = parser.parse_args()

    # Resolve the pairs before loading the model so an unsupported mode fails
    # fast with a clear message instead of a NameError after the expensive load.
    # An explicit (rhyme_family1, rhyme_family2) pair overrides the mode default.
    if args.rhyme_family1 is not None and args.rhyme_family2 is not None:
        pairs = [(args.rhyme_family1, args.rhyme_family2)]
    elif args.mode == "rhyme_family_steering":
        pairs = RHYME_FAMILY_PAIRS
    elif args.mode == "specific_word_steering":
        pairs = SPECIFIC_WORD_PAIRS
    else:
        raise ValueError(
            f"Unsupported mode {args.mode!r}: expected 'rhyme_family_steering' or "
            "'specific_word_steering', or pass both rhyme_family1 and rhyme_family2."
        )

    # Load model once
    print("Loading model...")
    model, tokenizer = get_model(args.model_name)

    try:
        main(
            args.mode,
            args.model_name,
            model,
            tokenizer,
            pairs,
            args.output_dir,
            num_prompts=args.num_prompts,
            LAYER_FRACTION=args.LAYER_FRACTION,
            steering_multiplier=args.multiplier,
        )
    finally:
        # Release the model even when main() raises
        del model, tokenizer
        cleanup_gpu_memory()
