import torch
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.geometry import unproject_depth_map_to_point_map
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
import numpy as np
import imageio
import os
from plyfile import PlyData, PlyElement

def visualize_depth(depth, save_path=None):
    """
    Render a single depth map as an 8-bit grayscale image.

    The values are rescaled linearly so that the smallest depth becomes 0 and
    the largest becomes 255. A map whose maximum is not greater than its
    minimum (e.g. a constant map) is rendered as all zeros. The input array
    is copied and never modified.

    Args:
    depth: Depth map as a numpy array (H,W) or (H,W,1)
    save_path: Path of the image file to write (via imageio); if None, nothing is written
    Returns:
    Depth visualization as a (H,W) uint8 numpy array (grayscale)
    Raises:
    ValueError: if the map is not 2D once a trailing size-1 axis is dropped
    """
    gray = depth.copy()

    # Drop a trailing singleton channel axis: (H,W,1) -> (H,W).
    if gray.ndim == 3 and gray.shape[-1] == 1:
        gray = gray.squeeze(-1)

    # Anything that is still not 2D cannot be rendered as one grayscale image.
    if gray.ndim != 2:
        raise ValueError(f"Depth map must be 2D after processing, got shape {gray.shape}")

    # Min-max normalization over every value in the map.
    lowest = gray.min()
    highest = gray.max()
    if highest > lowest:
        gray = (gray - lowest) / (highest - lowest) * 255
    else:
        gray = np.zeros_like(gray)
    gray = gray.astype(np.uint8)

    if save_path is not None:
        # Create the destination directory first; a bare filename means the
        # current directory.
        if os.path.dirname(save_path):
            target_dir = os.path.dirname(os.path.abspath(save_path))
        else:
            target_dir = '.'
        os.makedirs(target_dir, exist_ok=True)
        imageio.imwrite(save_path, gray)

    return gray

def batch_visualize_depth(depth_map, save_path=None, image_names=None):
    """
    Visualize a batch of depth maps, optionally writing one PNG per map.

    Each map is rendered by ``visualize_depth``. Output files are named after
    the stem of the matching entry in ``image_names`` when names are
    available, otherwise ``depth_map_{i}``. For 5D input with more than one
    batch the name is prefixed with ``b{batch}_`` so batches do not overwrite
    each other.

    Args:
        depth_map: Batch of depth maps as a torch tensor or numpy array,
            shaped (B,N,H,W,1), or (B,H,W,1) / (N,H,W,1)
        save_path: Directory to save the visualization images, if None,
            images won't be saved
        image_names: Optional sequence of source image paths, indexed by the
            image axis, used to name the output files. For 5D input, when
            omitted, a module-level ``image_names`` is used if one exists
            (this function previously read that global directly).
    Raises:
        ValueError: if the depth map is neither 4D nor 5D, or fewer names
            than images are supplied while saving
    """
    # Convert to a numpy array if a torch tensor was passed.
    if isinstance(depth_map, torch.Tensor):
        depth_map = depth_map.detach().cpu().numpy()

    rank = len(depth_map.shape)
    if rank == 5:  # (B,N,H,W,1)
        batch_size, num_images = depth_map.shape[:2]
        if image_names is None:
            # Legacy behaviour: this branch used to read a global `image_names`.
            image_names = globals().get("image_names")
    elif rank == 4:  # (B,H,W,1) or (N,H,W,1)
        batch_size, num_images = None, depth_map.shape[0]
    else:
        raise ValueError(f"Unsupported depth map shape: {depth_map.shape}")

    if save_path:
        if image_names is not None and len(image_names) < num_images:
            raise ValueError(
                f"Expected at least {num_images} image names, got {len(image_names)}"
            )
        # Make sure the output directory exists.
        os.makedirs(save_path, exist_ok=True)

    def output_path(index, batch_index=None):
        # Build the PNG path for one map, or None when saving is disabled.
        if not save_path:
            return None
        if image_names is not None:
            stem = os.path.splitext(os.path.basename(image_names[index]))[0]
        else:
            stem = f"depth_map_{index}"
        if batch_index is not None:
            stem = f"b{batch_index}_{stem}"
        return os.path.join(save_path, f"{stem}.png")

    if batch_size is None:
        for i in range(num_images):
            visualize_depth(depth_map[i], output_path(i))
    else:
        for b in range(batch_size):
            # Only disambiguate by batch when there is more than one batch.
            batch_tag = b if batch_size > 1 else None
            for i in range(num_images):
                visualize_depth(depth_map[b, i], output_path(i, batch_tag))
    
