from diffusers import DiffusionPipeline
import torch
import os
import sys
import argparse
import cv2
import numpy as np

# Make the directory one level above this script's folder importable, so the
# project-local `models` package imported below can be resolved.
_script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(_script_dir)
sys.path.append(parent_dir)

from models.layer_aware_dual_lora import LayerUNet
from diffusers import UNet2DConditionModel
from PIL import Image



def _parse_args():
    """Parse the command-line options for single-image inference.

    All four paths are mandatory. When one is missing, argparse prints a usage
    message and exits with status 2 instead of letting ``None`` reach the
    ``os.path`` calls in ``main``.
    """
    parser = argparse.ArgumentParser(description='Single image object removal inference')
    parser.add_argument('--image', '-i', required=True, help='Path to input image')
    parser.add_argument('--mask', '-m', required=True, help='Path to mask image')
    parser.add_argument('--output', '-o', required=True, help='Path to output image')
    parser.add_argument('--checkpoint', '-c', required=True, help='Path to checkpoint directory')
    return parser.parse_args()


def _prepare_inputs(image_path, mask_path, size=(512, 512)):
    """Load the image and mask and turn them into fp16 CUDA tensors.

    Args:
        image_path: Path of the input image; it is converted to RGB.
        mask_path: Path of the mask image; it is converted to single-channel "L".
        size: ``(height, width)`` both tensors are interpolated to.

    Returns:
        A tuple ``(img_tensor, mask_tensor, original_size)`` where
        ``original_size`` is the PIL ``(width, height)`` of the input image.
    """
    # Imported lazily so that an import failure is reported by the caller's
    # error handler instead of preventing the script from starting.
    from torchvision.transforms.functional import to_tensor, gaussian_blur
    import torch.nn.functional as F

    with Image.open(image_path) as image_file:
        original_img = image_file.convert("RGB")
    with Image.open(mask_path) as mask_file:
        mask = mask_file.convert("L")
    original_size = original_img.size
    print(f"Original image size: {original_size}")

    # to_tensor scales 8-bit pixel values to [0, 1]; shift them to [-1, 1].
    img_tensor = to_tensor(original_img).unsqueeze(0).float() * 2 - 1
    img_tensor = F.interpolate(img_tensor, size)
    img_tensor = img_tensor.to(torch.float16).to("cuda")

    mask_tensor = to_tensor(mask).unsqueeze(0).float()
    mask_tensor = F.interpolate(mask_tensor, size)
    # Blur, then binarize at a low threshold: every pixel whose blurred value
    # reaches 0.1 becomes 1, everything else 0.
    mask_tensor = gaussian_blur(mask_tensor, kernel_size=(77, 77))
    mask_tensor = (mask_tensor >= 0.1).to(mask_tensor.dtype)
    mask_tensor = mask_tensor.to(torch.float16).to("cuda")

    return img_tensor, mask_tensor, original_size


