import numpy as np 
import torch 
from scipy.spatial.transform import Rotation
import os 
import json 

def get_viewmat(optimized_camera_to_world):
    """
    Convert batched camera-to-world poses into gsplat world-to-camera matrices.

    Args:
        optimized_camera_to_world: tensor indexed as ``[:, :3, :3]`` (rotation)
            and ``[:, :3, 3:4]`` (translation), i.e. at least ``(B, 3, 4)``.

    Returns:
        ``(B, 4, 4)`` tensor on the same device and with the same dtype as the
        input, whose last row is ``[0, 0, 0, 1]``.
    """
    c2w_rot = optimized_camera_to_world[:, :3, :3]
    c2w_pos = optimized_camera_to_world[:, :3, 3:4]
    tensor_opts = {"device": c2w_rot.device, "dtype": c2w_rot.dtype}

    # Negate the camera y and z columns to align with gsplat conventions.
    axis_flip = torch.tensor([[[1, -1, -1]]], **tensor_opts)

    # A rigid transform [R | T] inverts analytically to [R^T | -R^T T].
    w2c_rot = (c2w_rot * axis_flip).transpose(1, 2)
    w2c_pos = -torch.bmm(w2c_rot, c2w_pos)

    # Stack [R^T | -R^T T] on top of the homogeneous row.
    upper_rows = torch.cat((w2c_rot, w2c_pos), dim=2)
    homogeneous_row = torch.tensor([[[0, 0, 0, 1]]], **tensor_opts).expand(
        c2w_rot.shape[0], 1, 4
    )
    return torch.cat((upper_rows, homogeneous_row), dim=1)

def camera_extrinsics_from_vtk(C, F, up, z_is_forward=True):
    """
    Build world-to-camera extrinsics from a look-at camera description.

    The inputs are the quantities a ``vtk.vtkCamera`` (e.g. PyVista's
    ``plotter.camera``) exposes through ``GetPosition()``, ``GetFocalPoint()``
    and ``GetViewUp()``; the camera object itself is not needed here.

    Args:
        C: Camera position, array-like of 3 values.
        F: Focal point, array-like of 3 values.
        up: View-up vector, array-like of 3 values.
        z_is_forward: If True (default), +Z of the camera frame points from
            the camera towards the focal point. If False, +Z points away from
            it (like the typical OpenGL "-Z forward" convention).

    Returns:
        ``(R, t)`` where ``R`` is ``(3, 3)`` with rows ``r`` (right), ``u``
        (true up) and ``f`` (forward), and ``t = -R @ C`` is ``(3,)``, so that
        ``X_cam = R @ X_world + t = R @ (X_world - C)``.

    Integer arrays and plain sequences are accepted; the normalisations are
    done out of place so they never try to write floats into an integer array.
    If ``F == C`` or ``up`` is parallel to the viewing direction, a norm is
    zero and the result contains NaNs.

    NOTE(review): because ``r = f x up`` and ``u = r x f``, the rows
    ``(r, u, f)`` form a left-handed basis (``det(R) == -1``) -- confirm that
    consumers expect this.
    """
    C = np.asarray(C)
    F = np.asarray(F)
    up = np.asarray(up)

    # 1) Forward vector: from the camera to the focal point, normalised.
    f = F - C
    f = f / np.linalg.norm(f)

    # Keep +Z = forward as-is; for +Z = backward (OpenGL-like) invert it.
    if not z_is_forward:
        f = -f

    # 2) Right vector.
    r = np.cross(f, up)
    r = r / np.linalg.norm(r)

    # 3) True up vector, orthogonal to both r and f.
    u = np.cross(r, f)
    u = u / np.linalg.norm(u)

    # Each row of R is one of the basis vectors: R[0] = r, R[1] = u, R[2] = f.
    R = np.array([r, u, f])  # shape (3, 3)

    # 4) Translation so that X_cam = R @ (X_world - C), i.e. extrinsic [R | -R C].
    t = -R @ C

    return R, t

