"""
Modified from https://github.com/paulengstler/invisible-stitch/blob/main/run.py

The original code moves the camera while keeping its center almost fixed.
To match our setting, the following changes were made:
- the scene point cloud is centered at the origin, and
- the camera follows an orbital trajectory around it.
"""
import numpy as np
import skimage
import torch
from diffusers.utils import export_to_gif
from huggingface_hub import hf_hub_download
from PIL import Image
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import (
    look_at_view_transform,
    PerspectiveCameras,
)
from tqdm.auto import tqdm

from utils.models import get_zoe_dc_model, get_sd_pipeline, infer_with_zoe_dc
from utils.ops import snap_high_gradients_to_nn, project_points, get_pointcloud, merge_pointclouds, outpaint_with_depth_estimation
from utils.render import render


# Module-level setup: these statements run at import time and create the
# globals (`device`, `zoe_dc_model`, `pipe`) that the functions below read.
# Everything is pinned to the first CUDA device, so a CUDA GPU is required.
device = torch.device("cuda:0")
torch.cuda.set_device(device)

# Depth model handed to `infer_with_zoe_dc` / `outpaint_with_depth_estimation`;
# its checkpoint is obtained from the Hugging Face Hub repo
# "paulengstler/invisible-stitch" via `hf_hub_download`.
zoe_dc_model = get_zoe_dc_model(ckpt_path=hf_hub_download(repo_id="paulengstler/invisible-stitch", filename="invisible-stitch.pt")).to(device)
# Pipeline handed to `outpaint_with_depth_estimation`; its `.device` also
# determines where the random generator for outpainting is created.
pipe = get_sd_pipeline(device)


def initialize_point_cloud(initial_image: Image.Image):
    """Create the initial point cloud from one image and center it at the origin.

    A depth map is estimated for ``initial_image``, every pixel is lifted to
    3D through a camera located at the world origin, and the resulting points
    are translated so that their mean coincides with the origin.

    Args:
        initial_image: Source image. Its pixels are reshaped to rows of three
            color values, so a 3-channel image is assumed.

    Returns:
        A ``(point_cloud, scene_dist)`` pair. ``scene_dist`` is the norm of
        the mean point *before* centering; because the camera sits at the
        origin, this is the distance from the camera to the scene centroid.
    """
    width, height = initial_image.size
    pixels = np.asarray(initial_image)

    # Only the rotation of the look-at transform is used: scaling the
    # translation by zero places the camera at the world origin.
    rotation, translation = look_at_view_transform(
        device=device, azim=0, elev=0, dist=0.01, at=torch.zeros((1, 3)))
    translation.mul_(0)

    principal_point = ((height - 1) / 2, (width - 1) / 2)
    cameras = PerspectiveCameras(
        R=rotation,
        T=translation,
        focal_length=torch.tensor([width], dtype=torch.float32),
        principal_point=(principal_point,),
        image_size=(initial_image.size,),
        device=device,
        in_ndc=False,
    )

    # Jumpstart the point cloud with a regular depth estimation. An all-zero
    # tensor is passed as the third argument (presumably "no prior depth
    # available" -- confirm in utils.models).
    image_chw = torch.from_numpy(pixels / 255.).permute(2, 0, 1).float()
    empty_depth = torch.zeros(height, width)
    depth = infer_with_zoe_dc(zoe_dc_model, image_chw, empty_depth)

    # Snapping high gradients to the nearest neighbor eliminates noodle artifacts.
    depth = snap_high_gradients_to_nn(depth.to(device), threshold=12).cpu()
    world_points = project_points(cameras, depth)

    # Shift the cloud so that its centroid becomes the origin, remembering how
    # far the centroid was from the camera.
    centroid = torch.mean(world_points[0], dim=0, keepdim=True)
    centered_points = world_points[0] - centroid
    scene_dist = torch.norm(centroid, dim=-1).item()

    colors = torch.from_numpy(pixels.copy()).reshape(-1, 3).float() / 255
    point_cloud = get_pointcloud(centered_points, device=device, features=colors.to(device))

    return point_cloud, scene_dist


