import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import json
import argparse
from tqdm import tqdm
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageDraw
from diffusers.pipelines import FluxPipeline

import sys
sys.path.append('/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl')

from src.flux.generate import generate, seed_everything
from src.flux.condition import Condition
from src.flux.module import Inter_Controller, Spatial_Controller
from src.flux.pipeline_tools import visualize_masks
from src.utils.dataset import find_applicable_scenes 
from src.utils.dataset_eligen import json_generation 

def prepare_condition_data(condition, data_path, condition_size, device, dtype, use_depth=True):
    """Build EliGen entity prompts and box masks from a layout annotation.

    Args:
        condition: Parsed annotation dict. Reads
            ``condition['ans_json']['entity_layout']``, a list whose entries
            hold an ``'entity_name'`` and a ``'bbox'``
            ``[x_min, y_min, x_max, y_max]`` given as fractions of the image
            side (each value is multiplied by ``condition_size``).
        data_path: Not used by this function; accepted so the signature
            matches ``prepare_condition_data_3d``.
        condition_size: Side length, in pixels, of the square mask canvas.
        device: Device the downsampled mask tensors are moved to.
        dtype: Dtype the downsampled mask tensors are cast to.
        use_depth: Not used by this function.

    Returns:
        A dict with
        ``'eligen_entity_prompts'``: the entity names, in layout order;
        ``'eligen_entity_masks'``: one tensor per entity. Each is the binarized
        RGB box mask resized to ``condition_size // 8`` on each side, with a
        leading axis of size 1 added;
        ``'eligen_entity_masks_pil'``: the full-resolution RGB PIL box masks.
        Returns ``None`` (after printing the reason) if the annotation cannot
        be processed.
    """
    try:
        layout = condition['ans_json']["entity_layout"]

        eligen_entity_prompts = []
        eligen_entity_masks = []
        eligen_entity_masks_pil = []
        mask_side = condition_size // 8
        for entity in layout:
            eligen_entity_prompts.append(entity['entity_name'])

            # bbox values are fractions of the image side; scale them to pixels.
            coordinates = entity["bbox"]
            x_min, y_min, x_max, y_max = (int(condition_size * c) for c in coordinates[:4])

            # Full-resolution binary box mask (kept for visualization).
            mask = Image.new("L", (condition_size, condition_size), 0)
            ImageDraw.Draw(mask).rectangle([x_min, y_min, x_max, y_max], fill=255)
            mask = mask.convert("RGB")
            eligen_entity_masks_pil.append(mask)

            # Downsampled 0/1 mask as a tensor on the requested device/dtype.
            small = np.array(mask.resize((mask_side, mask_side)))
            small = (small > 0).astype(np.uint8)
            mask_tensor = torch.from_numpy(small).to(device=device, dtype=dtype)
            eligen_entity_masks.append(mask_tensor.unsqueeze(0))
    except Exception as exc:
        # Best-effort: skip malformed samples, but unlike a bare ``except:``
        # let KeyboardInterrupt / SystemExit propagate, and say why we skipped.
        print(f"prepare_condition_data: skipping sample ({exc!r})")
        return None

    return {
        "eligen_entity_prompts": eligen_entity_prompts,
        "eligen_entity_masks": eligen_entity_masks,
        'eligen_entity_masks_pil': eligen_entity_masks_pil,
    }