def generate_poses_translate_single(pose, target, dt, start_angle = 0, interval=0.0001, translation_range=2, dim=1):
    """
    Sample camera positions along one axis and return their gsplat view matrices.

    The camera orientation is fixed: yaw and pitch are derived once from the
    direction ``target - pose`` (``start_angle`` is added to the yaw). The
    position is stepped along coordinate ``dim`` by the offsets
    ``np.arange(0, translation_range, interval)``.

    Args:
        pose: Initial camera position, array-like of 3 values.
        target: Point defining the viewing direction, array-like of 3 values.
        dt: Mapping with keys ``'transform'`` and ``'scale'``; presumably the
            contents of nerfstudio's ``dataparser_transforms.json`` (3x4
            transform and scalar scale) -- confirm against the caller.
        start_angle: Offset in radians added to the yaw.
        interval: Step between consecutive positions along ``dim``.
        translation_range: Total length of the sweep along ``dim``.
        dim: Index (0, 1 or 2) of the coordinate that is translated.

    Returns:
        List with one ``(1, 4, 4)`` torch view matrix per sampled position.

    NOTE(review): every sample is shifted by ``+interval / 2`` along ``dim``,
    which is the upper-bound pose of ``generate_poses_translate`` rather than
    the nominal position -- confirm that this is intended.
    """
    # Append the homogeneous row so the dataparser transform can be applied
    # to 4x4 poses.
    transform_ap = np.vstack((np.array(dt['transform']), np.array([0, 0, 0, 1])))
    scale = dt['scale']

    C = np.array(pose, dtype=np.float64)
    direction = np.array(target) - C
    direction = direction / np.linalg.norm(direction)
    yaw = np.arctan2(direction[1], direction[0]) + start_angle
    pitch = -np.arcsin(direction[2])

    # Column i of camera_poses is the i-th position. Building it from C (not
    # from the raw argument) lets lists and tuples work as well as arrays.
    dim_range = np.arange(0, translation_range, interval)
    camera_poses = C[:, None].repeat(len(dim_range), 1)
    camera_poses[dim] += dim_range

    # The orientation never changes during a pure translation, so it is
    # computed once instead of once per sample.
    rot_matrix = (
        Rotation.from_euler('xyz', [0, pitch, yaw]).as_matrix()
        @ Rotation.from_euler('yzx', [np.pi/2, 0, -np.pi/2]).as_matrix()
    )
    half_dist = interval / 2
    row_order = np.array([0, 2, 1, 3])

    res = []
    for i in range(camera_poses.shape[1]):
        c2w = np.zeros((4, 4))
        c2w[:3, :3] = rot_matrix
        c2w[:3, 3] = camera_poses[:, i]
        c2w[dim, 3] += half_dist
        c2w[3, 3] = 1

        # Axis-convention change: negate the camera y/z columns, swap the
        # y and z rows, then negate the new z row (presumably to match the
        # nerfstudio training frame -- confirm).
        c2w[0:3, 1:3] *= -1
        c2w = c2w[row_order, :]
        c2w[2, :] *= -1

        # Map into the dataparser frame and rescale the translation.
        camera_pose_transformed = (transform_ap @ c2w)[:3, :]
        camera_pose_transformed[:3, 3] *= scale
        res.append(get_viewmat(torch.Tensor(camera_pose_transformed)[None]))

    return res

