import os
import sys
import dotenv
from tqdm import tqdm
import random
import datasets
from openai import OpenAI
import traceback
import argparse

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils import (
    has_complex_control_flow,
    has_single_return,
    create_dir_w_timestamp,
    write_csv,
    write_txt,
    write_jsonl,
    load_prompt,
    encode_image,
    response_to_py_program,
    response_to_epics_program,
)

# Load variables from a .env file (if any) into the environment before the key is read.
dotenv.load_dotenv()
# Shared OpenAI client; the API key is taken from the OPENAI_API_KEY environment variable.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# Directory (relative to this script) whose sub-directories are the selectable GQA "kinds".
KINDS_DIR = os.path.join(os.path.dirname(__file__), "../datasets/gqa")

def load_gqa(split, num_sample=1000, seed=2025):
    """Load the GQA images and instructions for ``split`` and pick a seeded random subset.

    Args:
        split: GQA split name. It is used to build the dataset config names
            (``f"{split}_all_images"`` and ``f"{split}_all_instructions"``) and is
            also passed as the ``split`` argument of ``datasets.load_dataset``.
        num_sample: Maximum number of instruction indices to return. Fewer are
            returned when the instructions dataset is smaller than this.
        seed: Seed applied to the module-level ``random`` generator before sampling.

    Returns:
        A tuple ``(image_mappings, dataset_instructions, indices)``:
        ``image_mappings`` maps each image row's ``"id"`` to its ``"image"`` value,
        ``dataset_instructions`` is the loaded instructions dataset, and ``indices``
        is a list of at most ``num_sample`` distinct row indices into it.

    Raises:
        ValueError: If ``num_sample`` is negative.
    """
    if num_sample < 0:
        # A negative bound would silently slice from the end of the permutation.
        raise ValueError(f"num_sample must be non-negative, got {num_sample}")

    dataset_images = datasets.load_dataset("lmms-lab/GQA", f"{split}_all_images", split=split)
    dataset_instructions = datasets.load_dataset("lmms-lab/GQA", f"{split}_all_instructions", split=split)

    # NOTE(review): this walks every image row up front; within this file the mapping
    # is only consumed by commented-out code, so confirm it is still needed.
    image_mappings = {item["id"]: item["image"] for item in dataset_images}

    # Seeding the global generator keeps the chosen subset reproducible across runs.
    random.seed(seed)
    total = len(dataset_instructions)
    # Draw a full permutation and truncate it, instead of sampling ``num_sample``
    # items directly, so the selected subset stays the same as in earlier runs
    # that used this seed.
    indices = random.sample(range(total), total)[:num_sample]
    return image_mappings, dataset_instructions, indices


def save_gqa(kind_name, language, model_name, split):
    """Generate one program per sampled GQA question and save it to disk.

    For every sampled question the prompt template is filled in and sent to the
    chat completions API, and the reply is converted into a program in the
    requested language. A successful conversion is written to
    ``<problem_id>.prog``. If interpreting the reply or writing the program
    fails, the traceback is written to ``<problem_id>.err`` and printed, and
    processing continues with the next question. Errors raised by the API
    request itself are not caught and abort the run.

    Args:
        kind_name: Name of the sub-directory of ``KINDS_DIR`` that contains
            ``prompt.txt``; programs are written to its ``progs_<language>``
            sub-directory.
        language: Output language, either ``"epic"`` or ``"py"``.
        model_name: Model identifier passed to the chat completions API.
        split: Dataset split, forwarded to ``load_gqa``.

    Raises:
        ValueError: If ``language`` is not ``"epic"`` or ``"py"``. This is checked
            before the dataset is loaded or any API request is made.
    """
    converters = {
        "epic": response_to_epics_program,
        "py": response_to_py_program,
    }
    if language not in converters:
        # Fail fast: an unsupported language can never produce a program, so do
        # not spend an API request per example before finding that out.
        raise ValueError(
            f"Unsupported language {language!r}; expected one of {sorted(converters)}"
        )
    to_program = converters[language]

    kind_dir = os.path.join(KINDS_DIR, kind_name)
    prog_out_dir = os.path.join(kind_dir, f"progs_{language}")
    prompt_path = os.path.join(kind_dir, "prompt.txt")
    prompt = load_prompt(prompt_path)

    print("🔧 Loading GQA dataset...")
    # The image mapping is not used: only the question text is sent to the model.
    _image_mappings, dataset_instructions, indices = load_gqa(split)

    print(f"📂 Saving results to: {prog_out_dir}")
    print(f"📜 Loaded prompt from: {prompt_path}")
    print(f"🚀 Starting program generation for {len(indices)} examples using {model_name} ({language})...\n")
    os.makedirs(prog_out_dir, exist_ok=True)

    for i in tqdm(indices):
        # Fetch the row once instead of re-indexing the dataset for every field.
        example = dataset_instructions[i]
        problem_id = example["id"]
        question: str = example["question"]

        # Plain substitution instead of str.format: the template may contain
        # other curly brackets that str.format would try to interpret.
        program_prompt = prompt.replace("{question}", question)

        response = client.chat.completions.create(
            model=model_name,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": program_prompt,
                        },
                    ],
                }
            ],
            temperature=0,
        )

        program_path = os.path.join(prog_out_dir, f"{problem_id}.prog")
        err_path = os.path.join(prog_out_dir, f"{problem_id}.err")
        try:
            model_output = response.choices[0].message.content
            if model_output is None:
                # The API may return a message without text content; record it
                # as a per-example failure instead of crashing the whole run.
                raise ValueError("Model response contained no text content")
            program = to_program(model_output.strip(), question)
            write_txt(program_path, program)
        except Exception:
            # format_exc() already yields newline-terminated lines, so the
            # traceback is stored as-is rather than re-joined with extra "\n".
            write_txt(err_path, traceback.format_exc())
            traceback.print_exc()

if __name__ == "__main__":
    # Every sub-directory of KINDS_DIR is offered as a valid value for --kind.
    available_kinds = [
        entry
        for entry in os.listdir(KINDS_DIR)
        if os.path.isdir(os.path.join(KINDS_DIR, entry))
    ]

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-k', '--kind',
        required=True,
        choices=available_kinds,
        help='Kind of generation',
    )

    # Exactly one output language flag has to be given.
    language_group = parser.add_mutually_exclusive_group(required=True)
    language_group.add_argument('-e', '--epic', action='store_true', default=False, help='Save as EPIC')
    language_group.add_argument('-p', '--python', action='store_true', default=False, help='Save as Python')
    args = parser.parse_args()

    # The group is required and exclusive, so "not --epic" means "--python".
    if args.epic:
        lang = "epic"
    else:
        lang = "py"

    save_gqa(
        kind_name=args.kind,
        language=lang,
        model_name="gpt-4o-mini",
        split="val",
    )