import pandas as pd
import numpy as np
# from fastchat.conversation import Conversation, SeparatorStyle
from transformers import AutoTokenizer
import os
import sys
import srsly
import fire
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import time
import json
from utils import CosineSimilarityNet, format_test_data

def _load_records(test_file):
    """Read ``test_file`` as UTF-8 JSON and return its list of samples.

    The file may contain either a bare JSON list, or a dict whose
    ``"predictions"`` key holds that list.

    Raises:
        RuntimeError: if the file cannot be opened/parsed or does not yield a
            list. The underlying error is chained as ``__cause__``.
    """
    try:
        with open(test_file, 'r', encoding='utf-8') as f:
            raw_data = json.load(f)

        if isinstance(raw_data, dict) and 'predictions' in raw_data:
            data = raw_data['predictions']
        else:
            data = raw_data

        if not isinstance(data, list):
            raise ValueError(f"Invalid data format in {test_file}")
    except Exception as e:
        # Chain the original error so the real cause stays in the traceback.
        raise RuntimeError(f"Failed to load data from {test_file}: {e}") from e

    print(f"Loaded {len(data)} samples from {test_file}")
    return data


def _attach_chat_messages(data, system_prompt):
    """Store a two-turn (system, user) chat message list on each sample.

    The user turn is taken from the sample's ``"prompt"`` field, falling back
    to ``"question"`` when ``"prompt"`` is absent. The list is written to
    ``sample['final_prompt']`` in place; ``data`` is returned for chaining.
    """
    for sample in data:
        user_input = sample["prompt"] if "prompt" in sample else sample["question"]
        sample['final_prompt'] = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_input},
        ]
    return data


def _save_results(data, output_path):
    """Dump ``data`` as indented UTF-8 JSON to ``output_path``.

    Missing parent directories are created first. If the write still fails,
    the same payload is written to ``output_path + ".backup"`` so that the
    generation results are not lost.
    """
    try:
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            # The default output path lives under ./results/, which may not exist yet.
            os.makedirs(parent_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
        print(f"\nResults saved to {output_path}")
        print(f"Total processed samples: {len(data)}")
    except Exception as e:
        print(f"Error saving results: {e}")
        backup_path = output_path + ".backup"
        with open(backup_path, "w", encoding="utf-8") as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
        print(f"Results saved to backup file: {backup_path}")


def main(
        model_path = "/home/model/qwen2.5-1.5b",
        test_file = "/home/generate_embd/safe_model_pred.json",
        output_path = "./results/safe_model_pred_1b_generate.json",
        batch_size=256
        ):
    """Run batched greedy generation over a JSON test set and save the outputs.

    Each sample's ``"prompt"`` (or ``"question"``) is wrapped in a fixed system
    prompt, rendered with the tokenizer's chat template, and passed to the
    causal LM. The decoded continuation (at most 10 new tokens) is stored
    on the sample as ``'slm_response'``. The enriched samples are written to
    ``output_path`` as JSON.

    Args:
        model_path: Directory or hub id passed to ``from_pretrained`` for both
            the model and the tokenizer.
        test_file: JSON file holding a list of samples, or a dict with a
            ``"predictions"`` list.
        output_path: Where the annotated samples are written.
        batch_size: Number of samples per ``generate`` call; must be >= 1.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
        RuntimeError: if ``test_file`` cannot be loaded as a list of samples.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    system_prompt = "You are now a helpful personal AI assistant."

    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                torch_dtype=torch.float16,
                                                device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Decoder-only models must be left-padded for batched generation. Setting
    # the attribute works on every transformers version, unlike the per-call
    # `padding_side=` kwarg that older releases reject.
    tokenizer.padding_side = "left"
    # Truncate over-long prompts from the left. Right truncation would drop the
    # assistant header appended by `add_generation_prompt=True`.
    tokenizer.truncation_side = "left"
    if tokenizer.pad_token is None:
        # padding=True needs a pad token; reuse EOS when the tokenizer has none.
        tokenizer.pad_token = tokenizer.eos_token

    data = _load_records(test_file)

    # NOTE(review): format_test_data comes from the project's utils module;
    # it is assumed to return a list of dicts with "prompt" or "question".
    data = format_test_data(data)
    print(f"Successfully formatted {len(data)} samples")

    print("\nData format validation:")
    print(f"Total samples: {len(data)}")
    if data:
        print("First sample format:")
        print(json.dumps(data[0], indent=2, ensure_ascii=False)[:500])

    data = _attach_chat_messages(data, system_prompt)

    for i in tqdm(range(0, len(data), batch_size), desc="Processing batches"):
        # A list slice shares the sample dicts with `data`, so writing
        # 'slm_response' below updates `data` as well.
        batch_samples = data[i:i + batch_size]
        batch_messages = [item['final_prompt'] for item in batch_samples]

        try:
            prompts = tokenizer.apply_chat_template(
                batch_messages,
                tokenize=False,
                add_generation_prompt=True
            )

            model_inputs = tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=768
            ).to(model.device)

            with torch.no_grad():
                # Greedy decoding; temperature has no effect when do_sample=False.
                generated_ids = model.generate(
                    **model_inputs,
                    max_new_tokens=10,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                )

            # With left padding every row's prompt occupies the first
            # `prompt_len` positions, so the new tokens are the remaining columns.
            prompt_len = model_inputs["input_ids"].shape[1]
            responses = tokenizer.batch_decode(
                generated_ids[:, prompt_len:], skip_special_tokens=True
            )
        except Exception as e:
            # Best-effort: a failing batch is reported and gets empty responses
            # so the rest of the run (and the final dump) still completes.
            print(f"Error processing batch {i//batch_size}: {e}")
            responses = [""] * len(batch_samples)

        for sample, response in zip(batch_samples, responses):
            sample['slm_response'] = response

    _save_results(data, output_path)


if __name__ == "__main__":
    # Hand main() to Fire so its keyword parameters become command-line flags
    # (e.g. --model_path, --batch_size).
    fire.Fire(component=main)