 #!/usr/bin/env python3
"""
航空目标推理脚本 - Bg + Crop + Solar
直接读取数据集进行推理
"""

import os
import sys
import yaml
import torch
import random
import numpy as np
from pathlib import Path
from PIL import Image
import torch.nn.functional as F
import torchvision.transforms as T

# Make the repository root (four directory levels above this file) importable
# so the project-local `omini` package resolves when run as a plain script.
project_root = Path(__file__)
for _ in range(4):
    project_root = project_root.parent
sys.path.insert(0, os.fspath(project_root))

from omini.train_flux.train_aircraft_bg_crop_solar import AircraftSolarDataset, OminiSolarModel
from omini.pipeline.flux_omini_solar import Condition, generate, seed_everything
import omini.pipeline.flux_omini_solar as flux_omini_solar

def load_config(config_path: str):
    """Load a YAML configuration file and return its top-level mapping.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The parsed configuration as a dict.

    Raises:
        ValueError: If the file is empty or its top level is not a mapping
            (callers index it like ``config["train"]``).
    """
    # Open as bytes so PyYAML detects the encoding (UTF-8/UTF-16 per the YAML
    # spec) instead of depending on the platform's locale default.
    with open(config_path, "rb") as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Config {config_path!r} must contain a YAML mapping at the top "
            f"level, got {type(config).__name__}"
        )
    return config

def _as_pil(img):
    """Return ``img`` as a PIL image, converting a torch tensor if needed."""
    return T.ToPILImage()(img) if isinstance(img, torch.Tensor) else img


