#!/usr/bin/env python3
"""
航空目标批量推理脚本 (Hybrid: Solar Optimization + Frequency-Gated LoRA)
结合了光照优化 (Solar) 和频率分离 (Frequency) 的推理逻辑
"""

import os
import sys
import yaml
import torch
import random
from pathlib import Path
from PIL import Image
import numpy as np
import argparse
import torch.nn.functional as F

# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from omini.pipeline.flux_omini_solar import Condition, generate, seed_everything
# 使用 OminiSolarModel 作为基础，因为它包含了 Solar Encoder/Projectors 的加载逻辑
from omini.train_flux.train_aircraft_bg_crop_solar import OminiSolarModel
import omini.pipeline.flux_omini_solar as flux_omini_solar

def load_config(config_path: str):
    """Load a YAML configuration file.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The object produced by ``yaml.safe_load`` for the file contents
        (a dict for the training configs this script reads).
    """
    # YAML is UTF-8 by spec; being explicit avoids locale-dependent decode
    # failures when the config contains non-ASCII text (e.g. Chinese comments).
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config


def frequency_split_latents(latents, kernel_size=5):
    """Decompose packed latents into low- and high-frequency components.

    The token axis is treated as a flattened square H x W grid. A box filter
    (average pooling) gives the low-frequency band; the residual is the
    high-frequency band, so ``low + high == latents``. Intended to mirror the
    decomposition used by the training script.

    Args:
        latents: Tensor of shape [B, L, C] where L must be a perfect square.
        kernel_size: Positive odd size of the averaging window.

    Returns:
        Tuple ``(low, high)``, both shaped [B, L, C].

    Raises:
        ValueError: If L is not a perfect square or kernel_size is not a
            positive odd integer.
    """
    import math

    B, L, C = latents.shape
    H = math.isqrt(L)
    if H * H != L:
        raise ValueError(
            f"frequency_split_latents expects a square token grid, got L={L}"
        )
    if kernel_size < 1 or kernel_size % 2 == 0:
        # An even kernel with padding k // 2 and stride 1 changes the spatial
        # size, which would make ``x - low`` fail with a shape mismatch.
        raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
    W = H

    # [B, L, C] -> [B, C, H, W]; reshape tolerates the non-contiguous transpose.
    x = latents.transpose(1, 2).reshape(B, C, H, W)

    # Low-pass: box blur that keeps the spatial size. count_include_pad=False
    # averages only real pixels at the borders (no darkening from zero padding).
    pad = kernel_size // 2
    low = F.avg_pool2d(x, kernel_size=kernel_size, stride=1, padding=pad, count_include_pad=False)

    # High-pass: whatever the blur removed.
    high = x - low

    # Back to packed token layout [B, L, C].
    low_flat = low.flatten(2).transpose(1, 2)
    high_flat = high.flatten(2).transpose(1, 2)

    return low_flat, high_flat


class LowFreqCondition(Condition):
    """Condition whose encoded latents keep only the low-frequency band.

    Named for carrying subject style / lighting information.
    """

    def encode(self, pipe, empty=False):
        """Encode with the base Condition, then keep the low-pass component.

        Returns:
            ``(low_frequency_latents, ids)``; ids pass through unchanged.
        """
        encoded, token_ids = super().encode(pipe, empty)
        low_band = frequency_split_latents(encoded)[0]
        return low_band, token_ids


class HighFreqCondition(Condition):
    """Condition whose encoded latents keep only the high-frequency band.

    Named for carrying subject structure / edge information.
    """

    def encode(self, pipe, empty=False):
        """Encode with the base Condition, then keep the high-pass residual.

        Returns:
            ``(high_frequency_latents, ids)``; ids pass through unchanged.
        """
        encoded, token_ids = super().encode(pipe, empty)
        high_band = frequency_split_latents(encoded)[1]
        return high_band, token_ids


