import os
import pickle
import random
import shutil
import tempfile
from tqdm import tqdm
import numpy as np
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, set_seed

import submitit

# Root of the metadata tree. Every path in this script is built as
# f"{DATA_PATH}/cc12m/...". This is a relative path, so it is resolved against
# the working directory of whichever process opens it (the launcher or a
# submitit job). The trailing slash plus the "/" in those f-strings yields a
# harmless double slash.
DATA_PATH = "../metadata/"


# Placeholder caption recorded for an image whose caption could not be produced.
_GENERATION_FAILED = "GENERATION FAILED"

# Captioning instruction per ``object_num``. Higher values ask for longer
# captions that cover more of the scene. Every prompt must end with the full
# "<<CAPTION>>" template, because ``_extract_caption`` parses the reply between
# "<<" and ">>". Prompt 3 previously ended in "<<CAPTION>" (missing ">").
_USER_PROMPTS = {
    0: "Mention the main subjects in the image with only 1 word. Do not use specific identifiers, such as names. Format the caption as such: <<CAPTION>>",
    1: "Write a short caption with 3 or less words. Use only one adjective. Do not use specific identifiers, such as names. Format the caption as such: <<CAPTION>>",
    2: "Write a caption with 6 or less words, describing at most two main foreground subjects. Do not use specific identifiers, such as names. Format the caption as such: <<CAPTION>>",
    3: "Write a caption with 8 or less words, describing at most two main foreground subjects, as well as the background. Do not use specific identifiers, such as names. Format the caption as such: <<CAPTION>>",
}


def _extract_caption(decoded):
    """Return the text between the first "<<" and the next ">>" in ``decoded``.

    All whitespace runs, including newlines, are collapsed to single spaces.
    This keeps each caption on one line of the ``image*****caption`` output
    file. Returns ``_GENERATION_FAILED`` when ``decoded`` contains no "<<".
    """
    parts = decoded.split("<<")
    if len(parts) < 2:
        return _GENERATION_FAILED
    return " ".join(parts[1].split(">>")[0].split())


def _caption_image(model, processor, image_path, user_prompt):
    """Caption a single image; return the caption or ``_GENERATION_FAILED``.

    Failures while building the inputs (for example an unreadable image) are
    printed and recorded as ``_GENERATION_FAILED`` rather than aborting the
    whole chunk.
    """
    messages = [
        {
            "role": "system",
            "content": [{"type": "text",
                         "text": "You strictly follow the user's instructions."}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": user_prompt},
            ],
        },
    ]

    try:
        # Cast floating-point inputs (pixel values) to the model's own dtype
        # rather than a hard-coded float16, so they match the loaded weights.
        inputs = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt",
        ).to(model.device, dtype=model.dtype)
    except Exception as e:
        print(f"{type(e).__name__} while preparing {image_path}: {e}")
        return _GENERATION_FAILED

    input_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        output = model.generate(**inputs,
                                max_new_tokens=100,
                                do_sample=True,
                                temperature=0.6)

    try:
        # generate() returns prompt + completion; keep only the new tokens.
        new_tokens = output[0][input_len:]
    except Exception as e:
        print(f"{type(e).__name__} while slicing output for {image_path}: {e}")
        return _GENERATION_FAILED

    decoded = processor.decode(new_tokens, skip_special_tokens=True)
    return _extract_caption(decoded)


def _write_captions(gen_captions, dest_path):
    """Write one ``image_name*****caption`` line per pair to ``dest_path``.

    The lines are first written to a file in a temporary directory. That file
    is then copied to ``dest_path`` with a single ``shutil.copy``. The
    destination directory is created if missing.
    """
    os.makedirs(os.path.dirname(dest_path), exist_ok=True)
    with tempfile.TemporaryDirectory() as tmpdirname:
        tmp_file = os.path.join(tmpdirname, os.path.basename(dest_path))
        # Explicit UTF-8: model captions may contain non-ASCII characters.
        with open(tmp_file, "w", encoding="utf-8") as f:
            for image_name, caption in gen_captions:
                f.write(f"{image_name}*****{caption}\n")
        shutil.copy(tmp_file, dest_path)


