from PIL import Image
import tempfile
import argparse
import os
import sys
import glob
import time
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel

# Jigsaw3D pipeline, attention-processor and scheduler imports
from jigsaw3D.pipelines.pipeline_jigsaw3D_i2mv_sdxl import Jigsaw3DI2MVSDXLPipeline
from jigsaw3D.models.attention_processor import Jigsaw3DAttnProcessor
from jigsaw3D.schedulers.scheduling_shift_snr import ShiftSNRScheduler

from jigsaw3D.pipelines.pipeline_texture import ModProcessConfig, TexturePipeline
from jigsaw3D.utils import make_image_grid


# Top-level module prefixes under which the DeepSpeed checkpoint stores each
# pipeline component. Each prefix is matched together with its trailing ".",
# so no entry can shadow another (e.g. "text_encoder." vs "text_encoder_2.").
# The order is also the order in which components are loaded.
_CHECKPOINT_COMPONENTS = (
    "ref_unet",
    "vae",
    "text_encoder",
    "text_encoder_2",
    "cond_encoder",
    "unet",
)


def _register_legacy_module_aliases():
    """Point the legacy ``mvadapter`` module paths at the ``jigsaw3D`` modules.

    Done before unpickling the checkpoint, presumably so that objects pickled
    under the old ``mvadapter`` package paths can still be resolved.
    """
    sys.modules['mvadapter'] = sys.modules['jigsaw3D']

    legacy_name = 'mvadapter.systems.mvadapter_image_sdxl_single_albedo_branch'
    try:
        from jigsaw3D.systems import jigsaw3D_image_sdxl
        sys.modules[legacy_name] = jigsaw3D_image_sdxl
        print("Module mapping successful: mvadapter.systems.mvadapter_image_sdxl_single_albedo_branch -> jigsaw3D_image_sdxl")
    except ImportError as e:
        print(f"Import error during module mapping: {e}")
        # Fall back to the package root so the legacy name at least resolves.
        sys.modules[legacy_name] = sys.modules['jigsaw3D']


def _split_component_states(model_state):
    """Group a flat checkpoint state dict into one state dict per component.

    ``"unet.conv_in.weight"`` ends up as ``result["unet"]["conv_in.weight"]``.
    Keys that match no entry of ``_CHECKPOINT_COMPONENTS`` are reported and
    dropped.
    """
    print("\nSeparating parameters by component...")
    component_states = {name: {} for name in _CHECKPOINT_COMPONENTS}
    total_keys = 0
    for key, param in model_state.items():
        total_keys += 1
        for name in _CHECKPOINT_COMPONENTS:
            prefix = name + '.'
            if key.startswith(prefix):
                # Slice off only the leading prefix; str.replace would also strip
                # any later occurrence of it inside the key (e.g. "unet.x.unet.y").
                component_states[name][key[len(prefix):]] = param
                break
        else:
            print(f"! Unmatched key: {key}")

    print(f"Processed {total_keys} parameters")
    return component_states


def _cast_component_states(component_states, dtype):
    """Cast floating-point tensors of every component state dict to ``dtype`` in place."""
    print(f"\nConverting state dict to training precision: {dtype}")
    for comp_name, state_dict in component_states.items():
        if not state_dict:
            print(f"{comp_name}: No parameters to convert")
            continue
        for key, value in list(state_dict.items()):
            # Integer/bool tensors (e.g. index buffers) keep their dtype.
            if value.is_floating_point():
                state_dict[key] = value.to(dtype)

        first_key = next(iter(state_dict))
        print(f"{comp_name}: First param {first_key} type: {state_dict[first_key].dtype}")