def scan_inference_data(data_dir: str):
    """Scan an inference-data folder and pair each background with its mask/subject.

    Expected layout::

        data_dir/
          ├── backgrounds/ (*_erased.png)
          ├── subjects/    (*_crop.png)  [optional]
          └── masks/       (*_mask.png)  [required, used by the Solar encoder]

    A background ``<id>_erased.png`` is kept only when ``masks/<id>_mask.png``
    exists; ``subjects/<id>_crop.png`` is attached when present.

    Args:
        data_dir: Root folder to scan.

    Returns:
        List of dicts ordered by background filename, each with keys ``id``,
        ``background`` and ``mask`` (str paths) and ``subject`` (str path or
        None). Empty when ``backgrounds/`` or ``masks/`` is missing.
    """
    suffix = "_erased"
    data_dir = Path(data_dir)
    backgrounds_dir = data_dir / "backgrounds"
    subjects_dir = data_dir / "subjects"
    masks_dir = data_dir / "masks"

    samples = []

    if not backgrounds_dir.exists():
        print(f"❌ Error: backgrounds/ folder not found in {data_dir}")
        return samples

    if not masks_dir.exists():
        print(f"❌ Error: masks/ folder not found in {data_dir} (Required for Solar Optimization)")
        return samples

    # Checked once rather than per background file.
    has_subjects = subjects_dir.exists()

    for bg_file in sorted(backgrounds_dir.glob("*_erased.png")):
        # Strip only the trailing "_erased" (guaranteed by the glob). A plain
        # str.replace would also delete earlier "_erased" occurrences in the
        # name and break the mask/subject lookup.
        sample_id = bg_file.stem[: -len(suffix)]

        # Mask is mandatory.
        mask_file = masks_dir / f"{sample_id}_mask.png"
        if not mask_file.exists():
            print(f"⚠️  Warning: Mask not found for {sample_id}, skipping.")
            continue

        # Subject is optional.
        subject_file = None
        if has_subjects:
            candidate = subjects_dir / f"{sample_id}_crop.png"
            if candidate.exists():
                subject_file = candidate

        samples.append({
            "id": sample_id,
            "background": str(bg_file),
            "mask": str(mask_file),
            "subject": str(subject_file) if subject_file else None,
        })

    return samples


def create_default_subject(size):
    """Build a placeholder subject: a solid mid-gray RGB canvas.

    Args:
        size: (width, height) forwarded to ``Image.new``.
    """
    mid_gray = (128,) * 3
    return Image.new(mode="RGB", size=size, color=mid_gray)


def load_rgba_with_black_background(image_path: str, fill_color=(255, 255, 255)) -> Image.Image:
    """Load an image as RGB, flattening any transparency onto a solid color.

    NOTE: despite the function name, the default fill is WHITE (255, 255, 255).
    That is what this loader has always used, and silently changing it would
    alter inference results. Pass ``fill_color=(0, 0, 0)`` for a black
    background.
    TODO(review): confirm which background the training data used, then
    rename the function or change the default accordingly.

    Args:
        image_path: Path to the image file.
        fill_color: RGB tuple painted behind transparent pixels.

    Returns:
        A fully loaded RGB image that no longer holds the source file open.
    """
    with Image.open(image_path) as img:
        if img.mode == 'RGB':
            # copy() forces the pixel data to load so the result outlives the
            # file being closed when the with-block exits.
            return img.copy()
        # LA / PA images and palette images with a transparency key also carry
        # alpha that a plain convert('RGB') would drop; normalise them to RGBA.
        if img.mode in ('LA', 'PA') or (img.mode == 'P' and 'transparency' in img.info):
            img = img.convert('RGBA')
        if img.mode == 'RGBA':
            background = Image.new('RGB', img.size, fill_color)
            background.paste(img, mask=img.getchannel('A'))
            return background
        # Any other mode: convert() returns a new, loaded image.
        return img.convert('RGB')


def _load_sample_images(sample, target_size, condition_size, unified_subject_img):
    """Load and resize the background, mask and subject images for one sample.

    Args:
        sample: One dict produced by ``scan_inference_data`` (keys
            ``background``, ``mask``, ``subject``).
        target_size: (width, height) for the mask. Nearest-neighbour resizing
            keeps the mask values un-interpolated.
        condition_size: (width, height) for the background and subject images.
        unified_subject_img: Pre-loaded, pre-resized subject used when the
            sample has no subject of its own, or ``None``.

    Returns:
        ``(background_img, mask_img, subject_img)`` as PIL images.
    """
    # ``with`` closes the source files; convert()/resize() return new images.
    with Image.open(sample["background"]) as bg_src:
        background_img = bg_src.convert("RGB").resize(condition_size, Image.BILINEAR)
    with Image.open(sample["mask"]) as mask_src:
        mask_img = mask_src.convert("L").resize(target_size, Image.NEAREST)

    if sample["subject"] is not None:
        subject_img = load_rgba_with_black_background(sample["subject"])
        subject_img = subject_img.resize(condition_size, Image.BILINEAR)
    elif unified_subject_img is not None:
        subject_img = unified_subject_img
    else:
        subject_img = create_default_subject(condition_size)
    return background_img, mask_img, subject_img


