import os
import glob
from typing import Tuple, List, Dict, Union
import argparse

import numpy as np
import torch
from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from transformers import AutoModelForImageSegmentation

# jigsaw3D
# attn set:
from jigsaw3D.models.attention_processor import Jigsaw3DAttnProcessor

# pipeline set:
from jigsaw3D.pipelines.pipeline_jigsaw3D_i2mv_sdxl import Jigsaw3DI2MVSDXLPipeline

from jigsaw3D.schedulers.scheduling_shift_snr import ShiftSNRScheduler
from jigsaw3D.utils import get_orthogonal_camera, make_image_grid, tensor_to_image
from jigsaw3D.utils.render import NVDiffRastContextWrapper, load_mesh, render

from jigsaw3D.utils.jigsaw import apply_jigsaw_mask, apply_content_block_jigsaw
from jigsaw3D.utils.load_mesh import load_objaverse_mesh

import ast
from PIL import ImageFont, ImageDraw


def remove_bg_simple(image, net, transform, device, target_height=None, target_width=None):
    """
    Remove the background, resize, and flatten the result onto a gray canvas.

    Args:
        image: Input PIL image.
        net: Background removal model (forwarded to ``remove_bg``).
        transform: Image transformation function (forwarded to ``remove_bg``).
        device: Device the model input is moved to.
        target_height: Output height in pixels. When None, the height of the
            background-removed image is used.
        target_width: Output width in pixels. When None, the width of the
            background-removed image is used.

    Returns:
        Image.Image: RGB image of size (target_width, target_height) in which
        transparent regions are filled with gray (128, 128, 128).
    """
    # Cut out the foreground; the result carries the predicted mask as alpha.
    rgba_image = remove_bg(image, net, transform, device)

    # Image.alpha_composite only accepts RGBA operands, so normalise the mode
    # in case the helper returned e.g. an 'LA' image for a grayscale source.
    if rgba_image.mode != 'RGBA':
        rgba_image = rgba_image.convert('RGBA')

    # Resolve each missing dimension on its own so that passing only one of
    # target_height / target_width is honoured instead of silently discarded.
    source_width, source_height = rgba_image.size
    if target_width is None:
        target_width = source_width
    if target_height is None:
        target_height = source_height

    resized_rgba = rgba_image.resize((target_width, target_height), Image.LANCZOS)

    # Fully opaque gray canvas to composite the cut-out onto.
    gray_bg = Image.new('RGBA', (target_width, target_height), color=(128, 128, 128, 255))

    # Blend foreground over the canvas, then drop the alpha channel.
    return Image.alpha_composite(gray_bg, resized_rgba).convert('RGB')


def remove_bg(image, net, transform, device):
    """
    Remove the background and return the image with the mask as alpha channel.

    Args:
        image: Input PIL image in any mode; it is not modified.
        net: Segmentation model. It is called on a batch of one image and the
            last element of its output, after a sigmoid, is used as the mask.
        transform: Callable turning an RGB PIL image into a tensor for ``net``.
        device: Device the model input is moved to.

    Returns:
        Image.Image: RGBA image with the same size as ``image`` whose alpha
        channel is the predicted foreground mask.

    Raises:
        AssertionError: If the predicted mask is not 2D after squeezing.
    """
    # The model input must be 3-channel RGB. Convert every other mode (RGBA,
    # L, P, LA, CMYK, ...), not just RGBA, so the transform always sees RGB.
    if image.mode != 'RGB':
        image_rgb = image.convert('RGB')
    else:
        image_rgb = image

    image_size = image.size
    input_images = transform(image_rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        preds = net(input_images)[-1].sigmoid().cpu()

    pred = preds[0].squeeze()

    # Sanity check: pred should be a single-channel (H, W) map.
    assert len(pred.shape) == 2, f"Expected pred to be 2D (H, W), but got shape {pred.shape}"

    # Bring the mask back to the input resolution.
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)

    # Work on a copy of the RGB image so the caller's image stays untouched
    # and the result is always RGBA, whatever the input mode was.
    result = image_rgb.copy()
    result.putalpha(mask)
    return result


