import gc
import os
import argparse
import json
import numpy as np
import torch
import matplotlib.pyplot as plt
import re
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info

# Seed every RNG this script relies on (NumPy, torch CPU, all torch CUDA devices)
# with the same value so repeated runs are reproducible.
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(20)
del _seed_fn

# Single GPU used for the model and all model inputs in this script.
device = torch.device("cuda:2")

def adjust_residual_hook(direction, layer_idx, direction_weight):
    """Build a forward hook that shifts a module's output along a steering direction.

    Args:
        direction: Indexable per-layer steering vectors; ``direction[layer_idx]``
            must broadcast against the module's hidden-state output.
        layer_idx: Index of the layer this hook is attached to.
        direction_weight: Scalar multiplier applied to the steering vector.

    Returns:
        A ``hook_fn(module, input, output)`` for ``register_forward_hook``. When
        ``output`` is a tuple/list, only its first element is shifted and the
        remaining elements are passed through unchanged (as a tuple); when it is a
        bare tensor, the shifted tensor is returned.
    """
    def hook_fn(module, input, output):
        is_sequence = isinstance(output, (tuple, list))
        hidden = output[0] if is_sequence else output
        # Match dtype as well as device: adding a float32 vector to bf16
        # activations would otherwise promote the residual stream to float32.
        direction_layer = direction[layer_idx].to(device=hidden.device, dtype=hidden.dtype)
        shifted = hidden + direction_weight * direction_layer
        if is_sequence:
            return (shifted,) + tuple(output[1:])
        return shifted
    return hook_fn

def _mlp_direction_hook(direction, layer_idx, direction_weight):
    """Return a forward hook adding ``direction_weight * direction[layer_idx]`` to a tensor output.

    Built by a factory so each hook binds its own ``layer_idx``. An inline lambda
    closing over the ``for`` loop variable would only see the final index once the
    loop has finished, steering every layer with the last layer's vector.
    """
    def hook_fn(module, input, output):
        # Cast to the activation dtype too, so a float32 direction cannot promote
        # the bf16 residual stream to float32.
        direction_layer = direction[layer_idx].to(device=output.device, dtype=output.dtype)
        return output + direction_weight * direction_layer
    return hook_fn


def load_model(model_id, direction_weight=0.0, control="attn",
               direction_path="/data/R1_Onevision_direction_mmhalu.pt"):
    """Load the model and processor, optionally registering steering hooks.

    Args:
        model_id: Path or hub id passed to ``from_pretrained`` for both the
            processor and the model.
        direction_weight: Scale of the steering vector. ``0.0`` disables steering
            and no direction file is read.
        control: ``"attn"`` hooks every layer's ``self_attn`` output;
            ``"mlp"`` hooks every layer's ``mlp`` output.
        direction_path: File loaded with ``torch.load``; indexed by layer number
            (``direction[i]``) to get layer ``i``'s steering vector.

    Returns:
        ``(model, processor)``; the model is in eval mode on the module-level
        ``device``, in bfloat16 with flash-attention-2.

    Raises:
        ValueError: if steering is enabled and ``control`` is not "attn" or "mlp"
            (previously such values were silently ignored).
    """
    # Validate before the expensive model load.
    if direction_weight != 0.0 and control not in ("attn", "mlp"):
        raise ValueError(f"Unknown control type {control!r}; expected 'attn' or 'mlp'")

    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    ).to(device).eval()

    if direction_weight != 0.0:
        direction = torch.load(direction_path, map_location=device).to(device)

        print(f"Adding {control} hooks with direction weight {direction_weight}")
        # NOTE(review): assumes model.model.layers is the decoder-layer list; some
        # transformers releases may nest it differently (e.g. under language_model)
        # -- confirm for the installed version.
        for i, layer in enumerate(model.model.layers):
            if control == "attn":
                layer.self_attn.register_forward_hook(
                    adjust_residual_hook(direction, i, direction_weight)
                )
            else:
                layer.mlp.register_forward_hook(
                    _mlp_direction_hook(direction, i, direction_weight)
                )

    return model, processor

def extract_thinking(response, processor):
    """Return the text before the first ``</think`` marker and its token count.

    Everything preceding the first occurrence of ``"</think"`` is taken, so an
    opening ``<think>`` tag is included if present but the closing tag is not.
    All whitespace (leading, trailing and internal) is removed before tokenizing.

    Args:
        response: Decoded model output.
        processor: Object whose ``tokenizer`` is called on the extracted text.

    Returns:
        ``(thinking_text, num_tokens)``, or ``("", -1)`` when ``response`` has no
        ``"</think"`` marker.
    """
    # str.find gives the same result as the former lazy regex `(.*?)(?=</think)`
    # searched from the start, but in linear time. The regex rescanned from every
    # start position when the marker was absent, which is quadratic on long outputs.
    end = response.find("</think")
    if end == -1:
        return "", -1
    # NOTE(review): removing all whitespace merges words, which changes how the
    # text tokenizes; kept so lengths stay comparable with earlier runs.
    thinking_text = "".join(response[:end].split())
    token_ids = processor.tokenizer(thinking_text, return_tensors='pt')['input_ids'][0]
    return thinking_text, len(token_ids)


