import os
import json
import argparse
import numpy as np
import torch

def estimate_scene_center(c2ws, reg=1.0):
    """Estimate the point closest, in least squares, to every camera's view ray.

    Each camera contributes a ray with origin ``c2w[:3, 3]`` and direction
    ``-c2w[:3, 2]``. That is, cameras are assumed to look down their local
    -Z axis.

    NOTE(review): ``generate_rotated_camera_yup`` builds +Z-forward frames.
    Confirm which convention the input poses actually use before relying on
    this estimate.

    The returned point ``p`` minimises::

        sum_i ||(I - d_i d_i^T)(p - o_i)||^2 + reg * ||p||^2

    The ``reg * ||p||^2`` term keeps the normal matrix invertible when all
    rays are parallel or no cameras are given. The cost is a bias toward the
    world origin. ``reg=1.0`` reproduces the historical behaviour; pass
    ``reg=0.0`` for the unbiased ray intersection.

    Args:
        c2ws: iterable of camera-to-world matrices indexable as ``c2w[:3, k]``.
        reg: weight of the pull toward the world origin.

    Returns:
        Length-3 float64 numpy array with the estimated centre.

    Raises:
        numpy.linalg.LinAlgError: if the system is singular (e.g. ``reg=0``
            with all rays parallel, or no rays at all).
    """
    c2ws = list(c2ws)  # iterate once so generators are not exhausted
    rays_o = np.array([c2w[:3, 3] for c2w in c2ws], dtype=np.float64).reshape(-1, 3)
    rays_d = np.array([-c2w[:3, 2] for c2w in c2ws], dtype=np.float64).reshape(-1, 3)
    rays_d = rays_d / np.linalg.norm(rays_d, axis=1, keepdims=True)

    # Per-ray projector onto the plane orthogonal to that ray's direction.
    projs = np.eye(3)[None, :, :] - rays_d[:, :, None] * rays_d[:, None, :]

    A = reg * np.eye(3) + projs.sum(axis=0)
    b = np.einsum("nij,nj->i", projs, rays_o)
    # solve() is more accurate and cheaper than forming inv(A) explicitly.
    return np.linalg.solve(A, b)

def rotate_y_up(vec, angle_deg):
    """Rotate ``vec`` about the +Y axis by ``angle_deg`` degrees.

    This is the standard right-handed rotation R_y. A positive angle carries
    +Z toward +X and +X toward -Z. ``vec`` is left-multiplied by R_y, so it
    may be a single length-3 vector or a (3, N) stack of column vectors.
    """
    angle_rad = np.radians(angle_deg)
    c = np.cos(angle_rad)
    s = np.sin(angle_rad)
    r_y = np.array([[c, 0, s],
                    [0, 1, 0],
                    [-s, 0, c]])
    return r_y @ vec

def generate_rotated_camera_yup(c2w, scene_center, angle_deg=-30):
    """Orbit a camera about a Y-parallel axis through ``scene_center`` and re-aim it.

    The camera origin is rotated by ``angle_deg`` degrees about the axis
    parallel to +Y that passes through ``scene_center`` (via ``rotate_y_up``).
    A look-at frame is then rebuilt:

    - column 2 (+Z) points from the new origin toward ``scene_center``;
    - column 0 is ``up x z``;
    - column 1 is ``z x x``, i.e. the reference up vector made orthogonal
      to the new forward axis.

    The original camera's column 1 is the reference up vector.

    Args:
        c2w: camera-to-world matrix. Only ``c2w[:3, 1]`` (reference up) and
            ``c2w[:3, 3]`` (origin) are read. Presumably 4x4 — confirm.
        scene_center: length-3 point that the camera orbits and looks at.
        angle_deg: orbit angle in degrees.

    Returns:
        A new 4x4 float64 camera-to-world matrix whose 3x3 rotation part is
        orthonormal with determinant +1.

    Note:
        The result contains NaNs if the new origin coincides with
        ``scene_center``, or if the reference up is parallel to the new
        forward axis (both normalisations divide by zero).
    """
    origin = c2w[:3, 3]
    new_origin = scene_center + rotate_y_up(origin - scene_center, angle_deg)

    # Reference up: the source camera's own Y axis.
    up = c2w[:3, 1]

    z = scene_center - new_origin
    z = z / np.linalg.norm(z)
    x = np.cross(up, z)
    x = x / np.linalg.norm(x)
    # Right-handed basis: x cross y == z. Both z and x are unit and
    # perpendicular, so y is unit length too.
    y = np.cross(z, x)

    new_c2w = np.eye(4)
    new_c2w[:3, 0] = x
    # Write the re-orthogonalised y, not the raw `up`. `up` is generally not
    # perpendicular to the new forward axis, so writing it directly would
    # produce a skewed, non-orthonormal rotation.
    new_c2w[:3, 1] = y
    new_c2w[:3, 2] = z
    new_c2w[:3, 3] = new_origin
    return new_c2w

def rotate_all_poses_yup(c2ws, scene_center, angle_deg=30):
    """Orbit every pose in ``c2ws`` by ``angle_deg`` about ``scene_center``.

    Each pose is passed through ``generate_rotated_camera_yup``. The result
    is a list holding one new camera-to-world matrix per input pose, in the
    same order.
    """
    return [
        generate_rotated_camera_yup(pose, scene_center, angle_deg=angle_deg)
        for pose in c2ws
    ]

