import pandas as pd
import numpy as np
# from fastchat.conversation import Conversation, SeparatorStyle
from transformers import AutoTokenizer
import os
import sys
import srsly
import fire
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import time

def _masked_mean_pool(embeddings, attention_mask):
    """Average token embeddings over the sequence axis, skipping padding.

    Args:
        embeddings: Tensor indexed as (batch, sequence, hidden).
        attention_mask: Integer tensor indexed as (batch, sequence) holding
            1 for real tokens and 0 for padding.

    Returns:
        Tensor indexed as (batch, hidden) with the same dtype as
        ``embeddings``.
    """
    mask = attention_mask.unsqueeze(-1).to(embeddings.dtype)
    # The 0/1 product is exact in half precision; the reduction itself is
    # accumulated in float32 to avoid fp16 overflow / rounding.
    summed = (embeddings * mask).sum(dim=1, dtype=torch.float32)
    # clamp guards against division by zero for a fully masked row.
    token_counts = attention_mask.sum(dim=1, keepdim=True).clamp(min=1)
    pooled = summed / token_counts.to(torch.float32)
    return pooled.to(embeddings.dtype)


def main(
        model_path="./model/qwen2.5-1.5b",
        test_file="./data/train_safe_token.json",
        out_path="./generate_embd/qwen2.5-1.5b_emb_res.pt",
        batch_size=256,
        ):
    """Compute mean-pooled input embeddings for every sample and save them.

    Each sample's ``"question"`` is wrapped in a system + user chat, rendered
    with the tokenizer's chat template, tokenized, and passed through the
    model's input-embedding layer only (no transformer forward pass). The
    token embeddings are averaged over the non-padding positions, so a
    sample's vector does not depend on which other samples share its batch.

    Args:
        model_path: Path or identifier passed to ``from_pretrained`` for both
            the model and the tokenizer.
        test_file: JSON file read with ``srsly.read_json``; it is iterated as
            a sequence of records that each provide ``"question"`` and
            ``"safe"`` keys.
        out_path: Destination for ``torch.save``. The parent directory is
            created when missing.
        batch_size: Number of samples tokenized and embedded per step.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.

    The saved object is a list of ``{"tensor": <1-D numpy array>,
    "label": <sample["safe"]>}`` dicts, in the same order as the input file.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    system_prompt = "You are now a helpful personal AI assistant."

    def format_with_qwen(data):
        """Build one system + user chat message list per sample."""
        return [
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": sample["question"]},
            ]
            for sample in data
        ]

    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 torch_dtype=torch.float16,
                                                 device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        # padding=True needs a pad token; padded positions are masked out of
        # the mean below, so the concrete token chosen here is irrelevant.
        tokenizer.pad_token = tokenizer.eos_token

    # Look the layer up once, and send inputs to the device that actually
    # holds it (with device_map="auto" the model may span several devices).
    embed_layer = model.get_input_embeddings()
    embed_device = embed_layer.weight.device

    data = srsly.read_json(test_file)

    text_ls = format_with_qwen(data)
    emb_res = []

    for i in tqdm(range(0, len(text_ls), batch_size), desc="Processing batches"):
        start_time = time.perf_counter()
        batch_text = text_ls[i:i + batch_size]

        user_input = tokenizer.apply_chat_template(
            batch_text,
            tokenize=False,
            add_generation_prompt=True,
        )

        encoded = tokenizer(user_input, return_tensors="pt", padding=True,
                            truncation=True, max_length=768)
        input_ids = encoded["input_ids"].to(embed_device)
        attention_mask = encoded["attention_mask"].to(embed_device)

        with torch.no_grad():
            input_embeddings = embed_layer(input_ids)
            batch_emb_res = _masked_mean_pool(input_embeddings, attention_mask)

        # Extending with a 2-D array appends one 1-D row per sample.
        emb_res.extend(batch_emb_res.cpu().numpy())

        end_time = time.perf_counter()
        print(f"Batch {i // batch_size + 1} processed in {end_time - start_time:.2f} seconds.")

    label_res = [item['safe'] for item in data]

    res = [
        {"tensor": emb, "label": label}
        for emb, label in zip(emb_res, label_res)
    ]

    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(res, out_path)

# Script entry point: python-fire turns main()'s keyword arguments
# (model_path, test_file, out_path, batch_size) into command-line flags.
if __name__ == "__main__":
    fire.Fire(main)