#!/usr/bin/env python3
"""
航空目标推理脚本 - Mask Weighted 版本
使用 aircraft_mask_weighted 训练的模型进行推理
"""

import os
import sys
import yaml
import torch
import random
import numpy as np
from pathlib import Path
from PIL import Image

# Put the fourth ancestor of this file's path (presumably the project root —
# confirm against the repository layout) at the front of sys.path so that the
# `omini` package imports below can be resolved when running this script directly.
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from omini.train_flux.train_aircraft_mask_weighted import AircraftMaskWeightedDataset
from omini.pipeline.flux_omini import Condition, generate, seed_everything
from omini.train_flux.trainer import OminiModel


def load_config(config_path: str):
    """Load a YAML configuration file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the file contents
        (a ``dict`` when the document is a mapping).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the contents are not valid YAML.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform default
    # encoding, which breaks on non-ASCII content under non-UTF-8 locales.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def inference_on_training_samples(
    model,
    dataset,
    num_samples: int = 10,
    output_dir: str = "inference_results",
    seed: int = 42
):
    """Run inference on the first ``num_samples`` dataset samples and save the results.

    For every processed sample the generated image is written to
    ``<output_dir>/sample_XXXX.jpg``; the three condition images, the ground
    truth image (when it is a PIL image) and the target mask (when it is a
    tensor) are written to ``<output_dir>/conditions``.

    Args:
        model: Trained model. This function reads ``training_config``,
            ``adapter_names`` (indices 2, 3 and 4), ``flux_pipe`` and
            ``model_config`` from it.
        dataset: Indexable dataset exposing a ``return_pil_image`` attribute.
            The attribute is switched to ``True`` while this function runs and
            is always restored afterwards, even if inference fails.
        num_samples: Number of samples to process; clamped to ``len(dataset)``.
        output_dir: Directory that receives the results (created if missing).
        seed: Base seed; sample ``idx`` uses a generator seeded with ``seed + idx``.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Loop-invariant: create the side-output directory once up front.
    condition_dir = os.path.join(output_dir, "conditions")
    os.makedirs(condition_dir, exist_ok=True)

    # Sizes come from the training configuration attached to the model.
    condition_size = model.training_config["dataset"]["condition_size"]
    target_size = model.training_config["dataset"]["target_size"]

    # The three LoRA adapters used for the subject / fill / background conditions.
    subject_adapter = model.adapter_names[2]
    fill_adapter = model.adapter_names[3]
    background_adapter = model.adapter_names[4]

    # Clamp before printing so the banner reports the number of samples that
    # will actually be processed.
    num_samples = max(0, min(num_samples, len(dataset)))

    print(f"\n{'='*70}")
    print(f"Inference on {num_samples} training samples (Mask Weighted)")
    print(f"{'='*70}")
    print(f"Output directory: {output_dir}")
    print(f"Adapters: {subject_adapter}, {fill_adapter}, {background_adapter}")
    print(f"Target size: {target_size}")
    print(f"Condition size: {condition_size}")
    print("Dataset: Using Masks2 folder")

    # Loop-invariant settings.
    device = model.flux_pipe.device
    use_kv_cache = model.model_config.get("independent_condition", False)

    # Reuse the caller's dataset and temporarily make it return PIL images.
    # try/finally guarantees the flag is restored even when a sample fails,
    # so the caller's dataset is never left in a modified state.
    original_return_pil = dataset.return_pil_image
    dataset.return_pil_image = True
    try:
        # No gradients are needed for inference.
        with torch.no_grad():
            for idx in range(num_samples):
                print(f"\n[{idx+1}/{num_samples}] Processing sample {idx}...")

                sample = dataset[idx]

                # Condition images (PIL images because return_pil_image is on).
                subject_img = sample["condition_0"]
                position_img = sample["condition_1"]
                background_img = sample["condition_2"]
                prompt = sample["description"]

                # Debug output: basic statistics of the background image.
                bg_array = np.array(background_img)
                print(f"  Prompt: {prompt}")
                print(f"  Background stats: mean={bg_array.mean():.2f}, max={bg_array.max()}, min={bg_array.min()}")
                print(f"  Sizes - Subject: {subject_img.size}, Background: {background_img.size}, Position: {position_img.size}")

                conditions = [
                    Condition(
                        subject_img,
                        subject_adapter,
                        sample["position_delta_0"].tolist(),
                    ),
                    Condition(
                        position_img,
                        fill_adapter,
                        sample["position_delta_1"].tolist(),
                    ),
                    Condition(
                        background_img,
                        background_adapter,
                        sample["position_delta_2"].tolist(),
                    ),
                ]

                # Per-sample generator on the pipeline's device so each sample
                # is reproducible independently of the others.
                generator = torch.Generator(device=device)
                generator.manual_seed(seed + idx)

                print("  Generating image...")
                res = generate(
                    model.flux_pipe,
                    prompt=prompt,
                    conditions=conditions,
                    height=target_size[1],
                    width=target_size[0],
                    num_inference_steps=28,
                    guidance_scale=3.5,
                    generator=generator,
                    model_config=model.model_config,
                    kv_cache=use_kv_cache,
                )

                # Generated image.
                output_path = os.path.join(output_dir, f"sample_{idx:04d}.jpg")
                res.images[0].save(output_path)
                print(f"  ✓ Saved to {output_path}")

                # Condition images, for side-by-side comparison.
                subject_img.save(os.path.join(condition_dir, f"sample_{idx:04d}_subject.jpg"))
                position_img.save(os.path.join(condition_dir, f"sample_{idx:04d}_position.jpg"))
                background_img.save(os.path.join(condition_dir, f"sample_{idx:04d}_background.jpg"))

                # Ground truth.
                if isinstance(sample["image"], Image.Image):
                    sample["image"].save(os.path.join(condition_dir, f"sample_{idx:04d}_gt.jpg"))

                # Target mask (for comparison).
                if "target_mask" in sample and isinstance(sample["target_mask"], torch.Tensor):
                    # detach/cpu/float make the conversion work for tensors that
                    # track gradients, live on an accelerator, or use a dtype
                    # NumPy cannot represent (e.g. bfloat16).
                    mask_array = sample["target_mask"].detach().cpu().float().numpy()
                    # For a 3-D array, keep only the first slice along axis 0
                    # (presumably the channel axis — confirm against the dataset).
                    if mask_array.ndim == 3:
                        mask_array = mask_array[0]
                    # NOTE(review): assumes mask values lie in [0, 1] — confirm.
                    mask_img = Image.fromarray((mask_array * 255).astype(np.uint8))
                    mask_img.save(os.path.join(condition_dir, f"sample_{idx:04d}_mask.png"))
    finally:
        # Restore the dataset's original output mode.
        dataset.return_pil_image = original_return_pil

    print(f"\n{'='*70}")
    print(f"✓ Inference completed! Results saved to {output_dir}")
    print(f"{'='*70}")


