import json
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import argparse
from typing import List, Dict
import numpy as np

def load_model(model_id, torch_dtype=torch.bfloat16, device_map="auto"):
    """
    Loads the tokenizer and causal LM and wraps them in a text-generation pipeline.

    The tokenizer is guaranteed to have a pad token on return: an existing one
    is kept, otherwise the eos token is reused, and only if neither exists is a
    brand-new '[PAD]' token added (in which case the model's embedding matrix
    is resized to match the enlarged vocabulary).

    Args:
        model_id (str): The Hugging Face model ID.
        torch_dtype: The desired torch data type.
        device_map (str): Device mapping strategy.

    Returns:
        text_generator: The Hugging Face pipeline for text generation.
    """
    print(f"Loading model: {model_id}")
    print(f"Loading tokenizer for model: {model_id}")

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Track whether the vocabulary actually grew; only then do the model's
    # embeddings need resizing.
    added_new_pad_token = False
    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            print("Tokenizer does not have a pad_token. Setting pad_token to eos_token.")
            tokenizer.pad_token = tokenizer.eos_token
        else:
            # Define a new pad_token if eos_token is not available
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            added_new_pad_token = True
            print("Added a new pad_token '[PAD]' to the tokenizer.")

    # Batched generation with a decoder-only model must pad on the left so the
    # generated tokens directly follow each prompt rather than pad tokens.
    tokenizer.padding_side = "left"

    # SECURITY NOTE: trust_remote_code=True executes Python code shipped with
    # the model repository. Only use this loader with model IDs you trust.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        device_map=device_map,
        trust_remote_code=True  # Set to True if the model requires custom code
    )

    # Resize only when a token was really added. Resizing unconditionally (or,
    # as before, whenever the pad token was *not* an added token) can shrink
    # embeddings that were deliberately padded, while skipping the resize after
    # adding '[PAD]' leaves the new token id out of range.
    if added_new_pad_token:
        print("Resizing model embeddings to accommodate the new pad_token.")
        model.resize_token_embeddings(len(tokenizer))

    # Ensure the model knows the pad_token_id
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id
        print(f"Set model.config.pad_token_id to {model.config.pad_token_id}")

    print("Initializing text generation pipeline...")
    text_generator = pipeline(
        "text-generation",
        tokenizer=tokenizer,
        model=model,
        torch_dtype=torch_dtype,
        device_map=device_map,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,  # Explicitly pass pad_token_id
    )

    # Defensive fallback: make sure the pipeline's tokenizer has a pad id
    if text_generator.tokenizer.pad_token_id is None:
        text_generator.tokenizer.pad_token_id = model.config.pad_token_id
        print(f"Set text_generator.tokenizer.pad_token_id to {text_generator.tokenizer.pad_token_id}")

    return text_generator

def create_batch_prompts(captions: List[str]) -> List[List[Dict[str, str]]]:
    """
    Creates a batch of chat-style prompts in the correct format for the model.

    Each caption becomes its own two-message conversation: a fixed system
    message followed by a user message that embeds the caption in the
    refinement instruction.

    Args:
        captions (List[str]): List of original captions to be refined.

    Returns:
        List[List[Dict[str, str]]]: One conversation per caption, in input
        order. Every conversation is a list of ``{"role", "content"}`` dicts
        (system first, then user). An empty input yields an empty list.
    """
    system_content = (
        "You are a helpful assistant who are good at follow the instruction "
        "to provide specific format of content."
    )

    def build_instruction(caption: str) -> str:
        # An f-string (not str.format) so braces inside a caption are kept verbatim.
        return (
            "Please generate an extended, refined version of the following prompt:\n"
            f"{caption}\n\n"
            "The prompt should begin with the phrase of each provided prompt and continue with "
            "a different, vivid description of a specific scene. The prompt should be long and "
            "detailed with different feelings in it. Only answer the prompt part:"
        )

    # Fresh dicts per conversation so callers can mutate one without
    # affecting the others.
    return [
        [
            {"role": "system", "content": system_content},
            {"role": "user", "content": build_instruction(caption)},
        ]
        for caption in captions
    ]

