from __future__ import annotations
import os, json, argparse, torch
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, DiffusionPipeline, CogView4Pipeline, HiDreamImagePipeline, DiTTransformer2DModel, StableDiffusionXLPipeline   
from tqdm import tqdm
from uuid import uuid4
import random

class Config:
    """Command-line configuration for the batch text-to-image generator.

    Every CLI option becomes an attribute whose name uses underscores
    instead of dashes: ``model``, ``input_file``, ``output_dir``,
    ``worker_index``, ``num_workers``, ``steps``, ``guidance_scale``,
    ``height`` and ``width``. Constructing the object also creates
    ``output_dir`` on disk.
    """

    def __init__(self, argv=None):
        """Parse arguments and prepare the output directory.

        Args:
            argv: Argument list to parse instead of ``sys.argv[1:]``.
                ``None`` (the default) reads the real command line.

        Raises:
            SystemExit: Raised by ``argparse`` when arguments are missing or
                invalid. This includes ``--num-workers`` < 1 and a
                ``--worker-index`` outside ``[0, --num-workers)``.
        """
        parser = argparse.ArgumentParser(
            description="Unified text-to-image inference (HiDream, CogView4, SD1.4, SD1.5, SDXL)"
        )
        parser.add_argument("--model", choices=["hidream", "cogview4", "sd14", "sd15", "sdxl"], required=True)
        parser.add_argument("--input-file", type=str, required=True,
                            help="JSON Lines file with one prompt record per line")
        parser.add_argument("--output-dir", type=str, default="./OUTPUTS/",
                            help="output root; an empty string falls back to <model>/results")
        parser.add_argument("--worker-index", type=int, default=0,
                            help="0-based index of this worker among --num-workers")
        parser.add_argument("--num-workers", type=int, default=1,
                            help="total number of workers sharing the prompt file")
        parser.add_argument("--steps", type=int, default=25)
        parser.add_argument("--guidance-scale", type=float, default=3.5)
        parser.add_argument("--height", type=int, default=1024)
        parser.add_argument("--width", type=int, default=1024)
        args = parser.parse_args(argv)

        # Validate sharding up front. An index outside [0, num_workers) would
        # otherwise select no prompts, or all of them when num_workers == 1.
        if args.num_workers < 1:
            parser.error("--num-workers must be >= 1")
        if not 0 <= args.worker_index < args.num_workers:
            parser.error("--worker-index must be in the range [0, --num-workers)")

        # argparse already maps "--input-file" to the dest "input_file", etc.
        for name, value in vars(args).items():
            setattr(self, name, value)

        # The CLI default is non-empty, so this fallback only triggers for an
        # explicit empty --output-dir (os.makedirs("") would raise otherwise).
        if not self.output_dir:
            self.output_dir = os.path.join(self.model, "results")
        os.makedirs(self.output_dir, exist_ok=True)

# Parse the command line (Config also creates the output directory).
config = Config()

# Load prompts from the JSON Lines input file. Each object record contributes
# at most one prompt:
#   - its "enhanced" value (the first element when it is a non-empty list), or
#   - failing that, its "prompt" value.
# Blank lines, non-object records and records with neither usable field
# are skipped.
prompts = []
with open(config.input_file, "r", encoding="utf-8") as f:
    for line_no, line in enumerate(f, start=1):
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing newline at EOF
        try:
            data = json.loads(line)
        except json.JSONDecodeError as err:
            raise ValueError(f"{config.input_file}:{line_no}: invalid JSON") from err
        if not isinstance(data, dict):
            continue
        enhanced = data.get("enhanced")
        if isinstance(enhanced, list):
            # An empty list carries no prompt; fall back to "prompt" rather
            # than failing with IndexError.
            enhanced = enhanced[0] if enhanced else None
        chosen = enhanced if enhanced is not None else data.get("prompt")
        if chosen is not None:
            prompts.append(chosen)
total_prompts = len(prompts)

# Round-robin sharding: worker k takes prompts k, k + N, k + 2N, ...
# Each index is its position in `prompts` and is carried along so that
# output filenames stay globally unique across workers.
if config.num_workers > 1:
    prompts_to_process = [(i, p) for i, p in enumerate(prompts) if i % config.num_workers == config.worker_index]
else:
    prompts_to_process = list(enumerate(prompts))
print(f"Worker {config.worker_index}: processing {len(prompts_to_process)}/{total_prompts} prompts...")

# Load the diffusion pipeline for the selected model into `pipe`.
model_name = config.model.lower()

# Accepted SD1.x names -> local single-file checkpoints. An explicit mapping
# is used because substring matching on "15" would send the
# "stable-diffusion-1.5" alias to the SD1.4 weights.
SD1_CHECKPOINTS = {
    "sd14": "./sd-v1-4.safetensors",
    "stable-diffusion-1.4": "./sd-v1-4.safetensors",
    "sd15": "./v1-5-pruned-emaonly.safetensors",
    "stable-diffusion-1.5": "./v1-5-pruned-emaonly.safetensors",
}