def prepare_condition_data_3d(condition, data_path, condition_size, device, dtype, use_depth=True):
    """Build EliGen inputs from per-entity rendered depth images.

    For the ``i``-th entry of ``condition['ans_json']['entity_layout']``, the
    image ``{data_path}/{condition['caption']}/render_depth_{i}.png`` is
    loaded, resized to ``condition_size`` x ``condition_size`` and converted
    to RGB. A missing render is silently replaced by an all-black image,
    which means an empty mask for that entity. An unreadable render gets the
    same replacement, but a warning is printed first.

    Args:
        condition: Parsed annotation dict; reads ``'caption'`` and
            ``['ans_json']['entity_layout'][*]['entity_name']``.
        data_path: Root directory containing one sub-directory per caption.
        condition_size: Side length, in pixels, the renders are resized to.
        device: Device the downsampled mask tensors are moved to.
        dtype: Dtype the downsampled mask tensors are cast to.
        use_depth: Not used by this function.

    Returns:
        ``(condition_data, [0, 0])``. ``condition_data`` holds
        ``'condition'``: the RGB renders stacked with ``T.ToTensor()``;
        ``'eligen_entity_prompts'``: the entity names;
        ``'eligen_entity_masks'``: per-entity tensors, 1 where the render
        resized to ``condition_size // 8`` is non-zero, with a leading axis of
        size 1;
        ``'eligen_entity_masks_pil'``: full-resolution 0/255 PIL masks.
        Returns ``None`` if the layout has no entities or the annotation cannot
        be processed.
    """
    try:
        caption = condition['caption']
        entities = [entity['entity_name'] for entity in condition['ans_json']["entity_layout"]]
        if not entities:
            # Nothing to stack into a condition tensor.
            return None

        condition_imgs = []
        for idx, _name in enumerate(entities):
            path = f'{data_path}/{caption}/render_depth_{idx}.png'
            img = None
            try:
                # Context manager closes the file handle deterministically.
                with Image.open(path) as render:
                    img = render.resize((condition_size, condition_size)).convert("RGB")
            except FileNotFoundError:
                pass  # No render for this entity: expected, fall back to blank.
            except (OSError, ValueError) as exc:
                # Present but unreadable: still fall back, but don't hide it.
                print(f"prepare_condition_data_3d: unreadable render {path} ({exc!r}); using an empty mask")
            if img is None:
                img = Image.new("L", (condition_size, condition_size), 0).convert("RGB")
            condition_imgs.append(img)

        mask_side = condition_size // 8
        eligen_entity_masks = []
        eligen_entity_masks_pil = []
        for img in condition_imgs:
            # Downsampled 0/1 mask for the model input.
            small = (np.array(img.resize((mask_side, mask_side))) > 0).astype(np.uint8)
            eligen_entity_masks.append(torch.from_numpy(small).to(device=device, dtype=dtype).unsqueeze(0))

            # Full-resolution 0/255 mask for visualization.
            full = (np.array(img) > 0).astype(np.uint8)
            eligen_entity_masks_pil.append(Image.fromarray(full * 255))

        # Stack the renders in layout order (no reordering is applied).
        condition_tensor = torch.stack([T.ToTensor()(img) for img in condition_imgs])
    except Exception as exc:
        # Best-effort: skip malformed samples, but unlike a bare ``except:``
        # let KeyboardInterrupt / SystemExit propagate, and say why we skipped.
        print(f"prepare_condition_data_3d: skipping sample ({exc!r})")
        return None

    condition_data = {
        "condition": condition_tensor,
        "eligen_entity_prompts": entities,
        "eligen_entity_masks": eligen_entity_masks,
        'eligen_entity_masks_pil': eligen_entity_masks_pil,
    }
    return (condition_data, [0, 0])

