import os
import sys
import torch
from transformers import pipeline
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel
from diffusers import StableDiffusion3Pipeline
import random
import uuid
import time
import logging
import threading
from filelock import FileLock
import json  
from PIL import Image
import os
from tqdm import tqdm

# Model identifier passed to `from_pretrained` for both the quantized
# transformer and the full pipeline in initialize_sd_pipeline().
sd_model_id = "stabilityai/stable-diffusion-3.5-large"  # Keep original model ID

def initialize_sd_pipeline():
    """Build the Stable Diffusion 3 pipeline around a 4-bit quantized transformer.

    The ``transformer`` subfolder of ``sd_model_id`` is loaded with an NF4
    quantization config (bfloat16 compute dtype) and handed to
    ``StableDiffusion3Pipeline.from_pretrained``. Model CPU offload is
    enabled on the pipeline before it is returned.

    Returns:
        The configured ``StableDiffusion3Pipeline`` instance.
    """
    dtype = torch.bfloat16

    # 4-bit NF4 quantization settings applied to the transformer only.
    quant_kwargs = dict(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=dtype,
    )
    quantized_transformer = SD3Transformer2DModel.from_pretrained(
        sd_model_id,
        subfolder="transformer",
        quantization_config=BitsAndBytesConfig(**quant_kwargs),
        torch_dtype=dtype,
    )

    # The pipeline reuses the already-quantized transformer instead of
    # loading its own copy.
    pipe = StableDiffusion3Pipeline.from_pretrained(
        sd_model_id,
        transformer=quantized_transformer,
        torch_dtype=dtype,
    )
    pipe.enable_model_cpu_offload()
    return pipe

def get_random_dimensions(min_ratio=0.5, max_ratio=2.0, seed_key=None):
    """Pick an image size of at most one megapixel with a random aspect ratio.

    The width/height ratio is drawn from a triangular distribution over
    ``[min_ratio, max_ratio]``. The width is derived from a 1024 * 1024 pixel
    budget, the height from the width, and both are rounded down to a
    multiple of 32, so the pixel count never exceeds the budget.

    Args:
        min_ratio: Smallest allowed width/height ratio.
        max_ratio: Largest allowed width/height ratio.
        seed_key: Optional string. When truthy, the draw is seeded from the
            UUIDv5 of the key, so the same key always yields the same
            dimensions. When falsy, the draw is seeded from system entropy.

    Returns:
        A ``(width, height)`` tuple of ints, each a multiple of 32.
    """
    seed = uuid.uuid5(uuid.NAMESPACE_DNS, seed_key).int if seed_key else None

    # Use a private generator rather than random.seed(): reseeding the
    # module-level RNG here would change the random stream seen by every
    # other user of `random` in the process. Random(seed) produces exactly
    # the value that seeding the global generator would have produced.
    rng = random.Random(seed)

    max_pixel = 1024 * 1024
    ratio = rng.triangular(min_ratio, max_ratio)
    width = int((max_pixel * ratio) ** 0.5)
    height = int(width / ratio)

    # Round both sides down to a multiple of 32.
    width = width // 32 * 32
    height = height // 32 * 32
    return width, height

# Build the pipeline once, at module load time (this also runs when the file
# is imported). The __main__ block below passes this instance to
# generate_images().
sd_pipeline = initialize_sd_pipeline()