def _load_and_verify_state_dict(model, state_dict, component_name):
    """Load ``state_dict`` into ``model`` and print integrity diagnostics.

    Steps:
    - Report missing and unexpected keys.
    - Try a strict load, then fall back to ``strict=False`` on ``RuntimeError``.
    - Spot-check the first saved tensor against the value the model now holds.

    An empty ``state_dict`` is skipped with a warning.
    """
    if not state_dict:
        print(f"Warning: {component_name} state dict is empty, skipping")
        return

    print(f"\nLoading state dict for {component_name}...")

    # 1. Compare key sets (sorted so the printed examples are reproducible).
    model_keys = set(model.state_dict().keys())
    state_keys = set(state_dict.keys())
    missing_keys = sorted(model_keys - state_keys)
    unexpected_keys = sorted(state_keys - model_keys)

    if missing_keys:
        print(f"! {component_name} has {len(missing_keys)} missing parameters:")
        print(f"  Example missing: {missing_keys[:3]}...")

    if unexpected_keys:
        print(f"! {component_name} has {len(unexpected_keys)} unexpected parameters:")
        print(f"  Example unexpected: {unexpected_keys[:3]}...")

    # 2. Match the model's floating-point precision. Non-float tensors are left
    #    alone, and models without a ``dtype`` attribute skip the cast
    #    (load_state_dict copies into the existing parameters' dtype anyway).
    target_dtype = getattr(model, "dtype", None)
    if target_dtype is not None:
        state_dict = {
            k: v.to(dtype=target_dtype) if v.is_floating_point() else v
            for k, v in state_dict.items()
        }
    try:
        model.load_state_dict(state_dict)
        print(f"√ {component_name} state dict loaded (strict mode)")
    except RuntimeError as e:
        print(f"!! {component_name} load error: {str(e)}")
        print("Trying non-strict mode loading...")
        # NOTE: strict=False only tolerates missing/unexpected keys; tensor
        # shape mismatches still raise from this call.
        model.load_state_dict(state_dict, strict=False)
        print(f"√ {component_name} state dict loaded (non-strict mode)")

    # 3. Spot-check one parameter against the saved value.
    sample_key = next(iter(state_dict))
    loaded_state = model.state_dict()
    if sample_key not in loaded_state:
        print(f"!! {component_name} sample key {sample_key} not found in model")
        return

    model_param = loaded_state[sample_key]
    saved_param = state_dict[sample_key]
    if model_param.shape != saved_param.shape:
        print(f"!! {component_name} shape mismatch: {sample_key}")
        print(f"  Model: {model_param.shape}, Saved: {saved_param.shape}")
    elif model_param.numel() > 1000000:
        # Check size before copying to float32 on the CPU.
        print(f"  Skipping numerical check for large tensor: {sample_key}")
    else:
        model_param = model_param.cpu().float()
        saved_param = saved_param.cpu().float()
        if torch.allclose(model_param, saved_param, atol=1e-4):
            print(f"√ {component_name} param {sample_key} matches")
        else:
            print(f"!! {component_name} param {sample_key} numerical mismatch")
            diff = torch.abs(model_param - saved_param).max().item()
            print(f"  Max difference: {diff:.6f}")


