import torch
import argparse
from pathlib import Path
import sys
import json
from tqdm import tqdm
import os

from diffusers import StableDiffusion3Pipeline
from sd3_forward_with_bbs import modified_sd3
from precomputed import PrecomputedBoundingBox
from sd3_constrain_attention import ConstrainedJointAttnProcessor2_0
from preprocessing import add_text_length_to_bbs


# Configure these for your environment. Both can also be supplied through
# environment variables, so the file does not need editing and a real access
# token never has to be committed to source control.
# Hugging Face access token; the literal "<hf-token>" is a placeholder meaning "not set".
TOKEN = os.environ.get("HF_TOKEN", "<hf-token>")
# Hub repo id (or local checkpoint directory) of the SD3.5 model to load.
SD3_PATH = os.environ.get("SD3_PATH", "stabilityai/stable-diffusion-3.5-large")


if __name__ == "__main__":
    # Parse cmd args
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--image_out_path", type=str, default='compositional_generation/complex_img.png')
    parser.add_argument("--bbox_file", type=str, default='compositional_generation/bounding_boxes.json')
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--kernel_size", type=int, default=25)
    parser.add_argument("--latent_height", type=int, default=32)
    parser.add_argument("--latent_width", type=int, default=32)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--combine_timestep", type=int, default=10)
    parser.add_argument("--use_tight_segmentation", action='store_true')
    parser.add_argument("--blur_kernel", type=int, default=5)
    parser.add_argument("--seg_head", type=int, default=20)
    parser.add_argument("--seg_block", type=int, default=14)
    parser.add_argument("--seg_thresh", type=float, default=0.95)
    cmd_args = parser.parse_args()

    if cmd_args.kernel_size % 2 == 0:
        raise ValueError("Kernel size must be odd")

    if cmd_args.use_tight_segmentation:
        segmentation_dict = {'head': cmd_args.seg_head, 'block': cmd_args.seg_block, 'thresh': cmd_args.seg_thresh}
    else:
        segmentation_dict = None

    # load SD3.5 pipeline
    image_generation_pipeline = StableDiffusion3Pipeline.from_pretrained(SD3_PATH, torch_dtype=torch.float16).to("cuda")

    for i, block in enumerate(image_generation_pipeline.transformer.transformer_blocks):
        block.attn.processor = ConstrainedJointAttnProcessor2_0(block_id=i)

    prompt = cmd_args.prompt
    bb_gen = PrecomputedBoundingBox(cmd_args.bbox_file, cmd_args.latent_height, cmd_args.latent_width)

    bbs = bb_gen(prompt)
    bbs = add_text_length_to_bbs(bbs, image_generation_pipeline.tokenizer_2, prompt, segmentation_dict, cmd_args.blur_kernel)

    for i in range(cmd_args.batch_size):
        try:
            imgs = modified_sd3(
                image_generation_pipeline,
                cmd_args.combine_timestep,
                bbs,
                cmd_args.use_tight_segmentation,
                prompt,
                num_inference_steps=cmd_args.num_inference_steps,
                width=512,
                height=512,
                guidance_scale=3.5,
                num_images_per_prompt=1,
            )
        except Exception as e:
            raise e

        img = imgs.images[0]

        img.save(cmd_args.image_out_path + f"{i}.png")
