import os
import sys
import torch
from transformers import pipeline
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel
from diffusers import StableDiffusion3Pipeline
import random
import uuid
import time
import logging
import threading
from filelock import FileLock
import json  # Import json module to handle JSONL files

# Parse command-line arguments.
# Positional layout read below:
#   argv[1] device_id, argv[2] root_dir, argv[3] stop_file, argv[4] jsonl_file,
#   argv[5] logging_prefix (optional, defaults to "worker").
# This runs at module level, so a missing required argument raises IndexError
# as soon as the module is executed or imported.
device_id = sys.argv[1]  # GPU index; still a string here, converted to int below
root_dir = sys.argv[2]  # Base directory under which the folder_<n> output dirs are created
stop_file = sys.argv[3]  # The worker loop stops as soon as this path exists
jsonl_file = sys.argv[4]  # JSONL input file, read line by line by worker()
logging_prefix = sys.argv[5] if len(sys.argv) > 5 else "worker"

device_id = int(device_id)  # Raises ValueError if argv[1] is not an integer
device = f"cuda:{device_id}"  # Torch-style device string built from the GPU index

# Configure logging: every record goes both to a per-worker file under ./logs
# and to a StreamHandler, tagged with this worker's device id.
os.makedirs('logs', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format=f'%(asctime)s [%(levelname)s] [Worker {device_id}] %(message)s',
    handlers=[
        logging.FileHandler(f"logs/{logging_prefix}_worker_{device_id}.log"),
        logging.StreamHandler()
    ]
)

# Global Variables
# NOTE(review): llama_model_id is not referenced anywhere in the visible code;
# presumably left over from an earlier version - confirm before removing.
llama_model_id = "meta-llama/Llama-3.2-1B-Instruct"  # Keep original model ID
sd_model_id = "stabilityai/stable-diffusion-3.5-large"  # Keep original model ID
num_folders = 200  # Number of output buckets; worker() passes this to get_folder_path()

def initialize_sd_pipeline(gpu_id=None):
    """Build the Stable Diffusion 3.5 pipeline with a 4-bit NF4 transformer.

    The transformer is loaded separately with a bitsandbytes NF4 quantization
    config (bfloat16 compute) and injected into the pipeline, which is then
    put into model-CPU-offload mode.

    Args:
        gpu_id: Index of the CUDA device this pipeline should run on.
            Defaults to the module-level ``device_id`` parsed from the
            command line.

    Returns:
        The configured ``StableDiffusion3Pipeline``.

    Raises:
        Whatever ``from_pretrained`` / ``enable_model_cpu_offload`` raise
        (missing weights, no CUDA device, out of memory, ...).
    """
    if gpu_id is None:
        gpu_id = device_id

    # Previously the requested device id was parsed but never applied:
    # enable_model_cpu_offload() was called with its defaults, so every
    # worker ended up on the default CUDA device. Pin the device up front so
    # that anything placed on "the current device" during loading lands on
    # the right GPU as well.
    # NOTE(review): quantized weights are presumably materialised on the
    # current CUDA device at load time - confirm for the installed diffusers.
    if 0 <= gpu_id < torch.cuda.device_count():
        torch.cuda.set_device(gpu_id)
    else:
        # E.g. the launcher restricted visibility with CUDA_VISIBLE_DEVICES,
        # so the worker's id is not a valid local index. Keep the original
        # default-device behaviour instead of failing.
        logging.warning(
            f"CUDA device {gpu_id} is not visible; using the default CUDA device."
        )
        gpu_id = None

    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    # Load the quantized transformer model
    model_nf4 = SD3Transformer2DModel.from_pretrained(
        sd_model_id,
        subfolder="transformer",
        quantization_config=nf4_config,
        torch_dtype=torch.bfloat16
    )
    # Initialize the Stable Diffusion 3.5 pipeline with the quantized transformer
    sd_pipeline = StableDiffusion3Pipeline.from_pretrained(
        sd_model_id,
        transformer=model_nf4,
        torch_dtype=torch.bfloat16
    )
    # gpu_id=None is the library default, i.e. the original behaviour.
    sd_pipeline.enable_model_cpu_offload(gpu_id=gpu_id)
    return sd_pipeline

def get_folder_path(root_dir, key, num_folders=200):
    """Return the bucket directory for *key*, creating it if needed.

    *key* is converted with ``int()`` and reduced modulo *num_folders*; the
    resulting bucket is named ``folder_<n>`` with ``n`` starting at 1 and
    lives directly under *root_dir*.

    Raises:
        ValueError: if *key* cannot be converted to an integer.
    """
    bucket_number = int(key) % num_folders + 1
    bucket_dir = os.path.join(root_dir, f"folder_{bucket_number}")
    # Callers write into the returned directory right away, so make sure it exists.
    os.makedirs(bucket_dir, exist_ok=True)
    return bucket_dir

