import argparse
import os
import re
import json
import random
import torch
from transformers import  AutoTokenizer
from modeling_utils.modeling_qwen2 import Qwen2ForCausalLM
from datasets import load_dataset
import os
from tqdm import tqdm


def main():
    """Run steered generation over the evaluation set and save the results.

    Steps:
      1. Load every file in the evaluation data directory as JSON (files are
         ordered by the integer prefix of their name) and concatenate the
         loaded lists.
      2. Load the tokenizer and the Qwen2 model, load the steering vector
         from disk, and enable steering on the model.
      3. Generate a completion for each item's ``"prompt"``.
      4. Write the items in chunks of ``mini_batch_size`` to
         ``<save_dir>/<start>_<end>.json``. Each saved item gains the keys
         ``"model_output"`` and ``"output_length"``.

    All paths and hyper-parameters are hard-coded below.
    """
    random.seed(42)

    print("Loading data...")
    data_dir = "/path/to/seal_eval_data"

    def get_checkpoint_number(folder):
        # Sort key: the integer before the first "_" of the file stem,
        # e.g. "10_20.json" -> 10. Names without such a prefix sort as 0.
        try:
            return int(folder.split('.')[0].split("_")[0])
        except ValueError:
            return 0

    test_data = []
    for file in sorted(os.listdir(data_dir), key=get_checkpoint_number):
        # Read as UTF-8 explicitly so loading does not depend on the locale
        # and matches the encoding used when the results are written.
        with open(os.path.join(data_dir, file), encoding="utf-8") as f:
            test_data.extend(json.load(f))

    save_dir = "/path/to/save/eval_results"
    os.makedirs(save_dir, exist_ok=True)

    model_path = "/path/to/DeepSeek-R1-Distill-Qwen-14B"
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Left padding keeps every prompt flush against the generated tokens, so
    # the completion of each row starts at the same column.
    tokenizer.padding_side = "left"

    # Tokenizing with padding=True requires a pad token; fall back to the EOS
    # token (also used as pad_token_id for generation) when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = Qwen2ForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)

    steer_vec_path = "/path/to/your/steer_vector/steer_vec.pt"
    steer_vec = torch.load(steer_vec_path, weights_only=True)
    steer_vec = steer_vec.to(model.device)
    # Project-specific API of the local Qwen2ForCausalLM implementation.
    model.set_steering_flag(steering_flag=True, steering_layer=25, steer_vec=steer_vec,  steer_coef=1.0, tokenizer=tokenizer)

    mini_batch_size = 10  # number of items written to one result file
    batch_size = 1        # number of prompts per generate() call
    for start_pos in range(0, len(test_data), mini_batch_size):
        end_pos = min(start_pos + mini_batch_size, len(test_data))
        mini_batch = test_data[start_pos:end_pos]
        print(f"Processing data {start_pos} to {end_pos}...")
        for offset in tqdm(range(0, len(mini_batch), batch_size)):
            print(f"Processing data {start_pos + offset}...")
            # NOTE(review): project-specific hook; presumably resets the
            # per-generation steering state -- confirm in modeling_qwen2.
            model.start_new_round()
            batch_items = mini_batch[offset:offset + batch_size]
            prompts = [item["prompt"] for item in batch_items]
            tokenized_batch = tokenizer(prompts, return_tensors="pt", padding=True)
            tokenized_batch = {k: v.to(model.device) for k, v in tokenized_batch.items()}
            with torch.no_grad():
                generated = model.generate(**tokenized_batch, max_new_tokens=4096, use_cache=True, pad_token_id=tokenizer.eos_token_id)
            # With left padding all prompts share the same padded length, so
            # everything after this column is newly generated text.
            prompt_len = tokenized_batch["input_ids"].shape[1]
            # Store the completion on every item of the batch, not only the
            # first one, so batch_size > 1 does not leave items without output.
            for item, sequence in zip(batch_items, generated):
                item["model_output"] = tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True)

        for item in mini_batch:
            # Length of the re-tokenized decoded text using the tokenizer's
            # default settings (may include special tokens it adds -- confirm
            # whether that is intended for the reported metric).
            output_length = len(tokenizer(item["model_output"])["input_ids"])
            item["output_length"] = output_length

        with open(os.path.join(save_dir, f"{start_pos}_{end_pos}.json"), "w", encoding="utf-8") as f:
            json.dump(mini_batch, f, ensure_ascii=False, indent=4)


# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()