#!/usr/bin/env python3
"""
航空目标推理脚本 - Bg + Crop + Structured (Texture Guidance)
使用 aircraft_bg_crop 训练的模型进行推理，并加上纹理一致性引导
"""

import os
import sys
import yaml
import torch
import random
import numpy as np
from pathlib import Path
from PIL import Image
import torch.nn.functional as F
import torchvision.transforms as T

# Put the project root on the Python import path — presumably so that the
# `omini` package imports below resolve when this file is run as a script.
# The root is taken to be four `.parent` hops up from this file's path.
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from omini.train_flux.train_aircraft_bg_crop import AircraftBackgroundCropDataset
from omini.pipeline.flux_omini import Condition, generate, seed_everything
from omini.train_flux.trainer import OminiModel


def load_config(config_path: str):
    """Load a YAML configuration file.

    Args:
        config_path: Path to the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (its top-level
        node; ``None`` for an empty file).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, so configs containing non-ASCII text load the same everywhere.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

# --- Texture Consistency Energy Guidance Logic ---

def gram_matrix(input_tensor, mask_area=None):
    """Compute the per-sample Gram matrix of a feature map.

    The Gram matrix holds the inner products between every pair of feature
    channels and is used here as a texture/style descriptor.

    Args:
        input_tensor: Feature map of shape ``[B, C, H, W]``. It may be
            non-contiguous (e.g. the result of a ``permute``).
        mask_area: Optional number of active spatial positions. When given
            and positive, the result is normalized by ``C * mask_area``;
            otherwise by ``C * H * W``.

    Returns:
        A float32 tensor of shape ``[B, C, C]``.
    """
    # Always accumulate in float32 for precision.
    feats = input_tensor.float()
    B, C, H, W = feats.shape

    # reshape (rather than view) so that non-contiguous inputs are accepted;
    # for inputs where a view is possible it is still a zero-copy view.
    feats = feats.reshape(B, C, H * W)
    G = torch.bmm(feats, feats.transpose(1, 2))

    if mask_area is not None and mask_area > 0:
        normalizer = C * mask_area
    else:
        # No usable mask area: fall back to the full spatial extent.
        normalizer = C * H * W
    return G.div(normalizer)

def get_surround_mask(box_mask, dilation_pixels=5):
    """Return the band that surrounds the masked region (dilated minus original).

    The mask is grown with a stride-1 max-pool whose square window has side
    ``2 * dilation_pixels + 1`` and whose padding is ``dilation_pixels``, so
    the spatial size is preserved. Subtracting the input then leaves only the
    newly covered band.

    Args:
        box_mask: Float mask tensor accepted by ``F.max_pool2d``
            (e.g. ``[B, 1, H, W]``).
        dilation_pixels: Number of pixels by which the mask is grown.

    Returns:
        A tensor shaped like ``box_mask`` holding ``dilated - box_mask``.
    """
    radius = dilation_pixels
    window = radius * 2 + 1
    grown = F.max_pool2d(box_mask, kernel_size=window, stride=1, padding=radius)
    return grown - box_mask

def apply_texture_energy_guidance(latents, box_mask, steps=3, scale=100.0,
                                  step_size=0.1, dilation_pixels=5):
    """Nudge latents so the masked region's texture matches its surroundings.

    The latent tokens are laid out as a 2-D grid, the Gram matrices of the
    region inside ``box_mask`` and of a thin band around it are compared with
    an MSE "energy", and ``steps`` normalized gradient-descent updates are
    applied to the latents to lower that energy.

    Args:
        latents: Tensor of shape ``[B, L, C]`` (token sequence).
        box_mask: Mask of shape ``[1, 1, h, w]`` (or ``[B, 1, h, w]``) with
            ones inside the object box. It is resized with nearest-neighbour
            interpolation to the latent grid when the sizes differ.
        steps: Number of gradient updates.
        scale: Currently unused; kept so existing callers keep working.
        step_size: Length of each normalized gradient step.
        dilation_pixels: Width (in latent cells) of the surrounding band.

    Returns:
        Tensor of shape ``[B, L, C]`` with the dtype of ``latents``, detached
        from the autograd graph.

    Raises:
        ValueError: If ``L`` cannot be arranged as a grid that either matches
            the mask's aspect ratio or is square.
    """
    B, L, C = latents.shape
    mask_h, mask_w = box_mask.shape[-2:]

    # Recover the token grid. Prefer the unique (H, W) with H * W == L and the
    # same aspect ratio as the mask, so non-square targets work; for a square
    # mask this reduces to the square grid.
    H = W = 0
    if mask_h > 0 and mask_w > 0:
        H = int(round((L * mask_h / mask_w) ** 0.5))
        W = L // H if H > 0 else 0
    if H < 1 or H * W != L or H * mask_w != W * mask_h:
        # Fall back to a square grid.
        H = W = int(round(L ** 0.5))
        if H * W != L:
            raise ValueError(
                f"Cannot arrange {L} latent tokens as a grid matching a "
                f"{mask_h}x{mask_w} mask or as a square"
            )

    # Optimize in float32 regardless of the pipeline dtype.
    original_dtype = latents.dtype
    x_in = latents.transpose(1, 2).reshape(B, C, H, W).detach().clone().float()

    # Align the mask with the latents (device, dtype, and both spatial dims).
    box_mask = box_mask.to(device=x_in.device, dtype=x_in.dtype)
    if tuple(box_mask.shape[-2:]) != (H, W):
        box_mask = F.interpolate(box_mask, size=(H, W), mode='nearest')

    surround_mask = get_surround_mask(box_mask, dilation_pixels=dilation_pixels)

    # The masks do not change during optimization, so compute their areas
    # once (each .item() forces a device sync).
    area_obj = box_mask.sum().item()
    area_env = surround_mask.sum().item()

    # Callers typically run under torch.no_grad(); re-enable autograd locally.
    with torch.enable_grad():
        for _ in range(steps):
            # Fresh leaf each iteration so the gradient is w.r.t. current x_in.
            x_in = x_in.detach().requires_grad_(True)

            G_obj = gram_matrix(x_in * box_mask, mask_area=area_obj)
            G_env = gram_matrix(x_in * surround_mask, mask_area=area_env)
            energy = F.mse_loss(G_obj, G_env)

            (grad,) = torch.autograd.grad(energy, x_in)

            with torch.no_grad():
                # Unit-norm direction; the epsilon avoids dividing by zero
                # when the gradient vanishes (e.g. an empty mask).
                normalized_grad = grad / (grad.norm() + 1e-8)
                x_in = x_in - step_size * normalized_grad

    return x_in.flatten(2).transpose(1, 2).detach().to(original_dtype)

# --- End Texture Guidance Logic ---


def _as_pil_image(value):
    """Return ``value`` as a PIL image, converting it when it is a tensor."""
    if isinstance(value, torch.Tensor):
        return T.ToPILImage()(value)
    return value


def _resolve_mask_image(sample, target_size):
    """Pick the mask image for a sample, or build a centered fallback box."""
    mask_img = None
    if "target_mask" in sample:
        raw = sample["target_mask"]
        if isinstance(raw, Image.Image):
            mask_img = raw
        elif isinstance(raw, torch.Tensor):
            if raw.dim() == 3:
                raw = raw.squeeze(0)
            mask_img = T.ToPILImage()(raw)
    elif "mask" in sample:
        raw = sample["mask"]
        if isinstance(raw, str):
            try:
                # Context manager closes the file; convert() returns a copy.
                with Image.open(raw) as opened:
                    mask_img = opened.convert("L")
            except (OSError, ValueError) as exc:
                # Best effort: report and fall through to the fallback mask.
                print(f"  ⚠️  Could not load mask '{raw}': {exc}")
        elif isinstance(raw, Image.Image):
            mask_img = raw

    if mask_img is None:
        # No usable mask: use a box covering the central half of each axis.
        w, h = target_size
        mask_img = Image.new("L", (w, h), 0)
        mask_img.paste(255, (w // 4, h // 4, w * 3 // 4, h * 3 // 4))
    return mask_img


def _mask_image_to_tensor(mask_img, target_size):
    """Convert a mask image to a float ``[1, 1, H, W]`` tensor in ``[0, 1]``."""
    # Force single-channel 8-bit so RGB or 1-bit masks also produce a 2-D
    # array with values in 0..255.
    resized = mask_img.convert("L").resize(target_size, Image.NEAREST)
    mask = torch.from_numpy(np.array(resized)).float() / 255.0
    return mask.unsqueeze(0).unsqueeze(0)


def _make_texture_guidance_callback(mask_tensor, guided_steps):
    """Build a step-end callback applying texture guidance to early steps."""

    def texture_guidance_callback(pipe, step_index, timestep, callback_kwargs):
        if step_index < guided_steps:
            latents = callback_kwargs["latents"]
            latents_new = apply_texture_energy_guidance(
                latents,
                # Follow the latents' device instead of assuming "cuda".
                mask_tensor.to(latents.device),
                steps=3,
                scale=100.0,
            )
            # Write in place so the caller's latents tensor is updated.
            callback_kwargs["latents"][:] = latents_new
        return callback_kwargs

    return texture_guidance_callback


def inference_on_training_samples(
    model,
    dataset,
    num_samples: int = 10,
    output_dir: str = "inference_results_bg_crop_structured",
    seed: int = 42,
    num_inference_steps: int = 28,
):
    """Generate images for the first ``num_samples`` dataset samples.

    For each sample the subject (``condition_0``) and background
    (``condition_1``) images are used as conditions, texture guidance is
    applied during the first 30% of the denoising steps, and the generated
    image plus its conditions, mask and original image are written under
    ``output_dir``. A failure while generating or saving one sample is
    reported and the loop continues with the next sample.

    Args:
        model: Object exposing ``training_config``, ``model_config``,
            ``adapter_names``, ``device`` and ``flux_pipe``.
        dataset: Indexable dataset with a ``return_pil_image`` attribute; the
            attribute is forced to ``True`` for the duration of the call and
            restored afterwards, even when an error escapes.
        num_samples: Upper bound on the number of samples to process.
        output_dir: Directory that receives the results (created if missing).
        seed: Base seed; sample ``idx`` uses ``seed + idx``.
        num_inference_steps: Number of denoising steps passed to ``generate``.
    """
    import traceback

    os.makedirs(output_dir, exist_ok=True)

    # Read settings. tuple() because a YAML-loaded size may be a list.
    target_size = tuple(model.training_config["dataset"]["target_size"])
    inter_condition_attention = model.model_config.get("inter_condition_attention", False)

    # The two adapters used for the subject and background conditions.
    subject_adapter = model.adapter_names[2]
    background_adapter = model.adapter_names[3]

    num_samples = min(num_samples, len(dataset))
    # Guidance is applied only to the first 30% of the steps.
    guided_steps = int(num_inference_steps * 0.3)

    print(f"\n{'='*70}")
    print(f"Inference on {num_samples} training samples (Bg + Crop + Structured)")
    print(f"{'='*70}")
    print(f"Output directory: {output_dir}")
    print(f"Adapters: {subject_adapter}, {background_adapter}")

    # Temporarily make the dataset return PIL images.
    original_return_pil = dataset.return_pil_image
    dataset.return_pil_image = True
    try:
        # no_grad saves memory; the guidance step re-enables grad locally.
        with torch.no_grad():
            for idx in range(num_samples):
                print(f"\n[{idx+1}/{num_samples}] Processing sample {idx}...")

                sample = dataset[idx]
                subject_img = _as_pil_image(sample["condition_0"])
                background_img = _as_pil_image(sample["condition_1"])
                prompt = sample["description"]

                mask_img = _resolve_mask_image(sample, target_size)
                mask_tensor = _mask_image_to_tensor(mask_img, target_size)

                # Conditions (subject + background).
                subject_condition = Condition(
                    subject_img,
                    subject_adapter,
                    [-16, -32]
                )
                background_condition = Condition(
                    background_img,
                    background_adapter,
                    [16, -32]
                )

                callback = _make_texture_guidance_callback(mask_tensor, guided_steps)

                # Per-sample generator so each sample is reproducible.
                generator = torch.Generator(device=model.device)
                generator.manual_seed(seed + idx)

                print(f"  Generating image...")
                try:
                    res = generate(
                        model.flux_pipe,
                        prompt=prompt,
                        conditions=[subject_condition, background_condition],
                        height=target_size[1],
                        width=target_size[0],
                        num_inference_steps=num_inference_steps,
                        guidance_scale=3.5,
                        generator=generator,
                        model_config=model.model_config,
                        kv_cache=model.model_config.get("independent_condition", False),
                        inter_condition_attention=inter_condition_attention,
                        callback_on_step_end=callback,
                        callback_on_step_end_tensor_inputs=["latents"],
                    )

                    output_path = os.path.join(output_dir, f"sample_{idx}_generated.jpg")
                    res.images[0].save(output_path)
                    print(f"  ✓ Saved to {output_path}")

                    # Save the conditions, the mask and the ground truth.
                    condition_dir = os.path.join(output_dir, "conditions")
                    os.makedirs(condition_dir, exist_ok=True)
                    subject_img.save(os.path.join(condition_dir, f"sample_{idx}_subject.jpg"))
                    background_img.save(os.path.join(condition_dir, f"sample_{idx}_background.jpg"))
                    mask_img.save(os.path.join(condition_dir, f"sample_{idx}_mask.jpg"))

                    original_img = _as_pil_image(sample["image"])
                    original_img.save(os.path.join(condition_dir, f"sample_{idx}_original.jpg"))

                except Exception as e:
                    # Keep going: one bad sample should not abort the run.
                    print(f"  ❌ Generation failed: {e}")
                    traceback.print_exc()
                    continue
    finally:
        # Restore the dataset state even if an error escaped the loop.
        dataset.return_pil_image = original_return_pil

    print(f"\n{'='*70}")
    print(f"✓ Inference completed! Results saved to {output_dir}")


# Checkpoint tried when the config does not name an existing one.
_FALLBACK_CHECKPOINT_PATH = "runs_bg_crop/20260106-174326/ckpt/12000"


def _find_latest_checkpoint(save_path, run_name):
    """Return the highest-numbered checkpoint directory, or ``None``.

    Looks in ``<save_path>/<run_name>/ckpt`` first; if that does not exist,
    in the ``ckpt`` directory of the lexicographically last run directory
    under ``save_path``.
    """
    if not os.path.exists(save_path):
        return None

    ckpt_dir = os.path.join(save_path, run_name, "ckpt")
    if not os.path.exists(ckpt_dir):
        run_dirs = sorted(
            d for d in os.listdir(save_path)
            if os.path.isdir(os.path.join(save_path, d))
        )
        if run_dirs:
            ckpt_dir = os.path.join(save_path, run_dirs[-1], "ckpt")

    if not os.path.exists(ckpt_dir):
        return None

    step_dirs = [
        d for d in os.listdir(ckpt_dir)
        if d.isdigit() and os.path.isdir(os.path.join(ckpt_dir, d))
    ]
    if not step_dirs:
        return None
    # key=int compares numerically while keeping the directory's real name
    # (e.g. zero-padded names stay intact).
    return os.path.join(ckpt_dir, max(step_dirs, key=int))


def _resolve_checkpoint_path(training_config):
    """Return an existing checkpoint path, or ``None`` when none is found.

    Order: the config's ``resume_from_checkpoint``, then the built-in
    fallback path, then the latest checkpoint discovered under ``save_path``.
    """
    candidates = (
        training_config.get("resume_from_checkpoint", None),
        _FALLBACK_CHECKPOINT_PATH,
    )
    for candidate in candidates:
        if candidate and os.path.exists(candidate):
            return candidate

    save_path = training_config.get("save_path", "runs_bg_crop")
    wandb_name = (training_config.get("wandb") or {}).get("name", "aircraft_bg_crop")
    found = _find_latest_checkpoint(save_path, wandb_name)
    if found is not None:
        print(f"  ✓ Found latest checkpoint: {found}")
    return found


def main():
    """Run texture-guided inference over the whole dataset.

    Seeds all random number generators, loads the YAML config named by the
    ``OMINI_CONFIG`` environment variable (with a default path), builds the
    dataset with all random drops/augmentations disabled, resolves a
    checkpoint, loads the model, activates its LoRA adapters, switches
    everything to eval mode and calls ``inference_on_training_samples``.
    Returns early, after printing an error, when no checkpoint exists.
    """
    # Make the run deterministic from the very start.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    seed_everything(seed)

    print("="*70)
    print("🔧 Deterministic mode enabled for reproducible inference")
    print("="*70)

    # Locate and load the configuration.
    config_path = os.environ.get("OMINI_CONFIG", "./train/config/aircraft_bg_crop_inf.yaml")
    print(f"Loading config from {config_path}")
    config = load_config(config_path)

    # 1. Load the dataset.
    training_config = config["train"]
    dataset_config = training_config["dataset"]

    print("Loading dataset...")
    dataset = AircraftBackgroundCropDataset(
        dataset_root=dataset_config["dataset_root"],
        condition_size=tuple(dataset_config["condition_size"]),
        target_size=tuple(dataset_config["target_size"]),
        drop_text_prob=0.0,  # never drop at inference time
        drop_subject_prob=0.0,
        drop_background_prob=0.0,
        background_blur_prob=0.0,  # no blur at inference time
        augmentation_prob=0.0,  # no augmentation at inference time
        return_pil_image=True  # get PIL images directly
    )
    print(f"Dataset size: {len(dataset)}")

    # 2. Initialize the model.
    print("\n[2/3] Loading model...")

    # The configured checkpoint is honoured first; a path that does not
    # exist is never handed to the model.
    checkpoint_path = _resolve_checkpoint_path(training_config)

    if checkpoint_path is None:
        print("\n❌ Error: No checkpoint found!")
        print("Please train the model first or specify a checkpoint path.")
        return

    model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_path=checkpoint_path,  # checkpoint directory handed to OminiModel
        lora_config=training_config.get("lora_config", None),
        device="cuda",
        dtype=torch.bfloat16 if config["dtype"] == "bfloat16" else torch.float32,
        model_config=config.get("model", {}),
        adapter_names=[None, None, "subject", "background"],
    )
    model.training_config = training_config

    # Explicitly activate every LoRA adapter.
    print("\n  Activating LoRA adapters...")
    adapter_list = list(model.adapter_set)
    if adapter_list:
        try:
            # Option 1: activate all adapters at once.
            model.transformer.set_adapters(adapter_list)
            print(f"    ✓ Activated adapters: {adapter_list}")
        except Exception as e:
            print(f"    ⚠️  set_adapters failed: {e}")
            # Option 2: enable the adapters one by one.
            for adapter_name in adapter_list:
                try:
                    model.transformer.enable_adapters(adapter_name)
                    print(f"    ✓ Enabled adapter: {adapter_name}")
                except Exception as e2:
                    print(f"    ⚠️  Failed to enable {adapter_name}: {e2}")

        # Report the first module that lists active adapters.
        print("    Verifying adapter activation...")
        for name, module in model.transformer.named_modules():
            if hasattr(module, 'active_adapters'):
                if module.active_adapters:
                    print(f"      Active adapters in {name}: {module.active_adapters}")
                    break

    # Switch to evaluation mode (sub-modules set explicitly as a safeguard).
    model.eval()
    model.transformer.eval()
    model.flux_pipe.vae.eval()
    model.flux_pipe.text_encoder.eval()
    model.flux_pipe.text_encoder_2.eval()

    # Zero out every dropout probability as well.
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0

    print("  ✓ Model loaded and set to eval mode (with deterministic settings)")

    # 3. Run inference over the whole dataset.
    print("\n[3/3] Running inference...")
    inference_on_training_samples(model, dataset, num_samples=len(dataset))


# Run the inference pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()