def create_pipeline_from_deepspeed_checkpoint(
    device, 
    dtype, 
    num_views=6, 
    scheduler="ddpm", 
    lora_model=None,
    base_model=None,
    vae_model=None,
    use_which_ckpt=None,
    jigsaw3D_path=None,
):
    """Build a Jigsaw3D image-to-multiview SDXL pipeline from a DeepSpeed checkpoint.

    Args:
        device: Device the pipeline is moved to (e.g. ``"cuda"``).
        dtype: torch dtype for all weights (e.g. ``torch.float16``).
        num_views: Number of views passed to ``init_custom_adapter``.
        scheduler: ``"ddpm"`` or ``"lcm"`` selects the scheduler class wrapped
            by ``ShiftSNRScheduler``. Any other value (including ``None``)
            passes ``scheduler_class=None``.
        lora_model: Optional ``"<repo_or_dir>/<weight_file>"`` LoRA to load.
        base_model: Pretrained base-model id or path for ``from_pretrained``.
        vae_model: Optional VAE id or path overriding the base model's VAE.
        use_which_ckpt: Unused; kept for call compatibility. The checkpoint
            is selected by ``jigsaw3D_path``.
        jigsaw3D_path: Path to the DeepSpeed model-states file.

    Returns:
        The configured ``Jigsaw3DI2MVSDXLPipeline``.

    Raises:
        ValueError: If ``jigsaw3D_path`` is not provided.
    """
    if jigsaw3D_path is None:
        raise ValueError("jigsaw3D_path must point to a DeepSpeed checkpoint file")

    # ---- Reference UNet: frozen, adapter-equipped UNet handed to the main
    # ---- pipeline as `ref_unet`.
    ref_pipe_kwargs = {}
    if vae_model is not None:
        ref_pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)

    ref_pipe: Jigsaw3DI2MVSDXLPipeline = Jigsaw3DI2MVSDXLPipeline.from_pretrained(
        base_model, **ref_pipe_kwargs
    )
    ref_pipe.init_custom_adapter(
        num_views=num_views, self_attn_processor=Jigsaw3DAttnProcessor
    )

    ref_unet = ref_pipe.unet
    ref_unet.requires_grad_(False)
    ref_unet.eval()
    # Only the UNet is needed; drop the local references to the rest of the
    # reference pipeline (VAE, text encoders) before the main pipeline loads.
    del ref_pipe, ref_pipe_kwargs

    # ---- Load the DeepSpeed checkpoint.
    _register_legacy_module_aliases()

    checkpoint_path = jigsaw3D_path
    print(f"Loading checkpoint from: {checkpoint_path}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so only
    # load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    print(f"Checkpoint loaded, total keys: {len(checkpoint)}")

    # DeepSpeed model-state files keep the weights under the "module" key.
    if 'module' in checkpoint:
        model_state = checkpoint['module']
        print("Found 'module' key, using its contents as model state")
    else:
        model_state = checkpoint
        print("No 'module' key found, using entire checkpoint as model state")

    print(f"Model state dictionary keys count: {len(model_state)}")

    component_states = _split_component_states(model_state)
    _cast_component_states(component_states, dtype)
    # The per-component dicts now hold everything needed; release the raw
    # checkpoint (and its original-precision tensors).
    del checkpoint, model_state

    ref_unet = ref_unet.to(dtype)
    print(f"ref_unet dtype set to: {ref_unet.dtype}")

    # ---- Main pipeline.
    pipe_kwargs = {}
    if vae_model is not None:
        pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)

    pipe: Jigsaw3DI2MVSDXLPipeline = Jigsaw3DI2MVSDXLPipeline.from_pretrained(
        base_model, **pipe_kwargs, ref_unet=ref_unet
    )
    pipe.init_custom_adapter(
        num_views=num_views, self_attn_processor=Jigsaw3DAttnProcessor
    )

    for component_name in _CHECKPOINT_COMPONENTS:
        _load_and_verify_state_dict(
            getattr(pipe, component_name),
            component_states[component_name],
            component_name,
        )

    pipe.ref_unet.to(device=device, dtype=dtype)
    pipe.ref_unet.eval()
    pipe.ref_unet.requires_grad_(False)
    print("ref_unet set to device and dtype")

    # ---- Scheduler.
    if scheduler == "ddpm":
        scheduler_class = DDPMScheduler
    elif scheduler == "lcm":
        # Not imported at file level; import on demand.
        from diffusers import LCMScheduler
        scheduler_class = LCMScheduler
    else:
        scheduler_class = None

    pipe.scheduler = ShiftSNRScheduler.from_scheduler(
        pipe.scheduler,
        shift_mode="interpolated",
        shift_scale=8.0,
        scheduler_class=scheduler_class,
    )

    pipe.to(device=device, dtype=dtype)
    # cond_encoder is moved explicitly as well (mirrors the original setup).
    pipe.cond_encoder.to(device=device, dtype=dtype)

    if lora_model is not None:
        model_, name_ = lora_model.rsplit("/", 1)
        pipe.load_lora_weights(model_, weight_name=name_)

    return pipe


