import os

import torch
from diffusers import StableDiffusionGLIGENPipeline, FluxPipeline
from diffusers.utils import load_image
import argparse
import pandas as pd
import json
from tqdm import tqdm
import re

def main(arg, pipe):
    """Generate and save one image per prompt listed in an evaluation JSON file.

    Args:
        arg: Parsed CLI namespace. Reads ``arg.info_dir``, the path to a JSON
            file whose top-level ``"data"`` key holds a list of records, each
            with ``"prompt"`` and ``"id"`` keys. Also reads ``arg.output_dir``,
            the directory where ``<id>.png`` files are written; it is created
            if missing.
        pipe: Pipeline callable (this script passes a diffusers
            ``FluxPipeline``). It is called with a prompt plus generation
            kwargs and must return an object exposing ``.images``.
    """
    with open(arg.info_dir, "r", encoding="utf-8") as file:
        eval_datum = json.load(file)["data"]

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(arg.output_dir, exist_ok=True)

    for eval_data in tqdm(eval_datum):
        image = pipe(
            eval_data["prompt"],
            height=1024,
            width=1024,
            guidance_scale=3.5,
            num_inference_steps=50,
            max_sequence_length=512,
            # A fresh generator seeded with 0 for every prompt makes each image
            # reproducible regardless of its position in the list.
            generator=torch.Generator("cpu").manual_seed(0),
        ).images[0]
        image.save(os.path.join(arg.output_dir, "{}.png".format(eval_data["id"])))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", type=int, default=7)
    parser.add_argument("--few_shot_layout", type=int, default=4)
    parser.add_argument("--info_dir", type=str, default="")
    parser.add_argument("--output_dir", type=str, default="image_gen")
    parser.add_argument("--layout_name", type=str, default="layout_raw")
    parser.add_argument("--num_repeat", type=int, default=1)
    arg = parser.parse_args()
    from accelerate.utils import set_seed
    set_seed(0)
    
    cuda_number = arg.cuda
    if cuda_number == -1:
        cur_device = 'cpu'
    else:
        if torch.cuda.is_available():
            cur_device = "cuda:" + str(cuda_number)
        elif torch.backends.mps.is_available():
            cur_device = "mps"
        else:
            cur_device = "cpu"

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to(cur_device)
    pipe.set_progress_bar_config(disable=True)
    main(arg, pipe)
