#!/usr/bin/env python3
"""
航空目标批量推理脚本
从 inference_datas 文件夹批量推理数据
"""

import os
import sys
import yaml
import torch
from pathlib import Path
from PIL import Image
import numpy as np
import argparse

# Put the project root (the directory four levels above this file) at the
# front of sys.path; presumably this is what lets the `omini` package imports
# below resolve when the script is run directly.
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from omini.pipeline.flux_omini import Condition, generate
from omini.train_flux.trainer import OminiModel


def load_config(config_path: str):
    """Load a YAML configuration file.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's contents.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, so configs containing non-ASCII text load the same everywhere.
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


def scan_inference_data(data_dir: str):
    """
    Scan an inference data folder and collect every sample that can be run.

    Expected layout:
    data_dir/
    ├── backgrounds/
    │   └── {id}_erased.png
    ├── masks/
    │   └── {id}_mask.png
    └── subjects/ (optional)
        └── {id}_crop.png

    A background without a matching mask is skipped with a warning. The
    subject is optional and is reported as None when it is missing.

    Args:
        data_dir: Root directory of the inference data.

    Returns:
        List of dicts, sorted by background file name:
        [{"id": "xxx", "background": path, "mask": path, "subject": path or None}, ...]
        An empty list is returned when backgrounds/ or masks/ does not exist.
    """
    bg_marker = "_erased"

    root = Path(data_dir)
    backgrounds_dir = root / "backgrounds"
    masks_dir = root / "masks"
    subjects_dir = root / "subjects"

    samples = []

    if not backgrounds_dir.exists() or not masks_dir.exists():
        print(f"❌ Error: backgrounds/ or masks/ folder not found in {root}")
        return samples

    # The subjects folder is optional; check for it once rather than per file.
    has_subjects = subjects_dir.exists()

    for bg_file in sorted(backgrounds_dir.glob("*_erased.png")):
        # Strip only the trailing "_erased" marker. Using str.replace here
        # would also delete "_erased" occurring inside the ID itself and
        # make the mask/subject lookups miss.
        sample_id = bg_file.stem[:-len(bg_marker)]

        # The mask is mandatory for a sample.
        mask_file = masks_dir / f"{sample_id}_mask.png"
        if not mask_file.exists():
            print(f"⚠️  Warning: Mask not found for {sample_id}, skipping")
            continue

        # The subject crop is optional.
        subject_file = None
        if has_subjects:
            candidate = subjects_dir / f"{sample_id}_crop.png"
            if candidate.exists():
                subject_file = candidate

        samples.append({
            "id": sample_id,
            "background": str(bg_file),
            "mask": str(mask_file),
            "subject": str(subject_file) if subject_file else None,
        })

    return samples


def create_default_subject(size):
    """Build the fallback subject image: a solid mid-gray RGB canvas of the given size."""
    mid_gray = (128, 128, 128)
    placeholder = Image.new("RGB", size, mid_gray)
    return placeholder


def load_rgba_with_black_background(image_path: str) -> Image.Image:
    """
    Load an image as RGB, flattening an RGBA alpha channel onto a solid fill.

    NOTE: despite the function name, transparent regions are filled with
    WHITE (255, 255, 255), not black. The name is kept so existing callers
    keep working. The original comments state this handling matches the
    training script; confirm against that script before changing the colour.

    Args:
        image_path: Path of the image file to load.

    Returns:
        An RGB image:
        - ``RGB`` input is returned as loaded, unchanged;
        - ``RGBA`` input is pasted onto a white canvas using its alpha
          channel as the paste mask;
        - any other mode is converted with ``convert('RGB')``.
    """
    # The context manager closes the underlying file on every path instead
    # of leaving that to lazy loading / garbage collection.
    with Image.open(image_path) as img:
        if img.mode == 'RGB':
            # Force the pixel data to be read before the file is closed so
            # the returned image stays usable.
            img.load()
            return img

        if img.mode == 'RGBA':
            # White canvas; the 4th band (alpha) decides where the source
            # pixels replace the canvas.
            background = Image.new('RGB', img.size, (255, 255, 255))
            background.paste(img, mask=img.split()[3])
            return background

        # Any other mode: plain conversion (produces a new, loaded image).
        return img.convert('RGB')


@torch.no_grad()
def batch_inference(
    model,
    samples,
    output_dir: str = "inference_results/batch",
    target_size=(512, 512),
    condition_size=(512, 512),
    unified_subject_path: str = None,
    seed: int = 42
):
    """
    Run conditional generation for every sample and write the results to disk.

    For each sample three conditions are built (subject image, binary position
    mask, background image) and passed to ``generate``. The generated image is
    saved as ``{output_dir}/{id}_generated.jpg`` and the three condition images
    are saved under ``{output_dir}/conditions/`` for debugging. A sample whose
    images cannot be loaded, or whose generation fails, is reported and
    skipped; the remaining samples are still processed.

    Subject selection per sample, in priority order:
        1. the sample's own subject image, if it has one;
        2. the unified subject image, if it was loaded successfully;
        3. a plain gray placeholder.

    Args:
        model: Loaded model object. This function reads ``adapter_names``,
            ``device``, ``flux_pipe`` and ``model_config`` from it.
        samples: List of dicts with the keys "id", "background", "mask" and
            "subject" (the structure returned by ``scan_inference_data``).
        output_dir: Directory that receives the generated images.
        target_size: (width, height) of the generated image.
        condition_size: Size every condition image is resized to.
        unified_subject_path: Optional path of a subject image shared by all
            samples that do not have their own subject.
        seed: Base random seed; the sample at index ``i`` uses ``seed + i``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Adapter names are read from fixed slots 2, 3 and 4 of model.adapter_names
    # (main() builds the model with [None, None, "subject", "fill", "background"]).
    subject_adapter = model.adapter_names[2]
    fill_adapter = model.adapter_names[3]
    background_adapter = model.adapter_names[4]

    # The same prompt is used for every sample.
    unified_prompt = "Place an aircraft at the specified position"

    print(f"\n{'='*70}")
    print(f"Batch Inference on {len(samples)} samples")
    print(f"{'='*70}")
    print(f"Output directory: {output_dir}")
    print(f"Adapters: {subject_adapter}, {fill_adapter}, {background_adapter}")
    print(f"Target size: {target_size}")
    print(f"Condition size: {condition_size}")
    print(f"Unified subject: {unified_subject_path}")
    print(f"Unified prompt: {unified_prompt}")

    # Load the shared subject image once, up front.
    unified_subject_img = None
    if unified_subject_path:
        if not os.path.exists(unified_subject_path):
            # Previously this case fell through silently to the gray default.
            print(f"⚠️  Unified subject not found: {unified_subject_path}")
            print(f"   Will use default gray subject")
        else:
            try:
                # Read the original mode for the log line and close the file
                # right away (only needed for reporting).
                with Image.open(unified_subject_path) as probe:
                    source_mode = probe.mode
                unified_subject_img = load_rgba_with_black_background(unified_subject_path)
                unified_subject_img = unified_subject_img.resize(condition_size)
                print(f"✓ Loaded unified subject image from {unified_subject_path}")
                # The loader composites transparent areas onto white.
                print(f"  Image mode: {source_mode}, processed to RGB with white background")
            except Exception as e:
                print(f"⚠️  Failed to load unified subject: {e}")
                print(f"   Will use default gray subject")
                unified_subject_img = None

    succeeded = 0
    failed = 0

    for idx, sample in enumerate(samples):
        sample_id = sample["id"]
        print(f"\n[{idx+1}/{len(samples)}] Processing {sample_id}...")

        # 1. Load the images for this sample.
        try:
            with Image.open(sample["background"]) as bg_src:
                background_img = bg_src.convert("RGB").resize(condition_size)

            with Image.open(sample["mask"]) as mask_src:
                mask_img = mask_src.convert("L").resize(condition_size)

            if sample["subject"] is not None:
                # Prefer the sample's own subject.
                subject_img = load_rgba_with_black_background(sample["subject"])
                subject_img = subject_img.resize(condition_size)
            elif unified_subject_img is not None:
                # Otherwise fall back to the shared subject.
                subject_img = unified_subject_img
            else:
                # Last resort: gray placeholder.
                subject_img = create_default_subject(condition_size)

        except Exception as e:
            print(f"  ❌ Failed to load images: {e}")
            failed += 1
            continue

        # 2. Binarise the mask (values above 128 become white) and expand it
        #    to RGB; the white region marks the target position.
        mask_array = np.array(mask_img)
        binary_mask = (mask_array > 128).astype(np.uint8) * 255
        position_img = Image.fromarray(binary_mask, mode="L").convert("RGB")

        # Diagnostic statistics.
        bg_array = np.array(background_img)
        subject_array = np.array(subject_img)
        mask_ratio = (mask_array > 128).sum() / mask_array.size
        print(f"  Background stats: mean={bg_array.mean():.2f}, max={bg_array.max()}, size={background_img.size}")
        print(f"  Subject stats: mean={subject_array.mean():.2f}, max={subject_array.max()}, size={subject_img.size}")
        print(f"  Mask coverage: {mask_ratio*100:.1f}%, size={mask_img.size}")
        print(f"  Position mask size: {position_img.size}")

        # 3. Build the three conditions. The third argument is passed through
        #    to Condition unchanged (NOTE(review): looks like a position
        #    offset per condition - confirm against Condition's definition).
        subject_condition = Condition(
            subject_img,
            subject_adapter,
            [-16, -32]
        )
        position_condition = Condition(
            position_img,
            fill_adapter,
            [0, 0]
        )
        background_condition = Condition(
            background_img,
            background_adapter,
            [16, -32]
        )

        # 4. Generate, with a per-sample seed derived from the base seed.
        generator = torch.Generator(device=model.device)
        generator.manual_seed(seed + idx)

        print(f"  Generating image...")
        try:
            res = generate(
                model.flux_pipe,
                prompt=unified_prompt,
                conditions=[subject_condition, position_condition, background_condition],
                height=target_size[1],
                width=target_size[0],
                num_inference_steps=28,
                guidance_scale=3.5,
                generator=generator,
                model_config=model.model_config,
                kv_cache=model.model_config.get("independent_condition", False),
            )

            # 5. Save the generated image.
            output_path = os.path.join(output_dir, f"{sample_id}_generated.jpg")
            res.images[0].save(output_path)
            print(f"  ✓ Saved to {output_path}")

            # 6. Save the condition images next to it for debugging.
            condition_dir = os.path.join(output_dir, "conditions")
            os.makedirs(condition_dir, exist_ok=True)

            subject_img.save(os.path.join(condition_dir, f"{sample_id}_subject.jpg"))
            position_img.save(os.path.join(condition_dir, f"{sample_id}_position.jpg"))
            background_img.save(os.path.join(condition_dir, f"{sample_id}_background.jpg"))

        except Exception as e:
            print(f"  ❌ Generation failed: {e}")
            failed += 1
            continue

        succeeded += 1

    print(f"\n{'='*70}")
    print(f"✓ Batch inference completed! Results saved to {output_dir}")
    # Make partial or total failure visible in the final summary.
    print(f"  Succeeded: {succeeded}, failed: {failed}")
    print(f"{'='*70}")

