import json
import os
import sys
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DistributedSampler, DataLoader
from PIL import Image
from hy3dgen.rembg import BackgroundRemover
from hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline
from hy3dgen.texgen import Hunyuan3DPaintPipeline
from hy3dgen.text2image import HunyuanDiTPipeline
from tqdm import tqdm
import argparse
from decord import VideoReader, bridge

bridge.set_bridge("torch")
from torchvision import transforms
from torchvision.transforms import ToPILImage
import jsonlines
import csv
import numpy as np


class CSVDataset(Dataset):
    """Dataset of ``(video_path, mask_path)`` pairs filtered from a CSV file.

    Every CSV row must have exactly seven columns::

        video_path, prompt, mask_path, clip_sim, mean_ratio, min_iou, all_iou

    A row is kept only when ``0.05 < mean_ratio < 0.5`` and ``all_iou > 0.1``.
    The kept rows are then cut into consecutive shards of ``part_size``
    entries and only shard ``part`` is retained.

    Args:
        csv_path: Path of the CSV file to read.
        part: Zero-based shard index. ``None`` keeps every filtered row
            (previously this crashed with a TypeError).
        part_size: Number of rows per shard; defaults to 5000.
    """

    def __init__(
        self,
        csv_path="",
        part=None,
        part_size=5000,
    ):
        self.videos = []
        self.masks = []
        # newline="" is required by the csv module so quoted fields that
        # contain newlines or commas are parsed correctly.
        with open(csv_path, "r", newline="") as f:
            for row in csv.reader(f):
                (
                    video_path,
                    _prompt,
                    mask_path,
                    _clip_sim,
                    mean_ratio,
                    _min_iou,
                    all_iou,
                ) = row
                if 0.05 < float(mean_ratio) < 0.5 and float(all_iou) > 0.1:
                    self.videos.append(video_path)
                    self.masks.append(mask_path)

        if part is not None:
            start = part * part_size
            stop = start + part_size
            self.videos = self.videos[start:stop]
            self.masks = self.masks[start:stop]

    def __len__(self):
        """Return the number of retained samples."""
        return len(self.videos)

    def __getitem__(self, index):
        """Return the ``(video_path, mask_path)`` pair at *index*."""
        return self.videos[index], self.masks[index]


def init_dist():
    """Join the NCCL process group described by the launcher's environment.

    Reads ``LOCAL_RANK``, ``WORLD_SIZE`` and ``RANK`` from ``os.environ``
    (a missing variable raises ``KeyError``). Initialises the default process
    group with the ``env://`` method and the ``nccl`` backend, then binds this
    process to CUDA device ``LOCAL_RANK``.

    Returns:
        tuple[int, int, int]: ``(local_rank, rank, world_size)``.
    """
    env = os.environ
    local_rank, world_size, rank = (
        int(env[name]) for name in ("LOCAL_RANK", "WORLD_SIZE", "RANK")
    )

    dist.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(local_rank)

    return local_rank, rank, world_size


def _load_masked_first_frame(video_path, mask_path, tf):
    """Return the masked, resized first frame of a clip as a PIL image.

    Returns ``None`` (after printing a skip message) when the clip has fewer
    than 81 frames or is taller than it is wide.
    """
    vr = VideoReader(video_path)
    if len(vr) < 81:
        print(f"length too short, skip {video_path}")
        return None
    # 0-255; the permute below assumes channels-last (N, H, W, C) -- TODO confirm.
    image = vr.get_batch([0])
    del vr
    if image.shape[1] > image.shape[2]:
        print(f"not wide, skip {video_path}")
        return None

    mask_vr = VideoReader(mask_path)
    # Binarise the first mask frame to {0, 1}.
    mask = (mask_vr.get_batch([0]) / 255.0).clamp(0.0, 1.0).round()
    del mask_vr

    # NHWC -> NCHW for torchvision, then resize both tensors to the same size
    # so the element-wise product below is well-defined.
    mask = tf(mask.permute(0, 3, 1, 2))
    image = tf((image / 255.0).permute(0, 3, 1, 2))
    return ToPILImage()((mask * image).squeeze(0))


