#!/usr/bin/env python3
"""
Single Sample Inference for Aircraft (Bg + Crop + Solar + Low-Original Gated LoRA)
(Hardcoded Structure matching Batch Inference)
"""

import os
import sys
import yaml
import torch
import random
from pathlib import Path
from PIL import Image
import numpy as np
# import argparse # Removed for hardcoding
import torch.nn.functional as F

# Add project root to path
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from omini.pipeline.flux_omini_solar import Condition, generate, seed_everything
import omini.pipeline.flux_omini_solar as flux_omini_solar
from omini.train_flux.train_aircraft_bg_crop_solar import OminiSolarModel

# --- Helper Functions (Copied from Batch Script) ---

def load_config(config_path: str):
    """Load a YAML configuration file and return the parsed document.

    The file is decoded explicitly as UTF-8 (the YAML encoding) instead of the
    platform default, so configs containing non-ASCII text parse identically
    on every OS.

    Args:
        config_path: Path to the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces: a dict for a mapping document,
        ``None`` for an empty file.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load only builds plain Python types (no arbitrary object construction).
        return yaml.safe_load(f)

def frequency_split_latents(latents, kernel_size=5):
    """Decompose packed latents into low-frequency and high-frequency components.

    The token axis of ``latents`` is treated as a flattened square H x W grid
    (H == W). A stride-1 box blur (``avg_pool2d`` with ``count_include_pad=False``,
    so border pixels average only real neighbours) gives the low-frequency part.
    The residual is the high-frequency part, so ``low + high == latents``.

    Args:
        latents: Tensor of shape (B, L, C); L must be a perfect square.
        kernel_size: Size of the averaging window; must be a positive odd
            integer so the blurred grid keeps the input's H x W size.

    Returns:
        Tuple ``(low, high)`` of tensors, each with the same (B, L, C) shape
        as ``latents``.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer, or if L
            is not a perfect square.
    """
    if kernel_size < 1 or kernel_size % 2 == 0:
        # With padding = k // 2 and stride 1, an even kernel yields an (H+1) x (W+1)
        # output that cannot be subtracted from the input.
        raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")

    B, L, C = latents.shape
    side = int(round(L ** 0.5))
    if side * side != L:
        raise ValueError(f"sequence length {L} is not a perfect square; cannot reshape to a square grid")

    # (B, L, C) -> (B, C, L) -> (B, C, H, W): splitting the single L axis is a valid view.
    x = latents.transpose(1, 2).view(B, C, side, side)
    pad = kernel_size // 2
    low = F.avg_pool2d(x, kernel_size=kernel_size, stride=1, padding=pad, count_include_pad=False)
    high = x - low

    # Back to the packed (B, L, C) token layout.
    low_flat = low.flatten(2).transpose(1, 2)
    high_flat = high.flatten(2).transpose(1, 2)
    return low_flat, high_flat

class LowFreqCondition(Condition):
    """Condition whose encoded latents keep only the low-frequency band (style/lighting).

    Encoding is delegated to ``Condition.encode``; the resulting latents are
    passed through ``frequency_split_latents`` (default kernel) and the
    high-frequency residual is discarded. The ids are returned unchanged.
    """

    def encode(self, pipe, empty=False):
        encoded, position_ids = super().encode(pipe, empty)
        low_band = frequency_split_latents(encoded)[0]
        return low_band, position_ids

def create_default_subject(size):
    """Build a placeholder subject: a flat mid-grey RGB image.

    Args:
        size: (width, height) of the image to create.

    Returns:
        A new RGB ``PIL.Image.Image`` filled with (128, 128, 128).
    """
    mid_grey = (128,) * 3
    placeholder = Image.new("RGB", size, mid_grey)
    return placeholder

def load_rgba_with_black_background(
    image_path: str,
    background_color=(255, 255, 255),
) -> Image.Image:
    """Load an image as RGB, flattening any transparency onto a solid background.

    NOTE(review): despite the function name, the default fill is WHITE
    (255, 255, 255), which is the colour this helper has always used. The default
    is kept so results stay consistent with the batch script it was copied from.
    Pass ``background_color=(0, 0, 0)`` for a real black background. Confirm which
    colour training used.

    Images with an alpha band (RGBA, LA, PA) or palette/colour-key transparency
    are alpha-composited onto ``background_color``. Any other mode is converted
    to RGB directly.

    Args:
        image_path: Path of the image file.
        background_color: RGB fill colour placed behind transparent pixels.

    Returns:
        A new RGB ``PIL.Image.Image``, fully loaded. The source file handle is
        closed before returning.
    """
    # The context manager closes the file; convert() forces the pixel data to be
    # read before that happens, so the returned image is independent of the file.
    with Image.open(image_path) as img:
        has_alpha = "A" in img.getbands() or "transparency" in img.info
        if not has_alpha:
            return img.convert("RGB")
        rgba = img.convert("RGBA")
        background = Image.new("RGB", rgba.size, background_color)
        # paste() converts the RGBA source to RGB and uses the alpha band as the blend mask.
        background.paste(rgba, mask=rgba.getchannel("A"))
        return background

# --- Inference Logic ---

@torch.no_grad()
def single_inference(
    model,
    background_path,
    subject_path,
    mask_path,
    output_path,
    target_size=(512, 512),
    condition_size=(512, 512),
    seed=42
) -> None:
    """Run one Solar + gated-LoRA generation and save the result to disk.

    Steps:
      1. Load and resize the background, subject and mask images.
      2. Build three conditions: subject low-frequency, subject original, and
         background.
      3. Compute a (scale, shift) pair per Solar projector from the encoded
         background and the mask.
      4. Call ``generate`` and save the first output image.

    Args:
        model: Loaded ``OminiSolarModel``. This function reads its ``flux_pipe``,
            ``device``, ``dtype``, ``model_config``, ``solar_encoder`` and
            ``solar_projectors`` attributes.
        background_path: Path of the background image.
        subject_path: Path of the subject image. If it is falsy or does not exist,
            a grey placeholder from ``create_default_subject`` is used.
        mask_path: Path of the placement mask, loaded as 8-bit grayscale.
        output_path: File the generated image is written to. Its parent
            directory is created if needed.
        target_size: (width, height) of the generated image and of the mask.
        condition_size: (width, height) the background and subject images are
            resized to.
        seed: Seed for ``seed_everything`` and for the sampling generator.

    Returns:
        None. Image-loading and generation errors are printed, not raised:
        an image-load failure returns early, and a generation failure prints a
        traceback. Errors in the Solar-parameter step are NOT caught and
        propagate to the caller.
    """
    # Ensure output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"\n{'='*70}")
    print(f"Single Inference (Solar + Low-Original Gated LoRA)")
    print(f"{'='*70}")
    print(f"Background: {background_path}")
    print(f"Subject:    {subject_path}")
    print(f"Mask:       {mask_path}")
    print(f"Output:     {output_path}")
    print(f"Seed:       {seed}")

    # Adapter Names
    # These must match the adapter names the LoRA weights were loaded under.
    adapter_low = "subject_low"
    adapter_original = "subject_original"
    adapter_bg = "background"
    
    # Fixed text prompt for every sample (presumably the prompt used in training - confirm).
    unified_prompt = "Place an aircraft at the specified position"

    # 1. Prepare Images
    # NOTE(review): Image.open handles are not closed explicitly here.
    try:
        # Condition images are resized to condition_size; the mask to target_size.
        # NEAREST resampling introduces no new intermediate grey levels in the mask.
        background_img = Image.open(background_path).convert("RGB").resize(condition_size, Image.BILINEAR)
        mask_img = Image.open(mask_path).convert("L").resize(target_size, Image.NEAREST)
        
        if subject_path and os.path.exists(subject_path):
            subject_img = load_rgba_with_black_background(subject_path).resize(condition_size, Image.BILINEAR)
        else:
            print(f"⚠️  Subject not found or None, using default grey image.")
            subject_img = create_default_subject(condition_size)
            
    except Exception as e:
        print(f"❌ Failed to load images: {e}")
        return

    # 2. Build Conditions
    # The third argument is presumably a position offset for the condition tokens.
    # Both subject conditions share [-16, -32]; the background uses [16, -32].
    # Confirm against Condition.
    # Subject Low Frequency (Style)
    cond_low = LowFreqCondition(subject_img, adapter_low, [-16, -32])
    # Subject Original (Structure + Style)
    cond_original = Condition(subject_img, adapter_original, [-16, -32])
    # Background
    cond_bg = Condition(background_img, adapter_bg, [16, -32])
    
    # 3. Solar Parameters Calculation
    print("  Calculating Solar Parameters...")
    # Encode background to latents (the second returned value, ids, is unused here).
    bg_latents, _ = flux_omini_solar.encode_images(model.flux_pipe, background_img)
    
    # Unpack / Reshape latents
    # Packed (B, L, C) latents -> (B, C, H, W) grid. Assumes L is a perfect square.
    B, L, C = bg_latents.shape
    H_latent = int(L ** 0.5)
    W_latent = H_latent
    # Cast to float32 for the Solar encoder (presumably it runs in fp32 - confirm).
    bg_spatial = bg_latents.transpose(1, 2).view(B, C, H_latent, W_latent).to(torch.float32)
    
    # Prepare Mask
    # uint8 [0, 255] -> float [0, 1], shaped (1, 1, H, W) at target_size resolution.
    # NOTE(review): the mask is at pixel resolution while bg_spatial is at latent
    # resolution. Presumably solar_encoder resizes internally - confirm.
    mask_tensor = torch.from_numpy(np.array(mask_img)).float() / 255.0
    mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(model.device)
    
    # Run Solar Encoder
    context_vector = model.solar_encoder(bg_spatial, mask_tensor)
    
    # Run Projectors
    # Each projector's output is split in half along dim 1 into (scale, shift).
    # unsqueeze(1) inserts an axis, presumably to broadcast over tokens (confirm).
    # Both halves are then cast to the model dtype.
    solar_params_list = []
    for proj in model.solar_projectors:
        params = proj(context_vector)
        scale, shift = params.chunk(2, dim=1)
        scale = scale.unsqueeze(1).to(model.dtype)
        shift = shift.unsqueeze(1).to(model.dtype)
        solar_params_list.append((scale, shift))
        
    # 4. Generate
    # Seed global RNGs and a dedicated generator on the model device for sampling.
    seed_everything(seed)
    generator = torch.Generator(device=model.device)
    generator.manual_seed(seed)
    
    print(f"  Generating image...")
    try:
        res = generate(
            model.flux_pipe,
            prompt=unified_prompt,
            conditions=[cond_low, cond_original, cond_bg], 
            height=target_size[1],
            width=target_size[0],
            num_inference_steps=28,
            guidance_scale=3.5,
            generator=generator,
            model_config=model.model_config,
            # KV caching is enabled only when the config sets independent_condition.
            kv_cache=model.model_config.get("independent_condition", False),
            solar_params_list=solar_params_list,
        )
        
        # 5. Save Result
        res.images[0].save(output_path)
        print(f"  ✓ Saved to {output_path}")
        
    except Exception as e:
        # Best-effort: report the failure with a traceback, but do not re-raise.
        print(f"  ❌ Generation failed: {e}")
        import traceback
        traceback.print_exc()

# --- Main Execution ---

def _set_deterministic_mode(seed):
    """Seed every RNG this script uses and force deterministic cuDNN kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may select different kernels per run; disable it for reproducibility.
    torch.backends.cudnn.benchmark = False
    seed_everything(seed)