def _compute_solar_params(model, background_img, mask_img):
    """Run the Solar encoder and projectors for one background/mask pair.

    Returns:
        A list with one ``(scale, shift)`` tuple per module in
        ``model.solar_projectors``. Each tensor has a singleton dim inserted
        at index 1 and is cast to ``model.dtype``.

    Raises:
        ValueError: If the background latent token count is not a square grid.
    """
    import math

    bg_latents, _ = flux_omini_solar.encode_images(model.flux_pipe, background_img)

    # Packed latents [B, L, C] -> spatial [B, C, H, W] for the Solar encoder.
    batch, seq_len, channels = bg_latents.shape
    side = math.isqrt(seq_len)
    if side * side != seq_len:
        raise ValueError(f"Background latents have {seq_len} tokens, which is not a square grid")
    bg_spatial = bg_latents.transpose(1, 2).reshape(batch, channels, side, side).to(torch.float32)

    # Mode-"L" mask -> float in [0, 1], shaped [1, 1, H, W] on the model device.
    mask_tensor = torch.from_numpy(np.array(mask_img)).float() / 255.0
    mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(model.device)

    context_vector = model.solar_encoder(bg_spatial, mask_tensor)

    solar_params_list = []
    for proj in model.solar_projectors:
        scale, shift = proj(context_vector).chunk(2, dim=1)
        solar_params_list.append(
            (scale.unsqueeze(1).to(model.dtype), shift.unsqueeze(1).to(model.dtype))
        )
    return solar_params_list


