#!/usr/bin/env python3
"""
航空目标批量推理脚本 (Bg + Crop 版本)
取消 Mask 输入，只保留 Subject 和 Background
"""

import os
import sys
import yaml
import torch
import random
from pathlib import Path
from PIL import Image
import numpy as np
import argparse

# Add the project root to the Python path so the `omini` imports below resolve.
# The root is taken to be four `.parent` steps above this file; NOTE(review):
# this breaks silently if the script is moved to a different depth — confirm.
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from omini.pipeline.flux_omini import Condition, generate, seed_everything
from omini.train_flux.trainer import OminiModel


def load_config(config_path: str):
    """Parse the YAML file at ``config_path`` and return the resulting object.

    Parsing goes through ``yaml.safe_load``, so only standard YAML tags are
    accepted (no arbitrary Python object construction).
    """
    with open(config_path, 'r') as handle:
        return yaml.safe_load(handle)


def scan_inference_data(data_dir: str):
    """Collect inference samples from ``data_dir``.

    Expected layout::

        data_dir/
        ├── backgrounds/   <id>_erased.png   (required)
        └── subjects/      <id>_crop.png     (optional)

    No mask is needed as model input, so only backgrounds and (optionally)
    subject crops are scanned.

    Args:
        data_dir: Root directory of the inference data.

    Returns:
        A list of dicts, ordered by background filename, each with keys
        ``"id"``, ``"background"`` (path as str) and ``"subject"`` (path as
        str, or ``None`` when no matching ``<id>_crop.png`` exists). An empty
        list is returned (after printing an error) when ``backgrounds/`` is
        missing.
    """
    suffix = "_erased"
    data_dir = Path(data_dir)
    backgrounds_dir = data_dir / "backgrounds"
    subjects_dir = data_dir / "subjects"

    samples = []

    if not backgrounds_dir.exists():
        print(f"❌ Error: backgrounds/ folder not found in {data_dir}")
        return samples

    for bg_file in sorted(backgrounds_dir.glob(f"*{suffix}.png")):
        # Strip only the trailing "_erased". str.replace would also delete an
        # "_erased" that is part of the ID itself, breaking the crop lookup.
        sample_id = bg_file.stem[: -len(suffix)]

        # exists() is False when subjects/ itself is missing, so one check
        # covers both "no subjects folder" and "no crop for this ID".
        subject_file = subjects_dir / f"{sample_id}_crop.png"
        subject = str(subject_file) if subject_file.exists() else None

        samples.append({
            "id": sample_id,
            "background": str(bg_file),
            "subject": subject,
        })

    return samples


def create_default_subject(size):
    """Build a placeholder subject: a solid mid-gray RGB image of ``size``.

    Serves as the last-resort subject when a sample has neither its own crop
    nor a unified subject image.
    """
    mid_gray = (128, 128, 128)
    return Image.new(mode="RGB", size=size, color=mid_gray)


def load_rgba_with_black_background(image_path: str, fill_color=(255, 255, 255)) -> Image.Image:
    """Load an image as RGB, flattening any transparency onto a solid color.

    NOTE(review): despite the function name (and the previous docstring,
    which said "black"), the fill has always been white ``(255, 255, 255)``.
    That default is kept so outputs match earlier runs; pass
    ``fill_color=(0, 0, 0)`` for a black background. Confirm which one
    matches the training data.

    Args:
        image_path: Path of the image file to load.
        fill_color: RGB color placed under transparent pixels.

    Returns:
        An RGB ``PIL.Image.Image``.
    """
    img = Image.open(image_path)
    if img.mode == 'RGB':
        return img

    # Modes carrying an alpha channel, or palette images with a transparency
    # entry, are normalized to RGBA first. Otherwise a direct convert('RGB')
    # would silently drop their alpha instead of compositing onto fill_color.
    has_alpha = img.mode in ('RGBA', 'LA') or (
        img.mode == 'P' and 'transparency' in img.info
    )
    if not has_alpha:
        return img.convert('RGB')

    rgba = img if img.mode == 'RGBA' else img.convert('RGBA')
    background = Image.new('RGB', rgba.size, fill_color)
    background.paste(rgba, mask=rgba.split()[3])
    return background


