#!/usr/bin/env python3
"""
Debug Single Inference Script (Mimicking inference_aircraft_bg_crop.py)
"""

import os
import sys
import yaml
import torch
import random
import numpy as np
from pathlib import Path
from PIL import Image
import torchvision.transforms as T
import os
# [Config] Force offline mode to use local cache
# NOTE(review): set before the project imports below, which presumably load the
# Hugging Face libraries that read this variable -- confirm before reordering.
os.environ["HF_HUB_OFFLINE"] = "1"

# Add the project root directory (four levels above this file) to the front of
# sys.path, presumably so that the `omini` package imported below can be found.
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from omini.pipeline.flux_omini import Condition, generate, seed_everything
from omini.train_flux.trainer import OminiModel

def load_rgba_with_black_background(image_path: str) -> Image.Image:
    """Open an image and return it in RGB, flattening RGBA onto a white canvas.

    Despite the ``black_background`` in the name, transparent regions are
    filled with white ``(255, 255, 255)``. The original author notes that this
    matches the preprocessing used by the training script.

    Args:
        image_path: Path of the image file to open.

    Returns:
        * the opened image itself when it is already ``RGB``;
        * a new ``RGB`` image composited over white when the mode is ``RGBA``;
        * ``img.convert('RGB')`` for every other mode.
    """
    source = Image.open(image_path)
    mode = source.mode

    if mode == 'RGBA':
        # Bands come back in R, G, B, A order; the last one is the paste mask.
        *_, alpha = source.split()
        canvas = Image.new('RGB', source.size, (255, 255, 255))
        canvas.paste(source, mask=alpha)
        return canvas

    # RGB images are returned untouched; any other mode is converted.
    return source if mode == 'RGB' else source.convert('RGB')

