import torch
import kornia
from tqdm import tqdm
import torch.nn as nn
import os
import time
import multiprocessing as mp
from functools import partial

# Ask cuDNN to benchmark its algorithms and reuse the fastest one for
# repeated input configurations.
torch.backends.cudnn.benchmark = True
import rootutils

# Resolve the project root from the ".project-root" marker file; with
# pythonpath=True the root is presumably added to sys.path so project-local
# imports work when this file is run as a script (see the rootutils docs).
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)


class IouCalculator(nn.Module):
    """Estimate the overlap (IoU) between two views from their depth maps.

    Every pixel of view 1 is back-projected with its depth, transformed into
    the camera frame of view 2 and projected onto the image plane of view 2.
    A pixel counts as overlapping when it lands inside the image and its
    reprojected depth agrees with the depth stored in view 2.
    """

    def __init__(self, h, w, depth_threshold=0.1):
        """
        Args:
            h: image height in pixels.
            w: image width in pixels.
            depth_threshold: maximum absolute difference between the
                reprojected depth and the depth stored in view 2 for a pixel
                to count as overlapping (same unit as the depth maps).
        """
        super().__init__()
        self.depth_threshold = depth_threshold
        # Homogeneous pixel coordinates (x, y, 1), built once and moved along
        # with the module because they are registered as a buffer.
        meshgrid = kornia.create_meshgrid(
            h, w, normalized_coordinates=False, device="cpu"
        )
        meshgrid = torch.cat(
            [meshgrid, torch.ones(1, h, w, 1, device="cpu")], dim=-1
        )  # [1, H, W, 3]
        self.register_buffer("meshgrid", meshgrid)

    def forward(self, depth1, c2w1, depth2, c2w2, K, inv_K):
        """
        Args:
            depth1: [B, H, W] depth map of view 1 (values <= 0 are ignored)
            c2w1: [B, 4, 4] camera-to-world pose of view 1
            depth2: [B, H, W] depth map of view 2 (values <= 0 are ignored)
            c2w2: [B, 4, 4] camera-to-world pose of view 2
            K: [B, 3, 3] camera intrinsics
            inv_K: [B, 3, 3] inverse camera intrinsics
        Returns:
            iou: [B]
        """
        B, H, W = depth1.shape
        # reshape (not view) so that non-contiguous depth maps are accepted.
        depth1_flat = depth1.reshape(B, -1)  # [B, H*W]

        pixels = self.meshgrid.to(depth1.device).reshape(1, -1, 3)  # [1, H*W, 3]
        pixels = pixels.permute(0, 2, 1).expand(B, -1, -1)  # [B, 3, H*W]

        # Back-project the pixels of view 1 into its camera frame, then world.
        points_cam1 = torch.bmm(inv_K, pixels) * depth1_flat.unsqueeze(1)
        R1 = c2w1[:, :3, :3]  # [B, 3, 3]
        t1 = c2w1[:, :3, 3].unsqueeze(-1)  # [B, 3, 1]
        points_world = torch.bmm(R1, points_cam1) + t1  # [B, 3, H*W]

        # World -> camera 2. The inverse rotation is taken as the transpose,
        # i.e. the rotation block of c2w2 is treated as orthonormal.
        R2 = c2w2[:, :3, :3]  # [B, 3, 3]
        t2 = c2w2[:, :3, 3].unsqueeze(-1)  # [B, 3, 1]
        points_cam2 = torch.bmm(R2.permute(0, 2, 1), points_world - t2)

        # Perspective projection into view 2.
        pixels2 = torch.bmm(K, points_cam2)  # [B, 3, H*W]
        z = pixels2[:, 2, :].clamp(min=1e-6)  # avoid division by zero
        pixels2 = pixels2[:, :2, :] / z.unsqueeze(1)  # [B, 2, H*W]
        uv_round = pixels2.round().long()
        u, v = uv_round[:, 0, :], uv_round[:, 1, :]  # [B, H*W] each

        # A pixel of view 1 is valid if it has depth and lands inside view 2.
        # NOTE(review): points behind camera 2 are not rejected explicitly;
        # the clamped z normally pushes them far outside the image.
        valid_mask = (u >= 0) & (u < W) & (v >= 0) & (v < H) & (depth1_flat > 0)

        # Clamp only to keep the gather in bounds; out-of-image pixels are
        # already excluded through valid_mask.
        u_clamped = u.clamp(0, W - 1)
        v_clamped = v.clamp(0, H - 1)
        batch_indices = torch.arange(B, device=depth1.device)[:, None].expand(
            -1, H * W
        )
        actual_depths = depth2[batch_indices, v_clamped, u_clamped]  # [B, H*W]
        projected_depths = points_cam2[:, 2, :]  # [B, H*W]

        depth_valid = actual_depths > 0
        depth_diff = torch.abs(projected_depths - actual_depths)
        overlap_mask = (depth_diff < self.depth_threshold) & valid_mask & depth_valid

        intersection = overlap_mask.sum(dim=1).float()  # [B]
        union = (
            valid_mask.sum(dim=1) + (depth2 > 0).sum(dim=(1, 2)) - intersection
        ).float()
        iou = intersection / (union + 1e-6)
        iou[union <= 0] = 0.0  # empty union -> no overlap
        return iou