def generate_poses_translate(pose, target, dt, start_angle = 0, interval=0.0001, translation_range=2, dim=1):
    """
    Sample camera positions along one axis and bound each step's view matrix.

    The camera orientation is fixed: yaw and pitch are derived once from the
    direction ``target - pose`` (``start_angle`` is added to the yaw). The
    position is stepped along coordinate ``dim`` by the offsets
    ``np.arange(0, translation_range, interval)``. For every position two
    poses at ``-interval / 2`` and ``+interval / 2`` along ``dim`` are turned
    into view matrices, and the element-wise minimum / maximum of their
    translation columns gives the lower / upper bound.

    Args:
        pose: Initial camera position, array-like of 3 values.
        target: Point defining the viewing direction, array-like of 3 values.
        dt: Mapping with keys ``'transform'`` and ``'scale'``; presumably the
            contents of nerfstudio's ``dataparser_transforms.json`` (3x4
            transform and scalar scale) -- confirm against the caller.
        start_angle: Offset in radians added to the yaw.
        interval: Step between consecutive positions along ``dim``.
        translation_range: Total length of the sweep along ``dim``.
        dim: Index (0, 1 or 2) of the coordinate that is translated.

    Returns:
        ``(res_lb, res_ub, camera_poses, (poses_lb, poses_ub))`` where
        ``res_lb`` / ``res_ub`` are lists of ``(1, 4, 4)`` torch view-matrix
        bounds, ``camera_poses`` is the ``(3, N)`` array of sampled positions,
        and ``poses_lb`` / ``poses_ub`` are arrays of the 4x4 poses at
        ``-interval / 2`` / ``+interval / 2``, stored before any change of
        axis convention.
    """
    # Append the homogeneous row so the dataparser transform can be applied
    # to 4x4 poses.
    transform_ap = np.vstack((np.array(dt['transform']), np.array([0, 0, 0, 1])))
    scale = dt['scale']

    C = np.array(pose, dtype=np.float64)
    direction = np.array(target) - C
    direction = direction / np.linalg.norm(direction)
    yaw = np.arctan2(direction[1], direction[0]) + start_angle
    pitch = -np.arcsin(direction[2])

    # Column i of camera_poses is the i-th position. Building it from C (not
    # from the raw argument) lets lists and tuples work as well as arrays.
    dim_range = np.arange(0, translation_range, interval)
    camera_poses = C[:, None].repeat(len(dim_range), 1)
    camera_poses[dim] += dim_range

    # The orientation never changes during a pure translation, so it is
    # computed once instead of once per sample.
    rot_matrix = (
        Rotation.from_euler('xyz', [0, pitch, yaw]).as_matrix()
        @ Rotation.from_euler('yzx', [np.pi/2, 0, -np.pi/2]).as_matrix()
    )
    half_dist = interval / 2
    row_order = np.array([0, 2, 1, 3])

    def shifted_pose(eye, offset):
        # 4x4 pose at `eye`, moved by `offset` along the translated axis.
        mat = np.zeros((4, 4))
        mat[:3, :3] = rot_matrix
        mat[:3, 3] = eye
        mat[dim, 3] += offset
        mat[3, 3] = 1
        return mat

    def to_viewmat(c2w):
        # Axis-convention change: swap the y and z rows (fancy indexing makes
        # a copy, so `c2w` itself is untouched), negate the camera y/z
        # columns, then negate the new z row (presumably to match the
        # nerfstudio training frame -- confirm).
        converted = c2w[row_order, :]
        converted[0:3, 1:3] *= -1
        converted[2, :] *= -1
        # Map into the dataparser frame and rescale the translation.
        transformed = (transform_ap @ converted)[:3, :]
        transformed[:3, 3] *= scale
        return get_viewmat(torch.Tensor(transformed)[None])

    res_lb = []
    res_ub = []
    camera_poses_lb = []
    camera_poses_ub = []

    for i in range(camera_poses.shape[1]):
        eye = camera_poses[:, i]
        pose_ub = shifted_pose(eye, half_dist)
        pose_lb = shifted_pose(eye, -half_dist)
        camera_poses_ub.append(pose_ub)
        camera_poses_lb.append(pose_lb)

        view_mat1 = to_viewmat(pose_ub)
        view_mat2 = to_viewmat(pose_lb)

        # Both poses share one orientation, so the rotation block is copied
        # from view_mat1; only the translation column is bounded.
        view_mat_lb = torch.zeros((1, 4, 4))
        view_mat_ub = torch.zeros((1, 4, 4))
        view_mat_lb[:, :3, :3] = view_mat1[:, :3, :3]
        view_mat_ub[:, :3, :3] = view_mat1[:, :3, :3]
        view_mat_lb[:, :3, 3] = torch.minimum(view_mat1[:, :3, 3], view_mat2[:, :3, 3])
        view_mat_ub[:, :3, 3] = torch.maximum(view_mat1[:, :3, 3], view_mat2[:, :3, 3])
        view_mat_lb[:, 3, 3] = 1
        view_mat_ub[:, 3, 3] = 1

        res_lb.append(view_mat_lb)
        res_ub.append(view_mat_ub)

    return res_lb, res_ub, camera_poses, (np.array(camera_poses_lb), np.array(camera_poses_ub))