def load_config(config_path: str):
    """Parse a YAML configuration file and return its contents.

    The file is decoded explicitly as UTF-8 rather than with the platform's
    locale encoding, so configs containing non-ASCII text (for example
    comments in Chinese) load identically on every system.

    Args:
        config_path: Path to the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (``None`` for an
        empty file).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

def _resolve_checkpoint_path(training_config):
    """Return the LoRA checkpoint directory to load, or None if none exists on disk."""
    checkpoint_path = training_config.get("resume_from_checkpoint", None)
    # Hard-coded fallback for this debug script when the config names no checkpoint.
    if checkpoint_path is None:
        checkpoint_path = "runs_bg_crop/20260115-164618/ckpt/4000"

    if os.path.exists(checkpoint_path):
        return checkpoint_path

    # The configured/fallback path is missing: look for the newest checkpoint
    # under <save_path>/<run>/ckpt/<step>.
    save_path = training_config.get("save_path", "runs_bg_crop")
    wandb_name = training_config.get("wandb", {}).get("name", "aircraft_bg_crop")
    if not os.path.exists(save_path):
        return None

    ckpt_dir = os.path.join(save_path, wandb_name, "ckpt")
    if not os.path.exists(ckpt_dir):
        # No run directory named after the wandb run: use the run directory
        # that sorts last by name instead.
        run_dirs = [
            d for d in os.listdir(save_path)
            if os.path.isdir(os.path.join(save_path, d))
        ]
        if run_dirs:
            ckpt_dir = os.path.join(save_path, sorted(run_dirs)[-1], "ckpt")

    if not os.path.exists(ckpt_dir):
        return None

    checkpoints = [
        d for d in os.listdir(ckpt_dir)
        if os.path.isdir(os.path.join(ckpt_dir, d)) and d.isdigit()
    ]
    if not checkpoints:
        return None

    # Compare numerically but keep the real directory name, so zero-padded
    # step directories (e.g. "0500") still resolve to an existing path.
    latest_ckpt = max(checkpoints, key=int)
    checkpoint_path = os.path.join(ckpt_dir, latest_ckpt)
    print(f"  ✓ Found latest checkpoint: {checkpoint_path}")
    return checkpoint_path


def _activate_lora_adapters(model):
    """Activate every adapter in ``model.adapter_set`` on the transformer and print the first active layer."""
    print("\n  Activating LoRA adapters...")
    adapter_list = list(model.adapter_set)
    if not adapter_list:
        return

    try:
        model.transformer.set_adapters(adapter_list)
        print(f"    ✓ Activated adapters: {adapter_list}")
    except Exception as e:
        # Best effort: fall back to enabling the adapters one at a time.
        print(f"    ⚠️  set_adapters failed: {e}")
        for adapter_name in adapter_list:
            try:
                model.transformer.enable_adapters(adapter_name)
                print(f"    ✓ Enabled adapter: {adapter_name}")
            except Exception as e2:
                print(f"    ⚠️  Failed to enable {adapter_name}: {e2}")

    # Spot-check: report the first module that has a non-empty active_adapters.
    print("    Verifying adapter activation...")
    for name, module in model.transformer.named_modules():
        if hasattr(module, 'active_adapters'):
            if module.active_adapters:
                print(f"      Active adapters in {name}: {module.active_adapters}")
                break


def _report_adapter_status(model):
    """Count transformer modules exposing ``active_adapters`` and print how many have any active."""
    print("\n🔍 Deep Diagnostics: Checking LoRA Layer Status...")
    lora_count = 0
    active_count = 0
    for name, module in model.transformer.named_modules():
        # Any module exposing `active_adapters` is counted as LoRA-capable; the
        # project may use a PEFT wrapper or its own adapter layer.
        if hasattr(module, "active_adapters"):
            lora_count += 1
            if module.active_adapters:
                active_count += 1
                if active_count <= 5:  # only print the first few
                    print(f"  Layer {name}: Active={module.active_adapters}")

    print(f"  Total LoRA-capable layers: {lora_count}")
    print(f"  Active layers: {active_count}")

    if active_count == 0:
        print("  ⚠️ WARNING: No layers have active adapters! Subject will be ignored.")
    else:
        print("  ✓ Adapters seem active.")


def main() -> None:
    """Run one deterministic debug inference with subject and background conditions.

    Steps:
      1. Seed Python, NumPy and torch and switch cuDNN to deterministic mode.
      2. Load the YAML config named by the ``OMINI_CONFIG`` environment
         variable (default ``train/config/aircraft_bg_crop_inf.yaml``).
      3. Resolve a LoRA checkpoint directory; abort with an error message if
         none exists on disk.
      4. Build ``OminiModel`` on CUDA, activate its adapters, print adapter
         diagnostics and put every sub-model in eval mode.
      5. Load the two hard-coded condition images, resize them to 512x512,
         call ``generate`` and save the result plus both condition images
         under ``debug_single_results/``.

    Errors are reported by printing and returning; generation failures are
    caught and printed with a traceback rather than propagated.
    """
    # Set global deterministic behaviour right at program start.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    seed_everything(seed)

    print("="*70)
    print("🔧 Deterministic mode enabled for reproducible inference")
    print("="*70)

    # Locate and load the configuration.
    config_path = os.environ.get("OMINI_CONFIG", "train/config/aircraft_bg_crop_inf.yaml")
    print(f"Loading config from {config_path}")
    config = load_config(config_path)

    # 1. Load the model (mirrors inference_aircraft_bg_crop.py).
    training_config = config["train"]
    print("\n[2/3] Loading model...")

    # Never hand a non-existent path to the model: the resolver only returns
    # paths that exist on disk, otherwise None.
    checkpoint_path = _resolve_checkpoint_path(training_config)
    if checkpoint_path is None:
        print("\n❌ Error: No checkpoint found!")
        return

    # Use the dtype from the config to match the working inference script.
    dtype_to_use = torch.bfloat16 if config["dtype"] == "bfloat16" else torch.float32
    print(f"  Using dtype: {dtype_to_use}")

    model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_path=checkpoint_path,
        lora_config=training_config.get("lora_config", None),
        device="cuda",
        dtype=dtype_to_use,
        model_config=config.get("model", {}),
        adapter_names=[None, None, "subject", "background"],
    )
    model.training_config = training_config

    # Key fix kept from the original script: explicitly activate all LoRA adapters.
    _activate_lora_adapters(model)

    # [Diagnostics] Deep check of adapter status.
    _report_adapter_status(model)

    # Switch everything to evaluation mode.
    model.eval()
    model.transformer.eval()
    model.flux_pipe.vae.eval()
    model.flux_pipe.text_encoder.eval()
    model.flux_pipe.text_encoder_2.eval()

    # Additionally zero the probability of every Dropout module.
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0

    print("  ✓ Model loaded and set to eval mode (with deterministic settings)")

    # 3. Run inference on a single sample.
    print("\n[3/3] Running inference (Single Debug)...")

    # Hard-coded debug inputs and output location.
    subject_path = "augmented_results_batch/debug_conditions/304_A2_step0_subject.png"
    bg_path = "augmented_results_batch/debug_conditions/304_A2_step0_bg_input.png"
    output_dir = "debug_single_results"
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "debug_single_output_mimic.jpg")

    if not os.path.exists(subject_path) or not os.path.exists(bg_path):
        print("❌ Error: Input images not found!")
        return

    # Load the subject through the helper so its alpha channel is flattened.
    subject_img = load_rgba_with_black_background(subject_path)
    background_img = Image.open(bg_path).convert("RGB")

    # Resize to target size (assuming 512x512 from config or typical usage).
    target_size = (512, 512)
    subject_img = subject_img.resize(target_size)
    background_img = background_img.resize(target_size)

    # Indices 2 and 3 are the "subject" and "background" entries passed to
    # OminiModel above -- assumes the model stores adapter_names unchanged.
    subject_adapter = model.adapter_names[2]
    background_adapter = model.adapter_names[3]

    # NOTE(review): the third argument is presumably a position offset for the
    # condition tokens -- confirm against omini.pipeline.flux_omini.Condition.
    subject_condition = Condition(subject_img, subject_adapter, [-16, -32])
    background_condition = Condition(background_img, background_adapter, [16, -32])

    generator = torch.Generator(device=model.device)
    generator.manual_seed(seed)

    # Clear cache before generation
    torch.cuda.empty_cache()

    try:
        res = generate(
            model.flux_pipe,
            prompt="Place an aircraft at the specified position",
            conditions=[subject_condition, background_condition],
            height=target_size[1],
            width=target_size[0],
            num_inference_steps=28,
            guidance_scale=3.5,
            generator=generator,
            model_config=model.model_config,
            kv_cache=model.model_config.get("independent_condition", False), # Follow config
            inter_condition_attention=model.model_config.get("inter_condition_attention", False)
        )

        res.images[0].save(output_path)
        print(f"  ✓ Saved generated image to {output_path}")

        # Save the exact condition images that were fed to the model.
        subject_img.save(os.path.join(output_dir, "debug_condition_subject.jpg"))
        background_img.save(os.path.join(output_dir, "debug_condition_background.jpg"))
        print(f"  ✓ Saved conditions to {output_dir}")

    except Exception as e:
        # Top-level boundary of a debug script: report with full traceback
        # instead of crashing.
        print(f"  ❌ Generation failed: {e}")
        import traceback
        traceback.print_exc()

# Run the debug inference only when this file is executed as a script.
if __name__ == "__main__":
    main()