import argparse
import os
import multiprocessing as mp
import json
import random
import time
from tqdm import tqdm
from gen_grids import gen_chess_images
from oai_call import LlmClient


# Per-process worker state. Each pool worker fills these in via init_worker();
# small_job() reads them for every image it handles.
_LLM = _PROMPT = _TEMPERATURE = None


def init_worker(model_id, prompt, temperature):
    """Initialize per-process state for a multiprocessing pool worker.

    Intended to be passed as the pool's ``initializer``. It builds this
    process's own ``LlmClient`` and stores the prompt and temperature in
    module globals, which ``small_job`` reads for every task. That way each
    task only needs to carry an image path.

    Args:
        model_id: Model identifier forwarded to ``LlmClient``.
        prompt: Text prompt sent alongside every image.
        temperature: Temperature passed on every LLM call.
    """
    global _LLM, _PROMPT, _TEMPERATURE
    client = LlmClient(model_id=model_id)
    _LLM = client
    _PROMPT, _TEMPERATURE = prompt, temperature


def small_job(image_path, max_attempts=5):
    """Query the worker's LLM client about one image, retrying on HTTP 429.

    Uses the client, prompt and temperature stored by ``init_worker``.

    Args:
        image_path: Path of the image to send together with the prompt.
        max_attempts: Total number of calls to try before giving up when
            every call is rate-limited (status code 429).

    Returns:
        Whatever the LLM client returns. If every attempt was rate-limited,
        returns the string ``"ERROR"`` instead.

    Raises:
        Any exception from the client whose ``status_code`` is not 429. It is
        re-raised unchanged.
    """
    for attempt in range(max_attempts):
        try:
            return _LLM(
                text_prompt=_PROMPT,
                image_path=image_path,
                temperature=_TEMPERATURE,
            )
        except Exception as e:
            if getattr(e, "status_code", None) != 429:
                raise
            # Rate-limited: back off exponentially (1, 2, 4, ... seconds).
            # Skip the sleep after the last attempt; nothing follows it.
            if attempt < max_attempts - 1:
                time.sleep(2 ** attempt)
    return "ERROR"


def parse_args():
    """Parse command-line options for image generation and LLM evaluation.

    Returns:
        argparse.Namespace with ``root_folder``, ``grid_size``,
        ``min_patch_size``, ``max_patch_size``, ``pad2size``,
        ``num_examples``, ``skip_gen``, ``seed``, ``model_id`` and
        ``temperature``.
    """
    def valued(kind, default, text):
        # Keyword arguments for an ordinary option that takes a value.
        return {"type": kind, "default": default, "help": text}

    # Registration order is preserved, so --help lists options in this order.
    option_table = [
        ("--root_folder",
         valued(str, "images", "Folder to save generated chess images.")),
        ("--grid_size",
         valued(int, 6, "Number of patches along one side of the grid.")),
        ("--min_patch_size",
         valued(int, 40, "Minimum size of each patch in pixels.")),
        ("--max_patch_size",
         valued(int, 84, "Maximum size of each patch in pixels.")),
        ("--pad2size",
         valued(int, 512, "Output image size in pixels.")),
        ("--num_examples",
         valued(int, 100, "Number of chess images to generate.")),
        ("--skip_gen",
         {"action": "store_true", "help": "Whether to skip image generation."}),
        ("--seed",
         valued(int, 0, "Random seed for reproducibility.")),
        ("--model_id",
         valued(str, "gpt4o", "Model ID for the LLM API.")),
        ("--temperature",
         valued(float, 0.01, "Temperature for LLM generation.")),
    ]

    cli = argparse.ArgumentParser()
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


def _folder_name(grid_size, patch_size, pad2size):
    """Return the sub-folder name for one (grid, patch, pad) configuration.

    Used both when generating images and when listing them for evaluation,
    so the two steps always agree on where an image lives.
    """
    return f"g{grid_size}_p{patch_size}_pad{pad2size}"


def _patch_sizes(args):
    """Return the patch sizes to sweep, from min to max in steps of 2 px.

    When ``args.pad2size`` is positive, the upper bound is capped at
    ``pad2size // grid_size`` so a full grid fits in the padded image.
    """
    max_patch_size = args.max_patch_size
    if args.pad2size > 0:
        max_patch_size = min(max_patch_size, args.pad2size // args.grid_size)
    return list(range(args.min_patch_size, max_patch_size + 1, 2))


def main():
    """Generate grid images, ask the LLM for each grid's size, save replies.

    Writes ``seed{seed}.json`` to the current working directory. The file
    holds a list of ``[image_path, response]`` pairs, ordered by patch size
    and then numerically by image index.
    """
    args = parse_args()
    patch_sizes = _patch_sizes(args)

    if not args.skip_gen:
        for patch_size in tqdm(patch_sizes):
            save_folder = _folder_name(args.grid_size, patch_size, args.pad2size)
            gen_chess_images(
                save_folder=save_folder,
                grid_size=args.grid_size,
                patch_size=patch_size,
                num_examples=args.num_examples,
                pad2size=args.pad2size,
                root_dir=args.root_folder,
                seed=args.seed,
            )

    # Canonical order: by patch size, then numerically by image index. The
    # results are written in this order. A plain string sort would misorder
    # them ("10.png" before "2.png", "p100" before "p40").
    images = [
        os.path.join(
            args.root_folder,
            _folder_name(args.grid_size, ps, args.pad2size),
            f"{i}.png",
        )
        for ps in patch_sizes
        for i in range(args.num_examples)
    ]
    if not args.skip_gen:
        print(f" {len(images)} images generated.")

    prompt = (
        "There is an N-by-N grid in the image. Each grid cell is filled "
        "with a random color. Observe the grid carefully and find its "
        "grid size."
    )
    print(prompt)

    # Requests are sent in shuffled order, presumably so that rate-limit
    # failures are not concentrated on one patch size. The RNG is seeded
    # from --seed so the dispatch order is reproducible.
    dispatch_order = list(range(len(images)))
    random.Random(args.seed).shuffle(dispatch_order)
    dispatched = [images[i] for i in dispatch_order]

    # NOTE(review): qwen models get fewer workers, presumably because that
    # endpoint has tighter rate limits. Confirm.
    num_workers = 4 if "qwen" in args.model_id.lower() else 16
    with mp.Pool(
        processes=num_workers,
        initializer=init_worker,
        initargs=(args.model_id, prompt, args.temperature),
    ) as pool:
        # imap yields results in input order, so they line up with dispatched.
        shuffled_responses = list(
            tqdm(pool.imap(small_job, dispatched), total=len(dispatched))
        )

    # Undo the shuffle so each response sits next to its image path.
    responses = [None] * len(images)
    for position, response in zip(dispatch_order, shuffled_responses):
        responses[position] = response

    output = list(zip(images, responses))
    with open(f"seed{args.seed}.json", "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)


# Needed for multiprocessing. Under the "spawn" start method, pool workers
# re-import this module, and they must not re-run main() when they do.
if __name__ == "__main__":
    main()