def generate_poses_single(pose, target, dt, interval=0.005, start_angle = 0, angle_range=np.pi*2):
    """
    Orbit the camera around ``target`` and return one view matrix per yaw step.

    The initial yaw, pitch and radius come from ``target - pose``. The yaw is
    swept over ``np.arange(yaw, yaw + angle_range, interval) + start_angle``
    while pitch and radius stay fixed; at each step the camera sits on the
    sphere of that radius around ``target`` and is oriented with the Euler
    angles ``(0, pitch, new_yaw)``.

    Args:
        pose: Initial camera position, array-like of 3 values.
        target: Centre of the orbit, array-like of 3 values.
        dt: Mapping with keys ``'transform'`` and ``'scale'``; presumably the
            contents of nerfstudio's ``dataparser_transforms.json`` (3x4
            transform and scalar scale) -- confirm against the caller.
        interval: Yaw step in radians.
        start_angle: Offset in radians added to every yaw sample.
        angle_range: Total yaw sweep in radians.

    Returns:
        List with one ``(1, 4, 4)`` torch view matrix per yaw sample.
    """
    # Append the homogeneous row so the dataparser transform can be applied
    # to 4x4 poses.
    transform_ap = np.vstack((np.array(dt['transform']), np.array([0, 0, 0, 1])))
    scale = dt['scale']

    C = np.array(pose)
    F = np.array(target)
    direction = F - C
    r = np.linalg.norm(direction)
    direction = direction / r
    yaw = np.arctan2(direction[1], direction[0])
    pitch = -np.arcsin(direction[2])

    yaw_angles = np.arange(yaw, yaw + angle_range, interval) + start_angle

    # Loop invariants: the horizontal orbit radius, the constant vertical
    # offset, the fixed camera-frame rotation and the row permutation.
    horizontal_radius = r * np.cos(-pitch)
    vertical_offset = r * np.sin(-pitch)
    camera_frame = Rotation.from_euler('yzx', [np.pi/2, 0, -np.pi/2]).as_matrix()
    row_order = np.array([0, 2, 1, 3])

    res = []
    for new_yaw in yaw_angles:
        eye = F - np.array([
            horizontal_radius * np.cos(new_yaw),
            horizontal_radius * np.sin(new_yaw),
            vertical_offset,
        ])
        rot_matrix = Rotation.from_euler('xyz', [0, pitch, new_yaw]).as_matrix() @ camera_frame

        c2w = np.zeros((4, 4))
        c2w[:3, :3] = rot_matrix
        c2w[:3, 3] = eye
        c2w[3, 3] = 1

        # Axis-convention change: negate the camera y/z columns, swap the
        # y and z rows, then negate the new z row (presumably to match the
        # nerfstudio training frame -- confirm).
        c2w[0:3, 1:3] *= -1
        c2w = c2w[row_order, :]
        c2w[2, :] *= -1

        # Map into the dataparser frame and rescale the translation.
        camera_pose_transformed = (transform_ap @ c2w)[:3, :]
        camera_pose_transformed[:3, 3] *= scale
        res.append(get_viewmat(torch.Tensor(camera_pose_transformed)[None]))
    return res


