import os
import argparse
import math
import torch
from PIL import Image
import torchvision.transforms.functional as TF
import kaolin as kal
from kaolin.render.camera import (
    generate_perspective_projection,
    generate_rotate_translate_matrices,
)
from kaolin.render.mesh import dibr_rasterization, prepare_vertices, texture_mapping

from tqdm import tqdm
import torch.distributed as dist
from torch.utils.data import Dataset, DistributedSampler, DataLoader

from torchvision import transforms
import csv

import torchvision.io as io
import time
import warnings
import trimesh
import numpy as np
import open3d as o3d

# Silence every warning of category UserWarning for the whole process.
warnings.filterwarnings("ignore", category=UserWarning)


class MeshPairDataset(Dataset):
    """Dataset of (mask, mesh, video, camera) records joined from three CSVs.

    The three files are joined on the mask path stored in column 1 of each:

    * ``mesh_csv``   -- column 1: mask path, column 2: mesh path.
    * ``camera_csv`` -- column 1: mask path, columns 4..8: yaw, pitch,
      distance, rx, ry (parsed as floats).
    * ``rm_csv``     -- column 0: original video path, column 1: mask path,
      column 2: "rm" video path.

    Only ``rm_csv`` rows whose mask path has both a mesh entry and a camera
    entry produce a sample. Each sample is the tuple::

        (mask_path, mesh_path, rm_video_path, ori_video_path,
         yaw, pitch, distance, rx, ry)

    Args:
        rm_csv: Path to the CSV listing the videos.
        mesh_csv: Path to the CSV mapping mask paths to mesh paths.
        camera_csv: Path to the CSV holding per-mask camera parameters.
        part: Accepted for call-site compatibility; not used by this class.

    Raises:
        OSError: If any of the CSV files cannot be opened.
        IndexError: If a non-blank row has fewer columns than required.
        ValueError: If a camera value of a matched row is not a valid float.
    """

    # (key stored in the per-mask record, column index in camera_csv)
    _CAMERA_COLUMNS = (
        ("yaw", 4),
        ("pitch", 5),
        ("distance", 6),
        ("rx", 7),
        ("ry", 8),
    )

    def __init__(
        self,
        rm_csv="",
        mesh_csv="",
        camera_csv="",
        part=0,
    ):
        self.pairs = []
        mesh_map = {}

        # newline="" lets the csv module do its own newline handling.
        with open(mesh_csv, "r", newline="") as f:
            for row in csv.reader(f):
                if not row:  # csv.reader yields [] for blank lines
                    continue
                mesh_map[row[1]] = {"mesh_path": row[2]}

        with open(camera_csv, "r", newline="") as f:
            for row in csv.reader(f):
                if not row:
                    continue
                record = mesh_map.get(row[1])
                if record is None:
                    # No mesh for this mask: the camera values are never used,
                    # so they are not parsed.
                    continue
                for key, column in self._CAMERA_COLUMNS:
                    record[key] = float(row[column])

        with open(rm_csv, "r", newline="") as f:
            for row in csv.reader(f):
                if not row:
                    continue
                rm_mask_path = row[1]
                record = mesh_map.get(rm_mask_path)
                # Keep only masks that have a mesh AND camera parameters.
                if record is None or "yaw" not in record:
                    continue
                # mask_path, mesh_path, rm_video_path, ori_video_path, camera...
                self.pairs.append(
                    (
                        rm_mask_path,
                        record["mesh_path"],
                        row[2],
                        row[0],
                        record["yaw"],
                        record["pitch"],
                        record["distance"],
                        record["rx"],
                        record["ry"],
                    )
                )

        print(len(self.pairs))

    def __len__(self):
        """Return the number of joined samples."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the sample tuple stored at position ``idx``."""
        return self.pairs[idx]


# Resize transform targeting (640, 960), the same height/width that
# `rendering` passes to the rasterizer.
LARGE_RESIZE = transforms.Compose([transforms.Resize((640, 960))])