def batch_save_depth(depth_map, save_path=None, image_names=None):
    """
    Save a batch of depth maps as ``.npy`` files, one file per map.

    Files are named after the stem of the matching entry in ``image_names``
    when names are available, otherwise ``depth_map_{i}``. For 5D input with
    more than one batch the name is prefixed with ``b{batch}_`` so batches do
    not overwrite each other.

    Args:
        depth_map: Batch of depth maps as a torch tensor or numpy array,
            shaped (B,N,H,W,1), or (B,H,W,1) / (N,H,W,1)
        save_path: Directory to save the arrays, if None (or empty), nothing
            is saved
        image_names: Optional sequence of source image paths, indexed by the
            image axis, used to name the output files. When omitted, a
            module-level ``image_names`` is used if one exists (this function
            previously read that global directly).
    Raises:
        ValueError: if the depth map is neither 4D nor 5D, or fewer names
            than images are supplied while saving
    """
    # Convert to a numpy array if a torch tensor was passed.
    if isinstance(depth_map, torch.Tensor):
        depth_map = depth_map.detach().cpu().numpy()

    rank = len(depth_map.shape)
    if rank == 5:  # (B,N,H,W,1)
        batch_size, num_images = depth_map.shape[:2]
    elif rank == 4:  # (B,H,W,1) or (N,H,W,1)
        batch_size, num_images = None, depth_map.shape[0]
    else:
        raise ValueError(f"Unsupported depth map shape: {depth_map.shape}")

    # Without a destination there is nothing to do; np.save(None, ...) would raise.
    if not save_path:
        return

    if image_names is None:
        # Legacy behaviour: both branches used to read a global `image_names`.
        image_names = globals().get("image_names")
    if image_names is not None and len(image_names) < num_images:
        raise ValueError(
            f"Expected at least {num_images} image names, got {len(image_names)}"
        )

    # Make sure the output directory exists.
    os.makedirs(save_path, exist_ok=True)

    def output_path(index, batch_index=None):
        # Build the .npy path for one map.
        if image_names is not None:
            stem = os.path.splitext(os.path.basename(image_names[index]))[0]
        else:
            stem = f"depth_map_{index}"
        if batch_index is not None:
            stem = f"b{batch_index}_{stem}"
        return os.path.join(save_path, f"{stem}.npy")

    if batch_size is None:
        for i in range(num_images):
            np.save(output_path(i), depth_map[i])
    else:
        for b in range(batch_size):
            # Only disambiguate by batch when there is more than one batch.
            batch_tag = b if batch_size > 1 else None
            for i in range(num_images):
                np.save(output_path(i, batch_tag), depth_map[b, i])

def vggt_run(image_file_names, device="cpu"):
    """
    Run VGGT on a set of images and return the predicted camera parameters.

    Only the aggregator and the camera head are evaluated. The depth and
    point heads are available as ``model.depth_head`` / ``model.point_head``,
    each called with ``(aggregated_tokens_list, images, ps_idx)``, but they
    are not run here.

    Args:
        image_file_names: Image paths passed to ``load_and_preprocess_images``.
        device: Device to run on, e.g. "cpu", "cuda" or "cuda:0".
    Returns:
        Tuple ``(intrinsic, extrinsic)`` of numpy arrays with the leading
        batch axis squeezed out. Per the upstream note, the matrices follow
        the OpenCV convention (camera from world).
    """
    # autocast expects the bare device type ("cuda"), not an indexed name ("cuda:0").
    torch_device = torch.device(device)
    device_type = torch_device.type

    if device_type == "cuda":
        # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+).
        major = torch.cuda.get_device_capability(torch_device)[0]
        dtype = torch.bfloat16 if major >= 8 else torch.float16
    else:
        # Never query CUDA when not running on it (it raises on CUDA-less
        # hosts); bfloat16 is the autocast dtype supported on CPU.
        dtype = torch.bfloat16

    # Initialize the model and load the pretrained weights.
    # Per the original note, weights may be downloaded the first time this runs.
    model = VGGT.from_pretrained("./vggt/facebook/VGGT-1B").to(device)
    # Inference only; this is a no-op if the model is already in eval mode.
    model.eval()

    # Load and preprocess the images, then add a leading batch axis.
    images = load_and_preprocess_images(image_file_names).to(device)
    images = images[None]

    # Both context managers must be entered: `a and b` would evaluate to `b`
    # alone and silently skip no_grad.
    with torch.no_grad(), torch.amp.autocast(device_type=device_type, dtype=dtype):
        aggregated_tokens_list, ps_idx = model.aggregator(images)

        # Predict cameras; the last element of the camera head output is used.
        pose_enc = model.camera_head(aggregated_tokens_list)[-1]
        # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
        extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])

    return intrinsic.detach().cpu().numpy().squeeze(0), extrinsic.detach().cpu().numpy().squeeze(0)

            # point_map_by_unprojection = unproject_depth_map_to_point_map(depth_map.detach().squeeze(0).cpu().numpy(), 
            #                                                             extrinsic.detach().squeeze(0).cpu().numpy(), 
            #                                                             intrinsic.detach().squeeze(0).cpu().numpy())
            # image_col = images.cpu().detach().numpy()[0,0].reshape(-1, 3) *255
            # image_col = image_col.astype(np.uint8)
            # storePly(os.path.join(output_dir, "points3D.ply"), point_map_by_unprojection.reshape(-1, 3), image_col)