def generate_poses(pose, target, interval=0.005, start_angle = 0, angle_range=np.pi*2, dt=None):
    """
    Orbit the camera around ``target`` and bound each yaw step's view matrix.

    The initial yaw, pitch and radius come from ``target - pose``. The yaw is
    swept over ``np.arange(yaw, yaw + angle_range, interval) + start_angle``
    while pitch and radius stay fixed. For every yaw sample two poses are
    built by shifting the camera position by ``-/+ sin(interval / 2) * r``
    along the horizontal vector perpendicular to the viewing direction; the
    element-wise minimum / maximum of the translation columns of their view
    matrices gives the lower / upper bound.

    Args:
        pose: Initial camera position, array-like of 3 values.
        target: Centre of the orbit, array-like of 3 values.
        interval: Yaw step in radians.
        start_angle: Offset in radians added to every yaw sample.
        angle_range: Total yaw sweep in radians.
        dt: Optional mapping with keys ``'transform'`` and ``'scale'``. When
            omitted, it is loaded from ``dataparser_transforms.json`` in the
            nerfstudio output folder hard-coded relative to this file.

    Returns:
        ``(res_lb, res_ub, (poses_lb, poses_ub))`` where ``res_lb`` /
        ``res_ub`` are lists of ``(1, 4, 4)`` torch view-matrix bounds and
        ``poses_lb`` / ``poses_ub`` are arrays of the shifted 4x4 poses,
        stored before any change of axis convention.

    Raises:
        OSError: if ``dt`` is omitted and the default transforms file cannot
            be opened.
    """
    if dt is None:
        script_dir = os.path.dirname(os.path.realpath(__file__))
        output_folder = os.path.join(script_dir, '../../../../nerfstudio/outputs/airplane_sampled/splatfacto/2025-04-16_162206')
        transform_fn = os.path.join(output_folder, 'dataparser_transforms.json')
        with open(transform_fn, 'r') as fp:
            dt = json.load(fp)

    # Append the homogeneous row so the dataparser transform can be applied
    # to 4x4 poses.
    transform_ap = np.vstack((np.array(dt['transform']), np.array([0, 0, 0, 1])))
    scale = dt['scale']

    C = np.array(pose)
    F = np.array(target)
    direction = F - C
    r = np.linalg.norm(direction)
    direction = direction / r
    yaw = np.arctan2(direction[1], direction[0])
    pitch = -np.arcsin(direction[2])

    yaw_angles = np.arange(yaw, yaw + angle_range, interval) + start_angle

    # Loop invariants: the horizontal orbit radius, the constant vertical
    # offset, the half-width of one yaw step, the fixed camera-frame rotation
    # and the row permutation.
    horizontal_radius = r * np.cos(-pitch)
    vertical_offset = r * np.sin(-pitch)
    half_dist = np.sin(interval / 2) * r
    camera_frame = Rotation.from_euler('yzx', [np.pi/2, 0, -np.pi/2]).as_matrix()
    row_order = np.array([0, 2, 1, 3])

    def pose_at(rot_matrix, position):
        # 4x4 pose with the given orientation and position.
        mat = np.zeros((4, 4))
        mat[:3, :3] = rot_matrix
        mat[:3, 3] = position
        mat[3, 3] = 1
        return mat

    def to_viewmat(c2w):
        # Axis-convention change: swap the y and z rows (fancy indexing makes
        # a copy, so `c2w` itself is untouched), negate the camera y/z
        # columns, then negate the new z row (presumably to match the
        # nerfstudio training frame -- confirm).
        converted = c2w[row_order, :]
        converted[0:3, 1:3] *= -1
        converted[2, :] *= -1
        # Map into the dataparser frame and rescale the translation.
        transformed = (transform_ap @ converted)[:3, :]
        transformed[:3, 3] *= scale
        return get_viewmat(torch.Tensor(transformed)[None])

    res_lb = []
    res_ub = []
    camera_poses_lb = []
    camera_poses_ub = []

    for new_yaw in yaw_angles:
        eye = F - np.array([
            horizontal_radius * np.cos(new_yaw),
            horizontal_radius * np.sin(new_yaw),
            vertical_offset,
        ])
        direction = F - eye
        direction = direction / np.linalg.norm(direction)

        # Horizontal vector perpendicular to the viewing direction. It is left
        # unnormalised, so its length is that of the direction's horizontal
        # component.
        offset_vec = np.array([-direction[1], direction[0], 0])

        rot_matrix = Rotation.from_euler('xyz', [0, pitch, new_yaw]).as_matrix() @ camera_frame

        pose_ub = pose_at(rot_matrix, eye + offset_vec * half_dist)
        pose_lb = pose_at(rot_matrix, eye - offset_vec * half_dist)
        camera_poses_ub.append(pose_ub)
        camera_poses_lb.append(pose_lb)

        view_mat1 = to_viewmat(pose_ub)
        view_mat2 = to_viewmat(pose_lb)

        # Both poses share one orientation, so the rotation block is copied
        # from view_mat1; only the translation column is bounded.
        view_mat_lb = torch.zeros((1, 4, 4))
        view_mat_ub = torch.zeros((1, 4, 4))
        view_mat_lb[:, :3, :3] = view_mat1[:, :3, :3]
        view_mat_ub[:, :3, :3] = view_mat1[:, :3, :3]
        view_mat_lb[:, :3, 3] = torch.minimum(view_mat1[:, :3, 3], view_mat2[:, :3, 3])
        view_mat_ub[:, :3, 3] = torch.maximum(view_mat1[:, :3, 3], view_mat2[:, :3, 3])
        view_mat_lb[:, 3, 3] = 1
        view_mat_ub[:, 3, 3] = 1

        res_lb.append(view_mat_lb)
        res_ub.append(view_mat_ub)

    return res_lb, res_ub, (np.array(camera_poses_lb), np.array(camera_poses_ub))

if __name__ == "__main__":
    # generate_poses returns three values: the lower-bound view matrices, the
    # upper-bound view matrices, and the (lower, upper) camera-pose arrays.
    # Unpacking only two of them raises ValueError.
    res_lb, res_ub, camera_pose_bounds = generate_poses(
        np.array([40*np.sqrt(2)/2, 0, 40*np.sqrt(2)/2]),
        np.array([0,0,0]),
    )