def _build_frame_pairs(items, max_gap=100):
    """Enumerate frame pairs whose frame ids differ by at most ``max_gap``.

    Args:
        items: sorted frame ids; the position of an id in this list is used
            as the index of that frame in the depth/pose tensors.
        max_gap: largest allowed difference ``idx2 - idx1``.

    Returns:
        List of ``(idx1, idx2, pos1, pos2)`` tuples with ``pos1 < pos2``.
        A frame is never paired with itself.
    """
    pairs = []
    for pos1, idx1 in enumerate(items):
        for pos2 in range(pos1 + 1, len(items)):
            idx2 = items[pos2]
            if idx2 - idx1 > max_gap:
                break  # items are sorted, so every later frame is further away
            pairs.append((idx1, idx2, pos1, pos2))
    return pairs


def process_single_scan(scan_path, scan, gpu_id, counter, lock):
    """Compute the pairwise IoU matrix of one scan and save it as ``iou.pt``.

    Reads frame ids from ``<scan_path>/<scan>/depth`` and depths, poses and
    intrinsics from ``<scan_path>/<scan>/DKP.pt``. The result is a symmetric
    matrix indexed by frame id with a zero diagonal, written to
    ``<scan_path>/<scan>/iou.pt``. Scans that already have this file are
    skipped.

    Args:
        scan_path: directory containing the scan folders.
        scan: name of the scan folder to process.
        gpu_id: index of the CUDA device to run on.
        counter: shared value (``.value``) counting finished scans.
        lock: lock guarding ``counter``.

    Raises:
        RuntimeError: on CUDA errors; an out-of-memory error is re-raised
            only once the batch size cannot be reduced any further.
    """
    process_begin = time.time()
    scan_dir = os.path.join(scan_path, scan)
    output_path = os.path.join(scan_dir, "iou.pt")

    # Cheap early exit before any CUDA work is done.
    if os.path.exists(output_path):
        print(f"{scan} on gpu {gpu_id} already processed")
        with lock:
            counter.value += 1
        return

    torch.cuda.set_device(gpu_id)
    device = torch.device(f"cuda:{gpu_id}")

    # Frame ids are taken from the depth file names ("<id>.<ext>").
    items = sorted(
        int(img.split(".")[0]) for img in os.listdir(os.path.join(scan_dir, "depth"))
    )

    # Load the data.
    data = torch.load(
        os.path.join(scan_dir, "DKP.pt"),
        map_location=device,
        weights_only=True,
    )
    # assumes depths is [N, H, W] and poses is [N, 4, 4], ordered like the
    # sorted frame ids — TODO confirm against the script that writes DKP.pt
    depths = data["depths"].to(device)
    poses = data["poses"].to(device)
    intrinsic = data["intrinsic"][:3, :3].to(device)
    inv_K = torch.inverse(intrinsic)

    # Initialise the model for the actual depth resolution.
    height, width = depths.shape[-2:]
    iou_calculator = IouCalculator(height, width).to(device)

    # Generate the image pairs (positions only; tensors are gathered per batch).
    pairs = _build_frame_pairs(items, max_gap=100)

    result = torch.zeros(
        items[-1] + 1, items[-1] + 1, device="cpu", dtype=torch.float, pin_memory=True
    )

    def auto_batch_size(current_batch_size=4096):
        """Return half of the current batch size, capped by a memory estimate."""
        total_mem = torch.cuda.get_device_properties(device).total_memory
        free_mem = total_mem - torch.cuda.memory_allocated(device)
        sample_mem = height * width * 4 * 4 * 2
        max_batch = free_mem // sample_mem
        # Halve both bounds to keep a safety margin.
        return max(1, min(current_batch_size // 2, max_batch // 2))

    batchsize = auto_batch_size()
    batchstart = 0
    while batchstart < len(pairs):
        batch = pairs[batchstart : batchstart + batchsize]
        idx1, idx2, pos1, pos2 = (list(column) for column in zip(*batch))
        num = len(batch)

        out_of_memory = False
        try:
            with torch.no_grad():
                first = torch.tensor(pos1, device=device)
                second = torch.tensor(pos2, device=device)
                iou = iou_calculator(
                    depths[first],
                    poses[first],
                    depths[second],
                    poses[second],
                    intrinsic.repeat(num, 1, 1),
                    inv_K.repeat(num, 1, 1),
                ).cpu()
        except RuntimeError as e:
            # Retrying with batch size 1 cannot help, so give up instead of
            # looping forever.
            if "CUDA out of memory" not in str(e) or batchsize <= 1:
                raise
            out_of_memory = True

        if out_of_memory:
            # Handled outside the except block so the failed attempt's
            # tensors are no longer referenced by the traceback.
            torch.cuda.empty_cache()
            new_batchsize = auto_batch_size(batchsize)
            print(
                f"process {scan} on gpu {gpu_id}, batchsize {batchsize} too large, reduce to {new_batchsize}"
            )
            batchsize = new_batchsize
            continue  # retry the same batch start with the smaller size

        torch.cuda.empty_cache()
        result[torch.tensor(idx1), torch.tensor(idx2)] = iou
        batchstart += num

    # Only the upper triangle was filled; mirror it to make the matrix symmetric.
    result = result + result.t()
    torch.save(result, output_path)
    process_end = time.time()
    with lock:
        counter.value += 1
        processed = counter.value
    print(
        f"{scan} on gpu {gpu_id} processed in {process_end - process_begin:.2f} seconds, {processed} scans processed"
    )


def main():
    """Process every scan directory under the validation split on all GPUs.

    One worker process is started per visible GPU; scans are assigned GPU
    ids round-robin in sorted order.

    Raises:
        RuntimeError: if no CUDA device is visible.
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise RuntimeError("No CUDA device available; at least one GPU is required")

    scan_path = "./data/scannet/val"
    scans = sorted(
        scan
        for scan in os.listdir(scan_path)
        if os.path.isdir(os.path.join(scan_path, scan))
    )
    gpu_ids = [i % num_gpus for i in range(len(scans))]

    # PyTorch recommends the spawn start method for CUDA work in subprocesses.
    ctx = mp.get_context("spawn")
    with ctx.Manager() as manager, ctx.Pool(processes=num_gpus) as pool:
        progress_counter = manager.Value("i", 0)
        lock = manager.Lock()

        # scan_path must be repeated per task: zipping the bare string would
        # iterate over its characters and truncate the job list.
        pool.starmap(
            process_single_scan,
            zip(
                [scan_path] * len(scans),
                scans,
                gpu_ids,
                [progress_counter] * len(scans),
                [lock] * len(scans),
            ),
        )

        pool.close()
        pool.join()


if __name__ == "__main__":
    main()