def _find_latest_checkpoint(save_path: str):
    """Return the newest checkpoint directory of the newest run under save_path, or None."""
    if not os.path.isdir(save_path):
        return None

    run_dirs = [d for d in os.listdir(save_path) if os.path.isdir(os.path.join(save_path, d))]
    if not run_dirs:
        return None

    # Newest run = last run directory name in sorted order.
    ckpt_dir = os.path.join(save_path, sorted(run_dirs)[-1], "ckpt")
    if not os.path.isdir(ckpt_dir):
        return None

    checkpoints = [d for d in os.listdir(ckpt_dir) if os.path.isdir(os.path.join(ckpt_dir, d))]
    if not checkpoints:
        return None

    # Checkpoint folders are ordered by their numeric name; non-numeric
    # names sort as 0.
    checkpoints = sorted(checkpoints, key=lambda x: int(x) if x.isdigit() else 0)
    return os.path.join(ckpt_dir, checkpoints[-1])


def main():
    """
    Command-line entry point: scan the data folder, load the model and run batch inference.

    Checkpoint resolution order:
        1. ``--checkpoint`` from the command line;
        2. ``resume_from_checkpoint`` from the training config;
        3. the newest checkpoint of the newest run under the config's
           ``save_path`` (default "runs").
    A configured checkpoint path that does not exist is reported and then
    treated as "not specified". If no usable checkpoint is found the function
    prints an error and returns without loading the model.
    """
    parser = argparse.ArgumentParser(description="Batch inference for aircraft dataset")
    parser.add_argument("--data_dir", type=str, default="./inference_datas",
                        help="Directory containing inference data")
    parser.add_argument("--config", type=str, default="./train/config/aircraft_single_folder.yaml",
                        help="Config file path")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Checkpoint path (if not specified, will use latest)")
    parser.add_argument("--output_dir", type=str, default="inference_results2/batch",
                        help="Output directory")
    parser.add_argument("--unified_subject", type=str, default="./inference_datas2/train_660_0018_obj000.png",
                        help="Path to unified subject image (default: ./inference_datas2/train_660_0018_obj000.png)")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    args = parser.parse_args()

    # Validate the input paths before doing any expensive work.
    if not os.path.exists(args.data_dir):
        print(f"❌ Error: Data directory not found: {args.data_dir}")
        return

    if not os.path.exists(args.config):
        print(f"❌ Error: Config file not found: {args.config}")
        return

    # Collect the samples.
    print(f"Scanning data directory: {args.data_dir}")
    samples = scan_inference_data(args.data_dir)

    if not samples:
        print("❌ No samples found!")
        print("Please make sure your data directory has the following structure:")
        print("  inference_datas/")
        print("  ├── backgrounds/")
        print("  │   └── xxx_erased.png")
        print("  ├── masks/")
        print("  │   └── xxx_mask.png")
        print("  └── subjects/ (optional)")
        print("      └── xxx_crop.png")
        return

    print(f"✓ Found {len(samples)} samples")

    # Load the configuration.
    print(f"\nLoading config from {args.config}...")
    config = load_config(args.config)
    training_config = config["train"]
    dataset_config = training_config["dataset"]

    target_size = tuple(dataset_config["target_size"])
    condition_size = tuple(dataset_config["condition_size"])

    # Resolve the checkpoint. A hard-coded debug path used to be assigned
    # here unconditionally, which silently overrode --checkpoint and the
    # config value; it has been removed so the documented order applies.
    print("\nLoading model...")
    checkpoint_path = args.checkpoint or training_config.get("resume_from_checkpoint", None)
    if checkpoint_path is not None and not os.path.exists(checkpoint_path):
        print(f"  ⚠️  Checkpoint not found: {checkpoint_path}, searching for the latest checkpoint")
        checkpoint_path = None

    if checkpoint_path is None:
        checkpoint_path = _find_latest_checkpoint(training_config.get("save_path", "runs"))
        if checkpoint_path is not None:
            print(f"  ✓ Found latest checkpoint: {checkpoint_path}")

    if checkpoint_path is None:
        print("\n❌ Error: No checkpoint found!")
        print("Please train the model first or specify --checkpoint")
        return

    # Build the model with the LoRA weights from the resolved checkpoint.
    model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_path=checkpoint_path,
        lora_config=training_config.get("lora_config", None),
        device="cuda",
        dtype=torch.bfloat16 if config["dtype"] == "bfloat16" else torch.float32,
        model_config=config.get("model", {}),
        adapter_names=[None, None, "subject", "fill", "background"],
        optimizer_config=None,
        gradient_checkpointing=False,
    )
    model.training_config = training_config

    # Explicitly activate every LoRA adapter (marked as a key fix by the
    # original author).
    print("\n  Activating LoRA adapters...")
    adapter_list = list(model.adapter_set)
    if adapter_list:
        try:
            # Preferred: activate all adapters in one call.
            model.transformer.set_adapters(adapter_list)
            print(f"    ✓ Activated adapters: {adapter_list}")
        except Exception as e:
            print(f"    ⚠️  set_adapters failed: {e}")
            # Fallback: enable the adapters one at a time.
            for adapter_name in adapter_list:
                try:
                    model.transformer.enable_adapters(adapter_name)
                    print(f"    ✓ Enabled adapter: {adapter_name}")
                except Exception as e2:
                    print(f"    ⚠️  Failed to enable {adapter_name}: {e2}")

        # Report the first module that lists active adapters as a sanity check.
        print("    Verifying adapter activation...")
        for name, module in model.transformer.named_modules():
            if hasattr(module, 'active_adapters'):
                if module.active_adapters:
                    print(f"      Active adapters in {name[:50]}...: {module.active_adapters}")
                    break

    # Switch every component to evaluation mode.
    model.eval()
    model.transformer.eval()
    model.flux_pipe.vae.eval()
    model.flux_pipe.text_encoder.eval()
    model.flux_pipe.text_encoder_2.eval()

    # Additionally zero the probability of every Dropout module.
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0

    print("  ✓ Model loaded and set to eval mode (with adapter activation)")

    # Run the batch.
    batch_inference(
        model=model,
        samples=samples,
        output_dir=args.output_dir,
        target_size=target_size,
        condition_size=condition_size,
        unified_subject_path=args.unified_subject,
        seed=args.seed
    )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