def unified_generate(
    pipe,
    prompt,
    condition=None,
    eligen_entity_prompts=None,
    eligen_entity_masks=None,
    eligen_entity_masks_pil=None,
    file_path='result.png',
    model_config=None,
    condition_size=512,
    target_size=512,
    num_inference_steps=50,
    seed=42,
):
    """Generate one image with ``generate`` and save it (plus an optional mask overlay).

    Args:
        pipe: Pipeline passed straight to ``generate``.
        prompt: Text prompt.
        condition: Optional ``Condition``; when given, it is passed as
            ``conditions=[condition]`` and the output size becomes
            ``condition_size`` x ``condition_size``.
        eligen_entity_prompts: Optional per-entity prompts for EliGen.
        eligen_entity_masks: Optional per-entity masks for EliGen. Both this
            and ``eligen_entity_prompts`` must be non-empty for EliGen to be used.
        eligen_entity_masks_pil: Optional PIL masks. When non-empty, a
            visualization is written next to the sample, into a sibling
            directory where ``samples`` in the directory path is replaced by
            ``visual``, under the name ``<stem>_mask.png``.
        file_path: Where the generated image is saved.
        model_config: Forwarded to ``generate``; ``None`` means ``{}``.
        condition_size: Generation height/width when conditioning is used.
        target_size: The saved image is resized to this square size.
        num_inference_steps: Denoising steps passed to ``generate``.
        seed: Seed passed to ``seed_everything`` before generation.
    """
    # Fresh dict per call: a shared mutable default could leak state between calls.
    if model_config is None:
        model_config = {}

    seed_everything(seed)

    # Build the generate() keyword arguments.
    generate_kwargs = {
        "prompt": prompt,
        "default_lora": True,
        "num_inference_steps": num_inference_steps,
    }

    # Mode-specific arguments.
    if condition is not None:
        generate_kwargs["conditions"] = [condition]
        generate_kwargs["height"] = condition_size
        generate_kwargs["width"] = condition_size

    if eligen_entity_prompts and eligen_entity_masks:
        generate_kwargs.update({
            "eligen_entity_prompts": eligen_entity_prompts,
            "eligen_entity_masks": eligen_entity_masks,
            "height": condition_size,
            "width": condition_size,
        })

    # Run generation.
    res = generate(
        pipe,
        model_config=model_config,
        **generate_kwargs
    )

    # Save the result image.
    save_image = res.images[0]
    save_image.resize((target_size, target_size)).save(file_path)

    # Mask visualization.
    if eligen_entity_masks_pil:
        # splitext handles any extension length (the old [:-4] assumed ".png"),
        # and only the directory part is rewritten so a file name containing
        # "samples" is left intact.
        stem, _ext = os.path.splitext(file_path)
        out_dir, name = os.path.split(stem)
        mask_dir = out_dir.replace('samples', 'visual')
        if mask_dir:
            os.makedirs(mask_dir, exist_ok=True)
        mask_path = os.path.join(mask_dir, f"{name}_mask.png")
        visualize_masks(save_image, eligen_entity_masks_pil, eligen_entity_prompts, mask_path)

def _parse_args():
    """Parse the command-line options of this evaluation script."""
    parser = argparse.ArgumentParser()
    # --gpu / --total_gpus select this process's shard of the prompt list.
    parser.add_argument('--gpu', type=int, default=0, help='GPU index to use')
    parser.add_argument('--total_gpus', type=int, default=1, help='Total number of GPUs used')
    # NOTE(review): --batch_size is parsed but not used anywhere in this script.
    parser.add_argument('--batch_size', type=int, default=1, help='Number of objects to process per batch')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--model', type=str, default='eligen_3d')
    parser.add_argument('--data_path', type=str, default='t2i_compbench/dataset/non_spatial/render')
    parser.add_argument('--json_path', type=str, default="t2i_compbench/dataset/non_spatial/json")
    parser.add_argument('--save_path', type=str, default="t2i_compbench/dataset/non_spatial")
    parser.add_argument('--ckpt_path', type=str, default="runs_new/20250405-133306_eligen_loose_reward_weight_split_eligen_loose_None")
    parser.add_argument('--ckpt', type=str, default="30000")
    return parser.parse_args()


def _load_pipeline(flux_path, model_name, lora_path, eligen_path, device):
    """Load FLUX in bfloat16 and attach the LoRA adapters selected by ``model_name``.

    ``'3d'`` in ``model_name`` loads ``lora_path`` as adapter ``'default'``;
    ``'eligen'`` loads the state dict at ``eligen_path`` as adapter ``'eligen'``.
    With both, both adapters are activated.
    """
    pipe = FluxPipeline.from_pretrained(flux_path, torch_dtype=torch.bfloat16)
    print("Load Flux model successfully")

    if '3d' in model_name:
        pipe.transformer.load_lora_adapter(lora_path, adapter_name='default')
    if 'eligen' in model_name:
        # Load onto CPU (where the freshly loaded transformer lives) so the
        # checkpoint does not depend on the device it was saved from; the whole
        # pipeline is moved to `device` below. torch.load unpickles: only point
        # it at trusted checkpoints.
        state_dict = torch.load(eligen_path, map_location="cpu")
        pipe.transformer.load_lora_adapter(state_dict, prefix="transformer", adapter_name="eligen",)
        if '3d' in model_name:
            pipe.transformer.set_adapters(['eligen', 'default'])
        print("Load Flux lora successfully")
        active_adapters = pipe.get_active_adapters()
        print(active_adapters)

    return pipe.to(device=device, dtype=torch.bfloat16)