@torch.no_grad()
def batch_inference(
    model,
    samples,
    output_dir: str = "inference_results_solar_frequency/batch",
    target_size=(512, 512),
    condition_size=(512, 512),
    unified_subject_path: str = None,
    seed: int = 42
):
    """Generate one image per sample with the hybrid Solar + frequency-gated LoRA model.

    For each sample, three conditions are built:
    - the subject's low-frequency band;
    - the subject's high-frequency band;
    - the background.

    Solar (scale, shift) parameters are computed from the background latents
    and mask. A failure on one sample is reported, and the loop continues with
    the next sample.

    Args:
        model: Loaded ``OminiSolarModel`` with LoRA adapters and Solar components.
        samples: Dicts from ``scan_inference_data``.
        output_dir: Destination for ``<id>_generated.jpg`` and ``conditions/``.
        target_size: (width, height) of the generated image and of the mask.
        condition_size: (width, height) of the condition images.
        unified_subject_path: Optional subject image for samples lacking one.
        seed: Base seed; sample ``idx`` uses generator seed ``seed + idx``.
    """
    import traceback

    os.makedirs(output_dir, exist_ok=True)
    condition_dir = os.path.join(output_dir, "conditions")

    # Hybrid model adapter layout: [None, None, subject_low, subject_high, background]
    adapter_low = "subject_low"
    adapter_high = "subject_high"
    adapter_bg = "background"

    unified_prompt = "Place an aircraft at the specified position"

    print(f"\n{'='*70}")
    print(f"Batch Inference (Hybrid: Solar + Frequency) on {len(samples)} samples")
    print(f"{'='*70}")
    print(f"Output directory: {output_dir}")
    print(f"Adapters: {adapter_low}, {adapter_high}, {adapter_bg}")
    print(f"Using Seed: {seed}")

    # Optional shared subject used for every sample without its own subject.
    unified_subject_img = None
    if not unified_subject_path:
        print("⚠️  No unified subject provided")
    elif not os.path.exists(unified_subject_path):
        print(f"⚠️  Unified subject not found at {unified_subject_path}")
    else:
        try:
            unified_subject_img = load_rgba_with_black_background(unified_subject_path)
            unified_subject_img = unified_subject_img.resize(condition_size, Image.BILINEAR)
            print(f"✓ Loaded unified subject image from {unified_subject_path}")
        except Exception as e:
            print(f"⚠️  Failed to load unified subject: {e}")
            unified_subject_img = None

    for idx, sample in enumerate(samples):
        sample_id = sample["id"]
        print(f"\n[{idx+1}/{len(samples)}] Processing {sample_id}...")

        # 1. Images
        try:
            background_img, mask_img, subject_img = _load_sample_images(
                sample, target_size, condition_size, unified_subject_img
            )
        except Exception as e:
            print(f"  ❌ Failed to load images: {e}")
            continue

        # 2. Conditions: subject low band (style/lighting), subject high band
        #    (structure/edges), and the background.
        cond_low = LowFreqCondition(subject_img, adapter_low, [-16, -32])
        cond_high = HighFreqCondition(subject_img, adapter_high, [-16, -32])
        cond_bg = Condition(background_img, adapter_bg, [16, -32])

        # 3. Solar parameters. Guarded so one bad sample cannot abort the batch.
        try:
            solar_params_list = _compute_solar_params(model, background_img, mask_img)
        except Exception as e:
            print(f"  ❌ Solar parameter computation failed: {e}")
            traceback.print_exc()
            continue

        # 4. Generation: reset global RNGs, then a per-sample generator seed.
        seed_everything(seed)
        generator = torch.Generator(device=model.device)
        generator.manual_seed(seed + idx)

        print("  Generating image...")
        try:
            res = generate(
                model.flux_pipe,
                prompt=unified_prompt,
                conditions=[cond_low, cond_high, cond_bg],
                height=target_size[1],
                width=target_size[0],
                num_inference_steps=28,
                guidance_scale=3.5,
                generator=generator,
                model_config=model.model_config,
                kv_cache=model.model_config.get("independent_condition", False),
                solar_params_list=solar_params_list,
            )

            # 5. Save the result and the condition images used for it.
            output_path = os.path.join(output_dir, f"{sample_id}_generated.jpg")
            res.images[0].save(output_path)
            print(f"  ✓ Saved to {output_path}")

            os.makedirs(condition_dir, exist_ok=True)
            subject_img.save(os.path.join(condition_dir, f"{sample_id}_subject.jpg"))
            background_img.save(os.path.join(condition_dir, f"{sample_id}_background.jpg"))

        except Exception as e:
            print(f"  ❌ Generation failed: {e}")
            traceback.print_exc()

    print(f"\n{'='*70}")
    print(f"✓ Batch inference completed! Results saved to {output_dir}")


def _load_solar_components(model, checkpoint_path):
    """Load Solar encoder/projector weights from ``solar_components.pt``.

    Returns:
        True on success; False (after printing an error) if the file is missing.
    """
    solar_path = os.path.join(checkpoint_path, "solar_components.pt")
    if not os.path.exists(solar_path):
        print(f"❌ Error: solar_components.pt not found in {checkpoint_path}")
        return False
    print(f"Loading Solar Components from {solar_path}")
    # NOTE(security): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    state = torch.load(solar_path, map_location=model.device)
    model.solar_encoder.load_state_dict(state["encoder"])
    model.solar_projectors.load_state_dict(state["projectors"])
    return True


def _load_lora_adapters(model, checkpoint_path, adapter_names):
    """Load each ``<name>.safetensors`` LoRA found in checkpoint_path.

    Returns:
        Names of the adapters actually loaded, in ``adapter_names`` order.
    """
    loaded = []
    for adapter_name in adapter_names:
        weight_name = f"{adapter_name}.safetensors"
        lora_file = os.path.join(checkpoint_path, weight_name)
        if not os.path.exists(lora_file):
            print(f"⚠️  Warning: {adapter_name}.safetensors not found.")
            continue
        print(f"  Loading {adapter_name} from {lora_file}")
        model.flux_pipe.load_lora_weights(checkpoint_path, weight_name=weight_name, adapter_name=adapter_name)
        loaded.append(adapter_name)
    return loaded


def _activate_adapters(model, adapter_list):
    """Activate the given LoRA adapters on the transformer (best effort)."""
    print("\n  Activating LoRA adapters...")
    if not adapter_list:
        print("    ⚠️  No LoRA adapters were loaded; nothing to activate.")
        return
    try:
        model.transformer.set_adapters(adapter_list)
        print(f"    ✓ Activated adapters: {adapter_list}")
    except Exception as e:
        print(f"    ⚠️  set_adapters failed: {e}")
        # diffusers' enable_adapters() takes no adapter name; it re-enables
        # every adapter attached to the model.
        try:
            model.transformer.enable_adapters()
            print(f"    ✓ Enabled adapters: {adapter_list}")
        except Exception as e2:
            print(f"    ⚠️  Failed to enable adapters: {e2}")


