#!/usr/bin/env python3
"""
航空目标批量推理脚本 (Bg + Crop + Solar + Low-Original + Texture Energy Guidance)
增加基于纹理势能的引导 (Texture Consistency Energy Guidance) 以增强背景融合
"""

import os
import sys
import yaml
import torch
import random
from pathlib import Path
from PIL import Image
import numpy as np
import argparse
import torch.nn.functional as F
import math

# OpenCV is optional: when the import fails, a warning is printed and ``cv2``
# is bound to None instead of aborting the script.
try:
    import cv2
except ImportError:
    print("⚠️  OpenCV not found. Structured Noise will fallback to simple mask addition.")
    cv2 = None

# Put the directory four levels above this file (treated as the project root)
# at the front of the Python path; presumably this is what lets the ``omini``
# imports below resolve.
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from omini.pipeline.flux_omini_solar import Condition, generate, seed_everything
import omini.pipeline.flux_omini_solar as flux_omini_solar
from omini.train_flux.train_aircraft_bg_crop_solar import OminiSolarModel

def load_config(config_path: str):
    """Load a YAML configuration file.

    The file is decoded explicitly as UTF-8 so that non-ASCII content parses
    the same way regardless of the platform's locale default encoding.

    Args:
        config_path: Path to the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


def frequency_split_latents(latents, kernel_size=5):
    """Decompose token latents into low- and high-frequency components.

    The ``L`` tokens are laid out on a square ``H x H`` grid, blurred with a
    box filter (the low-frequency part), and the residual ``x - low`` is the
    high-frequency part, so ``low + high`` reconstructs the input.

    Args:
        latents: Tensor of shape ``[B, L, C]`` where ``L`` is a perfect square.
        kernel_size: Side length of the box filter; must be a positive odd
            integer so the blurred map keeps the input resolution.

    Returns:
        Tuple ``(low, high)``, each of shape ``[B, L, C]``.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer, or if
            ``L`` is not a perfect square.
    """
    # An even kernel with padding k // 2 yields an (H+1) x (W+1) map, which
    # cannot be subtracted from the input; reject it with a clear message.
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(
            f"kernel_size must be a positive odd integer, got {kernel_size}"
        )

    batch, num_tokens, channels = latents.shape
    side = int(round(num_tokens ** 0.5))
    if side * side != num_tokens:
        raise ValueError(
            f"latents must contain a square number of tokens, got L={num_tokens}"
        )

    # Reshape to spatial layout: [B, C, H, W]
    spatial = latents.transpose(1, 2).reshape(batch, channels, side, side)

    # Low pass: box blur; count_include_pad=False keeps border averages unbiased
    low = F.avg_pool2d(
        spatial,
        kernel_size=kernel_size,
        stride=1,
        padding=kernel_size // 2,
        count_include_pad=False,
    )

    # High pass: residual after removing the blurred component
    high = spatial - low

    # Flatten back to [B, L, C]
    low_flat = low.flatten(2).transpose(1, 2)
    high_flat = high.flatten(2).transpose(1, 2)

    return low_flat, high_flat


class LowFreqCondition(Condition):
    """Condition that exposes only the low-frequency part of its latents.

    The original author labels this component "Style/Lighting".
    """

    def encode(self, pipe, empty=False):
        """Encode through the parent class, then keep the low-pass latents.

        Returns:
            ``(low_frequency_latents, ids)``; ``ids`` is passed through
            unchanged from the parent ``encode``.
        """
        encoded_latents, position_ids = super().encode(pipe, empty)
        low_band = frequency_split_latents(encoded_latents)[0]
        return low_band, position_ids


def scan_inference_data(data_dir: str):
    """Scan an inference data folder and collect the usable samples.

    Expected layout::

        data_dir/
          ├── backgrounds/ (*_erased.png)
          ├── subjects/    (*_crop.png) [optional]
          └── masks/       (*_mask.png) [required]

    The sample id is the background file name with the trailing
    ``_erased.png`` removed. Only that trailing suffix is stripped, so ids
    that themselves contain ``_erased`` are preserved intact.

    Args:
        data_dir: Root folder containing the sub-folders above.

    Returns:
        A list of dicts with keys ``id``, ``background``, ``mask`` and
        ``subject`` (``None`` when no subject crop exists), ordered by
        background file name. The list is empty when ``backgrounds/`` or
        ``masks/`` is missing.
    """
    data_dir = Path(data_dir)
    backgrounds_dir = data_dir / "backgrounds"
    subjects_dir = data_dir / "subjects"
    masks_dir = data_dir / "masks"

    samples = []

    if not backgrounds_dir.exists():
        print(f"❌ Error: backgrounds/ folder not found in {data_dir}")
        return samples

    if not masks_dir.exists():
        print(f"❌ Error: masks/ folder not found in {data_dir} (Required for Solar/Structured Noise)")
        return samples

    has_subjects = subjects_dir.exists()
    bg_suffix = "_erased"

    # Walk every background file in a stable (sorted) order
    for bg_file in sorted(backgrounds_dir.glob("*_erased.png")):
        # The glob guarantees the stem ends with the suffix, so slice it off
        # instead of str.replace(), which would also remove inner occurrences.
        sample_id = bg_file.stem[: -len(bg_suffix)]

        # The mask is mandatory; skip samples without one
        mask_file = masks_dir / f"{sample_id}_mask.png"
        if not mask_file.exists():
            print(f"⚠️  Warning: Mask not found for {sample_id}, skipping.")
            continue

        # The subject crop is optional
        subject_file = None
        if has_subjects:
            candidate = subjects_dir / f"{sample_id}_crop.png"
            if candidate.exists():
                subject_file = candidate

        samples.append({
            "id": sample_id,
            "background": str(bg_file),
            "mask": str(mask_file),
            "subject": str(subject_file) if subject_file else None,
        })

    return samples


def create_default_subject(size):
    """Return a solid mid-gray RGB placeholder image of the given size."""
    mid_gray = (128,) * 3
    placeholder = Image.new("RGB", size, mid_gray)
    return placeholder


def load_rgba_with_black_background(image_path: str) -> Image.Image:
    """Load an image as RGB, flattening RGBA transparency onto a solid canvas.

    RGB images are returned as opened, RGBA images are pasted (using their
    alpha channel as the mask) onto a solid canvas, and every other mode is
    converted with ``convert('RGB')``.

    NOTE(review): despite the function name, the canvas used for RGBA input is
    white (255, 255, 255), not black; confirm which one is intended.
    """
    source = Image.open(image_path)
    mode = source.mode

    if mode == 'RGBA':
        canvas = Image.new('RGB', source.size, (255, 255, 255))
        alpha_channel = source.split()[3]
        canvas.paste(source, mask=alpha_channel)
        return canvas

    if mode != 'RGB':
        return source.convert('RGB')

    return source


# --- Texture Consistency Energy Guidance Logic ---

def gram_matrix(input_tensor, mask_area=None):
    """Compute the Gram matrix of a feature map.

    The Gram matrix holds the correlations between feature channels, which is
    used here as a texture-style descriptor.

    Args:
        input_tensor: Feature map of shape ``[B, C, H, W]``. It may be
            non-contiguous (for example a transposed view).
        mask_area: Optional number of active spatial positions. When given and
            positive, the result is normalised by ``C * mask_area``; otherwise
            by ``C * H * W``.

    Returns:
        Float32 tensor of shape ``[B, C, C]``.
    """
    # Force float32 for precision
    feats = input_tensor.float()
    B, C, H, W = feats.shape

    # reshape (not view) so non-contiguous inputs are handled instead of raising
    flat = feats.reshape(B, C, H * W)
    gram = torch.bmm(flat, flat.transpose(1, 2))

    if mask_area is not None and mask_area > 0:
        denom = C * mask_area
    else:
        denom = C * H * W
    return gram.div(denom)

def get_surround_mask(box_mask, dilation_pixels=5):
    """Return the ring around a box mask (dilated mask minus the mask itself).

    Args:
        box_mask: Mask tensor, expected as ``[1, 1, H, W]``.
        dilation_pixels: How far the mask is grown in every direction.

    Returns:
        Tensor of the same shape as ``box_mask``.
    """
    # Dilation is a stride-1 max-pool whose window spans `dilation_pixels`
    # on each side of the centre pixel.
    window = dilation_pixels * 2 + 1
    grown = F.max_pool2d(
        box_mask,
        kernel_size=window,
        stride=1,
        padding=dilation_pixels,
    )
    return grown - box_mask

def apply_texture_energy_guidance(latents, box_mask, steps=5, scale=500.0,
                                  step_size=0.1, dilation_pixels=5):
    """Nudge latents so the box region's texture matches its surroundings.

    Runs ``steps`` iterations of normalised gradient descent on the energy
    ``MSE(Gram(latents * box_mask), Gram(latents * surround_mask))``, where the
    surround mask is the ring produced by ``get_surround_mask``.

    Args:
        latents: Tensor of shape ``[B, L, C]`` with ``L`` a perfect square.
        box_mask: Mask of shape ``[1, 1, h, w]``; it is resized with
            nearest-neighbour interpolation to the latent grid when needed.
        steps: Number of gradient steps.
        scale: Unused. Kept for backward compatibility; the update uses the
            fixed, gradient-normalised ``step_size`` instead.
        step_size: L2 length of every update step (the gradient is normalised
            to unit norm first).
        dilation_pixels: Width, in latent pixels, of the surrounding ring.

    Returns:
        A new tensor of shape ``[B, L, C]`` with the dtype of ``latents``.
        When the box or its surrounding ring is empty the latents are returned
        unchanged (as a copy): without a reference region the energy would
        only pull the box features towards zero.

    Raises:
        ValueError: If ``L`` is not a perfect square.
    """
    B, L, C = latents.shape
    H = int(round(L ** 0.5))
    if H * H != L:
        raise ValueError(
            f"latents must contain a square number of tokens, got L={L}"
        )
    W = H

    # 1. Reshape to spatial [B, C, H, W]; optimise in float32 for precision
    original_dtype = latents.dtype
    x_in = latents.transpose(1, 2).reshape(B, C, H, W).detach().clone().float()

    # 2. Prepare masks at latent resolution (check both height and width)
    box_mask = box_mask.to(device=x_in.device, dtype=x_in.dtype)
    if tuple(box_mask.shape[-2:]) != (H, W):
        box_mask = F.interpolate(box_mask, size=(H, W), mode='nearest')

    # Mask areas do not change during the optimisation, so compute them once
    area_obj = box_mask.sum().item()
    if area_obj > 0:
        surround_mask = get_surround_mask(box_mask, dilation_pixels=dilation_pixels)
        area_env = surround_mask.sum().item()
    else:
        # Dilating an empty mask yields an empty ring
        surround_mask = None
        area_env = 0.0

    if area_obj <= 0 or area_env <= 0:
        print(f"    ⚠️  [Texture Guidance] Warning: Empty Mask! Box: {area_obj}, Surround: {area_env}")
        return latents.detach().clone()

    # 3. Optimisation loop: plain gradient descent with a unit-norm gradient.
    # enable_grad is required because callers run under torch.no_grad().
    # NOTE(review): the gradient also reaches the surround pixels (through
    # G_env), so the environment is updated too, not only the box region;
    # confirm that this is intended.
    with torch.enable_grad():
        for i in range(steps):
            x_in.requires_grad_(True)

            # Gram matrices of the masked regions, normalised by mask area
            G_obj = gram_matrix(x_in * box_mask, mask_area=area_obj)
            G_env = gram_matrix(x_in * surround_mask, mask_area=area_env)

            # Energy = style difference between object and environment
            energy = F.mse_loss(G_obj, G_env)
            (grad,) = torch.autograd.grad(energy, x_in)

            if i == 0:
                grad_mean = grad.abs().mean().item()
                print(f"    [Texture Guidance] Energy: {energy.item():.6e} | Grad: {grad_mean:.6e}")

            # Move against the normalised gradient; the result is a fresh
            # tensor without autograd history.
            with torch.no_grad():
                grad_norm = grad.norm() + 1e-8
                x_in = x_in - step_size * (grad / grad_norm)

    # 4. Return updated latents in Flux format [B, L, C]
    return x_in.flatten(2).transpose(1, 2).detach().to(original_dtype)

# --- End Texture Guidance Logic ---


@torch.no_grad()
def batch_inference(
    model,
    samples,
    output_dir: str = "inference_results_solar_low_original_structured/batch",
    target_size=(512, 512),
    condition_size=(512, 512),
    unified_subject_path: str = None,
    seed: int = 42,
    num_inference_steps: int = 28,
    guidance_scale: float = 3.5,
    texture_guidance_ratio: float = 0.8,
):
    """Generate one image per sample with texture energy guidance.

    For every sample the background, mask and subject images are loaded, three
    conditions are built (low-frequency subject, original subject,
    background), the solar scale/shift parameters are computed from the
    background latents and the mask, and ``generate`` is called with a
    step-end callback that applies ``apply_texture_energy_guidance``.
    A sample that fails to load or generate is reported and skipped.

    Args:
        model: Object providing ``flux_pipe``, ``device``, ``dtype``,
            ``solar_encoder``, ``solar_projectors`` and ``model_config``.
        samples: Dicts with keys ``id``, ``background``, ``mask`` and
            ``subject`` (as produced by ``scan_inference_data``).
        output_dir: Folder for ``<id>_generated.jpg``; the condition images
            are written to its ``conditions/`` sub-folder.
        target_size: ``(width, height)`` of the generated image and the mask.
        condition_size: Size the background and subject images are resized to.
        unified_subject_path: Optional subject image used for samples that
            have no subject of their own.
        seed: Base seed; sample ``idx`` uses a generator seeded with
            ``seed + idx``.
        num_inference_steps: Number of denoising steps passed to ``generate``.
        guidance_scale: Guidance scale passed to ``generate``.
        texture_guidance_ratio: Fraction of the denoising steps (counted from
            the start) during which texture guidance is applied.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Adapter names
    adapter_low = "subject_low"
    adapter_original = "subject_original"
    adapter_bg = "background"

    unified_prompt = "Place an aircraft at the specified position"

    # Texture guidance runs while step_index < guided_steps. It is derived
    # from the same step count that is passed to generate() so the two cannot
    # drift apart.
    guided_steps = int(num_inference_steps * texture_guidance_ratio)

    print(f"\n{'='*70}")
    print(f"Batch Inference (Structured Noise) on {len(samples)} samples")
    print(f"{'='*70}")
    print(f"Output directory: {output_dir}")
    print(f"Adapters: {adapter_low}, {adapter_original}, {adapter_bg}")
    print(f"Using Seed: {seed}")

    # Load the unified subject image, if one was provided
    unified_subject_img = None
    if unified_subject_path and os.path.exists(unified_subject_path):
        try:
            print(unified_subject_path)
            unified_subject_img = load_rgba_with_black_background(unified_subject_path)
            unified_subject_img = unified_subject_img.resize(condition_size, Image.BILINEAR)
            print(f"✓ Loaded unified subject image from {unified_subject_path}")
        except Exception as e:
            print(f"⚠️  Failed to load unified subject: {e}")
            unified_subject_img = None
    else:
        print(f"⚠️  No unified subject provided or not found at {unified_subject_path}")

    for idx, sample in enumerate(samples):
        sample_id = sample["id"]
        print(f"\n[{idx+1}/{len(samples)}] Processing {sample_id}...")

        # 1. Prepare images
        try:
            background_img = Image.open(sample["background"]).convert("RGB").resize(condition_size, Image.BILINEAR)
            mask_img = Image.open(sample["mask"]).convert("L").resize(target_size, Image.NEAREST)

            # Subject priority: per-sample crop > unified subject > gray placeholder
            if sample["subject"] is not None:
                subject_img = load_rgba_with_black_background(sample["subject"])
                subject_img = subject_img.resize(condition_size, Image.BILINEAR)
            elif unified_subject_img is not None:
                subject_img = unified_subject_img
            else:
                subject_img = create_default_subject(condition_size)
        except Exception as e:
            print(f"  ❌ Failed to load images: {e}")
            continue

        # 2. Build conditions
        cond_low = LowFreqCondition(subject_img, adapter_low, [-16, -32])
        cond_original = Condition(subject_img, adapter_original, [-16, -32])
        cond_bg = Condition(background_img, adapter_bg, [16, -32])

        # 3. Solar parameters: background latents on a square spatial grid
        bg_latents, _ = flux_omini_solar.encode_images(model.flux_pipe, background_img)
        B, L, C = bg_latents.shape
        H_latent = int(L ** 0.5)
        W_latent = H_latent
        bg_spatial = bg_latents.transpose(1, 2).view(B, C, H_latent, W_latent).to(torch.float32)

        # Mask as a [1, 1, H, W] float tensor in [0, 1]
        mask_tensor = torch.from_numpy(np.array(mask_img)).float() / 255.0
        mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(model.device)

        context_vector = model.solar_encoder(bg_spatial, mask_tensor)

        # Each projector output is split in half into a (scale, shift) pair
        solar_params_list = []
        for proj in model.solar_projectors:
            params = proj(context_vector)
            scale, shift = params.chunk(2, dim=1)
            scale = scale.unsqueeze(1).to(model.dtype)
            shift = shift.unsqueeze(1).to(model.dtype)
            solar_params_list.append((scale, shift))

        # 4. Texture energy guidance, injected through a step-end callback
        def texture_guidance_callback(pipe, step_index, timestep, callback_kwargs):
            # Only the first `guided_steps` denoising steps are guided
            if step_index < guided_steps:
                latents = callback_kwargs["latents"]
                # mask_tensor ([1, 1, H, W]) comes from the enclosing iteration
                latents_new = apply_texture_energy_guidance(
                    latents,
                    mask_tensor,
                    steps=5,
                    scale=2000.0
                )
                # Write in place into the tensor handed over by the pipeline
                callback_kwargs["latents"][:] = latents_new
            return callback_kwargs

        # 5. Generate
        seed_everything(seed)
        generator = torch.Generator(device=model.device)
        generator.manual_seed(seed + idx)

        print(f"  Generating image with Texture Energy Guidance...")
        try:
            res = generate(
                model.flux_pipe,
                prompt=unified_prompt,
                conditions=[cond_low, cond_original, cond_bg],
                height=target_size[1],
                width=target_size[0],
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
                model_config=model.model_config,
                kv_cache=model.model_config.get("independent_condition", False),
                solar_params_list=solar_params_list,
                callback_on_step_end=texture_guidance_callback,
                callback_on_step_end_tensor_inputs=["latents"],
            )

            output_path = os.path.join(output_dir, f"{sample_id}_generated.jpg")
            res.images[0].save(output_path)
            print(f"  ✓ Saved to {output_path}")

            # Also keep the conditioning images next to the result
            condition_dir = os.path.join(output_dir, "conditions")
            os.makedirs(condition_dir, exist_ok=True)
            subject_img.save(os.path.join(condition_dir, f"{sample_id}_subject.jpg"))
            background_img.save(os.path.join(condition_dir, f"{sample_id}_background.jpg"))

        except Exception as e:
            # Best effort: report the failure and move on to the next sample
            print(f"  ❌ Generation failed: {e}")
            import traceback
            traceback.print_exc()
            continue

    print(f"\n{'='*70}")
    print(f"✓ Batch inference completed! Results saved to {output_dir}")