def yield_keys_prompts(jsonl_file, use_original=False, skip_lines=3999):
    """Yield ``(key, prompts)`` pairs parsed from a JSONL file.

    Every processed line must be a JSON object with an ``entry`` dict whose
    ``index`` becomes the key. The prompts are chosen as follows:

    * ``use_original`` true: the single prompt ``entry["caption"]``.
    * record has ``result``: that value is the single prompt.
    * otherwise: one prompt per ``result1`` / optional ``result2`` dict,
      taking ``extended`` when it is present and non-empty, else
      ``generated``.

    Each pair is also printed to stdout before being yielded.

    Args:
        jsonl_file: Path of the JSONL file to read (decoded as UTF-8).
        use_original: Use the entry's caption instead of generated results.
        skip_lines: Number of leading lines to skip without parsing them.
            The default of 3999 matches the previously hard-coded behaviour
            (processing starts at line 4000).

    Yields:
        ``(key, prompts)`` where ``prompts`` is a list.

    Raises:
        KeyError: If a required field is missing from a record.
        json.JSONDecodeError: If a non-blank processed line is not valid JSON.
    """
    with open(jsonl_file, 'r', encoding='utf-8') as f:
        for line_number, line in enumerate(f, start=1):
            if line_number <= skip_lines:
                continue
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue

            data = json.loads(line)
            entry = data['entry']
            key = entry['index']

            if use_original:
                # Checked first so caption-only runs do not require any
                # result fields to be present.
                prompts = [entry["caption"]]
            elif "result" in data:
                prompts = [data['result']]
            else:
                results = [data['result1']]
                # The optional second result lives on the record itself, so
                # that is where its presence must be tested (not on `entry`).
                if 'result2' in data:
                    results.append(data['result2'])
                prompts = [
                    result.get('extended') or result['generated']
                    for result in results
                ]

            print(f"Key: {key}, Prompt: {prompts}")
            yield key, prompts

def get_folder_path(root_dir, key, num_folders=200):
    """Return the shard folder for ``key`` under ``root_dir``, creating it if needed.

    The key is hashed with UUIDv5 and reduced modulo ``num_folders``, so the
    same key always maps to the same ``folder_<n>`` directory, with ``n`` in
    ``1..num_folders``.

    Args:
        root_dir: Directory that contains the shard folders.
        key: String identifying the item.
        num_folders: Number of shard folders to spread keys over.

    Returns:
        Path of the (now existing) shard folder.
    """
    # Deterministic bucket: the 128-bit UUIDv5 value of the key, mod the shard count.
    bucket = uuid.uuid5(uuid.NAMESPACE_DNS, key).int % num_folders
    folder_number = bucket + 1  # folder names are 1-based
    target = os.path.join(root_dir, f"folder_{folder_number}")
    os.makedirs(target, exist_ok=True)
    return target

def should_skip(key, num_workers, worker_index):
    """Tell whether ``key`` is assigned to a worker other than ``worker_index``.

    Keys are assigned by taking the UUIDv5 of the key modulo ``num_workers``.

    Args:
        key: String identifying the item.
        num_workers: Total number of workers sharing the input.
        worker_index: Index of the calling worker.

    Returns:
        ``False`` when the key is assigned to ``worker_index``, else ``True``.
    """
    assigned_worker = uuid.uuid5(uuid.NAMESPACE_DNS, key).int % num_workers
    if assigned_worker == worker_index:
        return False
    return True

def generate_image(sd_pipeline, prompt, width, height, folder_path, output_uuid):
    """Render ``prompt`` and store it as ``<output_uuid>_image.webp`` in ``folder_path``.

    Nothing is generated when the target file already exists. Any exception
    raised while generating or saving is logged with its traceback and not
    re-raised, so the caller can continue with the next prompt.

    The image is first written to a temporary sibling file and then renamed
    into place. A failed or interrupted save therefore never leaves a
    truncated file under the final name, which a later run would mistake for
    a finished image and skip.

    Args:
        sd_pipeline: Callable pipeline; its result must expose ``.images``.
        prompt: Text prompt to render.
        width: Image width in pixels.
        height: Image height in pixels.
        folder_path: Directory to write into (created if missing).
        output_uuid: Base name used for the output file.

    Returns:
        None.
    """
    image_filename = os.path.join(folder_path, f"{output_uuid}_image.webp")
    if os.path.exists(image_filename):
        logging.info(f"Image '{output_uuid}_image.webp' already exists, skipping")
        return

    # Include the PID so two processes can never share a temporary file.
    tmp_filename = f"{image_filename}.{os.getpid()}.tmp"
    try:
        image = sd_pipeline(
            prompt=prompt,
            num_inference_steps=28,
            guidance_scale=6.5,
            max_sequence_length=512,
            width=width,
            height=height,
        ).images[0]
        os.makedirs(folder_path, exist_ok=True)
        # The format is passed explicitly, so the ".tmp" suffix does not
        # influence which encoder is used.
        image.save(tmp_filename, format='WEBP')
        # Rename only after the file is completely written.
        os.replace(tmp_filename, image_filename)
        logging.info(f"Saved image to '{image_filename}'")
    except Exception as e:
        logging.exception(f"Error generating image: {e}")
        # Best-effort removal of a partially written temporary file.
        try:
            os.remove(tmp_filename)
        except OSError:
            pass

