#!/usr/bin/env python3
"""
Single Sample Inference for Aircraft (Bg + Crop + Low-Original Gated LoRA)
"""

import os
import sys
import yaml
import torch
import random
from pathlib import Path
from PIL import Image
import numpy as np
import argparse
import torch.nn.functional as F

# Add project root to path
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from omini.pipeline.flux_omini_solar import Condition, generate, seed_everything
from omini.train_flux.trainer_mask_weighted import OminiModel

def load_config(config_path: str):
    """Load a YAML config file.

    The file is parsed with ``yaml.safe_load``, so arbitrary Python object
    tags are rejected.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The parsed document (a dict for a normal config file). An empty file
        yields ``{}`` instead of ``None`` so callers can index/``.get`` it.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    return {} if config is None else config

def frequency_split_latents(latents, kernel_size=5):
    """Decompose latent tokens into low- and high-frequency components.

    The tokens are laid out row-major on a square ``H x H`` grid and
    low-passed with a ``kernel_size x kernel_size`` box filter
    (``avg_pool2d``, stride 1, averaging only over in-bounds pixels at the
    borders). The high-frequency part is the residual, so ``low + high``
    reproduces the input up to floating-point rounding.

    Args:
        latents: Tensor of shape ``(B, L, C)``; ``L`` must be a perfect square.
        kernel_size: Box-filter size. It must be a positive odd integer so the
            filtered map keeps the input's spatial size.

    Returns:
        ``(low, high)``, both with the same ``(B, L, C)`` shape as ``latents``.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer, or if
            ``L`` is not a perfect square.
    """
    # An even kernel with padding k // 2 yields an (H+1)-sized map, which
    # would break the residual subtraction below.
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
    B, L, C = latents.shape
    # round() + explicit check: int(L ** 0.5) silently truncates for a
    # non-square L, which would otherwise surface as an opaque shape error.
    H = W = round(L ** 0.5)
    if H * W != L:
        raise ValueError(f"expected a square token grid, got sequence length {L}")
    # reshape (not view) so inputs with any stride layout are accepted.
    x = latents.transpose(1, 2).reshape(B, C, H, W)
    pad = kernel_size // 2
    low = F.avg_pool2d(x, kernel_size=kernel_size, stride=1, padding=pad, count_include_pad=False)
    high = x - low
    low_flat = low.flatten(2).transpose(1, 2)
    high_flat = high.flatten(2).transpose(1, 2)
    return low_flat, high_flat

class LowFreqCondition(Condition):
    """Condition that keeps only the low-frequency part of its latents.

    Intended to carry coarse style/lighting cues. The parent
    ``Condition.encode`` output is passed through ``frequency_split_latents``
    (default kernel size), the high-frequency residual is discarded, and the
    position ids are returned unchanged.
    """

    def encode(self, pipe, empty=False):
        """Encode with the parent class, then low-pass the latent tokens."""
        encoded, position_ids = super().encode(pipe, empty)
        low_only = frequency_split_latents(encoded)[0]
        return low_only, position_ids

def load_rgba_with_black_background(image_path: str, background_color: tuple = (255, 255, 255)) -> Image.Image:
    """Open an image and flatten any transparency onto a solid RGB background.

    NOTE(review): despite the function name, the fill colour defaults to
    white ``(255, 255, 255)``. It is kept as the default because the model
    may have been trained on subjects composited this way; confirm against
    the training data pipeline before changing it.

    Args:
        image_path: Path of the image to load.
        background_color: RGB colour placed behind transparent pixels.

    Returns:
        A fully-loaded ``RGB`` image that does not depend on the source file
        remaining open.
    """
    with Image.open(image_path) as img:
        # Any alpha channel (RGBA/LA/PA) or palette/colour-key transparency
        # has to be composited. A plain convert("RGB") would drop the alpha
        # and expose whatever colour sits under the transparent pixels.
        has_transparency = img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info
        if has_transparency:
            rgba = img.convert("RGBA")
            background = Image.new("RGB", rgba.size, background_color)
            background.paste(rgba, mask=rgba.getchannel("A"))
            return background
        if img.mode == "RGB":
            # copy() forces a full load so the result outlives the closed file.
            return img.copy()
        return img.convert("RGB")

@torch.no_grad()
def single_inference(
    model,
    background_path,
    subject_path,
    output_path,
    target_size=(512, 512),
    condition_size=(512, 512),
    seed=42,
    prompt="Place an aircraft at the specified position",
    num_inference_steps=28,
    guidance_scale=3.5,
):
    """Run one generation for a (background, subject) pair and save the result.

    Args:
        model: Loaded ``OminiModel``. It must expose ``flux_pipe``, ``device``
            and ``model_config``, and have the ``subject_low``,
            ``subject_original`` and ``background`` adapters loaded.
        background_path: Path of the background image.
        subject_path: Path of the subject image. Transparency is flattened by
            ``load_rgba_with_black_background``.
        output_path: Where to write the generated image. Missing parent
            directories are created.
        target_size: Output size as ``(width, height)``.
        condition_size: ``(width, height)`` both condition images are resized
            to. The aspect ratio is not preserved.
        seed: Seed for ``seed_everything`` and the sampling generator.
        prompt: Text prompt passed to ``generate``.
        num_inference_steps: Number of sampling steps passed to ``generate``.
        guidance_scale: Guidance scale passed to ``generate``.

    Failures while loading images, preparing the output directory, or
    generating are reported on stdout. The function then returns ``None``
    without raising.
    """
    print("\nProcessing Single Sample:")
    print(f"  Background: {background_path}")
    print(f"  Subject:    {subject_path}")
    print(f"  Output:     {output_path}")

    # 1. Load images; both conditions are resized to the same fixed size.
    try:
        background_img = Image.open(background_path).convert("RGB").resize(condition_size, Image.BILINEAR)
        subject_img = load_rgba_with_black_background(subject_path).resize(condition_size, Image.BILINEAR)
    except Exception as e:
        print(f"❌ Error loading images: {e}")
        return

    # 2. Conditions. The adapter names must match the adapters loaded on the
    # model. The subject is fed twice: once low-passed (LowFreqCondition) and
    # once unmodified.
    # NOTE(review): the third argument is presumably Condition's position
    # delta (token-grid offset); confirm against omini.pipeline.flux_omini_solar.
    adapter_low = "subject_low"
    adapter_original = "subject_original"
    adapter_bg = "background"

    cond_low = LowFreqCondition(subject_img, adapter_low, [-16, -32])
    cond_original = Condition(subject_img, adapter_original, [-16, -32])
    cond_bg = Condition(background_img, adapter_bg, [16, -32])

    # Create the output directory first, so a bad path fails before the
    # expensive sampling run rather than after it.
    try:
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        print(f"❌ Cannot create output directory: {e}")
        return

    # 3. Generate. Seed via seed_everything and a dedicated generator that
    # lives on the model's device.
    seed_everything(seed)
    generator = torch.Generator(device=model.device)
    generator.manual_seed(seed)

    print("  Generating...")
    try:
        res = generate(
            model.flux_pipe,
            prompt=prompt,
            conditions=[cond_low, cond_original, cond_bg],
            height=target_size[1],
            width=target_size[0],
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            model_config=model.model_config,
            kv_cache=False,
        )

        res.images[0].save(output_path)
        print(f"✅ Saved result to {output_path}")

    except Exception as e:
        # Top-level boundary of the CLI: report the failure with a traceback.
        print(f"❌ Generation failed: {e}")
        import traceback
        traceback.print_exc()

def main():
    """Parse CLI arguments, load the model and adapters, and run one inference.

    Exits with status 1 if the config file does not exist or has no
    ``flux_path`` entry.
    """
    parser = argparse.ArgumentParser(description="Single Inference")
    parser.add_argument("--config", type=str, default="./train/config/aircraft_bg_crop.yaml")
    parser.add_argument("--checkpoint", type=str, default="runs_bg_crop/20260106-174238/ckpt/14000")
    parser.add_argument("--background", type=str, required=True, help="Path to background image")
    parser.add_argument("--subject", type=str, required=True, help="Path to subject image")
    parser.add_argument("--output", type=str, default="single_result.jpg", help="Path to save output")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default="cuda", help="Torch device to load the model on")
    args = parser.parse_args()

    if not os.path.exists(args.config):
        print(f"❌ Config not found: {args.config}")
        # Non-zero exit status so shell scripts / CI notice the failure.
        sys.exit(1)

    # An empty YAML file loads as None; treat it like an empty mapping.
    config = load_config(args.config) or {}
    if "flux_path" not in config:
        print(f"❌ 'flux_path' missing from config: {args.config}")
        sys.exit(1)

    print(f"Loading model from {args.checkpoint}...")
    model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_path=None,
        lora_config=None,
        device=args.device,
        # Only "bfloat16" is special-cased; any other or missing value uses float32.
        dtype=torch.bfloat16 if config.get("dtype") == "bfloat16" else torch.float32,
        model_config=config.get("model", {}),
        # The named entries match the adapter names of the three conditions
        # built in single_inference. NOTE(review): the two leading None
        # entries presumably cover the model's own (non-condition) branches;
        # confirm against OminiModel.
        adapter_names=[None, None, "subject_low", "subject_original", "background"],
        gradient_checkpointing=False,
    )

    # Manually load weights since OminiModel init skips it when lora_config is None
    print("Loading adapter weights...")
    model.load_adapters(args.checkpoint)
    model.eval()

    single_inference(
        model,
        args.background,
        args.subject,
        args.output,
        seed=args.seed
    )

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()