def process_single_pair(args, mesh_path, image_path, save_name, pipe=None, texture_pipe=None, birefnet=None, transform_image=None, 
                       height=512, width=512, num_views=6, uv_size=2048, bg_remove_or_padding='bg_padding',
                       exclude_meshes=None, run_pipeline_fn=None, remove_bg_fn=None):
    """Texture one mesh from one reference image.

    The multi-view image grid comes from ``run_pipeline``, unless
    ``args.mv_path`` points to an existing grid. That grid is then
    back-projected onto the mesh with ``texture_pipe``.

    Args:
        args: Parsed CLI namespace. Reads ``save_dir`` and ``skip_existing``,
            plus ``mv_path`` (optional). When generating, it also reads
            ``text``, ``seed``, ``device``, ``reference_conditioning_scale``,
            ``img_patch_size``, ``mask_ratio``, ``vae_patch_size``,
            ``is_image_shuffle``, ``is_vae_shuffle`` and ``preprocess_mesh``.
        mesh_path: Mesh file to texture.
        image_path: Reference image for multi-view generation.
        save_name: Base name for the outputs written to ``args.save_dir``.
        pipe: Multi-view generation pipeline (unused when ``args.mv_path`` is set).
        texture_pipe: Back-projection / texture-completion pipeline.
        birefnet, transform_image: Unused here; background removal goes
            through ``remove_bg_fn``. Kept for call compatibility.
        height, width, num_views, uv_size, bg_remove_or_padding: Generation and
            texturing settings forwarded to the pipelines.
        exclude_meshes: Mesh file names to skip. Defaults to the module-level
            ``EXCLUDE_MESHES`` if it exists, otherwise nothing is excluded.
        run_pipeline_fn: Multi-view generation function. Defaults to the
            module-level ``run_pipeline``.
        remove_bg_fn: Background-removal callable for the generator. Defaults
            to the module-level ``remove_bg_fn`` (which may be ``None``).

    Returns:
        ``False`` if the mesh is excluded. ``True`` if it was processed, or if
        its output already exists and ``args.skip_existing`` is set.

    Raises:
        RuntimeError: If generation is needed but no ``run_pipeline`` is available.
    """
    # These names are module globals only when this file runs as a script;
    # explicit arguments take precedence.
    module_globals = globals()
    if exclude_meshes is None:
        exclude_meshes = module_globals.get("EXCLUDE_MESHES", ())
    if remove_bg_fn is None:
        remove_bg_fn = module_globals.get("remove_bg_fn")

    mesh_filename = os.path.basename(mesh_path)
    if mesh_filename in exclude_meshes:
        print(f"Skipping mesh in exclusion list: {mesh_filename}")
        return False

    output_shaded_model = os.path.join(args.save_dir, f"{save_name}_shaded.glb")
    # Permanent save path for the generated multi-view grid.
    output_mv_image = os.path.join(args.save_dir, f"{save_name}.png")

    if args.skip_existing and os.path.exists(output_shaded_model):
        print(f"Output file already exists: {output_shaded_model}, skipping processing")
        return True
    
    print(f"Processing: Mesh={os.path.basename(mesh_path)}, Image={os.path.basename(image_path)}")
    start_time = time.perf_counter()

    mv_path = getattr(args, "mv_path", None)
    if mv_path is not None:
        # A pre-generated multi-view grid was supplied ("skip generation step");
        # the script does not load the generation model in that case.
        print(f"Using existing multi-view image: {mv_path}")
        rgb_path = mv_path
    else:
        if run_pipeline_fn is None:
            run_pipeline_fn = module_globals.get("run_pipeline")
        if run_pipeline_fn is None:
            raise RuntimeError(
                "run_pipeline is not available; pass run_pipeline_fn or run this file as a script"
            )
        images, *_ = run_pipeline_fn(
            pipe,
            mesh_path=mesh_path,
            num_views=num_views,
            text=args.text,
            image=image_path,
            height=height,
            width=width,
            num_inference_steps=50,
            guidance_scale=3.0,
            seed=args.seed,
            reference_conditioning_scale=args.reference_conditioning_scale,
            negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
            device=args.device,
            remove_bg_fn=remove_bg_fn,
            img_patch_size=args.img_patch_size,
            mask_ratio=args.mask_ratio,
            vae_patch_size=args.vae_patch_size,
            is_image_shuffle=args.is_image_shuffle,
            is_vae_shuffle=args.is_vae_shuffle,
            output_type="pil",
            bg_remove_or_padding=bg_remove_or_padding
        )
        make_image_grid(images, rows=1).save(output_mv_image)
        rgb_path = output_mv_image

    # Back-project and complete the texture.
    out = texture_pipe(
        mesh_path=mesh_path,
        save_dir=args.save_dir,
        save_name=save_name,
        uv_unwarp=True,
        preprocess_mesh=args.preprocess_mesh,
        uv_size=uv_size,
        rgb_path=rgb_path,
        rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
        camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
        move_to_center=True,
    )

    elapsed = time.perf_counter() - start_time
    print(f"Completed processing: {save_name}, time taken: {elapsed:.2f} seconds")
    print(f"Output saved to: {out.shaded_model_save_path}")
    return True