def rendering(yaw, pitch, dist, rx, ry, verts, faces, proj, face_uvs):
    """Rasterize one mesh from a batch of cameras orbiting the origin.

    Each camera sits at spherical coordinates (``yaw``, ``pitch``, ``dist``)
    around the origin with +Y as the up direction and looks at the origin.
    After projection, the image-plane face coordinates are shifted by ``rx``
    (first coordinate) and ``ry`` (remaining coordinate) per batch element.

    Args:
        yaw, pitch, dist: Per-view camera angles and distance; the batch size
            is taken from ``yaw.shape[0]``.
        rx, ry: Per-view image-plane offsets, one value per batch element.
        verts: Mesh vertices, broadcast over the batch (no copy is made here).
        faces: Mesh face indices; ``faces.shape[0]`` is the face count.
        proj: Camera projection passed through to ``prepare_vertices``.
        face_uvs: Per-face UV coordinates, tiled over the batch.

    Returns:
        The first two rasterized feature maps, in the order the features are
        given to ``dibr_rasterization``: the map of the all-ones feature and
        the map of ``face_uvs``.
    """
    batch_size = yaw.shape[0]
    device = yaw.device

    # Camera position on a sphere around the origin (y-up).
    planar = dist * torch.cos(pitch)
    eye = torch.stack(
        (
            planar * torch.cos(yaw),
            dist * torch.sin(pitch),
            planar * torch.sin(yaw),
        ),
        dim=1,
    )  # [M,3]

    # Every camera looks at the origin with +Y up.
    look_at = torch.zeros((1, 3), device=device).expand(batch_size, -1)
    up_dir = torch.tensor([[0.0, 1.0, 0.0]], device=device).expand(batch_size, -1)
    camera_rot, camera_trans = generate_rotate_translate_matrices(eye, look_at, up_dir)

    # Share the single mesh across the batch as a broadcast view: [M,V,3].
    face_cam, face_img, face_normals = prepare_vertices(
        vertices=verts.unsqueeze(0).expand(batch_size, -1, -1),
        faces=faces,
        camera_proj=proj,
        camera_rot=camera_rot,
        camera_trans=camera_trans,
    )

    # Shift the projected coordinates in place, one offset per batch element.
    face_img[..., :1] += rx.view(batch_size, 1, 1, 1)
    face_img[..., 1:] += ry.view(batch_size, 1, 1, 1)

    # Features to rasterize: a constant 1 per face vertex, and the face UVs.
    ones_feature = torch.ones((batch_size, faces.shape[0], 3, 1), device=device)
    uv_feature = face_uvs.repeat(batch_size, 1, 1, 1)

    rendered, _, _ = dibr_rasterization(
        height=640,
        width=960,
        face_vertices_z=face_cam[..., 2],
        face_vertices_image=face_img,
        face_features=[ones_feature, uv_feature],
        face_normals_z=face_normals[..., 2],
        rast_backend="nvdiffrast_fwd",
    )

    return rendered[0], rendered[1]


def init_dist():
    """Initialise torch.distributed from the launcher's environment.

    Reads ``LOCAL_RANK``, ``WORLD_SIZE`` and ``RANK`` from ``os.environ``,
    creates the default process group with the ``gloo`` backend using the
    ``env://`` init method, and makes ``cuda:<LOCAL_RANK>`` the current CUDA
    device.

    Returns:
        Tuple ``(local_rank, rank, world_size)`` as ints.

    Raises:
        KeyError: If one of the three environment variables is not set.
    """
    # Read in the same order as before so a missing variable surfaces first
    # for LOCAL_RANK, then WORLD_SIZE, then RANK.
    local_rank, world_size, rank = (
        int(os.environ[name]) for name in ("LOCAL_RANK", "WORLD_SIZE", "RANK")
    )

    dist.init_process_group(backend="gloo", init_method="env://")
    torch.cuda.set_device(local_rank)
    return local_rank, rank, world_size