def extrapolate_point_cloud(
        prompt: str,
        azim_steps: list[float],
        dist: float,
        point_cloud: Pointclouds,
        dry_run: bool = False,
        discard_mask: bool = False,
        initial_image: Image.Image | None = None,
        depth_scaling: float = 1,
        seed: int = 0,
        **render_kwargs,
    ):
    """Grow ``point_cloud`` by outpainting views along an orbital trajectory.

    For every azimuth in ``azim_steps`` a camera is placed at distance
    ``dist`` from the origin (elevation 0, looking at the origin) and the
    current point cloud is rendered from it. Unless ``dry_run`` is set, the
    render is outpainted, a depth map is estimated for the outpainted image,
    and the pixels outside the (eroded) render mask are lifted to 3D and
    merged into the point cloud before the next view is processed.

    Args:
        prompt: Text prompt forwarded to ``outpaint_with_depth_estimation``.
        azim_steps: Azimuth angles of the views, visited in order. They are
            passed unchanged to ``look_at_view_transform``.
        dist: Camera distance from the origin for every view.
        point_cloud: Point cloud to render and extend.
        dry_run: If true, only render the views; nothing is outpainted and
            the point cloud is returned unmodified.
        discard_mask: If true, the per-frame ``"mask"`` entry is omitted.
        initial_image: Image whose size defines the camera intrinsics and the
            output resolution. Required despite the ``None`` default, which
            is kept only for signature compatibility.
        depth_scaling: Forwarded to ``outpaint_with_depth_estimation``.
        seed: Seed of the random generator used for outpainting.
        **render_kwargs: Extra keyword arguments forwarded to ``render``.

    Returns:
        A ``(frames, point_cloud)`` pair. ``frames`` holds one dict per view
        with the keys ``"rendered"``, ``"image"``, ``"mask"`` (unless
        ``discard_mask``), ``"transform_matrix"`` (the camera's world-to-view
        matrix as nested lists), ``"azim"``, ``"elev"`` and ``"dist"``; when
        not in a dry run also ``"center_point"``, ``"depth"`` and
        ``"mean_depth"``. ``point_cloud`` is the (possibly extended) cloud.

    Raises:
        ValueError: If ``initial_image`` is ``None``.
    """
    if initial_image is None:
        # Fail with a clear message instead of an AttributeError on `.size`.
        raise ValueError(
            "initial_image is required: its size defines the camera "
            "intrinsics and the output resolution"
        )

    w, h = initial_image.size
    optimization_bundle_frames = []

    generator = torch.Generator(device=pipe.device).manual_seed(seed)
    for azim in tqdm(azim_steps):
        R, T = look_at_view_transform(device=device, azim=azim, elev=0, dist=dist, at=torch.zeros((1, 3)))
        cameras = PerspectiveCameras(R=R, T=T, focal_length=torch.tensor([w], dtype=torch.float32), principal_point=(((h-1)/2, (w-1)/2),), image_size=(initial_image.size,), device=device, in_ndc=False)

        images, masks, depths = render(cameras, point_cloud, **render_kwargs)

        if dry_run:
            # in a dry run, we do not actually outpaint the image
            outpainted_img = Image.fromarray((255*images[0].cpu().numpy()).astype(np.uint8))
        else:
            # Zero the depth on the border of the rendered region so that only
            # the eroded interior is passed on as known depth.
            valid_depth = skimage.morphology.binary_erosion((depths[0] > 0).cpu().numpy(), footprint=None)
            eroded_depth = depths[0].clone()
            eroded_depth[~torch.from_numpy(valid_depth).to(depths.device)] = 0

            outpainted_img, aligned_depth = outpaint_with_depth_estimation(
                images[0],
                masks[0],
                eroded_depth,
                h,
                w,
                pipe,
                zoe_dc_model,
                prompt,
                cameras,
                dilation_size=2,
                depth_scaling=depth_scaling,
                generator=generator)

            # snap high gradients to nearest neighbor, which eliminates noodle artifacts
            aligned_depth = snap_high_gradients_to_nn(torch.from_numpy(aligned_depth).to(device), threshold=12).cpu()
            xy_depth_world = project_points(cameras, aligned_depth)

        # NOTE: this is the world-to-view matrix, even though it is stored
        # under the generic "transform_matrix" key.
        world_to_view = cameras.get_world_to_view_transform().get_matrix()[0]

        # Keys are inserted in the same order as before so that serialized
        # bundles keep their layout.
        frame = {
            "rendered": Image.fromarray((255*images[0].cpu().numpy()).astype(np.uint8)),
            "image": outpainted_img,
        }
        if not discard_mask:
            # Captured before the mask is eroded further below.
            frame["mask"] = masks[0].cpu().numpy()
        frame.update({
            "transform_matrix": world_to_view.tolist(),
            "azim": azim,
            "elev": 0,
            "dist": dist,
        })
        optimization_bundle_frames.append(frame)

        if dry_run:
            # in a dry run, we do not modify the point cloud
            continue

        frame["center_point"] = xy_depth_world[0].mean(dim=0).tolist()
        frame["depth"] = aligned_depth.cpu().numpy()
        frame["mean_depth"] = aligned_depth.mean().item()

        rgb = (torch.from_numpy(np.asarray(outpainted_img).copy()).reshape(-1, 3).float() / 255).to(device)

        # pytorch 3d's mask might be slightly too big (subpixels), so we erode it a little to avoid seams
        # in theory, 1 pixel is sufficient but we use 2 to be safe
        masks[0] = torch.from_numpy(skimage.morphology.binary_erosion(masks[0].cpu().numpy(), footprint=skimage.morphology.disk(2))).to(device)
        if torch.any(~masks[0]):
            # Only pixels outside the eroded mask contribute new points.
            new_pixels = ~masks[0].view(-1)
            partial_outpainted_point_cloud = get_pointcloud(xy_depth_world[0][new_pixels], device=device, features=rgb[new_pixels])

            point_cloud = merge_pointclouds([point_cloud, partial_outpainted_point_cloud])

    return optimization_bundle_frames, point_cloud