def get_random_dimensions(min_ratio=0.5, max_ratio=2.0):
    """Pick a random (width, height) of roughly one megapixel.

    The width/height ratio is drawn from a triangular distribution over
    [min_ratio, max_ratio]. The width is derived from a 1024*1024 pixel
    budget, the height from the width, and both are rounded down to a
    multiple of 32.

    Returns:
        Tuple ``(width, height)`` of ints, each a multiple of 32.
    """
    pixel_budget = 1024 * 1024
    grid = 32

    aspect = random.triangular(min_ratio, max_ratio)
    raw_width = int((pixel_budget * aspect) ** 0.5)
    raw_height = int(raw_width / aspect)

    # Snap each side down onto the 32-pixel grid.
    return raw_width - raw_width % grid, raw_height - raw_height % grid

def _save_image_atomically(image, image_filename):
    """Save *image* as WEBP at *image_filename* without exposing partial files.

    The image is written to a process-specific temporary name in the same
    directory and then renamed into place, so a crash or kill during the
    write can never leave a truncated file under the final name.
    """
    tmp_filename = f"{image_filename}.{os.getpid()}.tmp"
    try:
        # The format is passed explicitly, so the ".tmp" suffix does not matter.
        image.save(tmp_filename, format='WEBP')
        os.replace(tmp_filename, image_filename)
    finally:
        # Only still present if the save or rename failed.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)


def worker():
    """Generate one image per JSONL entry until the input or the stop file ends the run.

    For every line of ``jsonl_file`` the worker:

    1. stops if ``stop_file`` exists;
    2. reads the prompt from ``result1.extended`` and the output name from
       ``entry.key`` (entries missing either are skipped with a warning);
    3. skips keys whose ``<key>.webp`` already exists in their bucket folder;
    4. runs the Stable Diffusion pipeline at a random ~1 MP resolution and
       saves the result as WEBP.

    Per-entry failures are logged with a traceback and do not stop the loop.
    If the pipeline cannot be initialised the function logs the error and
    returns. Errors opening or reading the JSONL file propagate to the
    caller after the pipeline has been released.
    """
    # Initialize pipeline
    try:
        sd_pipeline = initialize_sd_pipeline()
    except Exception as e:
        # logging.exception keeps the traceback, which the plain message lacks.
        logging.exception(f"Error initializing pipelines: {e}")
        return

    logging.info(f"Worker started on device {device}.")

    try:
        with open(jsonl_file, 'r', encoding='utf-8') as jf:
            for line in jf:
                # Check for stop condition
                if os.path.exists(stop_file):
                    logging.info("Stop file detected. Exiting.")
                    break

                # Blank lines are not entries; do not report them as errors.
                if not line.strip():
                    continue

                try:
                    # Parse JSON line
                    data = json.loads(line)
                    entry = data.get("entry", {})
                    result1 = data.get("result1", {})

                    # Use "extended" item for prompt
                    prompt = result1.get("extended", "").strip()
                    if not prompt:
                        logging.warning("No 'extended' field found. Skipping entry.")
                        continue

                    # Use "key" for filename. Only a missing or empty key is
                    # rejected, so a numeric key of 0 is still processed.
                    key = entry.get("key")
                    if key is None or key == "":
                        logging.warning("No 'key' field found. Skipping entry.")
                        continue

                    # Determine folder path based on key
                    folder_path = get_folder_path(root_dir, str(key), num_folders)
                    image_filename = os.path.join(folder_path, f"{key}.webp")
                    if os.path.exists(image_filename):
                        logging.info(f"Image already exists for key {key}. Skipping.")
                        continue

                    # Generate image with Stable Diffusion
                    width, height = get_random_dimensions()
                    image = sd_pipeline(
                        prompt=prompt,
                        num_inference_steps=28,
                        guidance_scale=6.5,
                        max_sequence_length=512,
                        width=width,
                        height=height,
                    ).images[0]

                    # Save via a temp file so that an interrupted write is not
                    # mistaken for a finished image by the existence check above.
                    _save_image_atomically(image, image_filename)

                    logging.info(f"Image saved to {image_filename}")

                except Exception as e:
                    # Best effort per entry: log with traceback and move on.
                    logging.exception(f"Error during generation: {e}")

                # Optional: Sleep for a short duration to prevent overloading the system
                time.sleep(1)
    finally:
        # Release the pipeline even if reading the JSONL file failed.
        del sd_pipeline
        torch.cuda.empty_cache()

    logging.info("Worker has exited.")

# Script entry point: start the worker loop only when run directly.
# The command-line parsing and logging setup earlier in this file are at
# module level, so they execute on import as well.
if __name__ == '__main__':
    worker()