def main():
    """Run object-removal inference on a single image and save the result.

    Steps: parse the CLI options, validate the input paths, build a
    ``LayerUNet`` around the SDXL inpainting UNet, load the LoRA (and, when
    mask-aware, alpha) weights from the checkpoint directory, run the custom
    inpainting pipeline at 512x512 and save the output resized back to the
    input image's original size.

    Returns:
        None. Failures are reported on stdout and end the run early; only a
        missing command-line option exits the process (argparse, status 2).
    """
    # Parse command-line arguments
    args = _parse_args()

    if not os.path.exists(args.image):
        print(f"❌ Input image not found: {args.image}")
        return

    if not os.path.exists(args.mask):
        print(f"❌ Mask image not found: {args.mask}")
        return

    # An output path without a directory component writes to the current directory.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Check CUDA availability. Everything below moves tensors to "cuda", and
    # the device queries themselves raise when no CUDA device is present, so
    # stop here with a clear message.
    cuda_available = torch.cuda.is_available()
    print("CUDA available:", cuda_available)
    if not cuda_available:
        print("❌ CUDA is required for inference but no CUDA device is available")
        return
    device_index = torch.cuda.current_device()
    print("Current device:", device_index)
    print("Device name:", torch.cuda.get_device_name(device_index))

    # Model paths
    pretrain_path = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
    # Left empty, so the base SDXL inpainting UNet below is always used;
    # set it to a local directory to load a different UNet.
    pretrained_unet_path = ""
    resolution = 512

    print(f"Loading checkpoint from {args.checkpoint}")

    # Load the original unet first
    print(f"Loading pretrained UNet from {pretrained_unet_path}")
    if os.path.exists(pretrained_unet_path):
        original_unet = UNet2DConditionModel.from_pretrained(pretrained_unet_path, torch_dtype=torch.float16)
        print("✅ Loaded pretrained UNet")
    else:
        print("⚠️  Pretrained UNet not found, using base SDXL Inpaint UNet")
        original_unet = UNet2DConditionModel.from_pretrained(pretrain_path, subfolder="unet", torch_dtype=torch.float16)

    unet = LayerUNet(
        original_unet,
        lora_rank=16,
        lora_alpha=16.0,
        use_mask_aware=True
    )

    # Load the trained LoRA weights; inference is pointless without them.
    lora_weights_loaded = False
    lora_path = os.path.join(args.checkpoint, "lora_weights.pth")
    if os.path.exists(lora_path):
        try:
            unet.load_lora_weights(lora_path)
        except Exception as e:
            print(f"❌ Failed to load LoRA weights from {lora_path}: {e}")
        else:
            print(f"✅ Loaded LoRA weights from {lora_path}")
            lora_weights_loaded = True
    else:
        print(f"❌ LoRA weights not found: {lora_path}")

    if not lora_weights_loaded:
        print("❌ No LoRA weights found, cannot continue inference without LoRA weights")
        return

    # Alpha weights are optional: fall back to the defaults when they are
    # missing or cannot be loaded.
    if unet.use_mask_aware:
        alpha_path = os.path.join(args.checkpoint, "lora_weights_alpha.pth")
        if os.path.exists(alpha_path):
            try:
                unet.load_alpha_weights(alpha_path)
            except Exception as e:
                print(f"❌ Failed to load alpha weights from {alpha_path}: {e}")
                # NOTE(review): assumes a failed load may leave the alphas
                # partially updated; reset so the run starts from the same
                # state as the "file missing" case.
                print("⚠️  Falling back to default alpha values")
                unet.reset_alpha_to_default()
            else:
                print(f"✅ Loaded alpha weights from {alpha_path}")
        else:
            print(f"⚠️  No alpha weights found: {alpha_path}, using default values")
            unet.reset_alpha_to_default()

    if hasattr(unet, 'get_alpha_summary'):
        alpha_summary = unet.get_alpha_summary()
        print(f"Alpha Summary: {alpha_summary}")

    unet = unet.to("cuda", dtype=torch.float16)
    print("Finished loading LayerUNet!")

    pipe_edit = DiffusionPipeline.from_pretrained(
        pretrain_path,
        custom_pipeline="dual_lora_inpaint_pipeline.py",
        unet=unet,
        torch_dtype=torch.float16,
        variant="fp16"
    ).to("cuda")

    pipe_edit.config.allow_cpu_offload = False
    print("Pipeline device:", next(pipe_edit.unet.parameters()).device)

    print(f"Processing image: {args.image}")

    # Top-level boundary: report any preprocessing / inference / save error
    # with its traceback instead of letting it propagate.
    try:
        img_tensor, mask_tensor, original_size = _prepare_inputs(
            args.image, args.mask, size=(resolution, resolution)
        )

        prompt = ""
        print("Starting inference...")
        bg_result = pipe_edit(
            prompt=prompt,
            image=img_tensor,
            mask_image=mask_tensor,
            height=resolution,
            width=resolution,
            guidance_scale=7.5,
            num_inference_steps=20,
            strength=0.99,
            # Argument of the custom pipeline; presumably selects which LoRA
            # branch is applied (TODO confirm against the pipeline source).
            branch="background",
        ).images[0]

        # Bring the result back to the input image's original size.
        bg_result = bg_result.resize(original_size, Image.Resampling.LANCZOS)
        bg_result.save(args.output)
        print(f"Successfully saved result to: {args.output}")
        print(f"Output image size: {bg_result.size}")

    except Exception as e:
        print(f"Error processing image: {str(e)}")
        import traceback
        print(f"Traceback: {traceback.format_exc()}")
    


# Run inference only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main() 