@torch.no_grad()
def batch_inference(
    model,
    samples,
    output_dir: str = "inference_results_bg_crop/batch",
    target_size=(512, 512),
    condition_size=(512, 512),
    unified_subject_path: str = None,
    seed: int = 42,
    num_inference_steps: int = 28,
    guidance_scale: float = 3.5,
):
    """Generate one image per sample from a (subject, background) condition pair.

    For each sample, the background is resized to ``condition_size``. The
    subject comes, in order of preference, from:

    1. the sample's own crop;
    2. the unified subject image;
    3. a gray placeholder.

    Generated images go to ``<output_dir>/<id>_generated.jpg``. The condition
    images used go to ``<output_dir>/conditions/``. Samples whose images fail
    to load, or whose generation fails, are reported and skipped.

    Args:
        model: Loaded model exposing ``adapter_names``, ``device``,
            ``flux_pipe`` and ``model_config``. ``adapter_names[2]`` and
            ``adapter_names[3]`` are used as the subject and background adapters.
        samples: Dicts with ``"id"``, ``"background"`` and ``"subject"`` keys,
            as produced by ``scan_inference_data``.
        output_dir: Directory for generated and condition images.
        target_size: ``(width, height)`` of the generated images.
        condition_size: ``(width, height)`` that condition images are resized to.
        unified_subject_path: Optional fallback subject image for samples
            without their own crop.
        seed: Base seed; the sample at index ``idx`` uses ``seed + idx``.
        num_inference_steps: Number of denoising steps passed to ``generate``.
        guidance_scale: Guidance scale passed to ``generate``.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Create the conditions folder once, instead of on every successful sample.
    condition_dir = os.path.join(output_dir, "conditions")
    os.makedirs(condition_dir, exist_ok=True)

    # Only two LoRA adapters are used. With the adapter list built in main()
    # ([None, None, "subject", "background"]) they sit at indices 2 and 3.
    subject_adapter = model.adapter_names[2]
    background_adapter = model.adapter_names[3]

    unified_prompt = "Place an aircraft at the specified position"

    print(f"\n{'='*70}")
    print(f"Batch Inference (Bg+Crop) on {len(samples)} samples")
    print(f"{'='*70}")
    print(f"Output directory: {output_dir}")
    print(f"Adapters: {subject_adapter}, {background_adapter}")

    # Fallback subject, loaded once and shared by every sample without a crop.
    unified_subject_img = None
    if unified_subject_path and os.path.exists(unified_subject_path):
        try:
            unified_subject_img = load_rgba_with_black_background(unified_subject_path)
            unified_subject_img = unified_subject_img.resize(condition_size)
            print(f"✓ Loaded unified subject image from {unified_subject_path}")
        except Exception as e:
            print(f"⚠️  Failed to load unified subject: {e}")
            unified_subject_img = None
    else:
        print(f"⚠️  No unified subject provided or not found at {unified_subject_path}")

    num_ok = 0
    for idx, sample in enumerate(samples):
        sample_id = sample["id"]
        print(f"\n[{idx+1}/{len(samples)}] Processing {sample_id}...")

        # 1. Load the condition images.
        try:
            background_img = Image.open(sample["background"]).convert("RGB").resize(condition_size)

            if sample["subject"] is not None:
                subject_img = load_rgba_with_black_background(sample["subject"]).resize(condition_size)
            elif unified_subject_img is not None:
                subject_img = unified_subject_img
            else:
                subject_img = create_default_subject(condition_size)
        except Exception as e:
            print(f"  ❌ Failed to load images: {e}")
            continue

        # 2. Build the two conditions (subject + background).
        # NOTE(review): the third argument looks like a per-condition position
        # offset; the values are kept from the original — confirm they match
        # the ones used in training.
        subject_condition = Condition(subject_img, subject_adapter, [-16, -32])
        background_condition = Condition(background_img, background_adapter, [16, -32])

        # 3. Generate. The seed is derived from the base seed plus the sample index.
        generator = torch.Generator(device=model.device)
        generator.manual_seed(seed + idx)

        print("  Generating image...")
        try:
            res = generate(
                model.flux_pipe,
                prompt=unified_prompt,
                conditions=[subject_condition, background_condition],
                height=target_size[1],
                width=target_size[0],
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
                model_config=model.model_config,
                kv_cache=model.model_config.get("independent_condition", False),
            )

            output_path = os.path.join(output_dir, f"{sample_id}_generated.jpg")
            res.images[0].save(output_path)
            print(f"  ✓ Saved to {output_path}")

            # Save the exact condition images used, for side-by-side inspection.
            subject_img.save(os.path.join(condition_dir, f"{sample_id}_subject.jpg"))
            background_img.save(os.path.join(condition_dir, f"{sample_id}_background.jpg"))
        except Exception as e:
            print(f"  ❌ Generation failed: {e}")
            continue
        num_ok += 1

    print(f"\n{'='*70}")
    print(f"Generated {num_ok}/{len(samples)} samples")
    print(f"✓ Batch inference completed! Results saved to {output_dir}")


def _set_deterministic(seed: int) -> None:
    """Seed the Python, NumPy and torch RNGs and force deterministic cuDNN kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    seed_everything(seed)