def generate(chunk_id, chunk_size, object_num):
    """Caption one chunk of images with Gemma 3 and save the results.

    The image-name list is loaded from
    ``{DATA_PATH}/cc12m/full_cc12m_eval_clean.pkl``. The function takes
    entries ``[chunk_id * chunk_size, (chunk_id + 1) * chunk_size)`` and
    captions each ``{DATA_PATH}/cc12m/images/{name}.jpg``. It writes
    ``{DATA_PATH}/cc12m/gemma3_captions/gemma3_{object_num}obj_{chunk_id}.txt``
    with one ``image_name*****caption`` line per image. An image that could not
    be captioned gets the caption "GENERATION FAILED".

    Args:
        chunk_id: Index of the chunk to process.
        chunk_size: Number of image names per chunk.
        object_num: Caption detail level, one of 0, 1, 2, 3.

    Raises:
        ValueError: If ``object_num`` is not 0, 1, 2 or 3. This is checked
            before any data or model is loaded.
    """
    print(f"chunk_id: {chunk_id}, chunk_size: {chunk_size}, object_num: {object_num}")
    # Validate first so a bad argument does not pay for loading a 12B model.
    if object_num not in _USER_PROMPTS:
        raise ValueError("object_num should be 0, 1 or 2 or 3")
    user_prompt = _USER_PROMPTS[object_num]

    random.seed(0)
    set_seed(0)
    torch.manual_seed(0)
    np.random.seed(0)

    with open(f"{DATA_PATH}/cc12m/full_cc12m_eval_clean.pkl", "rb") as f:
        image_names = pickle.load(f)
    selected_image_names = image_names[chunk_id * chunk_size: (chunk_id + 1) * chunk_size]

    model_id = "google/gemma-3-12b-it"
    # NOTE(review): no torch_dtype is passed, so the weight dtype is the
    # transformers default. The Gemma 3 model card loads in bfloat16. Confirm
    # which precision is intended before changing it (it alters outputs).
    model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id, device_map="cuda"
    ).eval()
    processor = AutoProcessor.from_pretrained(model_id)

    image_dir = f"{DATA_PATH}/cc12m/images"
    gen_captions = []
    for image_name in tqdm(selected_image_names):
        caption = _caption_image(model, processor,
                                 f"{image_dir}/{image_name}.jpg", user_prompt)
        gen_captions.append((image_name, caption))

    _write_captions(
        gen_captions,
        f"{DATA_PATH}/cc12m/gemma3_captions/gemma3_{object_num}obj_{chunk_id}.txt",
    )
        

if __name__ == "__main__":
    # Only the number of images is needed here; each job reloads the list.
    with open(f"{DATA_PATH}/cc12m/full_cc12m_eval_clean.pkl", "rb") as f:
        image_names = pickle.load(f)

    chunk_size = 5000
    # Ceiling division. The previous ``len // chunk_size + 1`` submitted an
    # extra, empty chunk (which still loads the model) whenever the image count
    # was an exact multiple of chunk_size.
    chunk_num = (len(image_names) + chunk_size - 1) // chunk_size

    executor = submitit.AutoExecutor(folder="../logs/gemma3_logs")
    executor.update_parameters(
        timeout_min=int(60 * 12),
        mem_gb=20,
        name="gemma3",
        slurm_array_parallelism=1,
        slurm_nodes=1,
        slurm_gpus_per_node=1,
        slurm_tasks_per_node=1,
        slurm_cpus_per_task=5,
        # NOTE(review): empty partition, presumably a placeholder to fill in
        # for the target cluster.
        slurm_partition="",
    )

    jobs = []
    # One job per (caption detail level, chunk); the four levels are 0-3.
    with executor.batch():
        for object_num in range(4):
            for i in range(chunk_num):
                print(f"{i} : {i * chunk_size} -> {(i+1) * chunk_size}")
                jobs.append(executor.submit(generate, i, chunk_size, object_num))

    # Inside ``executor.batch()`` submissions are only queued; they are sent
    # when the context exits. Report the jobs once they have real ids.
    for job in jobs:
        print(job)