def preprocessing_mesh(
    path: str,
):
    """Load a textured mesh, normalise it, and move it to the GPU.

    Only the first mesh and the first material of the model file are used.
    Vertices are centred on their mean and uniformly scaled so the largest
    absolute coordinate becomes 0.5 (no scaling if all vertices coincide).

    Args:
        path: Path of the model file read with ``o3d.io.read_triangle_model``.

    Returns:
        Tuple ``(verts, faces, triangle_uvs, textures)`` of CUDA tensors:

        * ``verts``        -- float32, ``(V, 3)`` normalised vertex positions.
        * ``faces``        -- int64, ``(F, 3)`` triangle vertex indices.
        * ``triangle_uvs`` -- float32, reshaped to ``(F, 3, C)`` where ``C`` is
          the last dimension of the mesh's UV array.
        * ``textures``     -- albedo image of the first material, converted
          from uint8 to floats by dividing by 255.

    Raises:
        IndexError: If the model has no meshes or no materials.
    """
    model = o3d.io.read_triangle_model(path)
    mesh_o3d = model.meshes[0].mesh
    albedo_np = np.asarray(model.materials[0].albedo_img, dtype=np.uint8)

    # NOTE(review): the normals computed here are not read anywhere in this
    # function; the call is kept as-is -- confirm whether it can be dropped.
    mesh_o3d.compute_vertex_normals()

    # --- extract arrays ---
    verts_np = np.asarray(mesh_o3d.vertices, dtype=np.float32)  # (V,3)
    faces_np = np.asarray(mesh_o3d.triangles, dtype=np.int64)  # (F,3)
    uvs_np = np.asarray(mesh_o3d.triangle_uvs, dtype=np.float32)

    # --- normalise: centre at the origin ---
    verts_np -= verts_np.mean(axis=0, keepdims=True)

    # --- normalise: scale so the max absolute coordinate is 0.5 ---
    max_abs = np.abs(verts_np).max()
    if max_abs > 0:
        verts_np /= 2 * max_abs

    # --- to torch tensors on the current CUDA device ---
    verts = torch.from_numpy(verts_np).cuda()  # float32, (V,3)
    faces = torch.from_numpy(faces_np).cuda()  # int64, (F,3)
    triangle_uvs = torch.from_numpy(uvs_np).cuda()
    textures = torch.from_numpy(albedo_np).cuda() / 255.0

    # Group the flat per-corner UV list by face: (F, 3, C).
    triangle_uvs = triangle_uvs.view(
        faces.shape[0], faces.shape[1], triangle_uvs.shape[-1]
    )

    return verts, faces, triangle_uvs, textures


