#!/usr/bin/env python3
"""
航空目标批量推理脚本 (Bg + Crop + Frequency-Gated LoRA)
支持频率分离 (Low/High Frequency)，结构与 Solar 推理脚本保持一致
"""

import os
import sys
import yaml
import torch
import random
from pathlib import Path
from PIL import Image
import numpy as np
import argparse
import torch.nn.functional as F

# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from omini.pipeline.flux_omini_solar import Condition, generate, seed_everything
from omini.train_flux.trainer_mask_weighted import OminiModel

def load_config(config_path: str):
    """Load a YAML configuration file and return the parsed object.

    The file is decoded explicitly as UTF-8, so configs containing non-ASCII
    text parse identically regardless of the platform's default locale
    encoding. ``yaml.safe_load`` is used, so only plain YAML types are built.

    Args:
        config_path: path to the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file (``main`` indexes it
        as a dict, e.g. ``config["train"]``).
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def frequency_split_latents(latents, kernel_size=5, height=None, width=None):
    """Decompose token latents into low-frequency and high-frequency parts.

    Same decomposition as the training script. The ``L`` tokens are laid out
    on a ``height x width`` grid (row-major). A stride-1 box blur
    (``avg_pool2d`` with ``count_include_pad=False``) gives the low-frequency
    part. The residual ``x - low`` is the high-frequency part, so
    ``low + high`` reconstructs the input.

    Args:
        latents: tensor of shape ``[B, L, C]``.
        kernel_size: box-filter size. Must be a positive odd integer so that
            the ``kernel_size // 2`` padding keeps the grid size unchanged.
        height, width: grid dimensions. When both are omitted the grid is
            assumed to be square (``L`` must be a perfect square); pass both
            for non-square grids.

    Returns:
        ``(low, high)``, each of shape ``[B, L, C]``.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer, only one
            of ``height``/``width`` is given, or the grid does not hold
            exactly ``L`` tokens.
    """
    B, L, C = latents.shape

    # An even kernel with padding k//2 would produce an (H+1)x(W+1) output and
    # break the subtraction below with an opaque shape error.
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
    if (height is None) != (width is None):
        raise ValueError("height and width must be given together (or both omitted)")

    if height is None:
        # round() guards against float sqrt landing just below an integer.
        side = int(round(L ** 0.5))
        height, width = side, side
    if height * width != L:
        raise ValueError(f"cannot arrange {L} latent tokens into a {height}x{width} grid")

    # Reshape to spatial: [B, C, H, W]
    x = latents.transpose(1, 2).reshape(B, C, height, width)

    # Low pass: local mean, ignoring padded cells at the borders.
    pad = kernel_size // 2
    low = F.avg_pool2d(x, kernel_size=kernel_size, stride=1, padding=pad, count_include_pad=False)

    # High pass: residual detail.
    high = x - low

    # Flatten back to [B, L, C]
    low_flat = low.flatten(2).transpose(1, 2)
    high_flat = high.flatten(2).transpose(1, 2)

    return low_flat, high_flat


class LowFreqCondition(Condition):
    """Condition whose encoded latents keep only the low-frequency component.

    The low-frequency part (a local box-blur average of the latents) is meant
    to carry the subject's style / lighting. Token ids from the base encoding
    are returned unchanged.
    """

    def encode(self, pipe, empty=False):
        base_latents, token_ids = super().encode(pipe, empty)
        low_freq = frequency_split_latents(base_latents)[0]
        return low_freq, token_ids


class HighFreqCondition(Condition):
    """Condition whose encoded latents keep only the high-frequency residual.

    The residual (latents minus their local box-blur average) is meant to
    carry the subject's structure / edges. Token ids from the base encoding
    are returned unchanged.
    """

    def encode(self, pipe, empty=False):
        base_latents, token_ids = super().encode(pipe, empty)
        high_freq = frequency_split_latents(base_latents)[1]
        return high_freq, token_ids


def scan_inference_data(data_dir: str):
    """Scan an inference-data folder and pair each background with its subject crop.

    Expected layout::

        data_dir/
          ├── backgrounds/ (*_erased.png)
          ├── subjects/    (*_crop.png) [optional]

    Args:
        data_dir: root folder containing ``backgrounds/`` and optionally
            ``subjects/``.

    Returns:
        A list of dicts, sorted by background path, each with keys:

        - ``"id"``: the background file stem with only the trailing
          ``_erased`` removed.
        - ``"background"``: the background path as a string.
        - ``"subject"``: the path of ``subjects/<id>_crop.png`` as a string,
          or ``None`` when that file does not exist.

        An empty list is returned (after printing an error) when
        ``backgrounds/`` is missing.
    """
    data_dir = Path(data_dir)
    backgrounds_dir = data_dir / "backgrounds"
    subjects_dir = data_dir / "subjects"

    samples = []

    if not backgrounds_dir.exists():
        print(f"❌ Error: backgrounds/ folder not found in {data_dir}")
        return samples

    suffix = "_erased"
    for bg_file in sorted(backgrounds_dir.glob(f"*{suffix}.png")):
        # The glob guarantees the stem ends with the suffix. Strip only that
        # trailing occurrence: str.replace would also delete an "_erased"
        # inside the ID and break pairing with the subject crop.
        sample_id = bg_file.stem[: -len(suffix)]

        # Optional subject crop; a missing subjects/ dir simply yields None.
        subject_file = subjects_dir / f"{sample_id}_crop.png"

        samples.append({
            "id": sample_id,
            "background": str(bg_file),
            "subject": str(subject_file) if subject_file.is_file() else None,
        })

    return samples


def create_default_subject(size):
    """Build a placeholder subject: a solid mid-gray RGB image of ``size``.

    ``batch_inference`` falls back to this when a sample has no subject crop
    and no unified subject image was loaded.
    """
    mid_gray = (128,) * 3
    return Image.new(mode="RGB", size=size, color=mid_gray)


def load_rgba_with_black_background(image_path: str, bg_color=(255, 255, 255)) -> Image.Image:
    """Load an image as RGB, compositing any transparency onto a solid color.

    NOTE(review): despite the function name, the fill has always been white
    ``(255, 255, 255)``; that is kept as the default of ``bg_color``. Pass
    ``(0, 0, 0)`` for a black background. Confirm which color the training
    preprocessing used.

    Args:
        image_path: path to the image file.
        bg_color: RGB color placed behind transparent pixels.

    Returns:
        A fully loaded RGB image, independent of the (closed) source file.
        RGB inputs are returned unchanged. Images with an alpha channel
        (``RGBA``, ``LA``, ``PA``, or palette images carrying a
        ``transparency`` entry) are alpha-composited onto ``bg_color``.
        Any other mode is converted to RGB directly.
    """
    with Image.open(image_path) as img:
        if img.mode == "RGB":
            # Copy so the result stays usable after the file is closed.
            return img.copy()
        has_alpha = img.mode in ("RGBA", "LA", "PA") or (
            img.mode == "P" and "transparency" in img.info
        )
        if not has_alpha:
            return img.convert("RGB")
        rgba = img.convert("RGBA")

    background = Image.new("RGB", rgba.size, bg_color)
    background.paste(rgba, mask=rgba.getchannel("A"))
    return background


@torch.no_grad()
def batch_inference(
    model,
    samples,
    output_dir: str = "inference_results_frequency/batch",
    target_size=(512, 512),
    condition_size=(512, 512),
    unified_subject_path: str = None,
    seed: int = 42,
    prompt: str = "Place an aircraft at the specified position",
    num_inference_steps: int = 28,
    guidance_scale: float = 3.5,
):
    """Generate one image per sample using frequency-split subject conditions.

    For every sample the subject image is wrapped twice:

    - as a ``LowFreqCondition`` for the ``subject_low`` adapter;
    - as a ``HighFreqCondition`` for the ``subject_high`` adapter.

    Both use position delta ``[-16, -32]``. The background is wrapped as a
    plain ``Condition`` for the ``background`` adapter, with delta
    ``[16, -32]``. Outputs:

    - ``<output_dir>/<id>_generated.jpg``: the generated image;
    - ``<output_dir>/conditions/``: the resized subject/background images,
      saved only when generation succeeded.

    Per-sample load or generation failures are reported and skipped.

    Args:
        model: loaded model exposing ``flux_pipe``, ``device`` and
            ``model_config``.
        samples: list of dicts with ``"id"``, ``"background"`` and ``"subject"``
            (path or ``None``) keys, as produced by ``scan_inference_data``.
        output_dir: destination folder (created if missing).
        target_size: ``(width, height)`` of generated images.
        condition_size: ``(width, height)`` that condition images are resized to.
        unified_subject_path: optional subject image used for samples that have
            no subject crop of their own.
        seed: base seed. ``seed_everything(seed)`` runs before each sample and
            sample ``idx`` uses a generator seeded with ``seed + idx``.
        prompt: text prompt shared by all samples.
        num_inference_steps: denoising steps passed to ``generate``.
        guidance_scale: guidance scale passed to ``generate``.
    """
    import traceback

    os.makedirs(output_dir, exist_ok=True)
    condition_dir = os.path.join(output_dir, "conditions")
    os.makedirs(condition_dir, exist_ok=True)

    # Adapter names must match the LoRA weights loaded by the caller.
    # Model adapter layout: [None, None, subject_low, subject_high, background]
    adapter_low = "subject_low"
    adapter_high = "subject_high"
    adapter_bg = "background"

    print(f"\n{'='*70}")
    print(f"Batch Inference (Frequency-Gated LoRA) on {len(samples)} samples")
    print(f"{'='*70}")
    print(f"Output directory: {output_dir}")
    print(f"Adapters: {adapter_low}, {adapter_high}, {adapter_bg}")
    print(f"Using Seed: {seed}")

    # Optional subject shared by every sample lacking its own crop.
    unified_subject_img = None
    if not unified_subject_path:
        print("⚠️  No unified subject provided")
    elif not os.path.exists(unified_subject_path):
        print(f"⚠️  Unified subject not found at {unified_subject_path}")
    else:
        try:
            unified_subject_img = load_rgba_with_black_background(unified_subject_path)
            unified_subject_img = unified_subject_img.resize(condition_size, Image.BILINEAR)
            print(f"✓ Loaded unified subject image from {unified_subject_path}")
        except Exception as e:
            print(f"⚠️  Failed to load unified subject: {e}")
            unified_subject_img = None

    for idx, sample in enumerate(samples):
        sample_id = sample["id"]
        print(f"\n[{idx+1}/{len(samples)}] Processing {sample_id}...")

        # 1. Prepare images (subject priority: own crop > unified > gray).
        try:
            with Image.open(sample["background"]) as bg:
                background_img = bg.convert("RGB").resize(condition_size, Image.BILINEAR)

            if sample["subject"] is not None:
                subject_img = load_rgba_with_black_background(sample["subject"])
                subject_img = subject_img.resize(condition_size, Image.BILINEAR)
            elif unified_subject_img is not None:
                subject_img = unified_subject_img
            else:
                subject_img = create_default_subject(condition_size)
        except Exception as e:
            print(f"  ❌ Failed to load images: {e}")
            continue

        # 2. Build conditions: subject low (style) + subject high (structure) share
        #    the same position delta; the background uses its own.
        cond_low = LowFreqCondition(subject_img, adapter_low, [-16, -32])
        cond_high = HighFreqCondition(subject_img, adapter_high, [-16, -32])
        cond_bg = Condition(background_img, adapter_bg, [16, -32])

        # 3. Generate with a per-sample but reproducible seed.
        seed_everything(seed)
        generator = torch.Generator(device=model.device)
        generator.manual_seed(seed + idx)

        print("  Generating image...")
        try:
            res = generate(
                model.flux_pipe,
                prompt=prompt,
                conditions=[cond_low, cond_high, cond_bg],
                height=target_size[1],
                width=target_size[0],
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
                model_config=model.model_config,
                kv_cache=model.model_config.get("independent_condition", False),
            )

            # 4. Save the result, then the condition images used for it.
            output_path = os.path.join(output_dir, f"{sample_id}_generated.jpg")
            res.images[0].save(output_path)
            print(f"  ✓ Saved to {output_path}")

            subject_img.save(os.path.join(condition_dir, f"{sample_id}_subject.jpg"))
            background_img.save(os.path.join(condition_dir, f"{sample_id}_background.jpg"))

        except Exception as e:
            print(f"  ❌ Generation failed: {e}")
            traceback.print_exc()
            continue

    print(f"\n{'='*70}")
    print(f"✓ Batch inference completed! Results saved to {output_dir}")


def _enable_determinism(seed: int) -> None:
    """Seed Python/NumPy/torch RNGs and force deterministic cuDNN kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    seed_everything(seed)


