#!/usr/bin/env python3
"""
航空目标批量推理脚本 (Bg + Crop + Solar Channel Mix)
支持光照优化网络 (Channel Mix: Gated Non-Linear Modulation)
"""

import os
import sys
import yaml
import torch
import random
from pathlib import Path
from PIL import Image
import numpy as np
import argparse
import torch.nn.functional as F

# Put the project root (four `.parent` steps above this file) at the front of
# sys.path so the `omini` package imports below resolve when this script is
# run directly. abspath() first: a relative __file__ (e.g. plain "script.py"
# on older Python versions) would otherwise collapse every `.parent` to ".".
project_root = Path(os.path.abspath(__file__)).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from omini.pipeline.flux_omini_solar import Condition, generate, seed_everything
import omini.pipeline.flux_omini_solar as flux_omini_solar

# Import Custom Model and Forward Functions
from omini.train_flux.train_aircraft_bg_crop_solar_channel_mix import (
    OminiSolarChannelMixModel,
    solar_block_forward_channel_mix,
    solar_single_block_forward_channel_mix,
    solar_attn_forward_channel_mix
)

def load_config(config_path: str):
    """Load a YAML configuration file and return the parsed object.

    The file is decoded as UTF-8 explicitly so configs containing non-ASCII
    text (e.g. Chinese comments) parse identically on every platform rather
    than depending on the locale's default encoding.

    Args:
        config_path: Path to the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (a dict for the
        configs this script reads; ``None`` for an empty file).
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

def scan_inference_data(data_dir: str):
    """
    Scan an inference data directory and pair backgrounds with masks/subjects.

    Expected layout::

        data_dir/
          ├── backgrounds/ (*_erased.png)
          ├── subjects/    (*_crop.png) [optional]
          └── masks/       (*_mask.png) [required, used by the Solar encoder]

    A sample's ID is the background filename stem without its trailing
    ``_erased``. Backgrounds without a matching ``<id>_mask.png`` are skipped
    with a warning; a missing ``<id>_crop.png`` just yields ``subject=None``.

    Args:
        data_dir: Root directory containing the sub-folders above.

    Returns:
        A list of dicts with keys ``"id"``, ``"background"``, ``"mask"`` (path
        strings) and ``"subject"`` (path string or ``None``), ordered by
        background filename. An empty list is returned (after printing an
        error) when ``backgrounds/`` or ``masks/`` is missing.
    """
    root = Path(data_dir)
    backgrounds_dir = root / "backgrounds"
    subjects_dir = root / "subjects"
    masks_dir = root / "masks"

    samples = []

    if not backgrounds_dir.exists():
        print(f"❌ Error: backgrounds/ folder not found in {root}")
        return samples

    if not masks_dir.exists():
        print(f"❌ Error: masks/ folder not found in {root} (Required for Solar Optimization)")
        return samples

    bg_suffix = "_erased"
    for bg_file in sorted(backgrounds_dir.glob(f"*{bg_suffix}.png")):
        # Remove only the trailing "_erased"; str.replace would also strip
        # occurrences inside the ID itself ("a_erased_b_erased" -> "a_b").
        sample_id = bg_file.stem[: -len(bg_suffix)]

        mask_file = masks_dir / f"{sample_id}_mask.png"
        if not mask_file.exists():
            print(f"⚠️  Warning: Mask not found for {sample_id}, skipping.")
            continue

        # The subject crop is optional; fall back to None when absent.
        subject_file = None
        if subjects_dir.exists():
            candidate = subjects_dir / f"{sample_id}_crop.png"
            if candidate.exists():
                subject_file = candidate

        samples.append({
            "id": sample_id,
            "background": str(bg_file),
            "mask": str(mask_file),
            "subject": str(subject_file) if subject_file else None,
        })

    return samples

def create_default_subject(size):
    """Build a placeholder subject: a solid mid-gray RGB image of the given ``size``."""
    gray_level = 128
    placeholder = Image.new("RGB", size, (gray_level, gray_level, gray_level))
    return placeholder

def load_rgba_with_black_background(image_path: str, fill=(255, 255, 255)) -> Image.Image:
    """Load an image as RGB, flattening any transparency onto a solid colour.

    NOTE(review): despite the function name, transparent pixels have always
    been filled with WHITE (255, 255, 255). That default is kept so inference
    preprocessing does not silently change; confirm what the training
    pipeline used before passing ``fill=(0, 0, 0)``.

    Any alpha-bearing mode (RGBA, LA, PA, or palette images with a
    ``transparency`` entry) is composited onto ``fill``. Other modes are
    converted straight to RGB.

    Args:
        image_path: Path to the image file.
        fill: RGB colour placed under transparent pixels.

    Returns:
        A fully loaded ``RGB`` PIL image whose source file handle is closed.
    """
    with Image.open(image_path) as img:
        if img.mode == "RGB":
            # copy() forces the pixel data into memory before the file closes.
            return img.copy()
        has_alpha = img.mode in ("RGBA", "LA", "PA") or (
            img.mode == "P" and "transparency" in img.info
        )
        if not has_alpha:
            return img.convert("RGB")
        rgba = img.convert("RGBA")
    background = Image.new("RGB", rgba.size, fill)
    background.paste(rgba, mask=rgba.getchannel("A"))
    return background

def _latent_grid(seq_len, size):
    """Return ``(height, width)`` of the latent token grid holding ``seq_len`` tokens.

    ``size`` is ``(width, height)`` in PIL order. Assumes the encoder
    downsamples both axes by the same factor, so the grid keeps the aspect
    ratio of ``size``. Square inputs give ``sqrt(seq_len)`` per side, as the
    original square-only code did.

    Raises:
        ValueError: if no integer grid with that aspect ratio holds ``seq_len`` tokens.
    """
    width, height = size
    grid_h = max(1, round((seq_len * height / width) ** 0.5))
    grid_w = seq_len // grid_h
    if grid_h * grid_w != seq_len:
        raise ValueError(
            f"cannot arrange {seq_len} latent tokens into a grid matching size {size}"
        )
    return grid_h, grid_w


def _load_sample_images(sample, target_size, condition_size, unified_subject_img):
    """Load and resize ``(background, mask, subject)`` PIL images for one sample.

    Subject priority: the sample's own crop, then the unified subject, then a
    gray placeholder.
    """
    background_img = Image.open(sample["background"]).convert("RGB").resize(condition_size, Image.BILINEAR)
    # NEAREST keeps the mask binary-valued after resizing.
    mask_img = Image.open(sample["mask"]).convert("L").resize(target_size, Image.NEAREST)

    if sample["subject"] is not None:
        subject_img = load_rgba_with_black_background(sample["subject"])
        subject_img = subject_img.resize(condition_size, Image.BILINEAR)
    elif unified_subject_img is not None:
        subject_img = unified_subject_img
    else:
        subject_img = create_default_subject(condition_size)
    return background_img, mask_img, subject_img


def _compute_solar_params(model, background_img, mask_img, condition_size):
    """Run the Solar encoder and projectors, returning one ``(scale, shift, gate)`` per projector.

    The background latents are reshaped from ``[B, L, C]`` to a
    ``[B, C, H, W]`` float32 map (row-major token order). The mask is scaled
    to [0, 1] with shape ``[1, 1, h, w]``. Each projector output is split
    into three equal chunks on the last dim. Each chunk gets a singleton
    dim 1 and is cast to ``model.dtype``.
    Called from ``batch_inference``, which already runs under ``torch.no_grad``.
    """
    bg_latents, _ = flux_omini_solar.encode_images(model.flux_pipe, background_img)

    batch, seq_len, channels = bg_latents.shape
    grid_h, grid_w = _latent_grid(seq_len, condition_size)
    bg_spatial = bg_latents.transpose(1, 2).reshape(batch, channels, grid_h, grid_w).to(torch.float32)

    mask_tensor = torch.from_numpy(np.array(mask_img)).float() / 255.0
    mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(model.device)

    context_vector = model.solar_encoder(bg_spatial, mask_tensor)

    solar_params_list = []
    for proj in model.solar_projectors:
        scale, shift, gate = proj(context_vector).chunk(3, dim=-1)
        solar_params_list.append(
            tuple(t.unsqueeze(1).to(model.dtype) for t in (scale, shift, gate))
        )
    return solar_params_list


@torch.no_grad()
def batch_inference(
    model,
    samples,
    output_dir: str = "inference_results_solar_channel_mix/batch",
    target_size=(512, 512),
    condition_size=(512, 512),
    unified_subject_path: str = None,
    seed: int = 42
):
    """Generate one image per sample with the Solar Channel Mix model and save results.

    For each sample: load and resize the images, build the subject/background
    conditions, and compute per-block Solar modulation parameters from the
    background latents and mask. Then call ``generate`` with the channel-mix
    forward functions. Failures at any stage are reported and the sample is
    skipped, so one bad input does not abort the batch.

    Args:
        model: Loaded ``OminiSolarChannelMixModel`` (provides ``flux_pipe``,
            ``solar_encoder``, ``solar_projectors``, ``device``, ``dtype``,
            ``model_config``).
        samples: Records as returned by ``scan_inference_data``.
        output_dir: Where ``<id>_generated.jpg`` and ``conditions/`` are written.
        target_size: ``(width, height)`` of the generated image and the mask.
        condition_size: ``(width, height)`` the condition images are resized to.
        unified_subject_path: Optional subject used for samples with no own crop.
        seed: Base seed; sample ``idx`` uses generator seed ``seed + idx``.
    """
    import traceback

    os.makedirs(output_dir, exist_ok=True)
    condition_dir = os.path.join(output_dir, "conditions")

    subject_adapter = "subject"
    background_adapter = "background"
    unified_prompt = "Place an aircraft at the specified position"

    print(f"\n{'='*70}")
    print(f"Batch Inference (Solar Channel Mix) on {len(samples)} samples")
    print(f"{'='*70}")
    print(f"Output directory: {output_dir}")
    print(f"Adapters: {subject_adapter}, {background_adapter}")
    print(f"Using Seed: {seed}")

    # Optional subject shared by every sample that has no crop of its own.
    unified_subject_img = None
    if unified_subject_path and os.path.exists(unified_subject_path):
        try:
            unified_subject_img = load_rgba_with_black_background(unified_subject_path)
            unified_subject_img = unified_subject_img.resize(condition_size, Image.BILINEAR)
            print(f"✓ Loaded unified subject image from {unified_subject_path}")
        except Exception as e:
            print(f"⚠️  Failed to load unified subject: {e}")
            unified_subject_img = None
    elif unified_subject_path:
        print(f"⚠️  Unified subject not found at {unified_subject_path}")
    else:
        print("⚠️  No unified subject provided")

    for idx, sample in enumerate(samples):
        sample_id = sample["id"]
        print(f"\n[{idx+1}/{len(samples)}] Processing {sample_id}...")

        # 1. Images
        try:
            background_img, mask_img, subject_img = _load_sample_images(
                sample, target_size, condition_size, unified_subject_img
            )
        except Exception as e:
            print(f"  ❌ Failed to load images: {e}")
            continue

        # 2. Conditions (the position offsets are the values this script has always used)
        cond_subject = Condition(subject_img, subject_adapter, [-16, -32])
        cond_bg = Condition(background_img, background_adapter, [16, -32])

        # 3. Solar modulation parameters
        try:
            solar_params_list = _compute_solar_params(model, background_img, mask_img, condition_size)
        except Exception as e:
            print(f"  ❌ Solar parameter computation failed: {e}")
            traceback.print_exc()
            continue

        # 4. Generation (reseed per sample so each result is reproducible on its own)
        seed_everything(seed)
        generator = torch.Generator(device=model.device)
        generator.manual_seed(seed + idx)

        print(f"  Generating image...")
        try:
            res = generate(
                model.flux_pipe,
                prompt=unified_prompt,
                conditions=[cond_subject, cond_bg],
                height=target_size[1],
                width=target_size[0],
                num_inference_steps=28,
                guidance_scale=3.5,
                generator=generator,
                model_config=model.model_config,
                kv_cache=model.model_config.get("independent_condition", False),
                solar_params_list=solar_params_list,
                transformer_kwargs={
                    "attn_forward": solar_attn_forward_channel_mix,
                    "block_forward": solar_block_forward_channel_mix,
                    "single_block_forward": solar_single_block_forward_channel_mix,
                }
            )

            # 5. Save the result and the condition images used for it.
            output_path = os.path.join(output_dir, f"{sample_id}_generated.jpg")
            res.images[0].save(output_path)
            print(f"  ✓ Saved to {output_path}")

            os.makedirs(condition_dir, exist_ok=True)
            subject_img.save(os.path.join(condition_dir, f"{sample_id}_subject.jpg"))
            background_img.save(os.path.join(condition_dir, f"{sample_id}_background.jpg"))

        except Exception as e:
            print(f"  ❌ Generation failed: {e}")
            traceback.print_exc()
            continue

    print(f"\n{'='*70}")
    print(f"✓ Batch inference completed! Results saved to {output_dir}")

def _enable_determinism(seed: int) -> None:
    """Seed Python/NumPy/torch RNGs and force deterministic cuDNN kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    seed_everything(seed)