def _make_frame(file_path, c2w, frame_idx, focal, k, image_hw):
    """Build one per-image entry of a transforms JSON file.

    Args:
        file_path: image path, written verbatim to ``"file_path"``.
        c2w: camera-to-world matrix, already converted to nested lists.
        frame_idx: integer frame index.
        focal: single-element torch tensor with the focal length. It is used
            for ``focal``, ``fx``, ``fy`` and ``camera_angle_x``, and must be
            in the same units as ``image_hw``.
        k: intrinsics tensor. Only ``k[0, 2]`` and ``k[1, 2]`` are read
            (presumably the principal point of a 3x3 K — confirm).
        image_hw: ``(height, width)`` in pixels.
    """
    image_h, image_w = image_hw
    focal_px = focal.item()
    return {
        "file_path": file_path,
        "camera_hw": [image_h, image_w],
        # Pinhole horizontal field of view: 2 * atan(W / (2 f)).
        "camera_angle_x": 2 * np.arctan(image_w / (2 * focal.detach().cpu().numpy())).item(),
        "transform_matrix": c2w,
        "frame_idx": frame_idx,
        "focal": focal_px,
        "fx": focal_px,
        "cx": k[0, 2].item(),
        "fy": focal_px,
        "cy": k[1, 2].item(),
    }


def _write_transforms(path, frames, near, far, fg_path, bg_path):
    """Write a transforms JSON file (shared header + ``frames``) and report its path."""
    payload = {
        "near": near,
        "far": far,
        "frame_bkg_color": [0.0, 0.0, 0.0],
        "frames": frames,
        "fg_pts": fg_path,
        "bg_pts": bg_path,
    }
    with open(path, "w") as f:
        json.dump(payload, f, indent=4)
    print(f"Saved to {path}")


def main():
    """Convert ``<data_root>/duster_out`` tensors into transforms JSON files.

    Reads ``poses.pt``, ``focals.pt``, ``intrinsics.pt``, ``fg_pts.pt`` and
    ``bg_pts.pt`` from ``<data_root>/duster_out``, then writes three files
    into ``<data_root>``:

    - ``transforms_train_wild_smoke.json``: frames 0-239 from the original
      poses (``frames_cam00``).
    - ``transforms_train_wild_smoke_sv4d.json``: the same train frames
      followed by, for each of frames 0-239, one view per orbit angle
      (``frames_cam01`` .. ``frames_cam04``). Each view is orbited about the
      mean of the fg/bg points.
    - ``transforms_test_wild_smoke.json``: frames 240-269 from the original
      poses.

    Raises:
        ValueError: if fewer than 270 poses/focals/intrinsics are available.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", required=True, help="Root directory containing duster_out")
    args = parser.parse_args()

    out_dir = os.path.join(args.data_root, "duster_out")
    poses = torch.load(os.path.join(out_dir, "poses.pt"), weights_only=True)
    focals = torch.load(os.path.join(out_dir, "focals.pt"), weights_only=True)
    ks = torch.load(os.path.join(out_dir, "intrinsics.pt"), weights_only=True)
    fg = torch.load(os.path.join(out_dir, "fg_pts.pt"), weights_only=True)
    bg = torch.load(os.path.join(out_dir, "bg_pts.pt"), weights_only=True)
    fg_path = os.path.join(out_dir, "fg_pts.pt")
    bg_path = os.path.join(out_dir, "bg_pts.pt")

    num_train = 240        # frames [0, num_train) are the train split
    num_test_end = 270     # frames [num_train, num_test_end) are the test split
    available = min(len(poses), len(focals), len(ks))
    if available < num_test_end:
        raise ValueError(
            f"need at least {num_test_end} poses/focals/intrinsics, found {available}"
        )

    image_hw = (1080, 1920)
    near = 50.0
    far = 8000.0
    gen_dir = "frames_cam"
    orbit_angles_deg = [10, 20, 30, -10]

    # NOTE(review): this replaces the loaded focals with a constant 1920 px
    # (originally marked "# test"). It looks like a debugging override —
    # confirm before relying on it. Only the loaded tensor's shape is kept.
    focals = torch.full(focals.shape, 1920.0)

    poses_np = poses.detach().cpu().numpy()

    # Orbit centre: mean of all fg/bg points. The ray-intersection estimate
    # estimate_scene_center(poses_np) is an alternative.
    # detach().cpu() makes this work for CUDA or grad-tracking tensors too.
    scene_center = torch.cat((fg, bg)).mean(0).detach().cpu().numpy()
    new_poses = [
        np.array(rotate_all_poses_yup(poses_np, scene_center, angle_deg=angle))
        for angle in orbit_angles_deg
    ]

    def frame_for(view, i, c2w):
        # View 0 is the original camera; views 1.. are the orbited cameras.
        return _make_frame(
            f"{gen_dir}{view:02d}/frame_{i:03d}.png", c2w, i, focals[i], ks[i], image_hw
        )

    train_frames = [frame_for(0, i, poses_np[i].tolist()) for i in range(num_train)]
    _write_transforms(
        os.path.join(args.data_root, "transforms_train_wild_smoke.json"),
        train_frames, near, far, fg_path, bg_path,
    )

    # The sv4d split contains all original train frames followed by the
    # generated views, ordered by frame index and then view.
    sv4d_frames = list(train_frames)
    for i in range(num_train):
        for v, rotated in enumerate(new_poses, start=1):
            sv4d_frames.append(frame_for(v, i, rotated[i].tolist()))
    _write_transforms(
        os.path.join(args.data_root, "transforms_train_wild_smoke_sv4d.json"),
        sv4d_frames, near, far, fg_path, bg_path,
    )

    test_frames = [
        frame_for(0, i, poses_np[i].tolist()) for i in range(num_train, num_test_end)
    ]
    _write_transforms(
        os.path.join(args.data_root, "transforms_test_wild_smoke.json"),
        test_frames, near, far, fg_path, bg_path,
    )


if __name__ == "__main__":
    main()