def _find_latest_checkpoint(training_config):
    """Return the highest-numbered checkpoint directory under the save path, or None.

    The directory searched is ``<save_path>/<wandb name>/ckpt``. If that does
    not exist, ``ckpt`` inside the lexicographically last run directory of
    ``save_path`` is used instead. Run dirs are presumably timestamp-named
    (e.g. ``20260115-164618``), which makes "last" the most recent run.
    """
    save_path = training_config.get("save_path", "runs_bg_crop")
    wandb_name = (training_config.get("wandb") or {}).get("name", "aircraft_bg_crop")
    if not os.path.exists(save_path):
        return None

    ckpt_dir = os.path.join(save_path, wandb_name, "ckpt")
    if not os.path.exists(ckpt_dir):
        run_dirs = [d for d in os.listdir(save_path) if os.path.isdir(os.path.join(save_path, d))]
        if run_dirs:
            ckpt_dir = os.path.join(save_path, max(run_dirs), "ckpt")
    if not os.path.exists(ckpt_dir):
        return None

    steps = [d for d in os.listdir(ckpt_dir) if d.isdigit() and os.path.isdir(os.path.join(ckpt_dir, d))]
    if not steps:
        return None
    # Compare numerically but keep the original directory name, so
    # zero-padded names (e.g. "0012000") still resolve to an existing path.
    return os.path.join(ckpt_dir, max(steps, key=int))


def _resolve_checkpoint(cli_checkpoint, training_config):
    """Choose the LoRA checkpoint to load.

    Priority:

    1. ``--checkpoint``;
    2. ``train.resume_from_checkpoint`` from the config;
    3. the latest checkpoint found on disk.

    If a path was given but does not exist locally and nothing newer is
    found, the given path is still returned unchanged. Returns ``None`` only
    when no candidate exists at all.
    """
    checkpoint_path = cli_checkpoint or training_config.get("resume_from_checkpoint", None)
    if checkpoint_path is not None and os.path.exists(checkpoint_path):
        return checkpoint_path

    if checkpoint_path is not None:
        print(f"  ⚠️  Checkpoint not found at {checkpoint_path}, searching for the latest one...")
    latest = _find_latest_checkpoint(training_config)
    if latest is not None:
        print(f"  ✓ Found latest checkpoint: {latest}")
        return latest
    return checkpoint_path


def _activate_adapters(model) -> None:
    """Explicitly activate every LoRA adapter in ``model.adapter_set`` and report the result."""
    print("\n  Activating LoRA adapters...")
    adapter_list = list(model.adapter_set)
    if not adapter_list:
        return

    try:
        model.transformer.set_adapters(adapter_list)
        print(f"    ✓ Activated adapters: {adapter_list}")
    except Exception as e:
        print(f"    ⚠️  set_adapters failed: {e}")
        # NOTE(review): assumes a diffusers PeftAdapterMixin transformer, whose
        # enable_adapters() takes no adapter name. The previous per-name call
        # enable_adapters(name) would raise TypeError there.
        try:
            model.transformer.enable_adapters()
            print("    ✓ Enabled adapters")
        except Exception as e2:
            print(f"    ⚠️  Failed to enable adapters: {e2}")

    # Report the active adapters of the first module that exposes them.
    print("    Verifying adapter activation...")
    for name, module in model.transformer.named_modules():
        if getattr(module, "active_adapters", None):
            print(f"      Active adapters in {name[:50]}...: {module.active_adapters}")
            break