def _load_solar_components(model, checkpoint_path: str) -> bool:
    """Load Solar encoder/projector weights from ``solar_components.pt``; return False if absent."""
    solar_path = os.path.join(checkpoint_path, "solar_components.pt")
    if not os.path.exists(solar_path):
        print(f"❌ Error: solar_components.pt not found in {checkpoint_path}")
        return False
    print(f"Loading Solar Components from {solar_path}")
    state = torch.load(solar_path, map_location=model.device)
    model.solar_encoder.load_state_dict(state["encoder"])
    model.solar_projectors.load_state_dict(state["projectors"])
    return True


def _load_lora_adapters(model, checkpoint_path: str, adapter_names) -> list:
    """Load every ``<name>.safetensors`` present in the checkpoint; return the names actually loaded."""
    print("Loading LoRA adapters...")
    loaded = []
    for adapter_name in adapter_names:
        weight_name = f"{adapter_name}.safetensors"
        lora_file = os.path.join(checkpoint_path, weight_name)
        if not os.path.exists(lora_file):
            print(f"⚠️  Warning: {adapter_name}.safetensors not found.")
            continue
        print(f"  Loading {adapter_name} from {lora_file}")
        model.flux_pipe.load_lora_weights(checkpoint_path, weight_name=weight_name, adapter_name=adapter_name)
        loaded.append(adapter_name)
    return loaded


