import pickle
import os
from collections import defaultdict
import csv
from dsg.query_utils import generate_dsg
from dsg.parse_utils import parse_tuple_output, parse_dependency_output, parse_question_output

from vllm import LLM, SamplingParams

import submitit

CHUNK_NUMS = 15


def get_dsg_prompts(annot_file: str, complexity: int, chunk_id: int) -> dict:
    """
    Load one chunk of captions from a pickled annotation file.

    The annotation data is split into ``CHUNK_NUMS`` contiguous chunks (the
    chunk size is rounded up) and, within the requested chunk, every
    ``skip_interval``-th entry is selected.  ``skip_interval`` is derived
    from the file name: 50 if it contains "dci", 20 if it contains "gemma3".

    Args:
        annot_file: Path to a ``.pkl`` annotation file.  Every selected entry
            must provide ``entry['caps'][complexity]``.
        complexity: Index into the ``'caps'`` field of an entry.  Must be 0
            for "dci" files.
        chunk_id: Chunk to load, in ``[0, CHUNK_NUMS]``.

    Returns:
        A mapping from ``complexity_{complexity}_caption_{index // skip_interval}``
        to a dict whose ``'input'`` field holds the selected caption.

    Raises:
        ValueError: If the file is not a ``.pkl`` file, if no sampling
            interval is known for the file name, if ``complexity`` is
            non-zero for a "dci" file, or if ``chunk_id`` is out of range.
    """
    if not annot_file.endswith('.pkl'):
        raise ValueError(f"Unsupported file format: {annot_file}. Please provide a .pkl file.")

    if "dci" in annot_file:
        skip_interval = 50
        if complexity != 0:
            raise ValueError(f"complexity should be 0 for dci dataset, but got {complexity}")
    elif "gemma3" in annot_file:
        skip_interval = 20
    else:
        # Previously this case fell through and crashed later with an
        # UnboundLocalError on skip_interval.
        raise ValueError(
            f"Cannot determine the sampling interval for {annot_file}: "
            "expected 'dci' or 'gemma3' in the file name."
        )

    if not 0 <= chunk_id <= CHUNK_NUMS:
        raise ValueError(f"chunk_id should be in range [0, {CHUNK_NUMS}]")

    # NOTE: pickle.load can execute arbitrary code; only load annotation
    # files that come from a trusted source.
    with open(annot_file, 'rb') as f:
        data = pickle.load(f)

    total = len(data)
    # Ceiling division so that CHUNK_NUMS chunks cover every entry.
    chunk_size = -(-total // CHUNK_NUMS)

    start = chunk_id * chunk_size
    # Because chunk_size is rounded up, the last chunk may extend past the
    # end of the data; clamp it instead of indexing out of range.
    # Assumes entries are addressable as data[0] .. data[len(data) - 1].
    # chunk_id == CHUNK_NUMS is still accepted for backward compatibility;
    # it starts at or past the end of the data and therefore yields nothing.
    stop = min(start + chunk_size, total)

    # NOTE(review): when chunk_size is not a multiple of skip_interval,
    # `index // skip_interval` can yield the same key in two neighbouring
    # chunks -- confirm whether downstream merging relies on unique keys.
    id2prompts = defaultdict(dict)
    for index in range(start, stop, skip_interval):
        key = f'complexity_{complexity}_caption_{index // skip_interval}'
        metadata = defaultdict(dict)
        metadata['input'] = data[index]['caps'][complexity]
        id2prompts[key] = metadata
    return id2prompts


def _build_csv_rows(key, text, qid2tuple, qid2dependency, qid2question) -> list:
    """Build one CSV row per proposition id of a single caption."""
    rows = []
    for qid, tuple_text in qid2tuple.items():
        dependency = ",".join(map(str, qid2dependency[qid]))
        tuple_parts = tuple_text.split('-')
        rows.append([
            key,                      # item_id
            text,                     # text
            None,                     # keywords (not produced here)
            qid,                      # proposition_id
            dependency,               # dependency
            tuple_parts[0].strip(),   # category_broad
            tuple_parts[-1].strip(),  # category_detailed
            None,                     # tuple (not produced here)
            qid2question[qid],        # question_natural_language
        ])
    return rows


def dsg_generation(llm_model: str,
                   gpu_nums: int,
                   annot_file: str,
                   complexity: int,
                   output_file: str,
                   chunk_id: int) -> None:
    """
    Generate DSG tuples, dependencies and questions for one chunk of captions
    and write them to ``{output_file}_chunk_{chunk_id}.csv``.

    Args:
        llm_model: Model identifier passed to ``vllm.LLM``.
        gpu_nums: Value used for ``tensor_parallel_size``.
        annot_file: Annotation file forwarded to ``get_dsg_prompts``.
        complexity: Caption complexity forwarded to ``get_dsg_prompts`` and
            used in the item ids.
        output_file: Path prefix of the CSV file to write.
        chunk_id: Chunk forwarded to ``get_dsg_prompts``; also part of the
            output file name.

    Captions whose LLM output cannot be parsed or converted into rows are
    reported on stdout and skipped; the CSV always contains the header row.
    """
    llm = LLM(model=llm_model, dtype='bfloat16', tensor_parallel_size=gpu_nums)

    # The sampling configuration is identical for every request, so build it once.
    sampling_params = SamplingParams(
        temperature=0.6,
        max_tokens=3000,
    )

    def llama_completion(prompt: str, llm=llm) -> str:
        response = llm.generate(prompt, sampling_params)[0]
        return response.outputs[0].text.strip()

    id2prompts = get_dsg_prompts(annot_file, complexity, chunk_id)

    id2tuple_outputs, id2question_outputs, id2dependency_outputs = generate_dsg(
        id2prompts,
        # you can change this method with any method that
        # takes prompt as input and outputs LLM generation result.
        generate_fn=llama_completion
    )

    header = ["item_id",
              "text",
              "keywords",
              "proposition_id",
              "dependency",
              "category_broad",
              "category_detailed",
              "tuple",
              "question_natural_language"]
    csv_outputs = [header]

    for key in id2prompts:
        caption = id2prompts[key]['input']
        try:
            qid2tuple = parse_tuple_output(id2tuple_outputs[key]['output'])
            qid2dependency = parse_dependency_output(id2dependency_outputs[key]['output'])
            qid2question = parse_question_output(id2question_outputs[key]['output'])
            csv_items = _build_csv_rows(key, caption, qid2tuple, qid2dependency, qid2question)
        except Exception as e:
            # Best effort: LLM output is free-form, so any failure for one
            # caption is reported and that caption is skipped as a whole.
            # `.get` keeps the diagnostics themselves from raising when an
            # output entry is missing for this key.
            print("*" * 50)
            print(f"Error parsing outputs for key {key}: {e}")
            print(f"Key is {key}")
            print(f"Caption is {caption}")
            print(f"ID2tuple: {id2tuple_outputs.get(key)}")
            print(f"ID2dependency: {id2dependency_outputs.get(key)}")
            print(f"ID2question: {id2question_outputs.get(key)}")
            print("*" * 50)
            continue
        # Rows are only added once every proposition of the caption succeeded.
        csv_outputs.extend(csv_items)

    save_file = output_file + f"_chunk_{chunk_id}.csv"
    # Explicit UTF-8 so the result does not depend on the node's locale
    # (captions and generated questions may contain non-ASCII text).
    with open(save_file, 'w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerows(csv_outputs)


if __name__ == "__main__":
    # Alternatives: "meta-llama/Llama-3.3-70B-Instruct",
    # "meta-llama/Meta-Llama-3.1-8B-Instruct".
    llm_model = "google/gemma-3-27b-it"
    gpu_per_node = 4

    # Results of the gemma model and of every other model go to separate folders.
    save_folder = "dsg_cc12m" if llm_model == "google/gemma-3-27b-it" else "dsg_cc12m_llama"
    dest_folder = f"../../outputs/{save_folder}"
    # On a first run the folder does not exist yet; create it so that listing
    # it does not fail.
    os.makedirs(dest_folder, exist_ok=True)

    # Collect (complexity, chunk) pairs that already have a result file named
    # "dsg_complexity_{complexity}_chunk_{chunk_id}.csv".
    succeeded_combos = set()
    for csv_file in os.listdir(dest_folder):
        name_parts = csv_file.split('_')
        try:
            compl = int(name_parts[2])
            chunk_index = int(name_parts[-1].split('.')[0])
        except (IndexError, ValueError):
            # Not a result file produced by dsg_generation; ignore it.
            continue
        succeeded_combos.add((compl, chunk_index))

    executor = submitit.AutoExecutor(folder="../../outputs/dsg_logs/")
    executor.update_parameters(
        timeout_min=int(60*24),
        mem_gb=100,
        name="dsg_temp",
        slurm_array_parallelism=10,
        slurm_nodes=1,
        slurm_gpus_per_node=gpu_per_node,
        slurm_tasks_per_node=1,
        slurm_cpus_per_task=5,
    )
    with executor.batch():
        for chunk_id in range(CHUNK_NUMS + 1):
            # NOTE(review): only chunk_id == CHUNK_NUMS is submitted; every
            # lower chunk is skipped. This looks like a temporary filter --
            # confirm before relying on this script for a full run.
            if chunk_id < CHUNK_NUMS:
                continue
            for complexity in range(4):
                if (complexity, chunk_id) in succeeded_combos:
                    print(f"Skip job for complexity {complexity} and chunk {chunk_id} as it is already done.")
                    continue
                print(f"Submit job for complexity {complexity} and chunk {chunk_id}")
                executor.submit(dsg_generation,
                                llm_model,
                                gpu_per_node,
                                "../../metadata/cc12m/full_dict_gemma3_siglip_eval_clean_5k_4caps.pkl",
                                complexity,
                                f"../../outputs/{save_folder}/dsg_complexity_{complexity}",
                                chunk_id)