if __name__ == "__main__":
    # Example run: build a point cloud from one photo, extend it along an
    # orbit around the scene center, and export the frames as two GIFs.
    image_path = "examples/photo-1667788000333-4e36f948de9a.jpeg"
    prompt = "a street with traditional buildings in Kyoto, Japan"

    # `convert` returns an independent image, so the file can be closed right away.
    with Image.open(image_path) as source_image:
        img = source_image.convert("RGB")
    # Explicit check rather than `assert`, so it is not stripped under `python -O`.
    if img.width % 8 != 0 or img.height % 8 != 0:
        raise ValueError("Image dimensions must be multiples of 8")

    # Orbit in `total_steps` views spaced `step_size` apart in azimuth.
    step_size = 2
    total_steps = 25
    azim_steps = [step_size * i for i in range(total_steps)]

    point_cloud, scene_dist = initialize_point_cloud(img)
    print(f"{scene_dist=}, {point_cloud.points_padded()[0].mean(dim=0)=}")

    # The orbit radius equals the original camera-to-scene-center distance.
    bundle_frames, point_cloud = extrapolate_point_cloud(
        prompt,
        azim_steps,
        dist=scene_dist,
        point_cloud=point_cloud,
        discard_mask=True,
        initial_image=img,
        depth_scaling=0.5,
        seed=0,
        fill_point_cloud_holes=True,
    )

    rendered_images = [frame["rendered"] for frame in bundle_frames]
    generated_images = [frame["image"] for frame in bundle_frames]

    # Note: this is a fractional frame rate (0.8 fps for the settings above).
    gif_fps = 10*step_size/25
    export_to_gif(rendered_images, f"rendered_azim{step_size}_steps{total_steps}.gif", fps=gif_fps)
    export_to_gif(generated_images, f"generated_azim{step_size}_steps{total_steps}.gif", fps=gif_fps)