import pickle
import numpy as np
import mediapy as media

import torch
import torch.nn.functional as F

from prompt_tuning.gaussian_splatting import gs_render_batch
# from unidepth.models import UniDepthV2


def back_project_coords(depth_map, intrinsic):
    """Back-project a depth map into one 3D point per pixel.

    Pixel coordinates are expressed on a normalized grid spanning [0, 1]
    along both axes (not integer pixel indices), so ``intrinsic`` must be
    defined with respect to those normalized coordinates.

    Args:
        depth_map: 2D tensor of shape (H, W) holding per-pixel depth.
        intrinsic: Invertible (3, 3) camera matrix.

    Returns:
        Tensor of shape (H, W, 3): ``K^-1 @ [u, v, 1]`` scaled by the depth
        of the corresponding pixel.
    """
    H, W = depth_map.shape
    device = depth_map.device
    # Build the grid in the intrinsic's dtype so the matmul below does not
    # fail on a dtype mismatch (e.g. a float64 intrinsic).
    dtype = intrinsic.dtype

    # Normalized pixel grid in [0, 1].
    u = torch.linspace(0, 1, W, device=device, dtype=dtype)
    v = torch.linspace(0, 1, H, device=device, dtype=dtype)
    # With indexing='xy' both outputs have shape (H, W):
    # uu varies along the last axis, vv along the first.
    uu, vv = torch.meshgrid(u, v, indexing='xy')

    # Homogeneous pixel coordinates, (H, W, 3).
    pixels = torch.stack([uu, vv, torch.ones_like(uu)], dim=-1)

    # Apply the inverse intrinsic to every pixel (row-vector form).
    K_inv = torch.inverse(intrinsic)
    rays = pixels @ K_inv.T  # (H, W, 3)

    # Scale each ray by its depth.
    points_3d = rays * depth_map.unsqueeze(-1)  # (H, W, 3)

    return points_3d


def main():
    """Render the dense 3D tracks of the demo clip and save them as a video.

    Loads the pickled track data, normalizes the track coordinates by their
    largest absolute value, renders them with ``gs_render_batch`` using an
    identity extrinsic and a normalized intrinsic, and writes the result
    next to the input file at the frame rate of the companion
    ``sparse_2d_track.mp4``.
    """
    data_path = "data/webvid/DELTA_demo/dt3d_render/yellow-duck/dense_3d_track.pkl"
    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(data_path, "rb") as f:
        data = pickle.load(f)
    for key in data:
        data[key] = torch.Tensor(data[key]).to("cuda")
    # The companion video is opened only to reuse its frame rate.
    with media.VideoReader(data_path.replace("dense_3d_track.pkl", "sparse_2d_track.mp4")) as reader:
        video_fps = reader.fps

    H, W = [384, 512]
    H_ori, W_ori = [540, 960]
    # Compensate for the aspect-ratio difference between the original and
    # the rendered resolution by stretching one focal length.
    if W_ori / W > H_ori / H:   # width fills the frame; height is relatively squashed
        fx = 1
        fy = W_ori / H_ori / (W / H)
    else:
        fy = 1
        fx = H_ori / W_ori / (H / W)
    # Principal point at the center of the normalized [0, 1] image plane.
    intrinsic = torch.Tensor([
            [fx, 0, 0.5],
            [0, fy, 0.5],
            [0, 0, 1]
        ]).to("cuda")
    normalized_3d_tracks = data["coords"] / torch.abs(data["coords"]).max()
    dense_3d_tracks = normalized_3d_tracks.unsqueeze(0)

    intrinsics = intrinsic.unsqueeze(0)
    # Identity camera pose.
    extrinsic = torch.Tensor([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ]).to("cuda")
    scale = 0.0001
    color_3d_tracks = data["colors"].unsqueeze(0)
    # NOTE(review): the last three tensor arguments are presumably the
    # per-axis scale, a (0, 0, 0, 1) rotation and per-point opacity; they are
    # created on the CPU while the rest is on CUDA -- confirm against
    # gs_render_batch.
    rendered_videos = gs_render_batch(
        intrinsics,
        extrinsic,
        [H, W],
        dense_3d_tracks,
        torch.Tensor([scale, scale, scale]),
        torch.Tensor([0.0, 0.0, 0.0, 1.0]),
        color_3d_tracks/255,
        torch.ones((H*W,))
    )
    partial_video = rendered_videos[0].permute(0, 2, 3, 1).detach().cpu().numpy() * 255
    # Saturate before the uint8 cast; otherwise out-of-range values wrap
    # around and show up as speckle artifacts.
    partial_video = np.clip(partial_video, 0, 255).astype(np.uint8)
    media.write_video(data_path.replace("dense_3d_track.pkl", "flow_render_normalize_minmax.mp4"), partial_video, fps=video_fps)
    print("debug")
    


# Run the rendering demo only when this file is executed as a script.
if __name__ == "__main__":
    main()