def _extract_mask(sample, target_size):
    """Return the target mask of ``sample`` as a PIL image.

    Sources are tried in order: a PIL ``target_mask``; ``mask`` given as a file
    path or PIL image; a tensor ``target_mask``. If none is usable, a warning is
    printed and a dummy mask of size ``target_size`` (``(width, height)``) with
    the central half of each axis set to 255 is returned.
    """
    target_mask = sample.get("target_mask")
    if isinstance(target_mask, Image.Image):
        return target_mask

    mask = sample.get("mask")
    if isinstance(mask, str):
        try:
            # The context manager closes the file; convert() forces the load.
            with Image.open(mask) as opened:
                return opened.convert("L")
        except Exception as e:
            # Best-effort: report and fall through to the remaining sources.
            print(f"  ⚠️  Failed to load mask from path: {e}")
    elif isinstance(mask, Image.Image):
        return mask

    # The dataset may hand back ``target_mask`` as a tensor even when
    # ``return_pil_image`` is True, so convert it back to PIL.
    if isinstance(target_mask, torch.Tensor):
        m = target_mask.detach().cpu()
        if m.dim() == 3:
            m = m.squeeze(0)
        return T.ToPILImage()(m)

    print("  ⚠️  Warning: Mask not found, using dummy.")
    w, h = target_size
    dummy = Image.new("L", (w, h), 0)
    dummy.paste(255, (w // 4, h // 4, w * 3 // 4, h * 3 // 4))
    return dummy


def _mask_to_tensor(mask_img, target_size, device):
    """Convert a PIL mask to a ``[1, 1, H, W]`` float32 tensor in ``[0, 1]``.

    ``target_size`` is ``(width, height)``; the mask is resized with
    nearest-neighbour sampling so binary masks stay binary.
    """
    # Force single-channel 8-bit first. An RGB mask would otherwise become an
    # [H, W, 3] array, and a mode "1" mask a bool array that /255 maps to ~0.004.
    gray = mask_img.convert("L").resize(tuple(target_size), Image.NEAREST)
    mask = torch.from_numpy(np.array(gray)).float() / 255.0
    return mask[None, None].to(device)


def _latent_grid(num_tokens, image_size):
    """Return ``(H, W)`` of the packed latent token grid for an image.

    ``image_size`` is PIL's ``(width, height)``. Assumes the token grid keeps the
    image's aspect ratio (true for square images). The result is validated so
    that ``H * W == num_tokens``.

    Raises:
        ValueError: If no grid consistent with ``num_tokens`` can be inferred.
    """
    img_w, img_h = image_size
    h = max(1, int(round((num_tokens * img_h / img_w) ** 0.5)))
    w = num_tokens // h
    if h * w == num_tokens:
        return h, w
    side = int(round(num_tokens ** 0.5))
    if side * side == num_tokens:
        return side, side
    raise ValueError(
        f"Cannot infer a latent grid for {num_tokens} tokens from image size "
        f"{image_size}"
    )


def _solar_device(model):
    """Device holding the Solar encoder's parameters (``cuda`` if it has none)."""
    param = next(iter(model.solar_encoder.parameters()), None)
    return param.device if param is not None else torch.device("cuda")


def _compute_solar_params(model, background_img, mask_img, target_size):
    """Encode background + mask into one ``(scale, shift)`` pair per projector.

    Each ``scale``/``shift`` is the projector output split in half along dim 1,
    given a new axis at dim 1, and cast to ``model.dtype``.
    """
    bg_latents, _ = flux_omini_solar.encode_images(model.flux_pipe, background_img)
    batch, num_tokens, channels = bg_latents.shape
    h_lat, w_lat = _latent_grid(num_tokens, background_img.size)
    device = _solar_device(model)
    # [B, L, C] -> [B, C, H, W]; reshape tolerates the non-contiguous transpose.
    bg_spatial = (
        bg_latents.transpose(1, 2)
        .reshape(batch, channels, h_lat, w_lat)
        .to(device=device, dtype=torch.float32)
    )
    mask_tensor = _mask_to_tensor(mask_img, target_size, device)

    context_vector = model.solar_encoder(bg_spatial, mask_tensor)

    solar_params_list = []
    for proj in model.solar_projectors:
        scale, shift = proj(context_vector).chunk(2, dim=1)
        solar_params_list.append(
            (scale.unsqueeze(1).to(model.dtype), shift.unsqueeze(1).to(model.dtype))
        )
    return solar_params_list


def inference_on_training_samples(
    model,
    dataset,
    num_samples: int = 10,
    output_dir: str = "inference_results_solar",
    seed: int = 42,
):
    """Generate images for the first ``num_samples`` dataset items with Solar conditioning.

    For each sample this builds the subject/background conditions and encodes
    the background plus target mask into Solar ``(scale, shift)`` parameters.
    It then calls ``generate`` and writes ``sample_{idx}_generated.jpg`` to
    ``output_dir``, plus the subject, background, mask and original images to
    ``output_dir/conditions``.

    Args:
        model: Object exposing ``training_config``, ``flux_pipe``,
            ``solar_encoder``, ``solar_projectors``, ``model_config`` and
            ``dtype``.
        dataset: Indexable dataset with a ``return_pil_image`` attribute. The
            attribute is forced to True during the run and always restored
            afterwards, even if an error escapes.
        num_samples: Number of leading samples to process. Clamped to
            ``len(dataset)``.
        output_dir: Directory for the results. Created if missing.
        seed: Base seed; sample ``idx`` uses ``seed + idx``.

    Generation/saving failures for a single sample are printed with a
    traceback, and that sample is skipped.
    """
    os.makedirs(output_dir, exist_ok=True)
    condition_dir = os.path.join(output_dir, "conditions")
    os.makedirs(condition_dir, exist_ok=True)

    # (width, height), as used for Image.new / generate(width=..., height=...).
    target_size = model.training_config["dataset"]["target_size"]

    adapter_subject = "subject"
    adapter_bg = "background"

    # Clamp before reporting so the header shows the real count.
    num_samples = min(num_samples, len(dataset))

    print(f"\n{'='*70}")
    print(f"Inference on {num_samples} training samples (Solar)")
    print(f"{'='*70}")

    original_return_pil = dataset.return_pil_image
    dataset.return_pil_image = True
    try:
        with torch.no_grad():
            for idx in range(num_samples):
                print(f"\n[{idx+1}/{num_samples}] Processing sample {idx}...")

                sample = dataset[idx]
                subject_img = _as_pil(sample["condition_0"])
                background_img = _as_pil(sample["condition_1"])
                prompt = sample["description"]
                mask_img = _extract_mask(sample, target_size)

                # NOTE(review): fixed position offsets passed to Condition;
                # presumably they must match the values used at training time.
                cond_subject = Condition(subject_img, adapter_subject, [-16, -32])
                cond_bg = Condition(background_img, adapter_bg, [16, -32])

                solar_params_list = _compute_solar_params(
                    model, background_img, mask_img, target_size
                )

                generator = torch.Generator(device=model.flux_pipe.device)
                generator.manual_seed(seed + idx)

                print("  Generating image...")
                try:
                    res = generate(
                        model.flux_pipe,
                        prompt=prompt,
                        conditions=[cond_subject, cond_bg],
                        height=target_size[1],
                        width=target_size[0],
                        num_inference_steps=28,
                        guidance_scale=3.5,
                        generator=generator,
                        model_config=model.model_config,
                        kv_cache=model.model_config.get("independent_condition", False),
                        solar_params_list=solar_params_list,
                    )

                    output_path = os.path.join(output_dir, f"sample_{idx}_generated.jpg")
                    res.images[0].save(output_path)
                    print(f"  ✓ Saved to {output_path}")

                    subject_img.save(os.path.join(condition_dir, f"sample_{idx}_subject.jpg"))
                    background_img.save(os.path.join(condition_dir, f"sample_{idx}_background.jpg"))
                    mask_img.save(os.path.join(condition_dir, f"sample_{idx}_mask.jpg"))
                    original_img = _as_pil(sample["image"])
                    original_img.save(os.path.join(condition_dir, f"sample_{idx}_original.jpg"))
                except Exception as e:
                    # Best-effort per sample: report and move on to the next one.
                    print(f"  ❌ Generation failed: {e}")
                    import traceback
                    traceback.print_exc()
                    continue
    finally:
        dataset.return_pil_image = original_return_pil

    print(f"\n{'='*70}")
    print(f"✓ Inference completed! Results saved to {output_dir}")

def main():
    """Seed everything, load config/dataset/checkpoint, then run inference on all samples.

    Environment variables:
        OMINI_CONFIG: Path to the training YAML. Defaults to
            ``./train/config/aircraft_bg_crop.yaml``.
        OMINI_CKPT: Checkpoint directory containing ``subject.safetensors``,
            ``background.safetensors`` and, optionally,
            ``solar_components.pt``. Defaults to
            ``runs_bg_crop/20260104-200004/ckpt/10000``.
    """
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    seed_everything(seed)

    config_path = os.environ.get("OMINI_CONFIG", "./train/config/aircraft_bg_crop.yaml")
    print(f"Loading config from {config_path}")
    config = load_config(config_path)
    training_config = config["train"]
    dataset_config = training_config["dataset"]

    print("Loading dataset...")
    # All drop/augmentation probabilities are 0 so every sample is used as-is.
    dataset = AircraftSolarDataset(
        dataset_root=dataset_config["dataset_root"],
        condition_size=tuple(dataset_config["condition_size"]),
        target_size=tuple(dataset_config["target_size"]),
        drop_text_prob=0.0,
        drop_subject_prob=0.0,
        drop_position_prob=0.0,
        drop_background_prob=0.0,
        augmentation_prob=0.0,
        return_pil_image=True,
    )

    # Overridable so other runs can be evaluated without editing the script.
    checkpoint_path = os.environ.get("OMINI_CKPT", "runs_bg_crop/20260104-200004/ckpt/10000")
    print(f"Loading model from {checkpoint_path}...")

    model = OminiSolarModel(
        flux_pipe_id=config["flux_path"],
        lora_path=None,
        lora_config=None,
        device="cuda",
        dtype=torch.bfloat16 if config["dtype"] == "bfloat16" else torch.float32,
        model_config=config.get("model", {}),
        adapter_names=[None, None, "subject", "background"],
        gradient_checkpointing=False,
    )
    # Explicitly move model to cuda so all submodules (including solar_encoder) are on GPU.
    model = model.to("cuda")

    # Keep an ordered list: iterating a set of str is subject to hash
    # randomisation, which made adapter activation order vary between runs.
    adapter_names = ["subject", "background"]
    model.adapter_set = set(adapter_names)
    model.training_config = training_config

    solar_path = os.path.join(checkpoint_path, "solar_components.pt")
    if os.path.exists(solar_path):
        state = torch.load(solar_path, map_location=model.flux_pipe.device)
        model.solar_encoder.load_state_dict(state["encoder"])
        model.solar_projectors.load_state_dict(state["projectors"])
    else:
        # Do not fail, but make it visible: the Solar modules keep whatever
        # weights OminiSolarModel initialised them with.
        print(f"  ⚠️  {solar_path} not found; Solar encoder/projectors are not loaded from the checkpoint.")

    for adapter_name in adapter_names:
        model.flux_pipe.load_lora_weights(
            checkpoint_path,
            weight_name=f"{adapter_name}.safetensors",
            adapter_name=adapter_name,
        )

    model.transformer.set_adapters(adapter_names)
    model.eval()

    inference_on_training_samples(model, dataset, num_samples=len(dataset))

if __name__ == "__main__":
    # Run the inference pipeline only when executed as a script, not on import.
    main()