def _load_lora_adapters(model, checkpoint_path: str, adapter_names) -> list:
    """Load ``<name>.safetensors`` from ``checkpoint_path`` for each adapter name.

    Missing files are reported and skipped. Returns the names that were
    actually loaded, in the order given.
    """
    loaded = []
    for adapter_name in adapter_names:
        weight_name = f"{adapter_name}.safetensors"
        lora_file = os.path.join(checkpoint_path, weight_name)
        if not os.path.exists(lora_file):
            print(f"⚠️  Warning: {weight_name} not found.")
            continue
        print(f"  Loading {adapter_name} from {lora_file}")
        model.flux_pipe.load_lora_weights(checkpoint_path, weight_name=weight_name, adapter_name=adapter_name)
        loaded.append(adapter_name)
    return loaded


def _activate_adapters(model, adapter_names) -> None:
    """Activate exactly ``adapter_names`` on ``model.transformer``; warn on failure."""
    print("\n  Activating LoRA adapters...")
    if not adapter_names:
        print("    ⚠️  No LoRA adapters were loaded from the checkpoint")
        return
    # Only adapters that were really loaded are passed, in a fixed order;
    # naming a missing adapter makes set_adapters raise. There is no
    # per-adapter fallback: diffusers' enable_adapters() takes no adapter name.
    try:
        model.transformer.set_adapters(list(adapter_names))
        print(f"    ✓ Activated adapters: {list(adapter_names)}")
    except Exception as e:
        print(f"    ⚠️  set_adapters failed: {e}")