def _load_solar_components(model, checkpoint_path):
    """Load Solar encoder/projector weights from ``<checkpoint_path>/solar_components.pt``.

    Returns:
        True on success, or False (after printing an error) if the file is missing.
    """
    solar_path = os.path.join(checkpoint_path, "solar_components.pt")
    if not os.path.exists(solar_path):
        print(f"❌ Error: solar_components.pt not found in {checkpoint_path}")
        return False
    print(f"Loading Solar Components from {solar_path}")
    # Only the "encoder" and "projectors" entries are read. weights_only=True
    # restricts unpickling to tensors and plain containers (the torch>=2.6 default)
    # instead of executing arbitrary pickled objects.
    state = torch.load(solar_path, map_location=model.device, weights_only=True)
    model.solar_encoder.load_state_dict(state["encoder"])
    model.solar_projectors.load_state_dict(state["projectors"])
    return True


def _load_lora_adapters(model, checkpoint_path, adapter_names):
    """Load each ``<name>.safetensors`` LoRA present in ``checkpoint_path``; warn on missing files."""
    print("Loading LoRA adapters...")
    for adapter_name in adapter_names:
        weight_name = f"{adapter_name}.safetensors"
        lora_file = os.path.join(checkpoint_path, weight_name)
        if not os.path.exists(lora_file):
            print(f"⚠️  Warning: {adapter_name}.safetensors not found.")
            continue
        print(f"  Loading {adapter_name} from {lora_file}")
        model.flux_pipe.load_lora_weights(
            checkpoint_path, weight_name=weight_name, adapter_name=adapter_name
        )