def main():
    """Render randomized views of every dataset mesh and log them to a CSV.

    For each sample assigned to this rank, the mesh is loaded once and then
    rendered in 20 passes of ``--M`` views each. Every view draws:

    * yaw and pitch uniformly within +/- pi/8 of the sample's values,
    * distance uniformly within +/- 0.3 of the sample's value, with the lower
      bound clamped to at least 0.5,
    * image-plane offsets as the sample's rx/ry plus uniform jitter in
      [-0.2, 0.2] and [-0.1, 0.1] respectively.

    Images go to ``<out_dir>/image``, binary masks (rendered mask > 0.2) to
    ``<out_dir>/mask``, and one metadata line per view is appended to
    ``<out_dir>/<part>_<local_rank>.csv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--out_dir",
        default="",
        help="Output directory",
    )
    parser.add_argument(
        "--M",
        type=int,
        default=24,
        help="Number of views rendered per sampling pass",
    )
    parser.add_argument(
        "--part",
        type=int,
        default=0,
        help="Part identifier used as a prefix in output file names",
    )
    # The dataset needs real CSV paths; without these flags it was always
    # constructed with empty paths and failed on open("").
    parser.add_argument("--rm_csv", default="", help="CSV listing the videos")
    parser.add_argument(
        "--mesh_csv", default="", help="CSV mapping mask paths to mesh paths"
    )
    parser.add_argument(
        "--camera_csv", default="", help="CSV with per-mask camera parameters"
    )
    args = parser.parse_args()

    local_rank, rank, world_size = init_dist()
    dataset = MeshPairDataset(
        rm_csv=args.rm_csv,
        mesh_csv=args.mesh_csv,
        camera_csv=args.camera_csv,
        part=args.part,
    )

    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    loader = DataLoader(
        dataset, batch_size=1, sampler=sampler, num_workers=4, pin_memory=True
    )

    device = f"cuda:{local_rank}"
    image_dir = os.path.join(args.out_dir, "image")
    mask_dir = os.path.join(args.out_dir, "mask")
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)

    # NOTE(review): output names are keyed by local_rank, not the global rank;
    # processes on different nodes sharing one out_dir would overwrite each
    # other's files. Left unchanged to keep existing file names stable.
    # os.path.join keeps an empty --out_dir relative to the working directory
    # (consistent with image_dir/mask_dir) instead of resolving to "/".
    meta_path = os.path.join(args.out_dir, f"{args.part}_{local_rank}.csv")

    M = args.M
    num_passes = 20

    # The projection does not depend on the sample, so build it once.
    fovy = math.pi / 3
    proj = generate_perspective_projection(fovyangle=fovy, ratio=960 / 640).cuda()

    with open(meta_path, "w") as meta_file, tqdm(total=len(loader)) as bar:
        for batch_index, batch in enumerate(loader):
            (
                mask_path,
                mesh_path,
                rm_video_path,
                ori_video_path,
                yaw,
                pitch,
                distance,
                rx,
                ry,
            ) = batch

            # batch_size is 1: take the single string out of each field.
            mask_path = mask_path[0]
            rm_video_path = rm_video_path[0]
            ori_video_path = ori_video_path[0]
            mesh_path = mesh_path[0]

            # Per-sample reference values; never overwritten by the sampling
            # loop below so every pass is drawn around the same centre.
            base_yaw = yaw.float().to(device)
            base_pitch = pitch.float().to(device)
            base_distance = distance.float().to(device)
            base_rx = rx.float().to(device)
            base_ry = ry.float().to(device)

            vertices, faces, triangle_uvs, textures = preprocessing_mesh(mesh_path)
            # Same texture for every view of this mesh: tile it once per mesh.
            texture_batch = textures.permute(2, 0, 1).repeat(M, 1, 1, 1)

            yaw_min, yaw_max = base_yaw - math.pi / 8, base_yaw + math.pi / 8
            pitch_min, pitch_max = base_pitch - math.pi / 8, base_pitch + math.pi / 8
            dist_min = torch.clamp(base_distance - 0.3, min=0.5)
            dist_max = base_distance + 0.3

            for times in range(num_passes):
                view_yaw = torch.rand(M, device=device) * (yaw_max - yaw_min) + yaw_min
                view_pitch = (
                    torch.rand(M, device=device) * (pitch_max - pitch_min) + pitch_min
                )
                view_distance = (
                    torch.rand(M, device=device) * (dist_max - dist_min) + dist_min
                )
                # Jitter around the base offsets. Previously the jittered
                # values were fed back in, so the offsets drifted as a random
                # walk over the passes.
                view_rx = torch.empty(M, device=device).uniform_(-0.2, 0.2) + base_rx
                view_ry = torch.empty(M, device=device).uniform_(-0.1, 0.1) + base_ry

                with torch.no_grad():
                    soft_mask, texture_coords = rendering(
                        view_yaw,
                        view_pitch,
                        view_distance,
                        view_rx,
                        view_ry,
                        vertices,
                        faces,
                        proj,
                        triangle_uvs,
                    )

                    image = texture_mapping(
                        texture_coords,
                        texture_batch,
                        mode="bilinear",
                    )
                    # Channels-first for to_pil_image; one host copy per pass.
                    image = image.permute(0, 3, 1, 2).clamp(0, 1).cpu()
                    soft_mask = soft_mask.permute(0, 3, 1, 2)
                    soft_mask = (soft_mask > 0.2).float().cpu()

                # Pull the sampled parameters to the host once instead of
                # formatting individual GPU scalars for every metadata line.
                yaw_vals = view_yaw.tolist()
                pitch_vals = view_pitch.tolist()
                distance_vals = view_distance.tolist()
                rx_vals = view_rx.tolist()
                ry_vals = view_ry.tolist()

                for ind in range(image.shape[0]):
                    file_name = (
                        f"{args.part}_{batch_index}_{local_rank}_{times*M+ind}.png"
                    )

                    render_image_path = os.path.join(image_dir, file_name)
                    TF.to_pil_image(image[ind]).save(render_image_path)

                    render_mask_path = os.path.join(mask_dir, file_name)
                    TF.to_pil_image(soft_mask[ind]).save(render_mask_path)

                    print(
                        f"{ori_video_path},{mask_path},{mesh_path},{rm_video_path},{render_image_path},{render_mask_path},{yaw_vals[ind]:.6f},{pitch_vals[ind]:.6f},{distance_vals[ind]:.6f},{rx_vals[ind]:.6f},{ry_vals[ind]:.6f}",
                        file=meta_file,
                        flush=True,
                    )

                torch.cuda.empty_cache()
            bar.update(1)



# Script entry point: run the rendering pipeline only when executed directly.
if __name__ == "__main__":
    main()