import os
import cv2
import sys
import math
import torch
import numpy as np
from torchvision.utils import save_image

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../")))
from scene import Scene
from gaussian_splatting.gm_fluid import GaussianModel
from utils.graphics_utils import get_w2c
from helpers.helper_gaussian import get_model
from helpers.helper_parser import get_parser, write_args_to_file

from helpers.helper_pipe import get_render_pipe

from pdb import set_trace as bp

@torch.no_grad()
def visualize_velocity_on_image_cv2(image, coords, velocity, save_path, scale=1.0, step=10):
    """Overlay subsampled 2D velocity arrows on an image and save it with OpenCV.

    Args:
        image: Tensor with values nominally in ``[0, 1]`` (clamped). Either
            single-channel (anything that squeezes to ``(H, W)``, e.g.
            ``(1, H, W)``) or RGB ``(3, H, W)``.
        coords: ``(N, D)`` tensor with ``D >= 2``; columns 0/1 are the arrow
            start point ``(x, y)`` in pixels.
        velocity: ``(N, D)`` tensor with ``D >= 2``; columns 0/1 are the arrow
            direction ``(u, v)`` in pixels, multiplied by ``scale``.
        save_path: Destination file; its extension selects the encoder used
            by ``cv2.imwrite``.
        scale: Factor applied to ``velocity`` to obtain the arrow length.
        step: Only every ``step``-th point is drawn.

    Returns:
        bool: ``True`` if ``cv2.imwrite`` reported success, else ``False``.

    Raises:
        ValueError: If ``image`` is neither single-channel nor 3-channel CHW.
    """
    img = torch.clamp(image, 0.0, 1.0).squeeze().cpu()
    if img.dim() == 2:
        gray = (img.numpy() * 255).clip(0, 255).astype(np.uint8)
        canvas = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)  # [H, W, 3]
    elif img.dim() == 3 and img.shape[0] == 3:
        # CHW RGB -> HWC; contiguous so OpenCV receives a plain C-ordered array.
        rgb = (img.permute(1, 2, 0).contiguous().numpy() * 255).clip(0, 255).astype(np.uint8)
        canvas = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)  # [H, W, 3]
    else:
        raise ValueError(
            f"Unsupported image shape {tuple(image.shape)}; "
            "expected single-channel (1, H, W)/(H, W) or RGB (3, H, W)"
        )
    H, W = canvas.shape[:2]

    # Move all subsampled points to the host in one transfer instead of
    # syncing the device once per arrow.
    starts = coords[::step, :2].detach().cpu().numpy()
    deltas = velocity[::step, :2].detach().cpu().numpy()
    p1 = np.rint(starts)
    p2 = np.rint(starts + deltas * scale)

    # Keep arrows whose both endpoints lie inside the image. NaN/inf fail
    # every comparison, so non-finite points are dropped here too (they
    # would otherwise crash the int conversion).
    inside = np.ones(len(p1), dtype=bool)
    for pts in (p1, p2):
        inside &= (pts[:, 0] >= 0) & (pts[:, 0] < W) & (pts[:, 1] >= 0) & (pts[:, 1] < H)

    starts_px = p1[inside].astype(np.int64).tolist()
    ends_px = p2[inside].astype(np.int64).tolist()
    for (x1, y1), (x2, y2) in zip(starts_px, ends_px):
        # Color is BGR, so (255, 0, 0) draws blue arrows.
        cv2.arrowedLine(canvas, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=1, tipLength=0.25)

    ok = bool(cv2.imwrite(save_path, canvas))
    if ok:
        print(f"Saved to {save_path}")
    else:
        print(f"[Warning] cv2.imwrite failed for {save_path}")
    return ok

def _group_cameras_by_timestamp(cam_list):
    """Group cameras that share a ``timestamp``.

    Returns a dict mapping the index of each distinct timestamp (in ascending
    timestamp order) to the list of cameras with that timestamp, in their
    original order.
    """
    by_timestamp = {}
    for cam in cam_list:
        by_timestamp.setdefault(cam.timestamp, []).append(cam)
    return {i: by_timestamp[t] for i, t in enumerate(sorted(by_timestamp))}