def main():
    """Turn the first frame of each selected clip into a textured 3D mesh.

    Runs one process per GPU; ``init_dist`` reads the rank layout from the
    environment. Each rank takes its share of ``CSVDataset`` through a
    non-shuffling ``DistributedSampler``. For each kept clip it masks the
    first frame, removes the background, generates a mesh with the shape
    pipeline, textures it with the paint pipeline, and writes:

    * ``./obj_images_{part}_{rank}/{index}.png`` -- the conditioning image,
    * ``./meshs_{part}_{rank}/{index}.glb`` -- the textured mesh,
    * one ``video_path,mask_path,image_path,mesh_path`` row in
      ``./metadata_{part}_{rank}.csv``.

    CLI:
        --part: shard index forwarded to ``CSVDataset``.
        --csv_path: CSV to read. When omitted, ``CSVDataset``'s own default
            is used, as before.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--part", type=int, default=0)
    parser.add_argument(
        "--csv_path",
        type=str,
        default=None,
        help="CSV of (video, prompt, mask, clip_sim, mean_ratio, min_iou, all_iou) rows",
    )
    args = parser.parse_args()

    local_rank, rank, world_size = init_dist()
    try:
        dataset_kwargs = {"part": args.part}
        if args.csv_path is not None:
            dataset_kwargs["csv_path"] = args.csv_path
        dataset = CSVDataset(**dataset_kwargs)

        shape_pipe = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(
            "PATH_TO_YOUR_CHECKPOINT",
            subfolder="",
            variant="fp16",
        )
        paint_pipe = Hunyuan3DPaintPipeline.from_pretrained(
            "PATH_TO_YOUR_CHECKPOINT",
            subfolder="",
        )

        rembg = BackgroundRemover()
        seed = 42
        tf = transforms.Compose([transforms.Resize((720, 1280))])

        # Key outputs by the *global* rank. local_rank repeats on every node,
        # so multi-node runs on a shared filesystem would overwrite each
        # other's files. On a single node the two are typically equal, so the
        # paths do not change there.
        mesh_dir = f"./meshs_{args.part}_{rank}"
        image_dir = f"./obj_images_{args.part}_{rank}"
        metadata_path = f"./metadata_{args.part}_{rank}.csv"
        os.makedirs(mesh_dir, exist_ok=True)
        os.makedirs(image_dir, exist_ok=True)

        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=False
        )
        loader = DataLoader(
            dataset, batch_size=1, sampler=sampler, num_workers=2, pin_memory=True
        )

        with open(metadata_path, "w", newline="") as f, torch.inference_mode():
            # csv.writer quotes paths that contain commas; "\n" matches the
            # line endings the previous print()-based writer produced.
            writer = csv.writer(f, lineterminator="\n")
            for index, (video_paths, mask_paths) in enumerate(tqdm(loader)):
                video_path = video_paths[0]
                mask_path = mask_paths[0]

                image = _load_masked_first_frame(video_path, mask_path, tf)
                if image is None:
                    continue

                image_path = os.path.join(image_dir, f"{index}.png")
                mesh_path = os.path.join(mesh_dir, f"{index}.glb")

                if image.mode in ("RGB", "RGBA"):
                    image = rembg(image)

                image.save(image_path)
                mesh = shape_pipe(
                    image=image,
                    num_inference_steps=50,
                    octree_resolution=512,
                    num_chunks=20000,
                    generator=torch.manual_seed(seed),
                    output_type="trimesh",
                )[0]

                torch.cuda.empty_cache()

                textured = paint_pipe(mesh, image=image)
                textured.export(mesh_path)
                writer.writerow([video_path, mask_path, image_path, mesh_path])
                # Flush each row so progress survives a crash mid-run.
                f.flush()
                torch.cuda.empty_cache()
    finally:
        # Always tear down the process group, even if a sample fails.
        dist.destroy_process_group()


# Script entry point: run main() only when executed directly, not on import.
# NOTE(review): init_dist() reads LOCAL_RANK / WORLD_SIZE / RANK from the
# environment, so this presumably must be started by a distributed launcher
# (e.g. torchrun) that sets them -- confirm the intended launch command.
if __name__ == "__main__":
    main()