import torch
import argparse
from pathlib import Path
import sys
import json
from tqdm import tqdm
import os

from diffusers import DiffusionPipeline
from qwen_forward_with_bbs import modified_qwen
from precomputed import PrecomputedBoundingBox
from qwen_constrain_attention import ConstrainedQwenDoubleStreamAttnProcessor2_0
from preprocessing import add_text_length_to_bbs


# Hugging Face access token. It is read from the HF_TOKEN environment variable
# so that a real secret never has to be written into (and committed with) this
# file; the placeholder is only a fallback.
# NOTE(review): TOKEN is not referenced anywhere else in this script.
TOKEN = os.environ.get("HF_TOKEN", "<hf-token>")
# Hub repo id (or local directory) of the Qwen-Image pipeline; adjust to your own.
QWEN_PATH = "Qwen/Qwen-Image"


# Extensions recognised as "already an image filename" by indexed_output_path.
_IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg", ".webp", ".bmp")


def indexed_output_path(base, index):
    """Return the output file path for the ``index``-th generated image.

    If ``base`` ends in a known image extension (e.g. ``out/img.png``), the
    index is inserted before the extension (``out/img0.png``). Otherwise
    ``base`` is treated as a filename prefix and ``f"{index}.png"`` is
    appended (``out/img_`` -> ``out/img_0.png``), which matches the script's
    previous behaviour for prefix-style arguments.

    Args:
        base: Path or filename prefix given on the command line.
        index: Position of the image within the generated batch.

    Returns:
        A ``pathlib.Path`` for the image file.
    """
    base_path = Path(base)
    if base_path.suffix.lower() in _IMAGE_SUFFIXES:
        return base_path.with_name(f"{base_path.stem}{index}{base_path.suffix}")
    return Path(f"{base}{index}.png")


if __name__ == "__main__":
    # Parse cmd args
    parser = argparse.ArgumentParser(
        description="Generate images with Qwen-Image using bounding-box-constrained attention."
    )
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt to generate from.")
    parser.add_argument(
        "--image_out_path",
        type=str,
        default='compositional_generation/complex_img.png',
        help="Output path; the image index is inserted before the extension (or appended to a prefix).",
    )
    parser.add_argument(
        "--bbox_file",
        type=str,
        default='compositional_generation/bounding_boxes.json',
        help="File with precomputed bounding boxes.",
    )
    parser.add_argument("--num_inference_steps", type=int, default=50)
    # NOTE(review): kernel_size is only validated below; nothing in this script consumes it.
    parser.add_argument("--kernel_size", type=int, default=25)
    parser.add_argument("--latent_height", type=int, default=32)
    parser.add_argument("--latent_width", type=int, default=32)
    parser.add_argument("--batch_size", type=int, default=4, help="Number of images to generate.")
    parser.add_argument("--combine_timestep", type=int, default=10)
    parser.add_argument("--use_tight_segmentation", action='store_true')
    parser.add_argument("--blur_kernel", type=int, default=5)
    parser.add_argument("--seg_head", type=int, default=20)
    parser.add_argument("--seg_block", type=int, default=14)
    parser.add_argument("--seg_thresh", type=float, default=0.95)
    parser.add_argument("--true_cfg_scale", type=float, default=4.0)
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Optional RNG seed for reproducible generation (unseeded by default).",
    )
    cmd_args = parser.parse_args()

    if cmd_args.kernel_size % 2 == 0:
        raise ValueError("Kernel size must be odd")
    if cmd_args.batch_size < 1:
        raise ValueError("Batch size must be at least 1")
    # NOTE(review): if the blur in preprocessing requires an odd kernel,
    # blur_kernel should be validated here as well.

    if cmd_args.use_tight_segmentation:
        segmentation_dict = {'head': cmd_args.seg_head, 'block': cmd_args.seg_block, 'thresh': cmd_args.seg_thresh}
    else:
        segmentation_dict = None

    # load Qwen-Image pipeline
    image_generation_pipeline = DiffusionPipeline.from_pretrained(QWEN_PATH, torch_dtype=torch.bfloat16).to("cuda")

    # generator=None keeps the previous unseeded behaviour.
    generator = None
    if cmd_args.seed is not None:
        generator = torch.Generator(device="cuda").manual_seed(cmd_args.seed)

    # Install the bbox-constraining attention processor on every transformer block.
    for i, block in enumerate(image_generation_pipeline.transformer.transformer_blocks):
        block.attn.processor = ConstrainedQwenDoubleStreamAttnProcessor2_0(block_id=i)

    prompt = cmd_args.prompt
    bb_gen = PrecomputedBoundingBox(cmd_args.bbox_file, cmd_args.latent_height, cmd_args.latent_width)

    bbs = bb_gen(prompt)
    bbs = add_text_length_to_bbs(bbs, image_generation_pipeline.tokenizer_2, prompt, segmentation_dict, cmd_args.blur_kernel)

    # One call produces the whole batch. Previously this ran batch_size times,
    # each call generating batch_size images and keeping only the first one.
    # NOTE(review): width/height 512 presumably correspond to the default
    # 32x32 latent grid used for the bounding boxes; keep them in sync.
    output = modified_qwen(
        image_generation_pipeline,
        cmd_args.combine_timestep,
        bbs,
        cmd_args.use_tight_segmentation,
        prompt,
        num_inference_steps=cmd_args.num_inference_steps,
        width=512,
        height=512,
        num_images_per_prompt=cmd_args.batch_size,
        true_cfg_scale=cmd_args.true_cfg_scale,
        generator=generator,
        negative_prompt=" ",
    )

    for i, img in enumerate(output.images):
        out_path = indexed_output_path(cmd_args.image_out_path, i)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        img.save(out_path)