def _activate_adapters(model, adapter_names):
    """Activate the given LoRA adapters on the transformer, in the given (fixed) order."""
    print("\n  Activating LoRA adapters...")
    adapter_list = list(adapter_names)
    if not adapter_list:
        return
    try:
        model.transformer.set_adapters(adapter_list)
        print(f"    ✓ Activated adapters: {adapter_list}")
    except Exception as e:
        print(f"    ⚠️  set_adapters failed: {e}")
        # diffusers' PeftAdapterMixin.enable_adapters() takes no adapter name and
        # re-enables the adapters already marked active. This assumes
        # model.transformer is a diffusers model - confirm.
        try:
            model.transformer.enable_adapters()
            print("    ✓ Enabled adapters")
        except Exception as e2:
            print(f"    ⚠️  Failed to enable adapters: {e2}")


def _set_eval_mode(model):
    """Put the model and its pipeline sub-modules in eval mode and zero every Dropout."""
    model.eval()
    model.transformer.eval()
    model.flux_pipe.vae.eval()
    model.flux_pipe.text_encoder.eval()
    model.flux_pipe.text_encoder_2.eval()
    # eval() already disables dropout. Setting p=0 additionally guards against
    # anything flipping a module back to train mode.
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0


def main():
    """Run single-sample Solar inference using the hardcoded paths below.

    Steps:
      1. Seed all RNGs and read the YAML config.
      2. Build ``OminiSolarModel``.
      3. Load the Solar components and the LoRA adapters from the checkpoint dir.
      4. Activate the adapters in a fixed order and switch to eval mode.
      5. Call ``single_inference``.

    Returns early (printing an error) if the config file or
    ``solar_components.pt`` is missing.
    """
    # --- Hardcoded Configuration ---
    CONFIG_PATH = "./train/config/aircraft_bg_crop.yaml"
    CHECKPOINT_PATH = "runs_bg_crop/20260106-174326/ckpt/8000"
    
    # Input Paths (Please modify these)
    BACKGROUND_PATH = "augmented_results_batch/debug_conditions/304_A2_step0_bg_input.png"
    SUBJECT_PATH = "augmented_results_batch/debug_conditions/304_A2_step0_subject.png"
    MASK_PATH = "augmented_results_batch/debug_conditions/304_A2_step0_mask.png"
    OUTPUT_PATH = "single_result_solar.jpg"
    
    SEED = 42

    # Fixed order (a list, not a set). String hashing is randomized per process,
    # so list(set(...)) ordering - and therefore the adapter activation order,
    # which fixes the order LoRA deltas are summed - would vary between runs.
    ADAPTER_NAMES = ["subject_low", "subject_original", "background"]
    
    # -------------------------------
    
    # 🔧 1. Set Deterministic Mode
    _set_deterministic_mode(SEED)
    
    print("="*70)
    print("Single Inference (Hardcoded Structure)")
    print("="*70)

    # 1. Load Config
    print(f"Loading config from {CONFIG_PATH}...")
    if not os.path.exists(CONFIG_PATH):
        print(f"❌ Error: Config file not found: {CONFIG_PATH}")
        return
    config = load_config(CONFIG_PATH)
    training_config = config["train"]
    dataset_config = training_config["dataset"]
    
    target_size = tuple(dataset_config["target_size"])
    condition_size = tuple(dataset_config["condition_size"])
    
    # 2. Check Checkpoint
    # Deliberately not fatal here, so model construction can still be exercised.
    # Loading solar_components.pt below fails and returns if the directory is absent.
    if not os.path.exists(CHECKPOINT_PATH):
        print(f"❌ Error: Checkpoint not found at {CHECKPOINT_PATH}")
    
    print(f"Loading model from {CHECKPOINT_PATH}...")
    
    # 3. Initialize Model
    # (Matches batch script initialization exactly)
    model = OminiSolarModel(
        flux_pipe_id=config["flux_path"],
        lora_path=None,
        lora_config=None, 
        device="cuda",
        dtype=torch.bfloat16 if config["dtype"] == "bfloat16" else torch.float32,
        model_config=config.get("model", {}),
        adapter_names=[None, None, *ADAPTER_NAMES],
        gradient_checkpointing=False,
    )
    # Restore model.adapter_set for later use
    model.adapter_set = set(ADAPTER_NAMES)
    model.training_config = training_config
    
    # 4. Load Weights (Solar Components + LoRAs)
    if not _load_solar_components(model, CHECKPOINT_PATH):
        return
    _load_lora_adapters(model, CHECKPOINT_PATH, ADAPTER_NAMES)

    # Explicitly activate all LoRA adapters (deterministic order)
    _activate_adapters(model, ADAPTER_NAMES)

    # Set to eval mode and disable dropout
    _set_eval_mode(model)
    print("  ✓ Model loaded and set to eval mode")
    
    # 5. Run Single Inference
    single_inference(
        model,
        BACKGROUND_PATH,
        SUBJECT_PATH,
        MASK_PATH,
        OUTPUT_PATH,
        target_size=target_size,
        condition_size=condition_size,
        seed=SEED
    )
    print("Done!")

if __name__ == "__main__":
    main()