#!/usr/bin/env python3
"""
Batch inference script for aircraft targets (Bg + Crop + Solar Optimization).
Supports the solar (lighting) optimization network; requires solar_components.pt to be loaded.
"""

import os
import sys
import yaml
import torch
import random
from pathlib import Path
from PIL import Image
import numpy as np
import argparse

# Add the project root directory to the Python path
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from omini.pipeline.flux_omini_solar import Condition, generate, seed_everything
from omini.pipeline.flux_omini import specify_lora, apply_rotary_emb
from omini.train_flux.train_aircraft_bg_crop_solar import OminiSolarModel
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

# --- Visualization Globals ---
# Running count of visualize_attn_forward calls; batch_inference resets it to 0
# before every generate() call, and the step/layer indices are derived from it.
GLOBAL_LAYER_COUNTER = 0
# Captured attention maps, keyed by "step{S}_layer{L}_<name>" strings.
ATTN_MAP_STORAGE = {} # {step_layer: tensor}
# Number of attention calls assumed per denoising step.
TOTAL_LAYERS = 57 # 19 double + 38 single for Flux.1-dev? 
# DoubleStream: 19, SingleStream: 38. Total 57.
# NOTE(review): the 19/38 split is an assumption about the loaded transformer — confirm.
TARGET_LAYER_INDEX = 50 # Visualize a layer in Single Stream Block (19-56) for better spatial structure
TARGET_STEP_INDEX = 20 # Visualize a step in the middle/end