def preprocess_image(image: Image.Image, height, width):
    """
    Resize an image to (width, height) and flatten its alpha onto mid-gray.

    Args:
        image: Input PIL image; converted to RGBA when it is in another mode.
        height: Output height in pixels.
        width: Output width in pixels.

    Returns:
        Image.Image: RGB image in which each pixel is
        ``rgb * alpha + (1 - alpha) * 0.5`` of the resized input.
    """
    # Work in RGBA so the alpha channel takes part in the resize.
    rgba_source = image if image.mode == 'RGBA' else image.convert('RGBA')

    # Scale the full frame, transparent regions included.
    scaled = rgba_source.resize((width, height), Image.LANCZOS)

    # Normalised float32 pixels in [0, 1].
    pixels = np.array(scaled).astype(np.float32) / 255.0
    color, opacity = pixels[..., :3], pixels[..., 3:4]

    # Standard "over" blend against a constant 0.5 gray background.
    blended = color * opacity + (1 - opacity) * 0.5

    # Back to 8-bit (truncating) and wrap as a PIL image.
    as_bytes = (blended * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(as_bytes)


def _composite_rgba_on_gray(rgba: np.ndarray) -> Image.Image:
    """Blend an (H, W, 4) uint8 RGBA array over 0.5 gray and return an RGB image."""
    pixels = rgba.astype(np.float32) / 255.0
    alpha = pixels[:, :, 3:4]
    blended = pixels[:, :, :3] * alpha + (1 - alpha) * 0.5
    return Image.fromarray((blended * 255).clip(0, 255).astype(np.uint8))


def preprocess_image_scale_and_pad(
    image: Image.Image,
    height: int,
    width: int,
    scale: Union[float, int] = 0.8
) -> Tuple[Image.Image, Image.Image]:
    """
    Preprocess a reference image by shrinking it and centering it on gray.

    Args:
        image: Input PIL image; converted to RGBA when it is in another mode.
        height: Output height in pixels.
        width: Output width in pixels.
        scale: If an ``int``, the border in pixels left on every side
            (0 <= scale < min(height, width) // 2). Otherwise a ratio in
            (0, 1) applied to both output dimensions.

    Returns:
        Tuple[Image.Image, Image.Image]:
            - First is the input resized to (width, height) and composited
              on gray, without padding.
            - Second is the input shrunk, centered on a (width, height)
              canvas and composited on gray.

    Raises:
        ValueError: If ``scale`` is out of range, or if the resulting inner
            size would be smaller than one pixel.
    """
    # Resolve the inner (content) size from either a pixel border or a ratio.
    if isinstance(scale, int):
        max_padding = min(height, width) // 2
        if scale < 0 or scale >= max_padding:
            # 0 is accepted, so the valid interval is closed on the left.
            raise ValueError(f"Padding pixels should be in range [0, {max_padding})")
        new_height = height - 2 * scale
        new_width = width - 2 * scale
    else:
        if not 0 < scale < 1:
            raise ValueError("Scale ratio should be in range (0, 1)")
        new_height = int(height * scale)
        new_width = int(width * scale)

    # A very small ratio can truncate to zero; fail here with a clear message
    # instead of inside the resize call.
    if new_height < 1 or new_width < 1:
        raise ValueError(
            f"Scaled size {new_width}x{new_height} is smaller than one pixel"
        )

    # Ensure input image is in RGBA format
    if image.mode != 'RGBA':
        image = image.convert('RGBA')

    # Unpadded variant: full-frame resize, then the same gray compositing.
    original_image = _composite_rgba_on_gray(
        np.array(image.resize((width, height), Image.LANCZOS))
    )

    # With zero padding both outputs are identical.
    if isinstance(scale, int) and scale == 0:
        return original_image, original_image.copy()

    # Shrink the image and paste it in the middle of a transparent canvas;
    # the transparent border becomes gray after compositing.
    resized_array = np.array(image.resize((new_width, new_height), Image.LANCZOS))
    canvas = np.zeros((height, width, 4), dtype=np.uint8)
    start_h = (height - new_height) // 2
    start_w = (width - new_width) // 2
    canvas[start_h:start_h + new_height, start_w:start_w + new_width] = resized_array

    padded_image = _composite_rgba_on_gray(canvas)

    return original_image, padded_image

def run_pipeline(
    pipe,
    mesh_path,
    num_views,
    text,
    image,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    remove_bg_fn=None,
    reference_conditioning_scale=1.0,
    negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
    lora_scale=1.0,
    device="cuda",
    img_patch_size=128,
    mask_ratio=0.0,
    vae_patch_size=16,
    is_image_shuffle=False,
    is_vae_shuffle=False,
    output_type="pt",  # Ensure pipeline returns pt, or PIL images
    bg_remove_or_padding='bg_padding',
):
    """
    Build control images for a mesh, preprocess the reference image, and run the pipeline.

    Args:
        pipe: Pipeline object; it is called once and its ``.images`` attribute
            is returned (presumably a Jigsaw3DI2MVSDXLPipeline - confirm with caller).
        mesh_path: Path of the mesh. If it contains the substring
            "Objaverse-Ortho" the control images are loaded with
            ``load_objaverse_mesh``; otherwise the mesh is rendered here.
        num_views: Forwarded as ``num_images_per_prompt`` and used as the
            length of the camera ``distance`` list in the rendering branch.
        text: First positional argument of ``pipe`` (presumably the prompt).
        image: Reference image, either a path string (opened with PIL) or an
            already loaded image object.
        height: Render / output height in pixels.
        width: Render / output width in pixels.
        num_inference_steps: Forwarded to ``pipe``.
        guidance_scale: Forwarded to ``pipe``.
        seed: When an ``int`` other than -1, seeds a ``torch.Generator`` on
            ``device`` that is passed to ``pipe``; otherwise no generator is passed.
        remove_bg_fn: Not used by this function.
        reference_conditioning_scale: Forwarded to ``pipe``.
        negative_prompt: Forwarded to ``pipe``.
        lora_scale: Passed as ``cross_attention_kwargs={"scale": lora_scale}``.
        device: Device for rendering, control images, models and the generator.
        img_patch_size: Patch size for ``apply_jigsaw_mask``; only used when
            ``is_image_shuffle`` is true.
        mask_ratio: Mask ratio for ``apply_jigsaw_mask``; only used when
            ``is_image_shuffle`` is true.
        vae_patch_size: Forwarded to ``pipe``.
        is_image_shuffle: When true, the reference image is passed through
            ``apply_jigsaw_mask`` before being given to ``pipe``.
        is_vae_shuffle: Forwarded to ``pipe`` as ``vae_feature_shuffle``.
        output_type: Forwarded to ``pipe``.
        bg_remove_or_padding: 'bg_padding' pads the reference image with a
            fixed 64-pixel gray border; 'bg_remove' runs BiRefNet background
            removal and writes "reference_image_processed.png" to the current
            working directory.

    Returns:
        tuple: ``(images, pos_images, normal_images, original_reference_image,
        original_padded_reference_image, reference_image)`` where
        ``original_padded_reference_image`` is the preprocessed reference before
        the optional jigsaw step and ``reference_image`` is what was given to ``pipe``.

    Raises:
        ValueError: If ``bg_remove_or_padding`` is neither 'bg_padding' nor 'bg_remove'.
    """
    # NOTE: the data source is selected by a hard-coded substring match on mesh_path.
    print('------- mesh_path --------', mesh_path)
    if "Objaverse-Ortho" in mesh_path:
        control_images = load_objaverse_mesh(mesh_path).to(device)
        # Output result verification
        print(f"Control images shape: {control_images.shape}")  # torch.Size([6, 6, 512, 512])
        print(f"Data range - Min: {control_images.min().item():.4f}, Max: {control_images.max().item():.4f}")
        
        # Extract position and normal information from control_images
        # control_images shape is [6, 6, 512, 512]
        # First 3 channels are positions, last 3 channels are normals
        pos_part = control_images[:, :3, :, :]  # [6, 3, 512, 512]
        normal_part = control_images[:, 3:, :, :]  # [6, 3, 512, 512]
        
        # Convert position and normal parts to image format
        # First convert tensor from [B, C, H, W] -> [B, H, W, C]
        pos_images = pos_part.permute(0, 2, 3, 1).cpu().numpy()
        normal_images = normal_part.permute(0, 2, 3, 1).cpu().numpy()
        
        # Convert to PIL images
        # NOTE(review): no clamp before the uint8 cast - assumes values are already in [0, 1]; confirm.
        pos_images = [Image.fromarray((img * 255).astype(np.uint8)) for img in pos_images]
        normal_images = [Image.fromarray((img * 255).astype(np.uint8)) for img in normal_images]
        
    else:
        # Prepare cameras
        # NOTE(review): elevation/azimuth lists have 6 entries while distance has
        # num_views entries - presumably num_views must be 6 here; confirm.
        cameras = get_orthogonal_camera(
            elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
            distance=[1.8] * num_views,
            left=-0.55,
            right=0.55,
            bottom=-0.55,
            top=0.55,
            azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
            device=device,
        )

        # ctx = NVDiffRastContextWrapper(device=device)
        # use CUDA
        ctx = NVDiffRastContextWrapper(device=device, context_type="cuda")

        # mesh = load_mesh(mesh_path, rescale=True, device=device)
        mesh = load_mesh(mesh_path, move_to_center=True, rescale=True, device=device)
        render_out = render(
            ctx,
            mesh,
            cameras,
            height=height,
            width=width,
            render_attr=False,
            normal_background=0.0,
        )
        # Shift positions by +0.5 and map normals via /2 + 0.5, then clamp to [0, 1].
        pos_images = tensor_to_image((render_out.pos + 0.5).clamp(0, 1), batched=True)
        normal_images = tensor_to_image(
            (render_out.normal / 2 + 0.5).clamp(0, 1), batched=True
        )
        # Concatenate position and normal maps along the last axis, then move
        # that axis to position 1 for the pipeline.
        control_images = (
            torch.cat(
                [
                    (render_out.pos + 0.5).clamp(0, 1),
                    (render_out.normal / 2 + 0.5).clamp(0, 1),
                ],
                dim=-1,
            )
            .permute(0, 3, 1, 2)
            .to(device)
        )

    # Prepare image
    reference_image = Image.open(image) if isinstance(image, str) else image

    if bg_remove_or_padding == 'bg_padding':
        # Option 1: pad the background with a fixed pixel border
        original_reference_image, reference_image = preprocess_image_scale_and_pad(reference_image, height, width, 64) # int argument = padding size in pixels
    elif bg_remove_or_padding == 'bg_remove':
        # Option 2: remove the background with BiRefNet
        original_reference_image = reference_image.copy()
        print("======= Loading background removal model ========")
        # NOTE: the checkpoint path is hard-coded and the model is loaded on every call.
        birefnet = AutoModelForImageSegmentation.from_pretrained(
            "/root/autodl-tmp/style-pretrained-ckpts/BiRefNet", 
            trust_remote_code=True
        )
        birefnet.to(device)
        transform_image = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
    
        # Remove background and resize to target dimensions
        reference_image = remove_bg_simple(reference_image, birefnet, transform_image, device, height, width)
        # Side effect: the processed reference is written to the current working directory.
        reference_image.save("reference_image_processed.png")
        print("Image saved as reference_image_processed.png")
    else:
        raise ValueError("Invalid bg_remove_or_padding parameter. Must be either 'bg_padding' or 'bg_remove'")
        
        
    # Keep the preprocessed reference as it was before the optional jigsaw step.
    original_padded_reference_image = reference_image.copy()

    # image shuffle/ image jigsaw
    if is_image_shuffle:
        # convert PIL Image to PyTorch Tensor
        transform_to_tensor = transforms.ToTensor()
        reference_tensor = transform_to_tensor(reference_image).unsqueeze(0)  # [1, C, H, W]
        jigsaw_tensor = apply_jigsaw_mask(
            reference_tensor, 
            patch_size=img_patch_size,
            mask_ratio=mask_ratio,
            mask_type='random', # Mask type (random/center/grid)
            shuffle=is_image_shuffle,
        )

        # convert to PIL Image
        transform_to_pil = transforms.ToPILImage()
        reference_image = transform_to_pil(jigsaw_tensor.squeeze(0)) 

    
    # Only pass a generator when an explicit integer seed other than -1 is given.
    pipe_kwargs = {}
    if seed != -1 and isinstance(seed, int):
        pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)

    images = pipe(
        text,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_views,
        control_image=control_images,
        control_conditioning_scale=1.0,
        reference_image=reference_image,
        reference_conditioning_scale=reference_conditioning_scale,
        negative_prompt=negative_prompt,
        cross_attention_kwargs={"scale": lora_scale},
        output_type=output_type, # modify here
        vae_feature_shuffle=is_vae_shuffle,
        vae_patch_size=vae_patch_size,
        **pipe_kwargs,
    ).images

    # return images, pos_images, normal_images, reference_image
    return images, pos_images, normal_images, original_reference_image, original_padded_reference_image, reference_image


if __name__ == "__main__":
    # Script entry point: reports where the outputs were saved.
    # NOTE(review): `args` is not defined anywhere in the visible source, so
    # running this as a script would raise NameError here. Presumably the
    # argparse setup and the processing loop are missing - restore or confirm.

    print(f"all files done and saved into {args.output_dir}")