import os
import cv2
import sys
import math
import torch
import numpy as np
from torchvision.utils import save_image

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../")))
from scene import Scene
from gaussian_splatting.gm_fluid import GaussianModel
from utils.graphics_utils import get_w2c
from helpers.helper_gaussian import get_model
from helpers.helper_parser import get_parser, write_args_to_file

from helpers.helper_pipe import get_render_pipe

from pdb import set_trace as bp

import torch.nn.functional as F
import torchvision


def _look_at_c2w(position: torch.Tensor, target: torch.Tensor, cam_y: torch.Tensor) -> torch.Tensor:
    """Build a 4x4 camera-to-world matrix at ``position`` looking at ``target``.

    The rotation columns are (right, y_axis, forward), i.e. camera z points at the
    target. ``cam_y`` is only a reference direction for the camera y-axis; it must
    not be parallel to the viewing direction.

    Args:
        position: (3,) camera centre in world coordinates.
        target: (3,) world point the camera looks at.
        cam_y: (3,) reference direction for the camera y-axis.

    Returns:
        (4, 4) camera-to-world matrix on ``position``'s device/dtype.
    """
    forward = F.normalize(target - position, dim=0)
    right = F.normalize(torch.cross(cam_y, forward, dim=0), dim=0)
    # Re-orthogonalise y so (right, y_axis, forward) is an exact right-handed basis.
    y_axis = torch.cross(forward, right, dim=0)
    c2w = torch.eye(4, device=position.device, dtype=position.dtype)
    c2w[:3, :3] = torch.stack([right, y_axis, forward], dim=1)
    c2w[:3, 3] = position
    return c2w


def _set_camera_pose(cam, c2w: torch.Tensor) -> None:
    """Move ``cam`` to the pose ``c2w``, updating every pose-derived field.

    NOTE(review): assumes the 3DGS ``Camera`` convention -- ``world_view_transform``
    is the *transposed* world-to-camera matrix, ``full_proj_transform`` is
    ``world_view_transform @ projection_matrix`` and ``camera_center`` is the camera
    position. Confirm against this project's Camera class; any other cached
    pose-dependent fields (e.g. per-pixel rays) would also need refreshing.
    """
    w2c = torch.linalg.inv(c2w)
    cam.world_view_transform = w2c.transpose(0, 1).contiguous()
    cam.full_proj_transform = cam.world_view_transform @ cam.projection_matrix
    cam.camera_center = c2w[:3, 3].clone()


@torch.no_grad()
def main():
    """Render an orbit of novel views around the centroid of the loaded Gaussians.

    Loads frame 0 of the visual particles (``checkpoint_level_two`` under
    ``model_args.model_path``) and of the hidden/physical particles (``checkpoint``
    under ``args.load_path``). It then moves the first training camera of that
    frame along a circle about the world Y axis through the Gaussians' centroid.
    The sweep covers ``angle_range_deg`` degrees centred on the camera's original
    azimuth and keeps the camera's original height. Each of the ``n_views``
    renders is saved under ``<scene.model_path>/other_poses``.

    Returns early, with a warning, if the loaded model has no ``_velocity``.
    Note that the chosen training camera object is modified in place.
    """
    args, model_args, op_extract, pp_extract = get_parser()
    gaussians: GaussianModel = get_model(model_args.model)()

    render_func, GRsetting, GRzer = get_render_pipe(pp_extract.rd_pipe)
    scene = Scene(model_args, gaussians, loader=model_args.loader)

    gaussians.setup_constants(op_extract)
    save_dir = os.path.join(scene.model_path, "other_poses")
    os.makedirs(save_dir, exist_ok=True)

    physical_ckpt_path = os.path.join(args.load_path, "checkpoint")
    visual_ckpt_path = os.path.join(model_args.model_path, "checkpoint_level_two")

    # Group training cameras by timestamp: frame index i <-> i-th distinct timestamp.
    cam_list = scene.get_train_cameras()
    unique_timestamps = sorted({cam.timestamp for cam in cam_list})
    train_cam_dict = {
        i: [cam for cam in cam_list if cam.timestamp == ts]
        for i, ts in enumerate(unique_timestamps)
    }

    # Only frame 0 is rendered (the per-frame loop is currently disabled).
    frame_id = 0
    gaussians.load_visual(visual_ckpt_path, frame_id, scale=False)
    gaussians.load_hidden(physical_ckpt_path, frame_id)

    if not hasattr(gaussians, "_velocity"):
        # Actually skip, as the warning says, instead of failing later.
        print(f"[Warning] Frame {frame_id} has no _velocity. Skipping.")
        return

    base_cam = train_cam_dict[frame_id][0]
    center = gaussians._xyz.mean(dim=0)  # centroid of the (smoke) Gaussians, (3,)
    device, dtype = center.device, center.dtype

    angle_range_deg = 120
    n_views = 120
    angles = torch.deg2rad(
        torch.linspace(
            -angle_range_deg / 2, angle_range_deg / 2, steps=n_views, device=device, dtype=dtype
        )
    )

    # world_view_transform stores the transposed world-to-camera matrix (see
    # _set_camera_pose), so its [:3, 3] entries are zero. Invert its transpose to
    # get a proper camera-to-world matrix before reading the camera position.
    base_c2w = torch.linalg.inv(
        base_cam.world_view_transform.to(device=device, dtype=dtype).transpose(0, 1)
    )
    cam_pos = base_c2w[:3, 3]
    # Camera y-axis in world coordinates; only a reference direction for the
    # look-at basis (it matches the value the original code used here).
    cam_y = base_c2w[:3, 1]

    # Orbit in the world XZ plane around the centroid, keeping the base camera's
    # height offset so that delta_angle == 0 reproduces the base camera position.
    offset = cam_pos - center
    radius = torch.hypot(offset[0], offset[2])  # horizontal distance to the centroid
    height = offset[1]
    initial_angle = torch.atan2(offset[0], offset[2])  # initial horizontal (azimuth) angle

    bg_color = torch.tensor(
        [1.0, 1.0, 1.0] if model_args.white_background else [0.0, 0.0, 0.0],
        device=device,
    )

    for i, delta_angle in enumerate(angles):
        theta = initial_angle + delta_angle
        new_pos = center + torch.stack(
            [radius * torch.sin(theta), height, radius * torch.cos(theta)]
        )

        # Move the current camera to the new orbit pose.
        _set_camera_pose(base_cam, _look_at_c2w(new_pos, center, cam_y))

        rendered = render_func(
            base_cam,
            gaussians,
            pp_extract,
            bg_color,
            GRsetting=GRsetting,
            GRzer=GRzer,
            pos_type="visual",
        )
        torchvision.utils.save_image(
            rendered["render"],
            os.path.join(save_dir, f"frame_{frame_id:03d}_orbit_{i:03d}.png"),
        )
if __name__ == "__main__":
    # Script entry point: render the orbit only when run directly, not on import.
    main()