def generate_images(sd_pipeline, jsonl_file, root_dir, num_folders, num_workers, worker_index, use_original=False):
    """Generate images for every prompt in ``jsonl_file`` assigned to this worker.

    Records come from ``yield_keys_prompts``; keys for which ``should_skip``
    returns true are logged and ignored. For each remaining prompt the image
    size and target folder are derived from the key, the output name from the
    UUIDv5 of ``"<key>_<prompt>"``, and ``generate_image`` is called. The
    progress bar advances once per prompt.

    Args:
        sd_pipeline: Pipeline object forwarded to ``generate_image``.
        jsonl_file: Path of the JSONL input file.
        root_dir: Root directory for the output shard folders.
        num_folders: Number of shard folders under ``root_dir``.
        num_workers: Total number of workers sharing the input.
        worker_index: Index of this worker (also the progress bar position).
        use_original: Forwarded to ``yield_keys_prompts``.
    """
    with tqdm(total=0, position=worker_index, desc=f"Worker {worker_index}") as progress:
        for key, prompts in yield_keys_prompts(jsonl_file, use_original):
            if not should_skip(key, num_workers, worker_index):
                for prompt in prompts:
                    width, height = get_random_dimensions(seed_key=key)
                    target_dir = get_folder_path(root_dir, key, num_folders)
                    name_source = f"{key}_{prompt}"
                    file_stem = uuid.uuid5(uuid.NAMESPACE_DNS, name_source).hex
                    generate_image(sd_pipeline, prompt, width, height, target_dir, file_stem)
                    progress.update(1)
            else:
                logging.info(f"Skipping key '{key}'")

if __name__ == '__main__':
    # Command line:
    #   <script> JSONL_FILE ROOT_DIR NUM_FOLDERS NUM_WORKERS WORKER_INDEX [original]
    # Passing the literal word "original" as the sixth argument sets
    # use_original; any other value (or none) leaves it off.

    # Fail with a usage message instead of a bare IndexError.
    if len(sys.argv) < 6:
        sys.exit(
            f"Usage: {sys.argv[0]} JSONL_FILE ROOT_DIR NUM_FOLDERS "
            "NUM_WORKERS WORKER_INDEX [original]"
        )

    # Configure Logging
    os.makedirs('logs', exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] [Worker] %(message)s',
        handlers=[
            logging.FileHandler("logs/worker.log"),
            logging.StreamHandler()
        ]
    )

    # Load arguments
    jsonl_file = sys.argv[1]
    root_dir = sys.argv[2]
    num_folders = int(sys.argv[3])
    num_workers = int(sys.argv[4])
    worker_index = int(sys.argv[5])
    use_original = len(sys.argv) > 6 and sys.argv[6] == "original"

    # Both counts are used as modulo divisors, so zero would raise
    # ZeroDivisionError deep inside the run.
    if num_folders < 1 or num_workers < 1:
        sys.exit("NUM_FOLDERS and NUM_WORKERS must both be at least 1")
    # A worker index outside 0..NUM_WORKERS-1 never matches any key, so the
    # worker would silently do nothing.
    if not 0 <= worker_index < num_workers:
        sys.exit(f"WORKER_INDEX must be in the range 0..{num_workers - 1}")

    logging.info(f"Worker started with arguments: jsonl_file='{jsonl_file}', root_dir='{root_dir}', num_folders={num_folders}, num_workers={num_workers}, worker_index={worker_index}")

    generate_images(sd_pipeline, jsonl_file, root_dir, num_folders, num_workers, worker_index, use_original)