if __name__ == "__main__":
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--variant", type=str, default="sdxl", choices=["sdxl", "sd21"])
    # I/O
    parser.add_argument("--mesh", type=str, help="Path to a single mesh file")
    parser.add_argument("--image", type=str, help="Path to a single image file")
    parser.add_argument("--mesh_dir", type=str, help="Directory containing multiple mesh files")
    parser.add_argument("--image_dir", type=str, help="Directory containing multiple image files")
    parser.add_argument("--text", type=str, default="high quality")
    parser.add_argument("--seed", type=int, default=-1)
    parser.add_argument("--save_dir", type=str, default="./output")
    parser.add_argument("--save_name", type=str, default="i2tex_sample")
    # Extra
    parser.add_argument("--reference_conditioning_scale", type=float, default=1.0)
    parser.add_argument("--preprocess_mesh", action="store_true")
    parser.add_argument("--remove_bg", action="store_true")
    # Checkpoint selection (informational; the weights come from --jigsaw3D_path)
    parser.add_argument("--use_which_ckpt", type=str, default="baseline", 
                        help="Specify which checkpoint to use (e.g., baseline, vae_feature_shuffle, etc.)")
    # Models
    parser.add_argument(
        "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
    )
    parser.add_argument(
        "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix"
    )
    
    parser.add_argument("--img_patch_size", type=int, default=128)
    parser.add_argument("--mask_ratio", type=float, default=0.0)
    parser.add_argument("--is_vae_shuffle", action="store_true", default=False, 
                        help="Whether to shuffle vae patches during inference")
    parser.add_argument("--vae_patch_size", type=int, default=16)
    parser.add_argument("--is_image_shuffle", action="store_true", default=False, 
                        help="Whether to shuffle image patches during inference")
    # Existing multi-view image: when given, the generation model is not loaded
    parser.add_argument("--mv_path", type=str, default=None, 
                        help="Path to existing multi-view image. If provided, skip generation step")
    parser.add_argument("--split_img", action="store_true", default=False, 
                        help="")
    parser.add_argument("--skip_existing", action="store_true", default=False, 
                        help="Skip processing if output files already exist")
    # Batch mode: meshes whose file name contains this string are skipped
    parser.add_argument("--exclude_str", type=str, default="_unwarp", 
                        help="Exclude mesh files containing this string")
    
    parser.add_argument("--bg_remove_or_padding", type=str, default="bg_padding", 
                        help="reference image background use remove method or padding method")
    
    parser.add_argument(
        "--upscaler_ckpt_path", type=str, default="../style-pretrained-ckpts/RealESRGAN_x2plus.pth"
    )
    parser.add_argument(
        "--inpaint_ckpt_path", type=str, default="../style-pretrained-ckpts/big-lama.pt"
    )
    parser.add_argument(
        "--jigsaw3D_path", type=str, default="../jigsaw3D_ckpt/image_shuffle_patch_size_64/mp_rank_00_model_states.pt"
    )
    # Previously hard-coded; the default keeps the old location.
    parser.add_argument(
        "--birefnet_path", type=str, default="/root/autodl-tmp/style-pretrained-ckpts/BiRefNet",
        help="Path (or hub id) of the BiRefNet background-removal model",
    )

    args = parser.parse_args()
    
    os.makedirs(args.save_dir, exist_ok=True)
    
    # Module-level settings (process_single_pair may read some of these globals).
    device = args.device
    num_views = 6
    
    # Relative imports need a parent package (`python -m pkg.module`). When the
    # file is run directly they raise ImportError, so fall back to the sibling
    # module found on sys.path (the script's own directory).
    if args.variant == "sdxl":  # TODO: support a 768-resolution model
        try:
            from .inference_ig2mv_sdxl_for_texture_generation import remove_bg, run_pipeline
        except ImportError:
            from inference_ig2mv_sdxl_for_texture_generation import remove_bg, run_pipeline
        base_model = args.base_model
        vae_model = args.vae_model
        height = width = 512
        uv_size = 2048
    elif args.variant == "sd21":
        try:
            from .inference_ig2mv_sd import remove_bg, run_pipeline
        except ImportError:
            from inference_ig2mv_sd import remove_bg, run_pipeline
        base_model = "stabilityai/stable-diffusion-2-1-base"
        vae_model = None
        height = width = 512
        uv_size = 2048
    else:
        raise ValueError(f"Invalid variant: {args.variant}")
    
    # Models are loaded once and shared across all processed pairs.
    pipe = None
    texture_pipe = None
    birefnet = None
    transform_image = None
    remove_bg_fn = None
    
    if args.mv_path is None:
        print("======= Loading multi-view generation model ========")
        pipe = create_pipeline_from_deepspeed_checkpoint(
            device=device,
            dtype=torch.float16,
            num_views=num_views,
            scheduler=None,
            lora_model=None,
            base_model=base_model,
            vae_model=vae_model,
            use_which_ckpt=args.use_which_ckpt,
            jigsaw3D_path=args.jigsaw3D_path,
        )
        
        if args.remove_bg:
            print("======= Loading background removal model ========")
            birefnet = AutoModelForImageSegmentation.from_pretrained(
                args.birefnet_path, trust_remote_code=True
            )
            birefnet.to(args.device)
            transform_image = transforms.Compose(
                [
                    transforms.Resize((1024, 1024)),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ]
            )
            remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, args.device)
    
    print("======= Loading texture back-projection model ========")
    texture_pipe = TexturePipeline(
        upscaler_ckpt_path=args.upscaler_ckpt_path,
        inpaint_ckpt_path=args.inpaint_ckpt_path,
        device=device,
    )
    
    # Exclusion list: mesh file names to skip, e.g. "16a8213983c4c96b64b1f4d4547591cf.glb".
    EXCLUDE_MESHES = (
    )
    
    if args.mesh_dir and args.image_dir:
        print(f"======= Batch processing mode: mesh_dir={args.mesh_dir}, image_dir={args.image_dir} ========")
        
        mesh_files = []
        for ext in ["*.glb", "*.obj", "*.fbx"]:
            mesh_files.extend(glob.glob(os.path.join(args.mesh_dir, ext)))
        
        # Drop meshes matching --exclude_str or listed in EXCLUDE_MESHES. An empty
        # --exclude_str disables substring filtering ("" is a substring of every
        # name and would otherwise exclude everything). Sorting makes the
        # processing order reproducible (glob order is arbitrary).
        mesh_files = sorted(
            f for f in mesh_files
            if not (args.exclude_str and args.exclude_str in os.path.basename(f))
            and os.path.basename(f) not in EXCLUDE_MESHES
        )
    
        image_files = []
        for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp"]:
            image_files.extend(glob.glob(os.path.join(args.image_dir, ext)))
        image_files.sort()
        
        total_combinations = len(mesh_files) * len(image_files)
        processed = 0
        
        print(f"Found {len(mesh_files)} mesh files and {len(image_files)} image files")
        print(f"Total combinations to process: {total_combinations}")
        
        for mesh_path in mesh_files:
            mesh_name = os.path.splitext(os.path.basename(mesh_path))[0]
            
            for image_path in image_files:
                image_name = os.path.splitext(os.path.basename(image_path))[0]
                save_name = f"{mesh_name}_{image_name}"
                
                try:
                    success = process_single_pair(
                        args, mesh_path, image_path, save_name, 
                        pipe, texture_pipe, birefnet, transform_image,
                        height=height, width=width, num_views=num_views, uv_size=uv_size, 
                        bg_remove_or_padding=args.bg_remove_or_padding
                    )
                    if success:
                        processed += 1
                except Exception as e:
                    # Keep going with the remaining pairs, but keep the traceback.
                    print(f"Error processing {save_name}: {str(e)}")
                    traceback.print_exc()
        
        print(f"Processing completed! Successfully processed {processed}/{total_combinations} combinations")
    
    elif args.mesh and args.image:
        print("======= Single processing mode ========")
        save_name = args.save_name
        if save_name == "i2tex_sample":
            # Default name: derive it from the mesh and image file names.
            mesh_name = os.path.splitext(os.path.basename(args.mesh))[0]
            image_name = os.path.splitext(os.path.basename(args.image))[0]
            save_name = f"{mesh_name}_{image_name}"
        
        process_single_pair(
            args, args.mesh, args.image, save_name, 
            pipe, texture_pipe, birefnet, transform_image,
            height=height, width=width, num_views=num_views, uv_size=uv_size,
            bg_remove_or_padding=args.bg_remove_or_padding
        )
    
    else:
        print("Error: Must provide either single files (--mesh and --image) or directories (--mesh_dir and --image_dir)")
        sys.exit(1)
    
    # Clean up resources
    if pipe is not None:
        del pipe
    if texture_pipe is not None:
        del texture_pipe
    if birefnet is not None:
        del birefnet
    
    torch.cuda.empty_cache()
    print("Resources cleaned up")