import random
import os
import sys
import time
from tqdm import tqdm
import json

from PIL import Image
import numpy as np
from transformers import pipeline, set_seed
import torch
from transformers import file_utils

# Route the Hugging Face caches to one shared model directory.
# NOTE(review): transformers is already imported above this point, so the
# environment variables may be set too late to change its defaults -- confirm.
# cache_dir is also passed explicitly when the pipeline is built further down.
cache_dir = "/media/node4/ssd2/raghul/data/models/"
os.environ.update({
    'HF_HOME': cache_dir,
    'TRANSFORMERS_CACHE': cache_dir,
})
file_utils.TRANSFORMERS_CACHE = cache_dir


def to_markdown(text):
    """Normalize a model reply before it is parsed.

    Bullet characters are turned into indented ``*`` markers, Markdown code
    fences (with or without the ``json`` tag) are removed, and surrounding
    whitespace is trimmed.

    Args:
        text: Raw reply string.

    Returns:
        The cleaned string.
    """
    # Order matters: the tagged fence must be removed before the bare fence,
    # otherwise a stray "json" would be left behind.
    substitutions = (
        ('•', '  *'),
        ("```json", ""),
        ("```", ""),
    )
    for needle, replacement in substitutions:
        text = text.replace(needle, replacement)
    return text.strip()

# Seed every random number generator used here so runs are repeatable.
seed = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
if torch.cuda.is_available():
    # Seed the CUDA generators too when a GPU is present.
    torch.cuda.manual_seed_all(seed)

# Hugging Face transformers seeding helper, called with the same seed.
set_seed(seed)

print("[/] loading model...")
# Options for the image-text-to-text pipeline. cache_dir is forwarded through
# both model_kwargs and tokenizer_kwargs; the two dicts are kept as separate
# objects, exactly as when they are written inline.
_pipeline_options = dict(
    model="google/gemma-3n-e4b-it",
    device="cuda",
    torch_dtype=torch.bfloat16,
    model_kwargs={"cache_dir": cache_dir},
    tokenizer_kwargs={"cache_dir": cache_dir},
)
pipe = pipeline("image-text-to-text", **_pipeline_options)

print("[/] model loaded.")

def _build_messages(image, prompt):
    """Return the chat messages (system + user image/text turn) for one image."""
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are an expert in identifying food ingredients in images."}]
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }
    ]


def _save_json(path, payload):
    """Overwrite *path* with *payload* serialized as indented JSON."""
    with open(path, "w") as f:
        json.dump(payload, f, indent=2)


def main(argv):
    """Caption every file in an image directory and save the results as JSON.

    For each file the module-level ``pipe`` is asked for a JSON reply holding
    a ``dense_caption`` and a ``sparse_caption``. Both result files are
    rewritten after every successfully processed image. Files that cannot be
    opened, or whose reply cannot be parsed, are reported and skipped.

    Args:
        argv: Command-line vector. ``argv[1]`` is the image directory,
            ``argv[2]`` the output path for dense captions and ``argv[3]``
            the output path for sparse captions. Extra arguments are ignored.

    Raises:
        SystemExit: If fewer than three arguments are supplied.
    """
    if len(argv) < 4:
        prog = argv[0] if argv else "script"
        sys.exit(f"usage: {prog} IMAGES_DIR DENSE_JSON SPARSE_JSON")

    images_dir = argv[1]
    save_json_dense = argv[2]
    save_json_sparse = argv[3]
    # os.listdir returns entries in arbitrary order; sort so that the
    # processing order and the key order of the output files are stable.
    files = sorted(os.listdir(images_dir))

    prompt = """
        You are an expert food photographer and culinary describer. Your task is to provide two types of captions - dense and sparse descriptions for food images.
        1. Concise yet rich and appealing captions for food images. Focus on identifying the main dish, key ingredients, preparation styles (e.g., grilled, fried, braised), and any prominent garnishes or side dishes. Use evocative language to make the food sound appetizing. Do not include any personal opinions or extraneous commentary.

        Examples for this dense type captions:
        - "Dashi broth and crispy tempura in a bowl."
        - "Peachy ham hock in a rustic setting."
        - "spaghetti with sausage and meatballs"
        - Savoring slow-cooked pork belly on a platter

        2. Focus on identifying the main dish, key ingredients, preparation styles (e.g., grilled, fried, braised), and any prominent garnishes or side dishes. Be concise yet rich. Use evocative language to make the food sound appetizing. Do not include any personal opinions or extraneous commentary. Output only the food ingredients in a sparse format.

        Examples for this sparse type captions:
        - "rice, soup, vegetables, meat"
        - "wrap, lettuce, carrot, drink"
        - "sushi, avocado, salmon, ginger"
        - "noodles, rice, tomato, cucumber, chilli, sauce"
        - "egg, lettuce, bread"

        Output in a JSON format with two keys:
        "dense_caption": dense description of the food image as per above instructions.
        "sparse_caption": sparse description of the food image as per above instructions.

        """

    image2dense, image2sparse = {}, {}
    for file in tqdm(files):
        try:
            # Image.open keeps the file handle open; convert() yields an
            # independent in-memory copy, so the handle can be closed at once.
            with Image.open(os.path.join(images_dir, file)) as raw_image:
                image = raw_image.convert("RGB")

            messages = _build_messages(image, prompt)

            start_time = time.perf_counter()
            output = pipe(text=messages, max_new_tokens=200, do_sample=False)
            print("[/] time to process", time.perf_counter() - start_time)

            response = output[0]["generated_text"][-1]["content"]
            response_data = json.loads(to_markdown(response))
            print("[/] VLM response:", response_data)

            # Read both keys before storing either one, so a reply missing a
            # key cannot leave the two result dicts out of sync.
            dense_caption = response_data["dense_caption"]
            sparse_caption = response_data["sparse_caption"]
            image2dense[file] = dense_caption
            image2sparse[file] = sparse_caption

            # Rewritten on every iteration so the files always reflect the
            # images processed so far.
            _save_json(save_json_dense, image2dense)
            _save_json(save_json_sparse, image2sparse)

        except Exception as e:
            # Best-effort batch job: report the failing file and carry on.
            print(e)
            print(file)


if __name__ == "__main__":
    main(sys.argv)