def main():
    """CLI entry point: parse args, load config/model/checkpoint, scan data, run inference.

    Returns early (after printing an error) when the config file, checkpoint
    directory, Solar components or input samples are missing.
    """
    # The run that used to be hard-coded; now only a default, so --checkpoint
    # is honoured while a bare invocation keeps its previous behaviour.
    default_checkpoint = "runs_bg_crop/20260104-225100/ckpt/6000"

    parser = argparse.ArgumentParser(description="Batch Inference for Aircraft (Solar + Frequency)")
    parser.add_argument("--config", type=str, default="./train/config/aircraft_bg_crop.yaml", help="Path to training config")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=default_checkpoint,
        help="Path to checkpoint directory (default: %(default)s)",
    )
    parser.add_argument("--data_dir", type=str, default="./inference_datas", help="Directory containing inference data")
    parser.add_argument("--output_dir", type=str, default="inference_results_solar_frequency/batch", help="Output directory")
    parser.add_argument("--unified_subject", type=str, default=None, help="Path to a single subject image to use for all samples")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args()

    # 1. Global determinism: seed every RNG and force deterministic cuDNN.
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    seed_everything(seed)

    print("="*70)
    print("Aircraft Batch Inference (Bg + Crop + Solar + Frequency)")
    print("🔧 Deterministic mode enabled")
    print("="*70)

    # 2. Config
    print(f"Loading config from {args.config}...")
    if not os.path.exists(args.config):
        print(f"❌ Error: Config file not found: {args.config}")
        return
    config = load_config(args.config)
    training_config = config["train"]
    dataset_config = training_config["dataset"]

    target_size = tuple(dataset_config["target_size"])
    condition_size = tuple(dataset_config["condition_size"])

    # 3. Checkpoint: validate before paying the cost of loading the model.
    checkpoint_path = args.checkpoint
    if not checkpoint_path or not os.path.isdir(checkpoint_path):
        print(f"\n❌ Error: Checkpoint directory not found: {checkpoint_path}")
        print("Please provide --checkpoint path")
        return

    print(f"Loading model from {checkpoint_path}...")

    # 4. Model. OminiSolarModel provides the Solar encoder/projectors;
    #    lora_config=None because adapters are loaded manually below.
    adapter_names = ["subject_low", "subject_high", "background"]
    model = OminiSolarModel(
        flux_pipe_id=config["flux_path"],
        lora_path=None,
        lora_config=None,
        device="cuda",
        dtype=torch.bfloat16 if config["dtype"] == "bfloat16" else torch.float32,
        model_config=config.get("model", {}),
        adapter_names=[None, None, *adapter_names],  # leading Nones are placeholders
        gradient_checkpointing=False,
    )
    model.adapter_set = set(adapter_names)
    model.training_config = training_config

    # 5. Weights
    if not _load_solar_components(model, checkpoint_path):
        return

    print("Loading LoRA adapters...")
    loaded_adapters = _load_lora_adapters(model, checkpoint_path, adapter_names)
    # Activate only what actually loaded, in a deterministic order.
    _activate_adapters(model, loaded_adapters)

    # Eval mode everywhere; additionally zero every Dropout probability.
    model.eval()
    model.transformer.eval()
    model.flux_pipe.vae.eval()
    model.flux_pipe.text_encoder.eval()
    model.flux_pipe.text_encoder_2.eval()

    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0

    print("  ✓ Model loaded and set to eval mode (with adapter activation)")

    # 6. Data
    print(f"Scanning data directory: {args.data_dir}")
    samples = scan_inference_data(args.data_dir)
    if not samples:
        print("No samples found.")
        return
    print(f"✓ Found {len(samples)} samples")

    # 7. Inference
    batch_inference(
        model,
        samples,
        output_dir=args.output_dir,
        target_size=target_size,
        condition_size=condition_size,
        unified_subject_path=args.unified_subject,
        seed=args.seed
    )
    print("Done!")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()