def _project_to_image(pos, vel, cam, eps=1e-6):
    """Project world-space points and velocities into ``cam``'s pixel grid.

    Uses a pinhole model with the principal point at the image centre and
    focal lengths derived from ``cam.FoVx`` / ``cam.FoVy``.

    Args:
        pos: ``(N, 3)`` world-space positions.
        vel: ``(N, 3)`` world-space velocities.
        cam: Camera exposing ``R``, ``T``, ``FoVx``, ``FoVy``, ``image_width``
            and ``image_height``.
        eps: Minimum camera-space depth for a point to count as in front.

    Returns:
        Tuple ``(xy, uv, in_front)``:
        - ``xy``: ``(N, 2)`` pixel coordinates.
        - ``uv``: ``(N, 2)`` image-plane velocity, in pixels per unit of the
          velocity's time base.
        - ``in_front``: ``(N,)`` bool mask of points with depth ``> eps``.
        Entries outside ``in_front`` are not meaningful.
    """
    R = torch.as_tensor(cam.R, dtype=pos.dtype, device=pos.device)
    T = torch.as_tensor(cam.T, dtype=pos.dtype, device=pos.device).reshape(3)
    # Same world->camera transform as the previous inline code: p_cam = R @ p + T.
    # NOTE(review): upstream 3D Gaussian Splatting stores Camera.R transposed
    # (getWorld2View2 uses R.T); confirm which convention this project uses.
    p_cam = pos @ R.T + T
    # Velocities are directions, so only the rotation applies.
    v_cam = vel.to(pos.dtype) @ R.T

    W, H = cam.image_width, cam.image_height
    fx = 0.5 * W / math.tan(0.5 * float(cam.FoVx))
    fy = 0.5 * H / math.tan(0.5 * float(cam.FoVy))
    cx, cy = W / 2, H / 2

    z = p_cam[:, 2]
    in_front = z > eps
    # Substitute a dummy depth for masked-out points so the divisions stay finite.
    z = torch.where(in_front, z, torch.ones_like(z))

    xy = torch.stack([fx * p_cam[:, 0] / z + cx, fy * p_cam[:, 1] / z + cy], dim=-1)

    # Time derivative of u = fx * X / Z + cx is fx * (Vx * Z - X * Vz) / Z^2
    # (likewise for v), i.e. how fast the projected point moves in pixels.
    z2 = z * z
    uv = torch.stack(
        [
            fx * (v_cam[:, 0] * z - p_cam[:, 0] * v_cam[:, 2]) / z2,
            fy * (v_cam[:, 1] * z - p_cam[:, 1] * v_cam[:, 2]) / z2,
        ],
        dim=-1,
    )
    return xy, uv, in_front


@torch.no_grad()
def main():
    """Render every training view and save it with projected velocity arrows.

    For each distinct training timestamp, this loads the visual
    (``checkpoint_level_two`` under ``model_args.model_path``) and hidden
    (``checkpoint`` under ``args.load_path``) Gaussian state for that frame.
    Frames whose model has no ``_velocity`` attribute are skipped. For the
    rest, every training camera at that timestamp is rendered. The
    Gaussians' positions and velocities are then projected into that view,
    and the overlay is written to ``<scene.model_path>/vel_vis``.
    """
    args, model_args, op_extract, pp_extract = get_parser()
    gaussians: GaussianModel = get_model(model_args.model)()

    render_func, GRsetting, GRzer = get_render_pipe(pp_extract.rd_pipe)
    scene = Scene(model_args, gaussians, loader=model_args.loader)

    gaussians.setup_constants(op_extract)
    save_dir = os.path.join(scene.model_path, "vel_vis")
    os.makedirs(save_dir, exist_ok=True)

    physical_ckpt_path = os.path.join(args.load_path, "checkpoint")
    visual_ckpt_path = os.path.join(model_args.model_path, "checkpoint_level_two")

    # The background colour is the same for every view; build it once.
    bg_value = 1.0 if model_args.white_background else 0.0
    background = torch.tensor([bg_value, bg_value, bg_value], device="cuda")

    train_cam_dict = _group_cameras_by_timestamp(scene.get_train_cameras())

    for frame_id, viewpoint_list in train_cam_dict.items():
        # Load visual and hidden (physical) particles for this frame.
        gaussians.load_visual(visual_ckpt_path, frame_id, scale=False)
        gaussians.load_hidden(physical_ckpt_path, frame_id)

        if not hasattr(gaussians, "_velocity"):
            print(f"[Warning] Frame {frame_id} has no _velocity. Skipping.")
            continue

        pos = gaussians._xyz            # presumably (N, 3) — TODO confirm
        vel = gaussians._velocity       # presumably (N, 3) — TODO confirm

        for cam in viewpoint_list:
            rendered = render_func(
                cam,
                gaussians,
                pp_extract,
                background,
                GRsetting=GRsetting,
                GRzer=GRzer,
                pos_type="visual",
            )
            image = rendered["render"]

            xy, uv, in_front = _project_to_image(pos, vel, cam)
            # Points behind (or on) the camera plane would be mirrored into
            # the frame or produce inf coordinates; drop them.
            xy, uv = xy[in_front], uv[in_front]

            # cv2.imwrite picks the encoder from the extension, so make sure
            # there is one.
            file_name = f"vel_{cam.image_name}"
            if not os.path.splitext(file_name)[1]:
                file_name += ".png"
            save_path = os.path.join(save_dir, file_name)
            visualize_velocity_on_image_cv2(image, xy, uv, save_path, scale=3.0, step=10)

# Run the velocity visualization only when executed as a script, not on import.
if __name__ == "__main__":
    main()