import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor
from PIL import Image
import os
import argparse
import pandas as pd
import time
import gc
import csv
from pathlib import Path
import ast
from transformers import AutoTokenizer

def load_model(model_id, gpu_id):
    """Load the vision-language model and its processor.

    Args:
        model_id: Model identifier passed to ``from_pretrained``
            (e.g. "meta-llama/Llama-3.2-11B-Vision-Instruct").
        gpu_id: Index of the CUDA device the model weights are mapped to.

    Returns:
        A ``(model, processor)`` tuple. The model is loaded in bfloat16 on
        ``cuda:<gpu_id>``; the processor is loaded from the same ``model_id``.
    """
    vision_model = MllamaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map=f"cuda:{gpu_id}",
    )
    processor = AutoProcessor.from_pretrained(model_id)
    return vision_model, processor

def _output_column(expname, temp):
    """Return the CSV column name holding the generation for one temperature."""
    return f'output_temp_{expname}_{temp}'


def process_image(image_path, caption, model, processor, temperatures, gpu_id, expname=""):
    """Generate one enhanced prompt per sampling temperature for a single image.

    Args:
        image_path: Path of the image file to open.
        caption: Base caption inserted into the instruction prompt.
        model: Generation model; must expose ``generate`` and ``device``.
        processor: Callable that builds model inputs from (image, prompt) and
            exposes ``decode``.
        temperatures: Iterable of sampling temperatures; one generation each.
        gpu_id: Unused here (inputs are moved to ``model.device``); kept so
            existing positional callers keep working.
        expname: Experiment name embedded in the result keys. The original
            code read an undefined ``expname`` name here, which raised
            NameError on every call.

    Returns:
        Dict mapping ``output_temp_<expname>_<temp>`` to the decoded text.
        NOTE(review): ``outputs[0]`` is decoded in full, so the text
        presumably contains the prompt as well as the completion -- confirm
        that downstream parsing expects this.
    """
    # Convert inside the context manager so the file handle is released
    # immediately instead of staying open until garbage collection.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    #below prompt is for aesthetics, needs to get changes for a different objective
    prompt=f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
                        You are an expert in image analysis and AI-generated art. Your task is to analyze an image's base caption and enhance it for a given  objective while preserving its original details, style, and essence.
                        Generate a refined diffusion model prompt.
                        Prompt: [Generated 60-word prompt, which is extension of base caption]
                        The prompt is super important. Your task is to generate a prompt to create aesthetically enhanced images.
                        <|eot_id|><|start_header_id|>user<|end_header_id|>
                        <|image|>Generate a 60-word diffusion model prompt to generate an enhanced version of the image with improved aesthetics. Ensure all original details are preserved, and the enhancements match the style and essence                          of the original image and should be better in aesthetics.
                        The base caption for the image is: {caption}
                        Output format:                 
                        Caption: [Generated prompt]
                        Make sure to generate a  prompt. The prompt is super important. Return only in the formatted output. 
                        <|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

    print(prompt)
    # The inputs do not depend on the temperature, so build them once.
    inputs = processor(image, prompt, return_tensors="pt").to(model.device)
    results = {}
    for temp in temperatures:
        outputs = model.generate(
            **inputs,
            max_length=2048,
            temperature=temp,
            do_sample=True,
        )
        results[_output_column(expname, temp)] = processor.decode(outputs[0], skip_special_tokens=True)
    return results


def get_dataset_chunk(csv_path, start_idx, end_idx): #since processing is being done in chunks in multiple jobs
    """Read the rows of a CSV file that belong to one processing chunk.

    Columns whose name starts with "Unnamed" (typically a saved index) are
    dropped.

    Args:
        csv_path: Path of the input CSV.
        start_idx: Position of the first row to return.
        end_idx: Position of the last row to return (inclusive).

    Returns:
        DataFrame with rows ``start_idx`` through ``end_idx``.
    """
    df = pd.read_csv(csv_path)
    df = df.loc[:, ~df.columns.str.contains("^Unnamed")]
    return df.iloc[start_idx:end_idx+1]