def _capture_attention_maps(
    query,
    key_cat,
    key_spans,
    query_idx,
    h2_n,
    step_idx,
    layer_idx,
    target_size,
):
    """Record debugging attention maps for one image-branch query.

    Results are written into the module-level ``ATTN_MAP_STORAGE`` dict.

    Args:
        query: Query tensor of the branch, laid out as [B, Heads, L_q, D].
        key_cat: Concatenated keys this query attends to, [B, Heads, S, D].
        key_spans: Maps the index of every key branch that was actually
            concatenated into ``key_cat`` to its ``(start, end)`` column range.
            Branches removed by the group mask are absent.
        query_idx: Index of this query in the combined [text..., image...] list.
        h2_n: Number of text branches preceding the image branches.
        step_idx: Denoising step index used in the storage key.
        layer_idx: Layer index used in the storage key.
        target_size: (width, height) of the generated image in pixels; one
            token is assumed to cover 16x16 pixels.
    """
    branch_idx = query_idx - h2_n
    branch_name = "Target" if branch_idx == 0 else f"Cond{branch_idx-1}"
    num_query_tokens = query.shape[2]

    # Guard first: indexing a center token of an empty branch would raise.
    if num_query_tokens == 0:
        print(f"DEBUG: {branch_name} has 0 tokens, skipping Self-Attn visualization.")
        return

    # Same scaling as the default of F.scaled_dot_product_attention.
    scale = query.shape[-1] ** -0.5
    attn_scores = torch.einsum("bhld,bhsd->bhls", query, key_cat) * scale
    attn_probs = attn_scores.softmax(dim=-1)

    # --- Target -> Cond0 influence map ---
    # For each Target token: total attention mass spent on Cond0's keys.
    # Only the Target branch (image branch 0) may write these entries;
    # otherwise later condition branches would overwrite them under the same key.
    # Cond0 is image branch 1, i.e. combined index h2_n + 1. It is missing from
    # key_spans when there is no condition or the group mask hides it.
    cond0_span = key_spans.get(h2_n + 1)
    if branch_idx == 0 and cond0_span is not None:
        start_cond0, end_cond0 = cond0_span
        attn_cond0 = attn_probs[..., start_cond0:end_cond0].sum(dim=-1)  # [B, Heads, L_target]

        map_cond0_mean = attn_cond0.mean(dim=1).detach().cpu()  # [B, L_target]
        map_cond0_max = attn_cond0.max(dim=1)[0].detach().cpu()  # [B, L_target]

        stats_min = map_cond0_mean.min().item()
        stats_max = map_cond0_mean.max().item()
        stats_mean_val = map_cond0_mean.mean().item()

        ATTN_MAP_STORAGE[f"step{step_idx}_layer{layer_idx}_Target_pay_to_Cond0_Mean"] = map_cond0_mean[0]
        ATTN_MAP_STORAGE[f"step{step_idx}_layer{layer_idx}_Target_pay_to_Cond0_Max"] = map_cond0_max[0]

        print(f"DEBUG: Target->Cond0 Map (Mean). Min={stats_min:.6f}, Max={stats_max:.6f}, Mean={stats_mean_val:.6f}")

        if stats_max - stats_min < 1e-5:
            print("  ⚠️ WARNING: Target->Cond0 Map is essentially FLAT (Low Variance). Visualization will be noise.")

    # --- Center-token self-attention map ---
    # Flattened index L_q // 2 would land on the left edge of the middle row,
    # so compute the geometric center of the token grid when the grid is known.
    grid_h = target_size[1] // 16
    grid_w = target_size[0] // 16
    if num_query_tokens == grid_h * grid_w:
        center_idx = (grid_h // 2) * grid_w + (grid_w // 2)
    else:
        # Fallback if dimensions don't match
        center_idx = num_query_tokens // 2

    self_span = key_spans.get(query_idx)
    if self_span is None:
        # The group mask prevents this branch from attending to itself.
        print(f"DEBUG: {branch_name} does not attend to its own keys, skipping Self-Attn visualization.")
        return

    start_idx, end_idx = self_span
    map_full = attn_probs[0, :, center_idx, :].detach().cpu()  # [Heads, S]
    map_self = map_full[:, start_idx:end_idx]

    stats_mean = map_self.mean().item()
    stats_std = map_self.std().item()
    print(f"DEBUG: {branch_name} Self-Attn. Full={map_full.shape}, SelfSlice={map_self.shape}, Mean={stats_mean:.6f}, Std={stats_std:.6f}")

    # Average over heads -> one value per key token of this branch.
    ATTN_MAP_STORAGE[f"step{step_idx}_layer{layer_idx}_{branch_name}"] = map_self.mean(dim=0)


def visualize_attn_forward(
    attn,
    hidden_states,
    adapters,
    hidden_states2=None,
    position_embs=None,
    group_mask=None,
    cache_mode=None,
    to_cache=None,
    cache_storage=None,
    solar_scale=None,
    solar_shift=None,
    target_size=(512, 512),
    **kwargs,
):
    """Attention forward pass that also captures attention maps for visualization.

    Every call increments ``GLOBAL_LAYER_COUNTER``; the step and layer indices
    are derived from it using ``TOTAL_LAYERS``. When both match
    ``TARGET_STEP_INDEX`` / ``TARGET_LAYER_INDEX``, attention maps of the image
    branches are stored in ``ATTN_MAP_STORAGE``.

    Args:
        attn: Attention module providing the q/k/v projections, norms, ``heads``
            and output projections used below.
        hidden_states: List of image-branch hidden states, ordered
            [Target, Cond0, Cond1, ...].
        adapters: Adapter names indexed over [text branches..., image branches...];
            used to select the LoRA per image branch.
        hidden_states2: Optional list of text-branch hidden states.
        position_embs: Optional rotary embeddings, one entry per branch in the
            combined [text..., image...] order.
        group_mask: Optional 2D boolean tensor; ``group_mask[i][j]`` decides
            whether query branch ``i`` attends to key branch ``j``.
        cache_mode, to_cache, cache_storage: Accepted for signature
            compatibility but not used here.
            NOTE(review): KV caching is therefore not applied by this forward —
            confirm this is acceptable when the caller enables kv_cache.
        solar_scale, solar_shift: Optional modulation applied to the image
            values as ``value * (1 + scale) + shift``.
        target_size: (width, height) of the generated image, used to locate the
            center token of the token grid.
        **kwargs: Ignored.

    Returns:
        ``(h_out, h2_out)`` when text branches are present, otherwise ``h_out``,
        where each is a list of per-branch outputs.
    """
    global GLOBAL_LAYER_COUNTER

    # Avoid sharing one mutable default list across calls.
    if hidden_states2 is None:
        hidden_states2 = []

    # Determine current step and layer
    step_idx = GLOBAL_LAYER_COUNTER // TOTAL_LAYERS
    layer_idx = GLOBAL_LAYER_COUNTER % TOTAL_LAYERS

    # Debug print once per step
    if layer_idx == 0:
        print(f"DEBUG: visualize_attn_forward called. Step={step_idx}")

    GLOBAL_LAYER_COUNTER += 1

    capture_this = (layer_idx == TARGET_LAYER_INDEX) and (step_idx == TARGET_STEP_INDEX)

    if capture_this:
        print(f"DEBUG: Capturing attention map at Step={step_idx}, Layer={layer_idx}")

    bs, _, _ = hidden_states[0].shape
    h2_n = len(hidden_states2)

    def split_heads(x, dim_per_head):
        # [B, L, Heads * D] -> [B, Heads, L, D]
        return x.view(bs, -1, attn.heads, dim_per_head).transpose(1, 2)

    queries, keys, values = [], [], []

    # Text branch
    for hidden_state in hidden_states2:
        query = attn.add_q_proj(hidden_state)
        key = attn.add_k_proj(hidden_state)
        value = attn.add_v_proj(hidden_state)
        head_dim = key.shape[-1] // attn.heads
        query, key, value = (split_heads(t, head_dim) for t in (query, key, value))
        query, key = attn.norm_added_q(query), attn.norm_added_k(key)
        queries.append(query)
        keys.append(key)
        values.append(value)

    # Image branch
    for i, hidden_state in enumerate(hidden_states):
        with specify_lora((attn.to_q, attn.to_k, attn.to_v), adapters[i + h2_n]):
            query = attn.to_q(hidden_state)
            key = attn.to_k(hidden_state)
            value = attn.to_v(hidden_state)

        head_dim = key.shape[-1] // attn.heads
        query, key, value = (split_heads(t, head_dim) for t in (query, key, value))
        query, key = attn.norm_q(query), attn.norm_k(key)

        # Apply Solar Modulation (Standard Scale/Shift).
        # Must be done AFTER the reshape to [B, Heads, L, Head_Dim].
        if solar_scale is not None and solar_shift is not None:
            s_scale = solar_scale.view(bs, 1, attn.heads, head_dim).transpose(1, 2)
            s_shift = solar_shift.view(bs, 1, attn.heads, head_dim).transpose(1, 2)
            value = value * (1 + s_scale) + s_shift

        queries.append(query)
        keys.append(key)
        values.append(value)

    # Apply rotary embedding
    if position_embs is not None:
        queries = [apply_rotary_emb(q, position_embs[i]) for i, q in enumerate(queries)]
        keys = [apply_rotary_emb(k, position_embs[i]) for i, k in enumerate(keys)]

    # Attention Calculation & Visualization
    attn_outputs = []
    for i, query in enumerate(queries):
        keys_, values_ = [], []
        # Column range of every key branch that survives the group mask, so the
        # visualization slices the right columns even when branches are dropped.
        key_spans = {}
        cursor = 0
        for j, (k, v) in enumerate(zip(keys, values)):
            if (group_mask is not None) and not (group_mask[i][j].item()):
                continue
            keys_.append(k)
            values_.append(v)
            key_spans[j] = (cursor, cursor + k.shape[2])
            cursor += k.shape[2]

        key_cat = torch.cat(keys_, dim=2)
        value_cat = torch.cat(values_, dim=2)

        # Only image-branch queries are visualized.
        if capture_this and i >= h2_n:
            _capture_attention_maps(
                query,
                key_cat,
                key_spans,
                i,
                h2_n,
                step_idx,
                layer_idx,
                target_size,
            )

        # Standard Forward
        attn_output = F.scaled_dot_product_attention(
            query, key_cat, value_cat
        ).to(query.dtype)
        attn_output = attn_output.transpose(1, 2).reshape(bs, -1, attn.heads * head_dim)
        attn_outputs.append(attn_output)

    h_out, h2_out = [], []
    for i in range(h2_n):
        h2_out.append(attn.to_add_out(attn_outputs[i]))

    for i in range(len(hidden_states)):
        h = attn_outputs[i + h2_n]
        # Some blocks have no output projection; pass the result through as-is.
        if getattr(attn, "to_out", None) is not None:
            with specify_lora((attn.to_out[0],), adapters[i + h2_n]):
                h = attn.to_out[0](h)
        h_out.append(h)

    return (h_out, h2_out) if h2_n else h_out

# Make sure the Solar-related components can be loaded
import omini.pipeline.flux_omini_solar as flux_omini_solar

def load_config(config_path: str):
    """Load a YAML configuration file and return the parsed content.

    The file is read explicitly as UTF-8 so that non-ASCII text in the config
    (e.g. comments) is decoded the same way regardless of the platform locale.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The object produced by ``yaml.safe_load`` for the file content.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


def scan_inference_data(data_dir: str):
    """Scan an inference data folder and pair each background with its mask/subject.

    Expected layout:
    data_dir/
      ├── backgrounds/ (*_erased.png)
      ├── subjects/    (*_crop.png) [optional]
      └── masks/       (*_mask.png) [required, used by the Solar Encoder]

    Args:
        data_dir: Root directory containing the sub-folders above.

    Returns:
        A list of dicts with keys "id", "background", "mask" and "subject"
        (paths as strings; "subject" is None when no crop exists), ordered by
        background file name. The list is empty when ``backgrounds/`` or
        ``masks/`` is missing. Backgrounds without a mask are skipped.
    """
    bg_suffix = "_erased"

    data_dir = Path(data_dir)
    backgrounds_dir = data_dir / "backgrounds"
    subjects_dir = data_dir / "subjects"
    masks_dir = data_dir / "masks"

    samples = []

    if not backgrounds_dir.exists():
        print(f"❌ Error: backgrounds/ folder not found in {data_dir}")
        return samples

    if not masks_dir.exists():
        print(f"❌ Error: masks/ folder not found in {data_dir} (Required for Solar Optimization)")
        return samples

    has_subjects = subjects_dir.exists()

    for bg_file in sorted(backgrounds_dir.glob(f"*{bg_suffix}.png")):
        # Strip only the trailing "_erased"; str.replace would also remove
        # occurrences inside the id itself and break the mask lookup.
        sample_id = bg_file.stem[: -len(bg_suffix)]

        # The mask is mandatory.
        mask_file = masks_dir / f"{sample_id}_mask.png"
        if not mask_file.exists():
            print(f"⚠️  Warning: Mask not found for {sample_id}, skipping.")
            continue

        # The subject crop is optional.
        subject_file = None
        if has_subjects:
            candidate = subjects_dir / f"{sample_id}_crop.png"
            if candidate.exists():
                subject_file = candidate

        samples.append({
            "id": sample_id,
            "background": str(bg_file),
            "mask": str(mask_file),
            "subject": str(subject_file) if subject_file else None,
        })

    return samples


def create_default_subject(size):
    """Build the placeholder subject image: a uniform mid-gray RGB image of ``size``."""
    gray = (128, 128, 128)
    placeholder = Image.new("RGB", size, gray)
    return placeholder


def load_rgba_with_black_background(image_path: str, background_color=(255, 255, 255)) -> Image.Image:
    """Load an image as RGB, flattening RGBA transparency onto a solid color.

    NOTE: despite the function name, transparent areas have always been filled
    with WHITE here. That default is kept so existing results do not change;
    pass ``background_color=(0, 0, 0)`` to get an actual black background.
    NOTE(review): confirm which fill color matches the training pipeline.

    Args:
        image_path: Path of the image file to open.
        background_color: RGB color placed behind transparent pixels of RGBA
            images. Ignored for every other image mode.

    Returns:
        An RGB image: the opened image itself when it is already RGB, the
        alpha-composited copy for RGBA, or ``convert('RGB')`` otherwise.
    """
    img = Image.open(image_path)
    if img.mode == 'RGB':
        return img
    if img.mode == 'RGBA':
        background = Image.new('RGB', img.size, background_color)
        # Use the alpha channel (4th band) as the paste mask.
        background.paste(img, mask=img.split()[3])
        return background
    return img.convert('RGB')


def _save_attention_maps(attn_maps, sample_id, output_dir, target_size):
    """Render captured attention maps to PNG files (best effort).

    A map whose token count equals the target token grid is reshaped to 2D,
    upsampled to ``target_size`` and saved as a heatmap; any other non-empty
    map is saved as a 1D line plot. Failures are reported and never raised, so
    a plotting problem cannot prevent the generated image from being saved.

    Args:
        attn_maps: Dict mapping a map name to a 1D tensor of per-token values.
        sample_id: Sample identifier used as the file name prefix.
        output_dir: Directory the PNG files are written to.
        target_size: (width, height) of the generated image in pixels; one
            token is assumed to cover 16x16 pixels.
    """
    grid_h = target_size[1] // 16
    grid_w = target_size[0] // 16

    for key, attn_map in attn_maps.items():
        img_len = attn_map.shape[0]
        print(f"DEBUG: Saving Map {key}. Tokens={img_len}")
        if img_len == 0:
            continue

        img_attn = attn_map.float()
        try:
            if img_len == grid_h * grid_w:
                # [L] -> [1, 1, H, W] as required by F.interpolate, then back to [H, W].
                img_attn = img_attn.view(grid_h, grid_w).unsqueeze(0).unsqueeze(0)
                img_attn = F.interpolate(img_attn, size=(target_size[1], target_size[0]), mode='bilinear', align_corners=False)
                img_attn = img_attn.squeeze()

                # Normalize to [0, 1]; the epsilon protects against flat maps.
                img_attn = (img_attn - img_attn.min()) / (img_attn.max() - img_attn.min() + 1e-6)

                save_path = os.path.join(output_dir, f"{sample_id}_{key}.png")
                fig = plt.figure(figsize=(10, 10))
                try:
                    sns.heatmap(img_attn.numpy(), cmap='viridis', cbar=False)
                    plt.axis('off')
                    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
                finally:
                    # Always release the figure, even when saving fails.
                    plt.close(fig)
                print(f"  Saved attention map to {save_path}")
            else:
                # Fallback for 1D or mismatched sizes
                print(f"  ⚠️ Token count {img_len} matches neither square nor target_size {grid_h}x{grid_w}. Saving 1D plot.")

                vis_path = os.path.join(output_dir, f"{sample_id}_{key}_1d.png")
                fig = plt.figure(figsize=(10, 4))
                try:
                    plt.plot(img_attn.numpy())
                    plt.title(f"Attn {key} (1D)")
                    plt.savefig(vis_path)
                finally:
                    plt.close(fig)
                print(f"  ✓ Saved {vis_path}")
        except Exception as e:
            print(f"  ❌ Failed to save heatmap: {e}")


@torch.no_grad()
def batch_inference(
    model,
    samples,
    output_dir: str = "inference_results_solar/batch",
    target_size=(512, 512),
    condition_size=(512, 512),
    unified_subject_path: str = None,
    seed: int = 42
):
    """Generate one image per sample with subject/background conditions and solar modulation.

    For every sample the function loads the condition images, computes the
    solar scale/shift parameters from the background latents and the mask,
    runs ``generate`` with ``visualize_attn_forward`` as the attention forward,
    and writes the generated image, the captured attention maps and the
    condition images below ``output_dir``. A sample that fails at any stage is
    reported and skipped.

    Args:
        model: Loaded model exposing ``adapter_names``, ``flux_pipe``,
            ``solar_encoder``, ``solar_projectors``, ``model_config``,
            ``device`` and ``dtype``.
        samples: List of dicts as produced by ``scan_inference_data``.
        output_dir: Directory for all outputs (created if missing).
        target_size: (width, height) of the generated image.
        condition_size: (width, height) the condition images are resized to.
        unified_subject_path: Optional subject image used for samples that have
            no subject of their own.
        seed: Base seed; the generator of sample ``idx`` uses ``seed + idx``.
    """
    global GLOBAL_LAYER_COUNTER, ATTN_MAP_STORAGE

    os.makedirs(output_dir, exist_ok=True)

    # Adapter names are looked up by their fixed positions in model.adapter_names.
    subject_adapter = model.adapter_names[2]
    background_adapter = model.adapter_names[3]

    unified_prompt = "Place an aircraft at the specified position"

    print(f"\n{'='*70}")
    print(f"Batch Inference (Bg+Crop+Solar) on {len(samples)} samples")
    print(f"{'='*70}")
    print(f"Output directory: {output_dir}")
    print(f"Adapters: {subject_adapter}, {background_adapter}")
    print(f"Using Seed: {seed}")

    # Load the unified subject, if one was provided.
    unified_subject_img = None
    if unified_subject_path and os.path.exists(unified_subject_path):
        try:
            unified_subject_img = load_rgba_with_black_background(unified_subject_path)
            unified_subject_img = unified_subject_img.resize(condition_size, Image.BILINEAR)
            print(f"✓ Loaded unified subject image from {unified_subject_path}")
        except Exception as e:
            print(f"⚠️  Failed to load unified subject: {e}")
            unified_subject_img = None
    else:
        print(f"⚠️  No unified subject provided or not found at {unified_subject_path}")

    for idx, sample in enumerate(samples):
        sample_id = sample["id"]
        print(f"\n[{idx+1}/{len(samples)}] Processing {sample_id}...")

        # 1. Prepare images
        try:
            background_img = Image.open(sample["background"]).convert("RGB").resize(condition_size, Image.BILINEAR)
            mask_img = Image.open(sample["mask"]).convert("L").resize(target_size, Image.NEAREST)

            # Subject priority: per-sample crop > unified subject > gray placeholder.
            if sample["subject"] is not None:
                subject_img = load_rgba_with_black_background(sample["subject"])
                subject_img = subject_img.resize(condition_size, Image.BILINEAR)
            elif unified_subject_img is not None:
                subject_img = unified_subject_img
            else:
                subject_img = create_default_subject(condition_size)
        except Exception as e:
            print(f"  ❌ Failed to load images: {e}")
            continue

        # 2. Build conditions
        subject_condition = Condition(subject_img, subject_adapter, [-16, -32])
        background_condition = Condition(background_img, background_adapter, [16, -32])

        # 3. Solar parameters: run the Solar Encoder and projectors manually
        #    (following the logic of OminiSolarModel.training_step).
        try:
            # Pass the PIL image directly: encode_images preprocesses its input
            # itself, so handing it an already preprocessed [-1, 1] tensor would
            # preprocess twice and trigger the diffusers value-range warning.
            bg_latents, _ = flux_omini_solar.encode_images(model.flux_pipe, background_img)

            # Reshape latents for the CNN: [B, L, C] -> [B, C, H, W].
            B, L, C = bg_latents.shape
            # Derive the token grid from condition_size so non-square conditions
            # work; fall back to a square grid when the token count disagrees.
            H_latent = condition_size[1] // 16
            W_latent = condition_size[0] // 16
            if H_latent * W_latent != L:
                H_latent = W_latent = int(L ** 0.5)
            bg_spatial = bg_latents.transpose(1, 2).view(B, C, H_latent, W_latent).to(torch.float32)

            # Mask tensor [1, 1, H, W] in [0, 1]. It stays at target_size;
            # alignment with the latents is expected to happen inside the
            # Solar Encoder (per the original note — not visible from here).
            mask_tensor = torch.from_numpy(np.array(mask_img)).float() / 255.0
            mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(model.device)

            context_vector = model.solar_encoder(bg_spatial, mask_tensor)

            # Each projector yields a (scale, shift) pair split along dim 1.
            solar_params_list = []
            for proj in model.solar_projectors:
                params = proj(context_vector)
                scale, shift = params.chunk(2, dim=1)
                scale = scale.unsqueeze(1).to(model.dtype)
                shift = shift.unsqueeze(1).to(model.dtype)
                solar_params_list.append((scale, shift))
        except Exception as e:
            print(f"  ❌ Failed to compute solar parameters: {e}")
            continue

        # 4. Generate
        seed_everything(seed)
        generator = torch.Generator(device=model.device)
        generator.manual_seed(seed + idx)

        print(f"  Generating image...")
        try:
            # Reset the capture state so step/layer counting starts at zero.
            GLOBAL_LAYER_COUNTER = 0
            ATTN_MAP_STORAGE = {}

            res = generate(
                model.flux_pipe,
                prompt=unified_prompt,
                conditions=[subject_condition, background_condition],
                height=target_size[1],
                width=target_size[0],
                num_inference_steps=28,
                guidance_scale=3.5,
                generator=generator,
                model_config=model.model_config,
                kv_cache=model.model_config.get("independent_condition", False),
                solar_params_list=solar_params_list,  # precomputed solar parameters
                transformer_kwargs={
                    "attn_forward": visualize_attn_forward,
                    "target_size": target_size
                }
            )

            # Visualization: save the attention maps captured during generation.
            _save_attention_maps(ATTN_MAP_STORAGE, sample_id, output_dir, target_size)

            # Reset storage
            ATTN_MAP_STORAGE = {}

            # 5. Save the result
            output_path = os.path.join(output_dir, f"{sample_id}_generated.jpg")
            res.images[0].save(output_path)
            print(f"  ✓ Saved to {output_path}")

            # Save the condition images next to the result for reference.
            condition_dir = os.path.join(output_dir, "conditions")
            os.makedirs(condition_dir, exist_ok=True)
            subject_img.save(os.path.join(condition_dir, f"{sample_id}_subject.jpg"))
            background_img.save(os.path.join(condition_dir, f"{sample_id}_background.jpg"))

        except Exception as e:
            print(f"  ❌ Generation failed: {e}")
            continue

    print(f"\n{'='*70}")
    print(f"✓ Batch inference completed! Results saved to {output_dir}")


def main():
    """Command-line entry point: load the model and run batch inference.

    Steps: parse arguments, seed all random number generators, read the
    training config, validate the checkpoint directory, build the model, load
    the solar components and LoRA adapters, scan the data directory and call
    ``batch_inference``. Errors are reported on stdout and end the run early.
    """
    parser = argparse.ArgumentParser(description="Batch Inference for Aircraft (Solar Optimization)")
    parser.add_argument("--config", type=str, default="./train/config/aircraft_bg_crop.yaml", help="Path to training config")
    parser.add_argument("--checkpoint", type=str, default=None, help="Path to checkpoint directory (containing training_state.pt and solar_components.pt)")
    parser.add_argument("--data_dir", type=str, default="./inference_datas", help="Directory containing inference data (backgrounds/, masks/, subjects/)")
    parser.add_argument("--output_dir", type=str, default="inference_results_solar/batch", help="Output directory")
    parser.add_argument("--unified_subject", type=str, default=None, help="Path to a single subject image to use for all samples")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args()

    # Global determinism
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    seed_everything(seed)

    print("="*70)
    print("Aircraft Batch Inference (Bg + Crop + Solar)")
    print("🔧 Deterministic mode enabled")
    print("="*70)

    # 1. Load config
    print(f"Loading config from {args.config}...")
    if not os.path.exists(args.config):
        print(f"❌ Error: Config file not found: {args.config}")
        return
    config = load_config(args.config)
    training_config = config["train"]
    dataset_config = training_config["dataset"]

    target_size = tuple(dataset_config["target_size"])
    condition_size = tuple(dataset_config["condition_size"])

    # 2. Resolve the checkpoint. The --checkpoint argument is authoritative;
    #    it must not be overridden by a hard-coded run directory.
    checkpoint_path = args.checkpoint

    if checkpoint_path is None:
        print("\n❌ Error: No checkpoint found!")
        print("Please provide --checkpoint path (must contain solar_components.pt)")
        return

    # Fail fast: check the solar weights before the expensive model construction.
    solar_path = os.path.join(checkpoint_path, "solar_components.pt")
    if not os.path.exists(solar_path):
        print(f"❌ Error: solar_components.pt not found in {checkpoint_path}")
        return

    print(f"Loading model from {checkpoint_path}...")

    # 3. Build the model (OminiSolarModel).
    # lora_config is None on purpose: with a LoRA config the constructor would
    # create empty "subject"/"background" adapters, and load_lora_weights below
    # would then fail with "Adapter name ... already in use". The adapters are
    # created by load_lora_weights instead; adapter_names here only serves as a
    # positional lookup table.
    model = OminiSolarModel(
        flux_pipe_id=config["flux_path"],
        lora_path=None,
        lora_config=None,
        device="cuda",
        dtype=torch.bfloat16 if config["dtype"] == "bfloat16" else torch.float32,
        model_config=config.get("model", {}),
        adapter_names=[None, None, "subject", "background"],
        gradient_checkpointing=False,
    )
    # Restore model.adapter_set for later use.
    model.adapter_set = {"subject", "background"}
    model.training_config = training_config

    # 4. Load weights

    # Solar components.
    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True where the torch version allows).
    print(f"Loading Solar Components from {solar_path}")
    state = torch.load(solar_path, map_location=model.device)
    model.solar_encoder.load_state_dict(state["encoder"])
    model.solar_projectors.load_state_dict(state["projectors"])

    # LoRA adapters (subject, background)
    print("Loading LoRA adapters...")
    for adapter_name in ["subject", "background"]:
        lora_file = os.path.join(checkpoint_path, f"{adapter_name}.safetensors")
        if os.path.exists(lora_file):
            print(f"  Loading {adapter_name} from {lora_file}")
            model.flux_pipe.load_lora_weights(checkpoint_path, weight_name=f"{adapter_name}.safetensors", adapter_name=adapter_name)
        else:
            print(f"⚠️  Warning: {adapter_name}.safetensors not found.")

    # Explicitly activate all LoRA adapters. Sorting gives a stable order;
    # iterating the set directly can differ between runs.
    print("\n  Activating LoRA adapters...")
    adapter_list = sorted(model.adapter_set)
    if adapter_list:
        try:
            model.transformer.set_adapters(adapter_list)
            print(f"    ✓ Activated adapters: {adapter_list}")
        except Exception as e:
            print(f"    ⚠️  set_adapters failed: {e}")
            # Fall back to enabling the adapters one by one.
            for adapter_name in adapter_list:
                try:
                    model.transformer.enable_adapters(adapter_name)
                    print(f"    ✓ Enabled adapter: {adapter_name}")
                except Exception as e2:
                    print(f"    ⚠️  Failed to enable {adapter_name}: {e2}")

    # Evaluation mode
    model.eval()
    model.transformer.eval()
    model.flux_pipe.vae.eval()
    model.flux_pipe.text_encoder.eval()
    model.flux_pipe.text_encoder_2.eval()

    # Disable every dropout layer explicitly.
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0

    print("  ✓ Model loaded and set to eval mode (with adapter activation)")

    # 5. Scan data
    print(f"Scanning data directory: {args.data_dir}")
    samples = scan_inference_data(args.data_dir)
    if not samples:
        print("No samples found.")
        return
    print(f"✓ Found {len(samples)} samples")

    # 6. Run inference
    batch_inference(
        model,
        samples,
        output_dir=args.output_dir,
        target_size=target_size,
        condition_size=condition_size,
        unified_subject_path=args.unified_subject,
        seed=args.seed
    )
    print("Done!")

# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()