def _activate_adapters(model, adapter_list) -> None:
    """Activate the given LoRA adapters on the transformer, falling back to per-adapter enabling."""
    print("\n  Activating LoRA adapters...")
    if not adapter_list:
        print("    ⚠️  No LoRA adapters were loaded; running without them.")
        return
    try:
        model.transformer.set_adapters(adapter_list)
        print(f"    ✓ Activated adapters: {adapter_list}")
    except Exception as e:
        print(f"    ⚠️  set_adapters failed: {e}")
        for adapter_name in adapter_list:
            try:
                model.transformer.enable_adapters(adapter_name)
                print(f"    ✓ Enabled adapter: {adapter_name}")
            except Exception as e2:
                print(f"    ⚠️  Failed to enable {adapter_name}: {e2}")


def _set_eval_mode(model) -> None:
    """Put the model and pipeline sub-modules in eval mode and zero every Dropout probability."""
    model.eval()
    model.transformer.eval()
    model.flux_pipe.vae.eval()
    model.flux_pipe.text_encoder.eval()
    model.flux_pipe.text_encoder_2.eval()
    # Belt-and-braces on top of eval(): make every Dropout a no-op.
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0


def main():
    """CLI entry point: load config + checkpoint, scan the data folder, and run batch inference.

    ``--checkpoint`` defaults to the run this script was previously hard-wired
    to, but an explicit value is now honoured (it used to be overwritten).
    """
    parser = argparse.ArgumentParser(description="Batch Inference (Solar Channel Mix)")
    parser.add_argument("--config", type=str, default="./train/config/aircraft_bg_crop.yaml", help="Path to training config")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="runs_bg_crop/20260106-173535/ckpt/10000",
        help="Path to checkpoint directory",
    )
    parser.add_argument("--data_dir", type=str, default="./inference_datas", help="Directory containing inference data")
    parser.add_argument("--output_dir", type=str, default="inference_results_solar_channel_mix/batch", help="Output directory")
    parser.add_argument("--unified_subject", type=str, default=None, help="Path to a single subject image to use for all samples")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args()

    # 1. Global determinism
    _enable_determinism(args.seed)

    print("="*70)
    print("Aircraft Batch Inference (Solar Channel Mix)")
    print("🔧 Deterministic mode enabled")
    print("="*70)

    # 2. Config
    print(f"Loading config from {args.config}...")
    if not os.path.exists(args.config):
        print(f"❌ Error: Config file not found: {args.config}")
        return
    config = load_config(args.config)
    training_config = config["train"]
    dataset_config = training_config["dataset"]

    target_size = tuple(dataset_config["target_size"])
    condition_size = tuple(dataset_config["condition_size"])

    # 3. Checkpoint: validate before the expensive model construction.
    checkpoint_path = args.checkpoint
    if not checkpoint_path or not os.path.isdir(checkpoint_path):
        print(f"\n❌ Error: No checkpoint found at {checkpoint_path!r}!")
        print("Please provide --checkpoint path")
        return

    print(f"Loading model from {checkpoint_path}...")

    # 4. Model. lora_config=None so __init__ creates no adapters; the trained
    # LoRA weights are loaded from the checkpoint below instead.
    model = OminiSolarChannelMixModel(
        flux_pipe_id=config["flux_path"],
        lora_path=None,
        lora_config=None,
        device="cuda",
        dtype=torch.bfloat16 if config["dtype"] == "bfloat16" else torch.float32,
        model_config=config.get("model", {}),
        # Placeholder adapter names; the active set is controlled via set_adapters.
        adapter_names=[None, None, "subject", "background"],
        gradient_checkpointing=False,
    )
    model.adapter_set = set(["subject", "background"])
    model.training_config = training_config

    # 5. Weights
    if not _load_solar_components(model, checkpoint_path):
        return
    # Only adapters that actually loaded are activated; activating a missing
    # one would make set_adapters fail for all of them.
    loaded_adapters = _load_lora_adapters(model, checkpoint_path, ("subject", "background"))
    _activate_adapters(model, loaded_adapters)

    _set_eval_mode(model)
    print("  ✓ Model loaded and set to eval mode (with adapter activation)")

    # 6. Data
    print(f"Scanning data directory: {args.data_dir}")
    samples = scan_inference_data(args.data_dir)
    if not samples:
        print("No samples found.")
        return
    print(f"✓ Found {len(samples)} samples")

    # 7. Inference
    batch_inference(
        model,
        samples,
        output_dir=args.output_dir,
        target_size=target_size,
        condition_size=condition_size,
        unified_subject_path=args.unified_subject,
        seed=args.seed
    )
    print("Done!")

if __name__ == "__main__":
    # Script entry point. main() only ever returns None, so the exit status is 0
    # unless it raises.
    sys.exit(main())