def _find_latest_checkpoint(save_path: str):
    """Return the newest checkpoint directory below ``save_path``, or ``None``.

    The expected layout is ``<save_path>/<run>/ckpt/<step>``. The run is the
    lexicographically last sub-directory of ``save_path``; within its ``ckpt``
    folder the sub-directory with the largest integer name is chosen
    (non-numeric names are ranked as 0).
    """
    if not os.path.isdir(save_path):
        return None

    run_dirs = [d for d in os.listdir(save_path) if os.path.isdir(os.path.join(save_path, d))]
    if not run_dirs:
        return None

    ckpt_dir = os.path.join(save_path, max(run_dirs), "ckpt")
    if not os.path.isdir(ckpt_dir):
        return None

    checkpoints = [d for d in os.listdir(ckpt_dir) if os.path.isdir(os.path.join(ckpt_dir, d))]
    if not checkpoints:
        return None

    latest_ckpt = sorted(checkpoints, key=lambda x: int(x) if x.isdigit() else 0)[-1]
    return os.path.join(ckpt_dir, latest_ckpt)


def main():
    """Load the config, dataset and model, then run inference on 10 samples.

    The checkpoint is resolved in this order: the manual override defined in
    this function, the config's ``resume_from_checkpoint`` entry, and finally
    the latest checkpoint found under the config's ``save_path``. The function
    prints an error and returns early when the config file is missing, the
    dataset is empty, or no existing checkpoint can be found.
    """
    # Check the config first so a missing file exits before any global state
    # (RNG seeds, cuDNN flags) is touched.
    config_path = "./train/config/aircraft_mask_weighted.yaml"

    if not os.path.exists(config_path):
        print(f"❌ Error: Config file not found: {config_path}")
        return

    # Enable global deterministic behaviour before anything random happens.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    seed_everything(seed)

    print("="*70)
    print("🔧 Deterministic mode enabled for reproducible inference")
    print("="*70)

    # Load the configuration.
    print(f"Loading config from {config_path}...")
    config = load_config(config_path)
    training_config = config["train"]
    dataset_config = training_config["dataset"]

    # Build the dataset with every random drop / augmentation option set to zero.
    print("\n[1/3] Loading dataset...")
    dataset = AircraftMaskWeightedDataset(
        dataset_root=dataset_config["dataset_root"],
        condition_size=tuple(dataset_config["condition_size"]),
        target_size=tuple(dataset_config["target_size"]),
        drop_text_prob=0.0,
        drop_subject_prob=0.0,
        drop_position_prob=0.0,
        drop_background_prob=0.0,
        min_mask_ratio=0.0,
        max_mask_ratio=0.0,
        background_blur_prob=0.0,
        augmentation_prob=0.0,
    )
    print(f"  ✓ Dataset loaded: {len(dataset)} samples")
    print("    - Using Masks2 folder (Mask-weighted version)")

    if len(dataset) == 0:
        print("\n❌ Error: Dataset is empty!")
        return

    # Resolve the checkpoint.
    print("\n[2/3] Loading model...")
    # Manual override; takes precedence when the path exists. Checkpoints used
    # previously:
    #   runs_mask_weighted/20251128-004936/ckpt/16800
    #   runs_mask_weighted/20251128-004936/ckpt/14000
    manual_checkpoint = "runs_mask_weighted_normalized/20251126-231002/ckpt/5000"
    candidates = [manual_checkpoint, training_config.get("resume_from_checkpoint", None)]
    # Only accept a candidate that exists on disk, so a stale path can never
    # reach the model loader.
    checkpoint_path = next(
        (path for path in candidates if path and os.path.exists(path)),
        None,
    )
    if checkpoint_path is None:
        save_path = training_config.get("save_path", "runs_mask_weighted")
        checkpoint_path = _find_latest_checkpoint(save_path)
        if checkpoint_path is not None:
            print(f"  ✓ Found latest checkpoint: {checkpoint_path}")

    if checkpoint_path is None:
        print("\n❌ Error: No checkpoint found!")
        print("Please train the model first or specify a checkpoint path.")
        return

    # Load the model.
    model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_path=checkpoint_path,
        lora_config=training_config.get("lora_config", None),
        device="cuda",
        dtype=torch.bfloat16 if config["dtype"] == "bfloat16" else torch.float32,
        model_config=config.get("model", {}),
        adapter_names=[None, None, "subject", "fill", "background"],
        optimizer_config=None,
        gradient_checkpointing=False,
    )
    # inference_on_training_samples reads the sizes from this attribute.
    model.training_config = training_config

    # Explicitly activate all LoRA adapters.
    print("\n  Activating LoRA adapters...")
    adapter_list = list(model.adapter_set)
    if adapter_list:
        try:
            # First choice: activate every adapter in a single call.
            model.transformer.set_adapters(adapter_list)
            print(f"    ✓ Activated adapters: {adapter_list}")
        except Exception as e:
            # Best-effort fallback: enable the adapters one at a time and
            # report, rather than abort on, individual failures.
            print(f"    ⚠️  set_adapters failed: {e}")
            for adapter_name in adapter_list:
                try:
                    model.transformer.enable_adapters(adapter_name)
                    print(f"    ✓ Enabled adapter: {adapter_name}")
                except Exception as e2:
                    print(f"    ⚠️  Failed to enable {adapter_name}: {e2}")

        # Report the first module that exposes a non-empty active_adapters value.
        print("    Verifying adapter activation...")
        for name, module in model.transformer.named_modules():
            if getattr(module, "active_adapters", None):
                print(f"      Active adapters in {name}: {module.active_adapters}")
                break

    # Switch the model and its sub-modules to evaluation mode.
    model.eval()
    model.transformer.eval()
    model.flux_pipe.vae.eval()
    model.flux_pipe.text_encoder.eval()
    model.flux_pipe.text_encoder_2.eval()

    # Additionally zero the probability of every Dropout layer.
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0

    print("  ✓ Model loaded and set to eval mode (with deterministic settings)")

    # Run inference.
    print("\n[3/3] Running inference...")
    output_dir = "inference_results/aircraft_mask_weighted_normalized"
    inference_on_training_samples(
        model=model,
        dataset=dataset,
        num_samples=10,
        output_dir=output_dir,
        seed=seed
    )


# Run the inference pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