if model_name == "cogview4":
    pipe = CogView4Pipeline.from_pretrained("./CogView4-6B", torch_dtype=torch.bfloat16)
    # Offload idle sub-models to CPU; enabling it once is sufficient.
    pipe.enable_model_cpu_offload()
    # Decode latents in slices/tiles to reduce VAE peak memory.
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()
elif model_name == "hidream":
    # Per the diffusers HiDreamImagePipeline docs, Llama-3.1 is only the
    # *fourth* text encoder (tokenizer_4 / text_encoder_4). The other
    # encoders and the transformer are loaded from the pipeline directory
    # with their own classes. Substituting Llama for them, or loading the
    # transformer as DiTTransformer2DModel, does not match the pipeline.
    tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("./Meta-Llama-3.1-8B-Instruct")
    text_encoder_4 = LlamaForCausalLM.from_pretrained(
        "./Meta-Llama-3.1-8B-Instruct",
        torch_dtype=torch.bfloat16,
        output_hidden_states=True,  # the pipeline consumes Llama hidden states
        output_attentions=True,
    )
    pipe = HiDreamImagePipeline.from_pretrained(
        "./HiDream-Il-nf4",  # NOTE(review): "Il" vs "I1" - confirm the local directory name
        tokenizer_4=tokenizer_4,
        text_encoder_4=text_encoder_4,
        torch_dtype=torch.float16,
    )
    pipe.enable_model_cpu_offload()
elif model_name == "sdxl":
    # StableDiffusionXLPipeline has no safety_checker component, so none is passed.
    pipe = StableDiffusionXLPipeline.from_single_file(
        "./sdxl-base-1.0/sd_xl_base_1.0.safetensors",   # or "stabilityai/stable-diffusion-xl-base-1.0"
        torch_dtype=torch.float16,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")  # SDXL always runs fully on CUDA
elif model_name in SD1_CHECKPOINTS:
    pipe = StableDiffusionPipeline.from_single_file(
        SD1_CHECKPOINTS[model_name],
        safety_checker=None,
        torch_dtype=torch.float16,
    )
    # Replace the default scheduler with DPM-Solver++ multistep.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")  # SD1.4/1.5 always run fully on CUDA
else:
    raise ValueError("Unknown model specified.")


# Generate one image per assigned prompt. Each PNG gets a JSON sidecar with
# its generation parameters, and all records are also written to a log.
model_name = config.model.lower()
sd1_names = {"sd14", "sd15", "stable-diffusion-1.4", "stable-diffusion-1.5"}
SD1_NEGATIVE_PROMPT = "worst quality, bad anatomy, ugly, blurry, lowres, bad quality"
SDXL_NEGATIVE_PROMPT = (
    "low quality, jpeg artifacts, watermark, text, signature, blurry, "
    "oversaturated, mutated fingers, bad anatomy"
)

image_dir = os.path.join(config.output_dir, model_name)
os.makedirs(image_dir, exist_ok=True)  # hoisted: the directory is loop-invariant

results_log = []
for idx, prompt in tqdm(prompts_to_process, total=len(prompts_to_process), desc="Generating images", unit="prompt"):
    seed = random.randint(0, 2**32 - 1)
    # The generator must reach the pipeline call. Otherwise the logged seed
    # has no influence on the initial noise and images are not reproducible.
    generator = torch.Generator("cuda").manual_seed(seed)

    if model_name in sd1_names:
        height, width = 512, 512  # SD1.x resolution is fixed; --height/--width are ignored
    else:
        height, width = config.height, config.width

    call_kwargs = dict(
        height=height, width=width,
        num_inference_steps=config.steps,
        guidance_scale=config.guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    )
    if model_name == "cogview4":
        image = pipe(prompt=prompt, **call_kwargs).images[0]
    elif model_name == "hidream":
        image = pipe(prompt, **call_kwargs).images[0]
    elif model_name in sd1_names:
        image = pipe(prompt, negative_prompt=SD1_NEGATIVE_PROMPT, **call_kwargs).images[0]
    elif model_name == "sdxl":
        image = pipe(prompt, negative_prompt=SDXL_NEGATIVE_PROMPT, **call_kwargs).images[0]
    else:
        raise ValueError("Unknown model specified.")

    # A random UUID avoids collisions between runs; idx keeps the prompt order.
    random_id = str(uuid4())
    filename = f"{config.model}_{random_id}_{idx:04d}.png"
    filepath = os.path.join(image_dir, filename)
    image.save(filepath)

    record = {
        "prompt": prompt,
        "output_file": filepath,
        "steps": config.steps,
        "guidance_scale": config.guidance_scale,
        "seed": seed,
        "model": config.model,
        "height": height,
        "width": width,
    }
    results_log.append(record)
    # JSON sidecar next to the image. splitext only swaps the final suffix,
    # unlike str.replace, which would also rewrite ".png" inside directory names.
    with open(os.path.splitext(filepath)[0] + ".json", "w", encoding="utf-8") as f:
        json.dump(record, f, indent=4)

# With several workers sharing an output directory, each worker writes its own
# log so that they do not overwrite one another; a single worker keeps the
# original filename.
if config.num_workers > 1:
    log_name = f"generation_log_worker{config.worker_index}.json"
else:
    log_name = "generation_log.json"
log_path = os.path.join(config.output_dir, log_name)
with open(log_path, "w", encoding="utf-8") as logf:
    json.dump(results_log, logf, indent=2)
print(f"Done! Generated {len(results_log)} images. Log saved to {log_path}")