def main():
    """Generate one sample per entry of ``--data_path`` for this process's shard.

    For each entry name, ``{json_path}/{name}.json`` is read and its
    ``'caption'`` is used as the prompt. The image is written to
    ``{save_path}/samples/{name}_{seed}.png``; entries whose output already
    exists are skipped.
    """
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Paths and parameters. (FLUX.1-dev was previously assigned here too, but
    # was immediately overwritten; schnell is what actually gets loaded.)
    flux_path = "black-forest-labs/FLUX.1-schnell"
    num_inference_steps = 4 if 'schnell' in flux_path else 50
    condition_size = 512 if '3d' in args.model else 1024
    target_size = 512
    condition_type = "eligen_loose"
    data_path = args.data_path
    json_path = args.json_path
    save_path = args.save_path.replace(',', '')
    os.makedirs(save_path, exist_ok=True)

    model_config = {
        'ckpt': args.ckpt,
        'ckpt_path': args.ckpt_path,
        'condition_type': "eligen_loose",
        'inter_controller_type': None,
        'eligen_depth_attn': False,
        'latent_lora': ['eligen']
    }

    lora_path = f"{model_config['ckpt_path']}/ckpt/{model_config['ckpt']}/pytorch_lora_weights.safetensors"
    eligen_path = 'checkpoints/eligen.bin'
    pipe = _load_pipeline(flux_path, args.model, lora_path, eligen_path, device)

    # os.listdir order is unspecified; sort so every process shards the same
    # sequence and the --gpu/--total_gpus slices are disjoint and complete.
    prompt_list = sorted(os.listdir(data_path))[args.gpu::args.total_gpus]
    print(f'test_list length: {len(prompt_list)}')

    os.makedirs(f"{save_path}/samples", exist_ok=True)
    os.makedirs(f"{save_path}/visual", exist_ok=True)

    use_3d = '3d' in args.model
    use_eligen = 'eligen' in args.model
    seed = args.seed

    for sample_name in tqdm(prompt_list, total=len(prompt_list), desc="🚀 Processing batches", unit="batch"):
        file_path = f"{save_path}/samples/{sample_name}_{seed}.png"
        if os.path.exists(file_path):
            print(f"{file_path} exists")
            continue

        with open(f'{json_path}/{sample_name}.json', 'r', encoding='utf-8') as f:
            data = json.load(f)
        prompt = data['caption']

        condition = None
        condition_ = None
        if use_3d:
            prepared = prepare_condition_data_3d(data, data_path, condition_size, pipe.device, pipe.dtype)
            if prepared is None:
                continue
            condition, position_delta = prepared
            # prepare_condition_data_3d returns a dict (never a PIL image), so
            # it is passed to Condition unchanged.
            condition_ = Condition(
                condition_type=condition_type,
                condition=condition,
                position_delta=position_delta,
            )
        elif use_eligen:
            condition = prepare_condition_data(data, data_path, condition_size, pipe.device, pipe.dtype)
            if condition is None:
                continue

        # EliGen inputs are only forwarded for eligen models; otherwise the
        # unified_generate defaults (None) apply.
        eligen_kwargs = {}
        if use_eligen:
            eligen_kwargs = {
                key: condition[key]
                for key in ('eligen_entity_prompts', 'eligen_entity_masks', 'eligen_entity_masks_pil')
            }

        unified_generate(
            pipe, prompt, condition_,
            file_path=file_path, condition_size=condition_size, target_size=target_size, seed=seed,
            model_config=model_config,
            num_inference_steps=num_inference_steps,
            **eligen_kwargs,
        )


if __name__ == "__main__":
    main()

