#!/usr/bin/env python3
import os
import argparse
import math
import torch
import torch.nn.functional as F
from torch.optim import Adam
from PIL import Image
import torchvision.transforms.functional as TF
import kaolin as kal
from kaolin.render.camera import (
    generate_perspective_projection,
    generate_rotate_translate_matrices,
)
from kaolin.render.mesh import dibr_rasterization, prepare_vertices, dibr_soft_mask
from kaolin.metrics.render import mask_iou
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
import csv

import torchvision.io as io
import time
import warnings
import trimesh
import numpy as np
import open3d as o3d

import multiprocessing as mp
from queue import Empty

warnings.filterwarnings("ignore", category=UserWarning)



class MeshPairDataset:
    """One chunk of a CSV manifest listing (video, mask, mesh) path triples.

    Every non-empty CSV row contributes column 0 as the video path, column 1
    as the mask path and column 3 as the mesh path (column 2 is not read).
    After loading, only rows ``part_size * part`` (inclusive) up to
    ``part_size * (part + 1)`` (exclusive) are kept.

    Args:
        csv_path: Path of the CSV manifest to read.
        part: Zero-based index of the chunk to keep.
        part_size: Number of manifest rows per chunk.

    Raises:
        OSError: If ``csv_path`` cannot be opened.
        IndexError: If a non-empty row has fewer than four columns.
    """

    def __init__(
        self,
        csv_path="",
        part=0,
        part_size=10000,
    ):
        video_paths = []
        mask_paths = []
        mesh_paths = []

        # newline="" is what the csv module expects so that it can handle
        # line endings (and quoted embedded newlines) itself.
        with open(csv_path, "r", newline="") as f:
            for row in csv.reader(f):
                if not row:
                    # csv.reader yields [] for blank lines (e.g. a trailing
                    # empty line); indexing such a row would raise IndexError.
                    continue
                video_paths.append(row[0])
                mask_paths.append(row[1])
                mesh_paths.append(row[3])

        chunk = slice(part_size * part, part_size * (part + 1))
        self.video_paths = video_paths[chunk]
        self.mask_paths = mask_paths[chunk]
        self.mesh_paths = mesh_paths[chunk]

    def __len__(self):
        """Return the number of triples in this chunk."""
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Return ``(video_path, mask_path, mesh_path)`` for position ``idx``."""
        return self.video_paths[idx], self.mask_paths[idx], self.mesh_paths[idx]


# Resize pipelines applied to the mask frame in ``preprocessing``. Sizes are
# (height, width): the small one yields the 160x240 mask used for the coarse
# pose search, the large one the 320x480 mask used for fine optimisation.
SMALL_RESIZE = transforms.Compose([transforms.Resize((160, 240))])

LARGE_RESIZE = transforms.Compose([transforms.Resize((320, 480))])


def get_K(mesh_path):
    """Return the mesh file size in MiB and a size-dependent value K.

    Args:
        mesh_path: Path of the mesh file on disk.

    Returns:
        Tuple ``(size_mb, K)`` where ``K`` is 2 when the file is larger than
        100 MiB and 4 otherwise.
    """
    megabytes = os.path.getsize(mesh_path) / 1024**2
    k = 2 if megabytes > 100 else 4
    return megabytes, k

def preprocessing(mask_path):
    """Load the first frame of a mask video and build the masks used for fitting.

    The first decoded frame is averaged over dim 0, resized with
    ``LARGE_RESIZE`` and ``SMALL_RESIZE`` and binarised with a threshold of 10
    (presumably on a 0-255 intensity scale — confirm against the mask videos).

    Args:
        mask_path: Path of the mask video.

    Returns:
        Tuple ``(gt_mask, small_mask, rx0, ry0)``. Both masks are float 0/1
        tensors moved to the current CUDA device. ``rx0`` / ``ry0`` are 0-dim
        CPU tensors holding the centroid of the small mask's foreground,
        rescaled so that 0 is the image centre and the image spans roughly
        [-1, 1). When the mask has no foreground pixel both offsets are 0.
    """
    reader = io.VideoReader(mask_path, "video")
    mask_frame = next(iter(reader))["data"].float()
    # Collapse dim 0 (assumed to be the channel axis — TODO confirm) and keep
    # a leading singleton dimension for the resize transforms.
    mask_frame = mask_frame.mean(dim=0).unsqueeze(0)
    del reader

    gt_mask = (LARGE_RESIZE(mask_frame) > 10).float()
    small_mask = (SMALL_RESIZE(mask_frame) > 10).float()

    _, ys, xs = (small_mask > 0.5).nonzero(as_tuple=True)
    if xs.numel() == 0:
        # The mean of an empty tensor is NaN, which would propagate into every
        # sampled image offset; fall back to the image centre instead.
        rx0 = torch.zeros(())
        ry0 = torch.zeros(())
    else:
        # Take the normalisation size from the mask itself so it cannot drift
        # from the size configured in SMALL_RESIZE.
        height, width = small_mask.shape[-2:]
        rx0 = (xs.float().mean() / width - 0.5) * 2.0
        ry0 = (ys.float().mean() / height - 0.5) * 2.0

    return gt_mask.cuda(), small_mask.cuda(), rx0, ry0


def rendering(yaw, pitch, dist, rx, ry, gt_mask, verts, faces, proj, backend):
    """Rasterise the mesh from a batch of cameras and score it against a mask.

    Args:
        yaw, pitch, dist: 1-D tensors of equal length giving each camera's
            spherical position around the origin (y-up).
        rx, ry: Per-camera offsets added to the projected x / y image
            coordinates; reshaped to ``(batch, 1, 1, 1)`` before the add.
        gt_mask: Target mask; its dims 1 and 2 are used as render height and
            width, and it is repeated along dim 0 to match the batch.
        verts: Mesh vertices, shared by all cameras.
        faces: Mesh face indices.
        proj: Camera projection passed to ``prepare_vertices``.
        backend: Value forwarded as ``rast_backend`` to ``dibr_rasterization``.

    Returns:
        Tuple ``(loss, ious, soft_mask, hard_ious)``: the two values returned
        by ``mask_iou`` for the soft mask, the soft mask itself, and the
        second ``mask_iou`` value for the soft mask thresholded at 0.1
        (computed without gradients).
    """
    device = yaw.device
    num_views = yaw.shape[0]

    # Camera positions on a sphere around the origin (y is up).
    eye = torch.stack(
        [
            dist * torch.cos(pitch) * torch.cos(yaw),
            dist * torch.sin(pitch),
            dist * torch.cos(pitch) * torch.sin(yaw),
        ],
        dim=1,
    )  # [M,3]

    # Every camera looks at the origin with +y as the up vector.
    look_at = torch.zeros(1, 3, device=device).expand(num_views, -1)
    up_dir = torch.tensor([[0.0, 1.0, 0.0]], device=device).expand(num_views, -1)
    cam_rot, cam_trans = generate_rotate_translate_matrices(eye, look_at, up_dir)

    # Share the same vertices across the whole camera batch.
    batched_verts = verts.unsqueeze(0).expand(num_views, -1, -1)  # [M,V,3]

    fcam, fimg, fnorm = prepare_vertices(
        vertices=batched_verts,
        faces=faces,
        camera_proj=proj,
        camera_rot=cam_rot,
        camera_trans=cam_trans,
    )

    # Shift the projected coordinates in the image plane, per camera.
    fimg[:, :, :, :1] += rx.view(num_views, 1, 1, 1)
    fimg[:, :, :, 1:] += ry.view(num_views, 1, 1, 1)

    # Only the soft silhouette is needed; a constant feature is rasterised.
    ones_feature = torch.ones((num_views, faces.shape[0], 3, 1), device=device)
    _, soft_mask, _ = dibr_rasterization(
        height=gt_mask.shape[1],
        width=gt_mask.shape[2],
        face_vertices_z=fcam[..., 2],
        face_vertices_image=fimg,
        face_features=[ones_feature],
        face_normals_z=fnorm[..., 2],
        rast_backend=backend,
    )

    target = gt_mask.repeat(num_views, 1, 1)
    loss, ious = mask_iou(soft_mask, target)
    with torch.no_grad():
        # Binarised silhouette, used for candidate selection and stopping.
        _, hard_ious = mask_iou((soft_mask > 0.1).float(), target)

    return loss, ious, soft_mask, hard_ious


def preprocessing_mesh(path: str, target_faces: int = 50000):
    """Load a triangle mesh, decimate it if needed and normalise its vertices.

    The mesh is simplified with quadric decimation when it has more than
    ``target_faces`` triangles. Vertices are then centred on their mean and
    divided by twice the largest absolute coordinate, so every coordinate
    ends up within [-0.5, 0.5].

    Args:
        path: Mesh file to read with Open3D.
        target_faces: Triangle count above which the mesh is decimated, and
            the target passed to the decimation.

    Returns:
        Tuple ``(verts, faces)``: float32 vertices and int64 triangle indices
        as tensors on the current CUDA device.

    Raises:
        ValueError: If the loaded mesh has no vertices.
    """
    mesh_o3d = o3d.io.read_triangle_mesh(path)
    # NOTE(review): the normals computed here are not read anywhere in this
    # function; kept in case the decimation step benefits — confirm and drop.
    mesh_o3d.compute_vertex_normals()
    num_faces = np.asarray(mesh_o3d.triangles).shape[0]

    if num_faces > target_faces:
        mesh_o3d = mesh_o3d.simplify_quadric_decimation(target_faces)

    # The float64 -> float32 conversion copies, so the in-place arithmetic
    # below does not touch the Open3D mesh.
    verts_np = np.asarray(mesh_o3d.vertices, dtype=np.float32)
    faces_np = np.asarray(mesh_o3d.triangles, dtype=np.int64)

    if verts_np.shape[0] == 0:
        # Without this check the reductions below fail with an opaque
        # "zero-size array" ValueError (after a NaN-mean warning).
        raise ValueError(f"mesh {path!r} contains no vertices")

    centroid = verts_np.mean(axis=0, keepdims=True)
    verts_np -= centroid

    max_abs = np.abs(verts_np).max()
    if max_abs > 0:
        verts_np /= 2 * max_abs

    verts = torch.from_numpy(verts_np).cuda()
    faces = torch.from_numpy(faces_np).cuda()
    return verts, faces


def parse_devices(dev_arg: str | None):
    if dev_arg:
        return [int(x) for x in dev_arg.split(",") if x.strip() != ""]
    vis = os.environ.get("CUDA_VISIBLE_DEVICES")
    if vis:
        return list(range(len(vis.split(","))))
    return list(range(torch.cuda.device_count()))


def worker(proc_id, device_id, index_queue, dataset, out_dir, part, args):
    """Estimate camera poses on one GPU for items pulled from a shared queue.

    For every index taken from ``index_queue`` the worker:

    1. loads and normalises the mesh and builds the target masks;
    2. coarse stage: repeatedly samples ``args.M`` random poses and keeps
       those whose hard IoU against the small mask exceeds 0.4, until at
       least 16 are collected or 201 rounds have run;
    3. fine stage: optimises up to 12 of those poses with Adam against the
       large mask for at most ``args.fine_iters`` iterations, stopping early
       once the best hard IoU reaches ``args.iou_thresh``.

    The best silhouette is saved as ``<out_dir>/<part>_<device_id>/<idx>.png``
    and one comma-separated line (paths, pose, IoU) is written to
    ``<out_dir>/<part>_<device_id>.csv``, which is truncated on start. Both
    names depend on ``device_id`` only, so two workers sharing a device id
    would write to the same files.

    Args:
        proc_id: Position of this worker in the device list (not used here).
        device_id: CUDA device index this worker runs on.
        index_queue: Shared queue of dataset indices to process.
        dataset: Indexable returning ``(video_path, mask_path, mesh_path)``.
        out_dir: Root output directory.
        part: Dataset part number, used in output names.
        args: Namespace providing ``M``, ``lr``, ``fine_iters`` and
            ``iou_thresh``.
    """
    # Seconds to wait on the queue before concluding it is drained.
    queue_timeout_s = 1.0
    coarse_iou_min = 0.4  # hard IoU a random pose needs to become a candidate
    coarse_target = 16  # stop sampling once this many candidates exist
    coarse_max_rounds = 200  # sampling stops after this count is exceeded
    fine_candidates = 12  # number of candidates refined by gradient descent

    torch.cuda.set_device(device_id)
    device = torch.device(f"cuda:{device_id}")

    save_dir = os.path.join(out_dir, f"{part}_{device_id}")
    os.makedirs(save_dir, exist_ok=True)
    meta_path = os.path.join(out_dir, f"{part}_{device_id}.csv")

    fovy = math.pi / 3
    proj = generate_perspective_projection(fovyangle=fovy, ratio=480 / 320).cuda()

    with open(meta_path, "w") as meta_file:
        while True:
            try:
                # A non-blocking get can raise Empty while items remain (the
                # read lock is held by another worker, or the producer's
                # feeder thread has not flushed yet), which would end this
                # worker prematurely. A short timeout avoids that.
                batch_index = index_queue.get(timeout=queue_timeout_s)
            except Empty:
                break

            video_path, mask_path, mesh_path = dataset[batch_index]

            # Load & normalize mesh
            start = time.time()
            vertices, faces = preprocessing_mesh(mesh_path)
            mesh_time = time.time() - start

            gt_mask, gt_mask_c, rx0, ry0 = preprocessing(mask_path)

            M = args.M
            yaw_min, yaw_max = 0.0, 2 * math.pi
            pitch_min, pitch_max = 0.0, math.pi
            dist_min, dist_max = 0.2, 3

            yaws, pitchs, dists, rxs, rys = [], [], [], [], []
            rounds = 0
            t1 = time.time()
            # Only the file size is used (for log messages).
            file_size, _ = get_K(mesh_path)

            # ---- coarse stage: random search against the small mask ----
            while len(yaws) < coarse_target:
                yaw = torch.rand(M, device=device) * (yaw_max - yaw_min) + yaw_min
                pitch = (
                    torch.rand(M, device=device) * (pitch_max - pitch_min) + pitch_min
                )
                dist = torch.rand(M, device=device) * (dist_max - dist_min) + dist_min
                # Image-plane offsets are sampled around the mask centroid.
                rx = torch.empty(M, device=device).uniform_(-0.4, 0.4) + rx0
                ry = torch.empty(M, device=device).uniform_(-0.4, 0.4) + ry0

                with torch.no_grad():
                    _, _, _, hard_ious = rendering(
                        yaw,
                        pitch,
                        dist,
                        rx,
                        ry,
                        gt_mask_c,
                        vertices,
                        faces,
                        proj,
                        "nvdiffrast_fwd",
                    )

                index = torch.nonzero(hard_ious > coarse_iou_min, as_tuple=True)[0]
                yaws.extend(yaw[index].tolist())
                pitchs.extend(pitch[index].tolist())
                dists.extend(dist[index].tolist())
                rxs.extend(rx[index].tolist())
                rys.extend(ry[index].tolist())
                rounds += 1

                if rounds > coarse_max_rounds:
                    break

            if len(yaws) == 0:
                print(
                    f"[GPU {device_id}] skip idx={batch_index} (no coarse candidates)"
                )
                continue

            t2 = time.time()
            torch.cuda.empty_cache()

            # ---- fine stage: gradient descent against the large mask ----
            yaws = torch.tensor(yaws, device=device)[:fine_candidates]
            pitchs = torch.tensor(pitchs, device=device)[:fine_candidates]
            dists = torch.tensor(dists, device=device)[:fine_candidates]
            rxs = torch.tensor(rxs, device=device)[:fine_candidates]
            rys = torch.tensor(rys, device=device)[:fine_candidates]

            yaw_f = torch.nn.Parameter(yaws)
            pitch_f = torch.nn.Parameter(pitchs)
            dist_f = torch.nn.Parameter(dists)
            rx_f = torch.nn.Parameter(rxs)
            ry_f = torch.nn.Parameter(rys)

            optimizer = Adam([yaw_f, pitch_f, dist_f, rx_f, ry_f], lr=args.lr)
            scheduler = StepLR(optimizer, step_size=20, gamma=0.5)
            torch.cuda.empty_cache()

            best_iou = torch.tensor(0.0, device=device)
            for i in range(args.fine_iters):
                optimizer.zero_grad()
                try:
                    loss, _, soft_mask, hard_ious = rendering(
                        yaw_f,
                        pitch_f,
                        dist_f,
                        rx_f,
                        ry_f,
                        gt_mask,
                        vertices,
                        faces,
                        proj,
                        "nvdiffrast",
                    )
                except Exception as e:
                    print(f"[GPU {device_id}] potential OOM mesh size: {file_size}")
                    print(e)
                    # Ends this worker entirely; indices still queued are left
                    # for the remaining workers.
                    return

                # IoU of the best candidate in the most recent render.
                cur_best_iou, best_idx = hard_ious.max(0)
                best_iou = cur_best_iou

                if cur_best_iou >= args.iou_thresh or (i == args.fine_iters - 1):
                    # Record the result BEFORE stepping the optimizer so that
                    # the logged pose is the one that produced the saved mask
                    # and the logged IoU.
                    final_mask = soft_mask[best_idx].detach()
                    out = TF.to_pil_image(final_mask.clamp(0, 1))
                    save_path = os.path.join(save_dir, f"{batch_index}.png")
                    out.save(save_path)

                    print(
                        f"{video_path},{mask_path},{mesh_path},{save_path},"
                        f"{yaw_f[best_idx]:.6f},{pitch_f[best_idx]:.6f},{dist_f[best_idx]:.6f},"
                        f"{rx_f[best_idx]:.6f},{ry_f[best_idx]:.6f},{cur_best_iou:.6f}",
                        file=meta_file,
                        flush=True,
                    )
                    break

                loss.backward()
                optimizer.step()
                scheduler.step()

            t3 = time.time()
            print(
                f"[GPU {device_id}] idx={batch_index} | best_iou={best_iou.item():.4f} | "
                f"prep={mesh_time:.3f}s | search={(t2-t1):.3f}s | optimize={(t3-t2):.3f}s | sizeMB={file_size:.1f}"
            )


def main():
    """Parse CLI arguments and run one worker process per requested GPU.

    All indices of the selected dataset part are placed on a shared queue;
    each worker process pulls indices from it until the queue is empty.
    Exits through ``parser.error`` (status 2) when ``--out_dir`` or
    ``--csv_path`` is missing or when no device could be resolved.
    """
    parser = argparse.ArgumentParser(
        description="Coarse-to-fine camera estimation via silhouette IoU (multiprocessing)"
    )
    parser.add_argument(
        "--out_dir",
        default="",
        help="Output directory",
    )
    parser.add_argument(
        "--M", type=int, default=128, help="Number of random samples in coarse stage"
    )
    parser.add_argument(
        "--fine_iters", type=int, default=200, help="Iterations for fine optimization"
    )
    parser.add_argument(
        "--iou_thresh", type=float, default=0.9, help="Early stop IoU threshold"
    )
    parser.add_argument(
        "--lr", type=float, default=1e-1, help="Learning rate for fine stage"
    )
    parser.add_argument("--part", type=int, required=True)
    parser.add_argument(
        "--devices",
        type=str,
        default="0,1,2,3,4,5,6,7",
        help="Comma-separated GPU ids, e.g. '0,1,2,3'",
    )
    parser.add_argument(
        "--csv_path",
        type=str,
        default="",
    )

    args = parser.parse_args()

    # The empty-string defaults cannot work (os.makedirs("") and open("")
    # both fail), so report them as usage errors instead of a traceback.
    if not args.out_dir:
        parser.error("--out_dir must be a non-empty path")
    if not args.csv_path:
        parser.error("--csv_path must be a non-empty path")

    os.makedirs(args.out_dir, exist_ok=True)

    devices = parse_devices(args.devices)
    if not devices:
        # Without workers the queue would be filled and nothing processed.
        parser.error("no GPU devices resolved from --devices / environment")

    dataset = MeshPairDataset(csv_path=args.csv_path, part=args.part)

    index_queue = mp.Queue()
    for i in range(len(dataset)):
        index_queue.put(i)

    procs = []
    for proc_id, dev in enumerate(devices):
        proc = mp.Process(
            target=worker,
            args=(proc_id, dev, index_queue, dataset, args.out_dir, args.part, args),
        )
        proc.start()
        procs.append(proc)

    for proc in procs:
        proc.join()

    # If workers stopped early, unread indices may still be buffered in this
    # process; do not let interpreter shutdown block on flushing them into a
    # pipe nobody reads any more.
    index_queue.cancel_join_thread()

    for dev, proc in zip(devices, procs):
        if proc.exitcode != 0:
            print(f"[GPU {dev}] worker exited with code {proc.exitcode}")


if __name__ == "__main__":
    import multiprocessing as mp

    # Force the "spawn" start method before any worker process is created
    # (presumably because the workers use CUDA — confirm).
    current_method = mp.get_start_method(allow_none=True)
    if current_method != "spawn":
        mp.set_start_method("spawn", force=True)
    main()