def process_chunk(csvpath, expname, model_id, base_image_dir, start_idx, end_idx, gpu_id):
    """Run generation for one chunk of the dataset and append results to a CSV.

    The output file is ``output_<expname>/<model>_<gpu_id>_idx<start>-<end>.csv``.
    It is opened in append mode and flushed after every row, so partial
    results survive an interrupted job. Rows that fail are reported on stdout
    and skipped.

    Args:
        csvpath: Input CSV; rows are expected to carry ``filename`` and
            ``caption`` columns, the latter a Python-literal list of captions.
        expname: Experiment name used in the output directory and columns.
        model_id: Model identifier handed to ``load_model``.
        base_image_dir: Directory that the ``filename`` values are relative to.
        start_idx: First row position of the chunk.
        end_idx: Last row position of the chunk (inclusive).
        gpu_id: CUDA device index to run on.
    """
    model, processor = load_model(model_id, gpu_id)
    print("model loaded")
    print(f"Processing items {start_idx} to {end_idx} on GPU {gpu_id}")
    mname = model_id.split("/")[-1]
    os.makedirs(f"output_{expname}", exist_ok=True) #output_Dir
    csv_filename = f'output_{expname}/{mname}_{gpu_id}_idx{start_idx}-{end_idx}.csv'
    out_path = Path(csv_filename)
    # Also write the header when a previous run left an empty file behind.
    needs_header = not out_path.exists() or out_path.stat().st_size == 0

    data_chunk = get_dataset_chunk(csvpath, start_idx, end_idx) #pass csv file

    temperatures = [0.75, 0.85, 0.95]
    # Derive the output columns from the temperature list so the header can
    # never drift from the keys produced by process_image.
    fieldnames = list(data_chunk.columns) + [_output_column(expname, t) for t in temperatures]

    with open(csv_filename, 'a', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        for _, row in data_chunk.iterrows():
            # Fetched outside the try so the error handler itself cannot raise
            # when the column is missing.
            filename = row.get('filename', '<unknown>')
            try:
                image_path = os.path.join(base_image_dir, row['filename']) #path to train base images downloaded from hugging face
                caption = ast.literal_eval(row['caption'])[0] #first caption as base caption
                results = process_image(
                    image_path, caption, model, processor, temperatures, gpu_id,
                    expname=expname,
                )
                print("generated")
                row_data = {**row.to_dict(), **results}
                print(row_data)
                writer.writerow(row_data) #write output csv

                csvfile.flush()  # Force write to disk
                print(f"Row written for {filename}")
                torch.cuda.empty_cache()
            except Exception as e:
                # Best-effort batch job: report the failure and move on.
                print(f"Error processing {filename}: {e}")

def main():
    """Parse command-line arguments and process one dataset chunk.

    Exits with an argparse usage error (status 2) when the requested row
    range is negative or inverted, so a bad range is rejected before the
    model is loaded.
    """
    parser = argparse.ArgumentParser(description='Inference with Llama Vision model on Flickr30k dataset')
    parser.add_argument('--csvpath', type=str, required=True, help='Input csv path')
    parser.add_argument('--expname', type=str, required=True, help='name of the experiment, will be used to save outputs')
    parser.add_argument('--base_image_dir', type=str, required=True, help='The base image dir downloaded from hf')
    parser.add_argument('--model_id', type=str, required=True, help='name of the model version to be used')
    parser.add_argument('--gpu_id', type=int, required=True, help='GPU index to run inference on')
    parser.add_argument('--start_idx', type=int, required=True, help='Start index of the dataset chunk')
    parser.add_argument('--end_idx', type=int, required=True, help='End index of the dataset chunk')
    # NOTE: --batch_size is accepted for command-line compatibility but is not
    # forwarded anywhere; images are processed one at a time.
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for processing')

    args = parser.parse_args()
    # Negative or inverted bounds would silently select an empty or
    # unintended slice of the dataset.
    if args.start_idx < 0 or args.end_idx < args.start_idx:
        parser.error(
            f"invalid chunk range: start_idx={args.start_idx}, end_idx={args.end_idx} "
            "(need 0 <= start_idx <= end_idx)"
        )
    process_chunk(args.csvpath, args.expname, args.model_id, args.base_image_dir, args.start_idx, args.end_idx, args.gpu_id)


if __name__ == "__main__":
    main()