def get_response(model, processor, image_path, question):
    """Ask the model one question about one image and return the decoded reply.

    The question is immediately followed (no separator) by a fixed instruction
    asking for reasoning inside <think> </think> and the answer inside
    <answer> </answer>. Generation is capped at 4096 new tokens, and only the newly
    generated tokens are decoded, with special tokens skipped.

    Args:
        model: Model exposing ``generate``.
        processor: Processor used for the chat template, input encoding and decoding.
        image_path: Image location, passed through ``process_vision_info``.
        question: Question text placed before the instruction.

    Returns:
        The decoded text of the single generated sequence.
    """
    instruction_prompt = f"{question}You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE in <answer> </answer> tags."
    chat = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image_path},
            {"type": "text", "text": instruction_prompt},
        ],
    }]

    templated = processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    images, videos = process_vision_info(chat)
    model_inputs = processor(
        text=[templated],
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        sequences = model.generate(**model_inputs, max_new_tokens=4096)
        # generate() returns prompt + continuation; drop each row's prompt tokens.
        continuations = [
            seq[len(prompt_ids):]
            for prompt_ids, seq in zip(model_inputs.input_ids, sequences)
        ]
        decoded = processor.batch_decode(
            continuations, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    return decoded[0]

def run_steering_experiment(model, processor, input_json, output_dir, num_samples=50,
                            image_dir='/data/MMhalu/images/'):
    """Run the steering experiment over a question template and save the results.

    For each of the first ``num_samples`` template entries, the image is looked up
    by file name in ``image_dir`` and the model is queried via ``get_response``.
    The text before ``</think`` is then measured with ``extract_thinking``.
    ``results.json`` and ``thinking_length_distribution.png`` are written into
    ``output_dir``, which is created if missing.

    Args:
        model: Model as returned by ``load_model``.
        processor: Processor as returned by ``load_model``.
        input_json: JSON file holding a list of entries with ``image_src`` and
            ``question`` keys.
        output_dir: Destination directory for the outputs.
        num_samples: Maximum number of template entries to process.
        image_dir: Directory containing the images; only the base name of each
            entry's ``image_src`` is used.

    Returns:
        The dict written to ``results.json``.
        - ``think_lengths`` keeps one entry per sample (``-1`` where no
          ``</think`` marker was found).
        - ``avg_thinking_length`` and the histogram use only the non-negative
          lengths.
        - ``num_missing_think`` counts the excluded samples.
    """
    with open(input_json, 'r') as f:
        json_data = json.load(f)[:num_samples]

    thinking_lengths = []
    responses_data = []

    for idx, line in enumerate(json_data):
        image_name = os.path.basename(line['image_src'])
        image_path = os.path.join(image_dir, image_name)
        question = line['question']

        response = get_response(model, processor, image_path, question)
        thinking_part, thinking_length = extract_thinking(response, processor)
        thinking_lengths.append(thinking_length)

        responses_data.append({
            "image": image_name,
            "question": question,
            "response": response,
            "thinking": thinking_part,
            "thinking_length": thinking_length
        })

        print(f"Processed sample {idx + 1}/{len(json_data)}: thinking length = {thinking_length}")

    # extract_thinking reports -1 when no "</think" marker is present; averaging or
    # plotting that sentinel would skew the statistics, so it is excluded here.
    valid_lengths = [n for n in thinking_lengths if n >= 0]

    os.makedirs(output_dir, exist_ok=True)
    results = {
        "responses": responses_data,
        "think_lengths": thinking_lengths,
        "avg_thinking_length": sum(valid_lengths) / len(valid_lengths) if valid_lengths else 0,
        "num_missing_think": len(thinking_lengths) - len(valid_lengths),
    }

    with open(os.path.join(output_dir, "results.json"), "w") as f:
        json.dump(results, f, indent=4)

    fig = plt.figure(figsize=(10, 6))
    plt.hist(valid_lengths, bins=30, alpha=0.7, edgecolor='black')
    plt.xlabel("Thinking Length (tokens)")
    plt.ylabel("Frequency")
    plt.title("Distribution of Thinking Length After Steering")
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.savefig(os.path.join(output_dir, "thinking_length_distribution.png"))
    plt.close(fig)

    return results

if __name__ == "__main__":
    # Command-line entry point: load (optionally steered) model, run the
    # experiment, and report the average thinking length.
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str,
                        default='/data/MMhalu/response_template.json',
                        help='Template file containing images and questions')
    parser.add_argument('--model_id', type=str, default="/data/R1-Onevision_RL/",
                        help='Path to the model')
    parser.add_argument('--num_samples', type=int, default=96,
                        help='Number of samples to process')
    parser.add_argument('--direction_weight', type=float, default=-0.15,
                        help='Weight for thinking length direction steering')
    parser.add_argument('--control', type=str, choices=['attn', 'mlp'], default='attn',
                        help='Type of control to apply (attention or MLP)')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Directory to save results (default: /data/R1-Onevision_RL/<direction_weight>)')
    args = parser.parse_args()

    if args.output_dir is None:
        # NOTE(review): this default does not include --control, so 'attn' and
        # 'mlp' runs with the same weight write to (and overwrite) one directory.
        args.output_dir = f"/data/R1-Onevision_RL/{args.direction_weight}"

    model, processor = load_model(args.model_id, args.direction_weight, args.control)

    results = run_steering_experiment(model, processor, args.input, args.output_dir, args.num_samples)

    print(f"Average thinking length: {results['avg_thinking_length']:.2f}")
    print(f"Results saved to {args.output_dir}")