def main():
    """Command-line entry point for aircraft batch inference.

    Steps:

    1. Parse CLI arguments and enable deterministic seeding.
    2. Read the training YAML config.
    3. Build an ``OminiModel`` without auto-created adapters.
    4. Load the ``subject_low`` / ``subject_high`` / ``background`` LoRA
       weights from ``--checkpoint``.
    5. Activate the adapters that loaded and switch the model to eval mode.
    6. Scan ``--data_dir`` and run ``batch_inference``.
    """
    parser = argparse.ArgumentParser(description="Batch Inference for Aircraft (Frequency-Gated LoRA)")
    parser.add_argument("--config", type=str, default="./train/config/aircraft_bg_crop.yaml", help="Path to training config")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="runs_bg_crop/20260104-224519/ckpt/2000",
        help="Path to checkpoint directory containing <adapter>.safetensors files",
    )
    parser.add_argument("--data_dir", type=str, default="./inference_datas", help="Directory containing inference data")
    parser.add_argument("--output_dir", type=str, default="inference_results_frequency/batch", help="Output directory")
    parser.add_argument("--unified_subject", type=str, default=None, help="Path to a single subject image to use for all samples")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args()

    # 1. Global determinism
    _enable_determinism(args.seed)

    print("="*70)
    print("Aircraft Batch Inference (Bg + Crop + Frequency-Gated LoRA)")
    print("🔧 Deterministic mode enabled")
    print("="*70)

    # 2. Load config
    print(f"Loading config from {args.config}...")
    if not os.path.exists(args.config):
        print(f"❌ Error: Config file not found: {args.config}")
        return
    config = load_config(args.config)
    training_config = config["train"]
    dataset_config = training_config["dataset"]

    target_size = tuple(dataset_config["target_size"])
    condition_size = tuple(dataset_config["condition_size"])

    # 3. Check checkpoint. --checkpoint is honored; the former hard-coded
    #    run path is now only its default.
    checkpoint_path = args.checkpoint
    if not checkpoint_path or not os.path.isdir(checkpoint_path):
        print(f"\n❌ Error: Checkpoint directory not found: {checkpoint_path}")
        print("Please provide --checkpoint path")
        return

    print(f"Loading model from {checkpoint_path}...")

    # 4. Build the model. lora_config=None keeps __init__ from creating
    #    adapters; the LoRA weights are loaded manually below.
    adapter_names = ["subject_low", "subject_high", "background"]
    model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_path=None,
        lora_config=None,
        device="cuda",
        dtype=torch.bfloat16 if config["dtype"] == "bfloat16" else torch.float32,
        model_config=config.get("model", {}),
        adapter_names=[None, None, *adapter_names],  # first two are placeholders
        gradient_checkpointing=False,
    )
    model.adapter_set = set(adapter_names)
    model.training_config = training_config

    # 5. Load and activate LoRA weights
    print("Loading LoRA adapters...")
    loaded_adapters = _load_lora_adapters(model, checkpoint_path, adapter_names)
    _activate_adapters(model, loaded_adapters)

    # Eval mode for every sub-module used during generation.
    model.eval()
    model.transformer.eval()
    model.flux_pipe.vae.eval()
    model.flux_pipe.text_encoder.eval()
    model.flux_pipe.text_encoder_2.eval()

    # Belt-and-braces: also zero every dropout probability.
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0

    print("  ✓ Model loaded and set to eval mode (with adapter activation)")

    # 6. Scan data
    print(f"Scanning data directory: {args.data_dir}")
    samples = scan_inference_data(args.data_dir)
    if not samples:
        print("No samples found.")
        return
    print(f"✓ Found {len(samples)} samples")

    # 7. Run inference
    batch_inference(
        model,
        samples,
        output_dir=args.output_dir,
        target_size=target_size,
        condition_size=condition_size,
        unified_subject_path=args.unified_subject,
        seed=args.seed,
    )
    print("Done!")


if __name__ == "__main__":
    main()