def main():
    """Command-line entry point: load the model and run batch inference.

    Steps: parse arguments, seed all random number generators, read the YAML
    config, build ``OminiSolarModel``, load the solar components and the LoRA
    adapters from the checkpoint folder, activate the adapters, scan the data
    folder, and call ``batch_inference``.

    ``--checkpoint`` is honoured when given; otherwise the previously
    hard-coded run folder is used as the default. Errors are reported with a
    printed message followed by an early return.
    """
    parser = argparse.ArgumentParser(description="Batch Inference (Solar + Low-Original + Structured Noise)")
    parser.add_argument("--config", type=str, default="./train/config/aircraft_bg_crop.yaml")
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--data_dir", type=str, default="./inference_datas")
    parser.add_argument("--output_dir", type=str, default="inference_results_solar_low_original_structured/batch")
    parser.add_argument("--unified_subject", type=str, default=None)
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()

    # Seed every RNG in use
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    seed_everything(seed)

    print("="*70)
    print("Aircraft Batch Inference (Solar + Low-Original + Structured Noise)")
    print("="*70)

    if not os.path.exists(args.config):
        print(f"❌ Error: Config file not found: {args.config}")
        return
    config = load_config(args.config)
    training_config = config["train"]
    dataset_config = training_config["dataset"]

    target_size = tuple(dataset_config["target_size"])
    condition_size = tuple(dataset_config["condition_size"])

    # The CLI argument wins; the formerly hard-coded run is only a fallback.
    default_checkpoint = "runs_bg_crop/20260106-174326/ckpt/14000"
    checkpoint_path = args.checkpoint or default_checkpoint
    # Fail fast, before the expensive model construction below
    if not os.path.isdir(checkpoint_path):
        print("\n❌ Error: No checkpoint found! Please provide --checkpoint path")
        return

    print(f"Loading model from {checkpoint_path}...")

    adapter_names = ["subject_low", "subject_original", "background"]

    model = OminiSolarModel(
        flux_pipe_id=config["flux_path"],
        lora_path=None,
        lora_config=None,
        device="cuda",
        dtype=torch.bfloat16 if config["dtype"] == "bfloat16" else torch.float32,
        model_config=config.get("model", {}),
        adapter_names=[None, None, "subject_low", "subject_original", "background"],
        gradient_checkpointing=False,
    )
    model.adapter_set = set(adapter_names)
    model.training_config = training_config

    # Load weights
    solar_path = os.path.join(checkpoint_path, "solar_components.pt")
    if os.path.exists(solar_path):
        print(f"Loading Solar Components from {solar_path}")
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state = torch.load(solar_path, map_location=model.device)
        model.solar_encoder.load_state_dict(state["encoder"])
        model.solar_projectors.load_state_dict(state["projectors"])
    else:
        print(f"❌ Error: solar_components.pt not found in {checkpoint_path}")
        return

    print("Loading LoRA adapters...")
    for adapter_name in adapter_names:
        lora_file = os.path.join(checkpoint_path, f"{adapter_name}.safetensors")
        if os.path.exists(lora_file):
            print(f"  Loading {adapter_name} from {lora_file}")
            model.flux_pipe.load_lora_weights(checkpoint_path, weight_name=f"{adapter_name}.safetensors", adapter_name=adapter_name)
        else:
            print(f"⚠️  Warning: {adapter_name}.safetensors not found.")

    print("\n  Activating LoRA adapters...")
    # Use the ordered list rather than iterating the set so the activation
    # order is the same on every run.
    adapter_list = list(adapter_names)
    if adapter_list:
        try:
            model.transformer.set_adapters(adapter_list)
            print(f"    ✓ Activated adapters: {adapter_list}")
        except Exception as e:
            # Fall back to enabling the adapters one by one
            print(f"    ⚠️  set_adapters failed: {e}")
            for adapter_name in adapter_list:
                try:
                    model.transformer.enable_adapters(adapter_name)
                    print(f"    ✓ Enabled adapter: {adapter_name}")
                except Exception as e2:
                    print(f"    ⚠️  Failed to enable {adapter_name}: {e2}")

    model.eval()

    samples = scan_inference_data(args.data_dir)
    if not samples:
        print("No samples found.")
        return

    batch_inference(
        model,
        samples,
        output_dir=args.output_dir,
        target_size=target_size,
        condition_size=condition_size,
        unified_subject_path=args.unified_subject,
        seed=args.seed
    )
    print("Done!")

# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()