def refine_captions_batch(text_generator, captions: List[str], batch_size: int = 8, max_new_tokens: int = 384):
    """
    Lazily generates refined versions of multiple prompts in batches.

    This is a generator: no model call happens until the result is iterated,
    and each batch is only sent to the model when the previous batch's results
    have been consumed.

    Args:
        text_generator: The Hugging Face text generation pipeline.
        captions (List[str]): List of original captions to be refined.
        batch_size (int): Number of prompts to process simultaneously.
        max_new_tokens (int): Maximum number of tokens to generate per prompt.

    Yields:
        For each caption, in input order, the last element of the first
        returned sequence's ``'generated_text'``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (raised on first
            iteration, since this is a generator).
    """
    # A negative step would silently produce nothing and zero would fail with
    # an opaque range() error, so reject both explicitly.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    with tqdm(total=len(captions), desc="Processing captions") as pbar:
        # Process captions in batches
        for start in range(0, len(captions), batch_size):
            batch_captions = captions[start:start + batch_size]
            batch_messages = create_batch_prompts(batch_captions)

            # Generate refined prompts for the batch
            outputs = text_generator(
                batch_messages,
                max_new_tokens=max_new_tokens,
                temperature=1,
                batch_size=batch_size,
                pad_token_id=text_generator.tokenizer.pad_token_id
            )
            pbar.update(len(batch_captions))

            # NOTE(review): with chat-style (list-of-messages) input the
            # pipeline presumably returns 'generated_text' as the full message
            # list, so [-1] would be the assistant message dict rather than a
            # plain string. Confirm whether ['content'] is what downstream
            # consumers actually want before changing the yielded shape.
            for output in outputs:
                yield output[0]['generated_text'][-1]

def process_jsonl_batch(input_path: str, output_path: str, text_generator, batch_size: int = 8, max_new_tokens: int = 384):
    """
    Reads the input JSONL file, processes entries in batches, and writes to the output JSONL file.

    Every input line is expected to be a JSON object with an ``"entry"`` object
    holding a non-empty string ``"caption"``. Lines that are not valid JSON or
    do not match that shape are reported and skipped. Each kept record is
    written back with an added ``"result1": {"generated": ...}`` field.

    Args:
        input_path (str): Path to the input JSONL file.
        output_path (str): Path to the output JSONL file.
        text_generator: The Hugging Face text generation pipeline.
        batch_size (int): Number of entries to process simultaneously.
        max_new_tokens (int): Maximum number of tokens to generate per prompt.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Validate up front so a bad value fails before the output file is truncated.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # First, read all entries and collect valid captions
    entries = []
    captions = []

    with open(input_path, 'r', encoding='utf-8') as infile:
        for line in tqdm(infile, desc="Reading entries"):
            # Keep the try body minimal: only parsing can raise JSONDecodeError.
            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Skipping invalid JSON line: {e}")
                continue

            # Tolerate non-object lines, a null/non-object "entry", and a
            # null/non-string "caption" instead of crashing on .get/.strip.
            entry = data.get("entry", {}) if isinstance(data, dict) else {}
            raw_caption = entry.get("caption", "") if isinstance(entry, dict) else ""
            caption = raw_caption.strip() if isinstance(raw_caption, str) else ""

            if not caption:
                print(f"Skipping entry with missing caption: {entry}")
                continue

            entries.append(data)
            captions.append(caption)

    # Process captions in batches (lazy: generation runs as results are consumed)
    refined_prompts = refine_captions_batch(
        text_generator,
        captions,
        batch_size=batch_size,
        max_new_tokens=max_new_tokens
    )

    # Write results to output file as they are produced
    with open(output_path, 'w', encoding='utf-8') as outfile:
        for data, refined in zip(entries, refined_prompts):
            data["result1"] = {"generated": refined}
            outfile.write(json.dumps(data, ensure_ascii=False) + '\n')

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Batch refine captions in a JSONL file using LLaMA 3.2 1B Instruct model.")
    parser.add_argument('--input', type=str, required=True, help='Path to the input JSONL file.')
    parser.add_argument('--output', type=str, required=True, help='Path to the output JSONL file.')
    parser.add_argument('--model_id', type=str, default="meta-llama/Llama-3.2-1B-Instruct", help='Hugging Face model ID.')
    parser.add_argument('--batch_size', type=int, default=8, help='Number of prompts to process simultaneously.')
    parser.add_argument('--max_new_tokens', type=int, default=256, help='Maximum number of tokens to generate per prompt.')

    args = parser.parse_args()

    # Load the model and tokenizer
    text_generator = load_model(args.model_id)

    # Process the input file and generate the output.
    # --max_new_tokens is forwarded so the CLI flag actually takes effect.
    process_jsonl_batch(
        args.input,
        args.output,
        text_generator,
        batch_size=args.batch_size,
        max_new_tokens=args.max_new_tokens
    )

    print(f"Processing complete. Refined captions written to {args.output}")