import os
import json
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
from peft import PeftModel
from datasets import load_dataset
import time
from zoneinfo import ZoneInfo

def _load_instructions(instruction_path, max_samples):
    """Load the prompts to evaluate together with their source records.

    Args:
        instruction_path: If it contains the substring ``"BeaverTails"``, the
            JSONL file at the hard-coded relative path
            ``../data/beavertail_30k/test.jsonl`` is read and only examples
            whose ``is_safe`` field is falsy are kept. Otherwise it is treated
            as the path to a JSON file holding a list of records, each with
            an ``"instruction"`` key.
        max_samples: Upper bound on the number of records returned.

    Returns:
        A ``(instructions, records)`` tuple of equal-length lists: the prompt
        strings and the dicts that will be copied into the output file.
    """
    # NOTE: this is a substring match, so any path containing "BeaverTails"
    # selects the dataset branch rather than being opened as a JSON file.
    if "BeaverTails" in instruction_path:
        print("Loading BeaverTails dataset...")
        dataset = load_dataset('json', data_files="../data/beavertail_30k/test.jsonl", split="train")

        instructions = []
        records = []
        for example in dataset:
            # Stop scanning once the quota is filled instead of walking the
            # remainder of the dataset for nothing.
            if len(instructions) >= max_samples:
                break
            if example["is_safe"]:
                continue
            instructions.append(example["prompt"])
            records.append({"instruction": example["prompt"]})
        return instructions, records

    print(f"Loading instructions from {instruction_path}...")
    with open(instruction_path, 'r', encoding='utf-8') as f:
        all_data = json.load(f)
    records = all_data[:max_samples]
    instructions = [record['instruction'] for record in records]
    return instructions, records


def _generate_batch(model, tokenizer, batch_instructions):
    """Greedy-decode one response for each instruction in the batch.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer with a chat template; the caller is expected to
            have configured a pad token and left padding.
        batch_instructions: List of user prompt strings.

    Returns:
        A list of decoded response strings, one per instruction, with the
        prompt tokens and special tokens removed.
    """
    prompts_text = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": inst}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for inst in batch_instructions
    ]
    model_inputs = tokenizer(
        prompts_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
        # The rendered chat template already contains the special tokens
        # (e.g. BOS); letting the tokenizer add them again would duplicate
        # them at the start of every prompt.
        add_special_tokens=False,
    ).to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        do_sample=False,
        num_beams=1,
        max_new_tokens=256,
        # Passed explicitly so generate() does not have to fall back to (and
        # warn about) a default pad token on every batch.
        pad_token_id=tokenizer.pad_token_id,
    )

    # generate() returns prompt + continuation. With left padding every row's
    # prompt occupies the same padded width, so one slice strips all prompts.
    prompt_length = model_inputs['input_ids'].shape[1]
    response_ids = generated_ids[:, prompt_length:]
    return tokenizer.batch_decode(response_ids, skip_special_tokens=True)


def main():
    """Run batched greedy inference over a set of instructions and save the results.

    Command-line arguments select the base model, an optional LoRA adapter to
    merge into it, the instruction source, the output JSON path, the batch
    size and the maximum number of samples. Each output record is a copy of
    the input record with an added ``"output"`` field holding the stripped
    model response. The full result list is written once, after all batches
    have been generated.
    """
    parser = argparse.ArgumentParser(description="Inference Script.")
    parser.add_argument("--model_folder", default='/root/model/llama-gsm8k-ft',
                        help="Path or Hugging Face name of the base model.")
    parser.add_argument("--lora_folder",
                        help="Path to the LoRA adapter folder to merge. Leave empty to use base model.")
    parser.add_argument("--instruction_path", default='BeaverTails',
                        help="Path to the instruction JSON file or 'BeaverTails' to use the dataset.")
    parser.add_argument("--output_path", default='../eval_results/safety/llama.json',
                        help="Path to save the output JSON file.")
    parser.add_argument("--batch_size", type=int, default=8, help="Inference batch size. Adjust based on your VRAM.")
    parser.add_argument("--max_samples", type=int, default=1000, help="Maximum number of samples to process.")

    args = parser.parse_args()

    # Fail fast on values that would otherwise crash in range() (0) or
    # silently produce an empty/truncated run (negative) after the expensive
    # model load.
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer.")
    if args.max_samples < 0:
        parser.error("--max_samples must be non-negative.")

    print(f"Running with arguments: {args}")

    # A bare filename has an empty dirname; treat that as the current directory.
    output_folder = os.path.dirname(args.output_path) or '.'
    os.makedirs(output_folder, exist_ok=True)
    print(f"Output will be saved to {args.output_path}")

    instruction_lst, input_data_lst = _load_instructions(args.instruction_path, args.max_samples)
    print(f"Loaded {len(instruction_lst)} instructions.")

    print(f"Loading tokenizer from {args.model_folder}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_folder)

    # Batched generation needs a pad token, and decoder-only models must be
    # padded on the left so that generation continues directly after the prompt.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'
    print(f"Tokenizer setup for batching. PAD token: '{tokenizer.pad_token}', Padding side: '{tokenizer.padding_side}'")

    print(f"Loading model '{args.model_folder}' in bfloat16 mode...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_folder,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2"
    )

    if args.lora_folder and os.path.exists(args.lora_folder):
        print(f"Loading and merging LoRA weights from {args.lora_folder}...")
        model = PeftModel.from_pretrained(model, args.lora_folder)
        model = model.merge_and_unload()
        print("LoRA weights merged successfully.")
    else:
        if args.lora_folder:
            # Keep the best-effort fallback, but make it obvious that the
            # requested adapter was NOT applied to this run.
            print(f"WARNING: LoRA folder '{args.lora_folder}' does not exist.")
        print("No LoRA folder provided or found. Using the base model.")

    model.eval()
    print("Model set to evaluation mode (model.eval()).")

    output_lst = []
    total_count = len(instruction_lst)
    start_time = time.time()

    with torch.no_grad():
        for i in tqdm(range(0, total_count, args.batch_size), desc="Batch Generating Responses"):
            batch_end = i + args.batch_size
            responses_text = _generate_batch(model, tokenizer, instruction_lst[i:batch_end])

            print(responses_text)

            # Pair each response with the record it was generated from.
            for record, response in zip(input_data_lst[i:batch_end], responses_text):
                output_data = record.copy()
                output_data['output'] = response.strip()
                output_lst.append(output_data)

    total_time = time.time() - start_time
    samples_per_second = total_count / total_time if total_time > 0 else 0
    print(f"Generated {total_count} responses in {total_time:.2f}s ({samples_per_second:.2f} samples/s).")

    print("Saving results...")
    with open(args.output_path, 'w', encoding='utf-8') as f:
        json.dump(output_lst, f, indent=4, ensure_ascii=False)

# Run the inference pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()