def _set_eval_mode(model) -> None:
    """Put the model and its pipeline components into eval mode and zero all dropout."""
    model.eval()
    # The pipeline components may not be registered submodules of `model`,
    # so they are switched to eval explicitly.
    model.transformer.eval()
    model.flux_pipe.vae.eval()
    model.flux_pipe.text_encoder.eval()
    model.flux_pipe.text_encoder_2.eval()

    # Belt and braces: eval() already disables Dropout, but zero p as well.
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0


def main():
    """CLI entry point.

    Steps:

    1. Scan the data directory and load the config.
    2. Resolve the checkpoint and load the model.
    3. Run batch inference.
    """
    parser = argparse.ArgumentParser(description="Batch inference for aircraft dataset (Bg + Crop)")
    parser.add_argument("--data_dir", type=str, default="./inference_datas",
                        help="Directory containing inference data")
    parser.add_argument("--config", type=str, default="./train/config/aircraft_bg_crop.yaml",
                        help="Config file path")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Checkpoint path (if not specified, uses train.resume_from_checkpoint "
                             "from the config, else the latest checkpoint found)")
    parser.add_argument("--output_dir", type=str, default="inference_results_bg_crop/batch",
                        help="Output directory")
    parser.add_argument("--unified_subject", type=str, default=None,
                        help="Path to unified subject image (optional)")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    args = parser.parse_args()

    # 1. Global determinism.
    _set_deterministic(args.seed)

    print("="*70)
    print("Aircraft Batch Inference (Bg + Crop)")
    print("🔧 Deterministic mode enabled")
    print("="*70)

    if not os.path.exists(args.data_dir):
        print(f"❌ Error: Data directory not found: {args.data_dir}")
        return

    if not os.path.exists(args.config):
        print(f"❌ Error: Config file not found: {args.config}")
        return

    print(f"Scanning data directory: {args.data_dir}")
    samples = scan_inference_data(args.data_dir)

    if not samples:
        print("❌ No samples found!")
        print("Please make sure your data directory has the following structure:")
        print("  inference_datas/")
        print("  ├── backgrounds/")
        print("  │   └── xxx_erased.png")
        print("  └── subjects/ (optional)")
        print("      └── xxx_crop.png")
        return

    print(f"✓ Found {len(samples)} samples")

    print(f"\nLoading config from {args.config}...")
    config = load_config(args.config)
    training_config = config["train"]
    dataset_config = training_config["dataset"]

    target_size = tuple(dataset_config["target_size"])
    condition_size = tuple(dataset_config["condition_size"])

    print("\nLoading model...")
    # Previously a hard-coded path here silently overrode --checkpoint and the
    # config value. The CLI flag now takes effect as documented.
    checkpoint_path = _resolve_checkpoint(args.checkpoint, training_config)
    if checkpoint_path is None:
        print("\n❌ Error: No checkpoint found!")
        print("Please train the model first or specify --checkpoint")
        return

    model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_path=checkpoint_path,
        lora_config=training_config.get("lora_config", None),
        device="cuda",
        dtype=torch.bfloat16 if config["dtype"] == "bfloat16" else torch.float32,
        model_config=config.get("model", {}),
        # Bg+Crop uses only two adapters (subject, background) in slots 2 and 3.
        adapter_names=[None, None, "subject", "background"],
        optimizer_config=None,
        gradient_checkpointing=False,
    )
    model.training_config = training_config

    _activate_adapters(model)
    _set_eval_mode(model)

    print("  ✓ Model loaded and set to eval mode (with adapter activation)")

    batch_inference(
        model=model,
        samples=samples,
        output_dir=args.output_dir,
        target_size=target_size,
        condition_size=condition_size,
        unified_subject_path=args.unified_subject,
        seed=args.seed,
    )


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()