import time
from collections import OrderedDict
import math
import torch
import numpy as np
import matplotlib.pyplot as plt

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from pytransform3d import rotations as pr
import cv2
from pytorch3d import transforms
from einops import rearrange, reduce, repeat
try:
    # Prefer the project's shared implementation when it is importable.
    from helperfunctions.hfunctions import assert_torch_invalid
except ImportError:
    # Fallback used when the helper package is not available.
    def assert_torch_invalid(X, string):
        """Assert that tensor ``X`` is non-empty and contains only finite values.

        Args:
            X: tensor to validate.
            string: context label embedded in the failure message.

        Raises:
            AssertionError: if ``X`` contains NaN or inf values, or is empty.
                The message names the failed check and ``string``.
        """
        # The message must be the string itself: ``assert c, print(...)``
        # would raise an AssertionError whose message is ``None``.
        assert not torch.isnan(X).any(), 'NaN problem [' + string + ']'
        assert not torch.isinf(X).any(), 'inf problem [' + string + ']'
        assert torch.all(torch.isfinite(X)), 'Some elements not finite problem [' + string + ']'
        assert X.numel() > 0, 'Empty tensor problem [' + string + ']'


def render_semantics(out_dict, H, W, args, data_dict=None):
    '''Render the 3D eye model into 2D image-plane semantics.

    The raw predictions in ``out_dict`` are scaled and bounded, then broadcast
    to one entry per frame. They are turned into pupil/iris template point
    clouds, and those are projected with a pinhole camera whose principal
    point is the image centre.

    Keys read from ``out_dict`` (raw values, scaled by ``scale_and_bound``):
        T: eyeball centre position.
        R: Euler angles (ZYX); the first component is zeroed during scaling.
        r_pupil: pupil radius.
        r_iris: iris radius.
        L: eyeball radius.
        focal: [f_x, f_y].

    Args:
        out_dict: dict of raw predictions. It is updated in place with the
            scaled/flattened values plus ``L_p``, where
            L_p = sqrt(|L**2 - r_iris**2|) is the depth of the pupil/iris plane.
        H, W: image height and width.
        args: config dict. Reads 'scale_bound_eye', 'batch_size', 'frames',
            'temp_n_angles' and 'temp_n_radius'.
        data_dict: unused; kept for API compatibility.

    Returns:
        (out_dict, rend_dict). ``rend_dict`` holds:
        - the projected pupil/iris template points and their edge indexes;
        - the projected pupil and eyeball centres;
        - the unit 3D gaze vector (z flipped) and the normalised 2D gaze vector;
        - ``rotation_rad``, which stacks components 2 and 1 of R in radians.
    '''
    T = out_dict['T']
    R = out_dict['R']
    r_pupil = out_dict['r_pupil']
    r_iris = out_dict['r_iris']
    L = out_dict['L']
    focal = out_dict['focal']

    # Scale and bound predictions into physically plausible ranges
    # (after this R is in degrees).
    T, R, r_pupil, r_iris, L, L_p, focal = scale_and_bound(T, R, r_pupil, r_iris, L, focal, args)

    # Broadcast per-sequence parameters over frames and merge (batch, frame).
    T, R, r_pupil, r_iris, L, L_p, focal, _ = tensor_shape_compatibility(
        T, R, r_pupil, r_iris, L, L_p, focal, args)

    # Euler angles [deg] -> rotation matrices.
    Rotation = euler_to_rotation(R)

    # 3D pupil and iris template point clouds plus their edge-point indexes.
    pupil_XYZ, iris_XYZ, edge_idx_iris, edge_idx_pupil = template_generation(
        args['temp_n_angles'], args['temp_n_radius'], r_pupil, r_iris, L_p, args=args)

    # Camera intrinsics. NOTE(review): fy also uses focal[:, 0], as elsewhere
    # in this file; confirm whether focal[:, 1] was ever intended.
    fx = focal[:, 0]
    fy = focal[:, 0]
    cx = W / 2
    cy = H / 2

    _, pupil_UV, _, iris_UV = project_templates_to_2D(
        pupil_XYZ, iris_XYZ, Rotation, T, fx, fy, cx, cy)
    # Template point 0 sits at radius 0, i.e. it is the pupil centre.
    pupil_c_UV = pupil_UV[:, 0, :]

    # Project the eyeball centre (origin of the eye-model frame).
    eyeball_origin = torch.zeros(T.shape[0], 1, 3, device=T.device, dtype=T.dtype)
    _, eyeball_c_UV = extrinsics_project(eyeball_origin, Rotation, T, fx, fy, cx, cy)
    eyeball_c_UV = eyeball_c_UV.squeeze(1)

    # The gaze vector in eye-model coordinates is [0, 0, -1]. In camera
    # coordinates it is Rotation @ [0, 0, -1] = -Rotation[:, :, 2], since
    # [R|T] maps eye-model coordinates to camera coordinates.
    gaze_vector_3D = -Rotation[:, :, 2]
    # Flip z to convert to the Wolfgang dataset convention: same axes as our
    # camera frame, but z looks towards the camera.
    gaze_vector_3D[:, 2] = -gaze_vector_3D[:, 2]

    # 2D gaze: unit vector from the projected eyeball centre to the pupil centre.
    temp_gaze_vector_UV = pupil_c_UV - eyeball_c_UV
    gaze_vector_UV = temp_gaze_vector_UV / \
        (torch.norm(temp_gaze_vector_UV, dim=-1, keepdim=True) + 1e-9)

    # Rotation angles in radians, mainly for NVGaze gaze labels. R is already
    # in degrees with its first component zeroed by scale_and_bound. The
    # previous code scaled it by 80 a second time, which made the angles
    # 80x too large.
    rotation_deg = torch.stack((R[..., 2], R[..., 1]), dim=1)
    rotation_rad = rotation_deg * math.pi / 180

    # Hand the scaled parameters back to the caller.
    out_dict['T'] = T
    out_dict['R'] = R
    out_dict['r_pupil'] = r_pupil
    out_dict['r_iris'] = r_iris
    out_dict['L'] = L
    out_dict['L_p'] = L_p
    out_dict['focal'] = focal

    rend_dict = {
        'pupil_UV': pupil_UV,
        'iris_UV': iris_UV,
        'edge_idx_iris': edge_idx_iris,
        'edge_idx_pupil': edge_idx_pupil,
        'pupil_c_UV': pupil_c_UV,
        'eyeball_c_UV': eyeball_c_UV,
        'gaze_vector_3D': gaze_vector_3D,
        'gaze_vector_UV': gaze_vector_UV,
        'rotation_rad': rotation_rad
    }

    return out_dict, rend_dict

def scale_and_bound(T, R, r_pupil, r_iris, L, focal, args):
    '''Map raw network outputs onto physically plausible eye-model parameters.

    Each preset selected by ``args['scale_bound_eye']`` applies fixed affine
    maps (``a * x + b``) to the raw predictions. The original authors quote
    these reference ranges (source not given): iris radius 4.5-7.5, pupil
    radius 1-4.5, eyeball radius 9-15.

    Args:
        T, R, r_pupil, r_iris, L, focal: raw prediction tensors.
        args: config dict. Only 'scale_bound_eye' is read; valid values are
            'version_0', 'version_0_1' and 'version_1'.

    Returns:
        (T, R, r_pupil, r_iris, L, L_p, focal), where
        L_p = sqrt(|L**2 - r_iris**2|) and R = 80 * raw R in degrees, with its
        first component zeroed.

    Raises:
        NotImplementedError: for an unknown preset name.
    '''
    mode = args['scale_bound_eye']

    # (a, b) pairs for ``a * x + b``. T uses per-axis scale/offset triples.
    # pupil/iris = None means the radii are bounded relative to L instead.
    presets = {
        'version_0': dict(L=(5.0, 12.0), pupil=(4.0, 4.0), iris=(4.0, 5.5),
                          T_scale=(17.0, 17.0, 60.0), T_offset=(0.0, 0.0, 55.0),
                          focal=(750.0, 750.0)),
        'version_0_1': dict(L=(3.5, 12.5), pupil=(2.0, 3.0), iris=(2.0, 6.0),
                            T_scale=(20.0, 20.0, 60.0), T_offset=(0.0, 0.0, 60.0),
                            focal=(800.0, 800.0)),
        'version_1': dict(L=(3.5, 12.5), pupil=None, iris=None,
                          T_scale=(20.0, 20.0, 60.0), T_offset=(0.0, 0.0, 60.0),
                          focal=(600.0, 800.0)),
    }
    if mode not in presets:
        raise NotImplementedError
    cfg = presets[mode]

    # Eyeball radius.
    L = cfg['L'][0] * L + cfg['L'][1]

    if cfg['iris'] is None:
        # Raw [-1, 1] maps onto [0.3 L, 0.55 L] for the iris and onto
        # [1, r_iris] for the pupil.
        iris_lo, iris_hi = L * 0.3, L * 0.55
        r_iris = iris_lo + (iris_hi - iris_lo) * ((r_iris + 1.0) / 2.0)
        pupil_lo = 1.0
        r_pupil = pupil_lo + (r_iris - pupil_lo) * ((r_pupil + 1.0) / 2.0)
    else:
        r_pupil = cfg['pupil'][0] * r_pupil + cfg['pupil'][1]
        r_iris = cfg['iris'][0] * r_iris + cfg['iris'][1]

    # Depth of the iris plane below the eyeball centre.
    L_p = torch.sqrt(torch.abs(L ** 2 - r_iris ** 2))

    # Eyeball centre with per-axis scale and offset.
    T_scale = torch.tensor([[cfg['T_scale']]]).to(T.device)
    T_offset = torch.tensor([[cfg['T_offset']]]).to(T.device)
    T = T_scale * T + T_offset

    # Degrees; the first (roll, Z in the ZYX convention) component is zeroed.
    R = 80.0 * R * (torch.tensor([[0, 1, 1]])).to(R.device)

    # Focal length.
    focal = cfg['focal'][0] * focal + cfg['focal'][1]

    return T, R, r_pupil, r_iris, L, L_p, focal

def tensor_shape_compatibility(T, R, r_pupil, r_iris, L, L_p, focal, args):
    '''Broadcast per-sequence parameters over frames, then merge (batch, frame).

    T, L, L_p, focal and r_iris are expanded along dim 1 to ``args['frames']``,
    so their dim 1 must be 1 or ``frames``. R and r_pupil are taken as given.
    Every tensor then has its first two dims flattened into one.

    Returns:
        (T, R, r_pupil, r_iris, L, L_p, focal, iterations), where
        iterations = args['batch_size'] * args['frames'].
    '''
    batches = args['batch_size']
    frames = args['frames']

    def merge_batch_frame(x):
        # (B, F, ...) -> (B * F, ...)
        return torch.flatten(x, start_dim=0, end_dim=1)

    # TODO: remove once the per-frame loop is gone.
    T = merge_batch_frame(T.expand(-1, frames, -1))
    focal = merge_batch_frame(focal.expand(-1, frames, -1))
    L, L_p, r_iris = (merge_batch_frame(x.expand(-1, frames)) for x in (L, L_p, r_iris))

    R = merge_batch_frame(R)
    r_pupil = merge_batch_frame(r_pupil)

    return T, R, r_pupil, r_iris, L, L_p, focal, batches * frames

def euler_to_rotation(R_deg):
    '''Convert Euler angles in degrees to rotation matrices (ZYX convention).

    Args:
        R_deg: Euler angles in degrees; the last axis holds the three angles.

    Returns:
        The rotation matrices from pytorch3d's ``euler_angles_to_matrix``.
    '''
    radians = R_deg * math.pi / 180
    return transforms.euler_angles_to_matrix(radians, convention="ZYX")


def vectorized_linspace(start, end, steps):
    """Batched version of ``torch.linspace``.

    Args:
        start: tensor of any shape.
        end: tensor with the same shape as ``start``. Its device is used for
            the output.
        steps: number of samples (int).

    Returns:
        Tensor of shape ``start.size() + (steps,)``. Along the last axis it runs
        linearly from ``start`` (index 0) to ``end`` (index -1).
    """
    assert start.size() == end.size()
    device = end.device

    # Interpolation weights along the new trailing axis.
    w_start = torch.linspace(1, 0, steps=steps).to(device)
    w_end = torch.linspace(0, 1, steps=steps).to(device)

    # (..., 1) broadcast against (steps,) yields (..., steps).
    lo = start.to(device).unsqueeze(-1)
    hi = end.unsqueeze(-1)
    return w_start * lo + w_end * hi

def template_edge_indexes(angles, N_radius):
    '''Indexes of edge points on the outermost ring of the flattened templates.

    Template points are stored ring-major, so point (ring, a) sits at index
    ``ring * N_angles + a``. Only the outer ring (``N_radius - 1``) is used.
    The angle masks are computed from the first row of ``angles``
    (presumably all rows are identical; TODO confirm).

    Args:
        angles: 2-D tensor (B, N_angles) of sample angles in radians.
        N_radius: number of rings (int).

    Returns:
        (edge_idx_iris, edge_idx_pupil): dicts with keys 'left', 'right',
        'sides' (left followed by right) and 'outline' (every outer-ring
        point).
        - Iris sides are +/-0.1*pi wedges around pi (left) and around 0
          (right).
        - Pupil sides split the circle into (0.5*pi, 1.5*pi] (left) and the
          rest (right).
    '''
    assert angles.ndim == 2
    assert isinstance(N_radius, int)

    N_angles = angles.shape[1]
    theta = angles[0]
    ring_start = (N_radius - 1) * N_angles
    pi = math.pi

    def to_indexes(mask):
        # Angle mask -> flat indexes on the outer ring.
        return ring_start + torch.where(mask)[0]

    def pack(left_mask, right_mask):
        left = to_indexes(left_mask)
        right = to_indexes(right_mask)
        return {
            'left': left,
            'right': right,
            'sides': torch.cat((left, right)),
            'outline': ring_start + torch.arange(N_angles),
        }

    edge_idx_iris = pack(
        (theta >= pi * 0.9) & (theta <= pi * 1.1),
        (theta <= pi * 0.1) | (theta >= pi * 1.9),
    )
    edge_idx_pupil = pack(
        (theta > pi * 0.50) & (theta <= pi * 1.50),
        (theta <= pi * 0.50) | (theta > pi * 1.50),
    )
    return edge_idx_iris, edge_idx_pupil

def template_generation(N_angles, N_radius, r_pupil, r_iris, L_p, args):
    '''Build 3D pupil (disc) and iris (annulus) template point clouds.

    Points are sampled on ``N_radius`` concentric rings times ``N_angles``
    angles in [0, 1.9999*pi]. They are flattened ring-major into
    ``N_radius * N_angles`` points per sample.
    - Pupil rings run from radius 0 to ``r_pupil``.
    - Iris rings run from ``r_pupil`` to ``r_iris``.
    - Both templates lie in the plane z = -L_p of the eye-model frame.

    Args:
        N_angles, N_radius: sampling resolution (ints).
        r_pupil, r_iris, L_p: 1-D tensors of equal length B.
        args: config dict. The batch size is now taken from the tensors, so a
            partial last batch works; the argument is kept for API
            compatibility.

    Returns:
        pupil_XYZ, iris_XYZ: tensors of shape (B, N_radius * N_angles, 3).
        edge_idx_iris, edge_idx_pupil: edge-index dicts from
            ``template_edge_indexes``.
    '''
    assert isinstance(N_angles, int)
    assert isinstance(N_radius, int)
    assert r_pupil.ndim == r_iris.ndim == L_p.ndim == 1
    assert r_pupil.shape[0] == r_iris.shape[0] == L_p.shape[0]

    # Derive B from the data. args['batch_size'] * args['frames'] is wrong
    # for a smaller final batch and made vectorized_linspace's shape assert fail.
    B = r_pupil.shape[0]
    device = r_pupil.device

    # Angles [0, 2*pi) [B, N_ang]
    angles = np.linspace([0] * B, [1.9999 * math.pi] * B, N_angles,
                         axis=-1, dtype=np.float32)
    angles = torch.from_numpy(angles).to(device)

    # Indexes of the edge points on the outer ring.
    edge_idx_iris, edge_idx_pupil = template_edge_indexes(angles, N_radius)

    # Ring radii: pupil is a filled disc, iris an annulus around it.
    radius_pupil = vectorized_linspace(torch.zeros_like(r_pupil), r_pupil, N_radius)
    radius_iris = vectorized_linspace(r_pupil, r_iris, N_radius)

    # (B, 1, N_angles) x (B, N_radius, 1) -> (B, N_radius, N_angles)
    angles = rearrange(angles, 'b n -> b 1 n')
    radius_pupil = rearrange(radius_pupil, 'b n -> b n 1')
    radius_iris = rearrange(radius_iris, 'b n -> b n 1')

    def _plane_points(radius):
        # Flatten ring-major and place the points in the plane z = -L_p,
        # so they end up in front of the camera after [R|T].
        X = rearrange(radius * torch.cos(angles), 'b n1 n2 -> b (n1 n2)')
        Y = rearrange(radius * torch.sin(angles), 'b n1 n2 -> b (n1 n2)')
        Z = -(rearrange(L_p, 'b -> b 1') * torch.ones_like(Y))
        return torch.stack((X, Y, Z), dim=-1)

    pupil_XYZ = _plane_points(radius_pupil)
    iris_XYZ = _plane_points(radius_iris)

    return pupil_XYZ, iris_XYZ, edge_idx_iris, edge_idx_pupil

def extrinsics_project(P, R, T, fx, fy, cx, cy):
    '''Apply the extrinsics [R|T] to points, then project them with a pinhole model.

    Args:
        P: points in eye-model coordinates, shape (B, N, 3).
        R: rotation matrices, shape (B, 3, 3).
        T: translations, shape (B, 3).
        fx, fy: per-sample focal lengths, 1-D tensors of shape (B,).
        cx, cy: principal point (Python floats).

    Returns:
        (P_3D, UV):
        - P_3D: camera-frame points, shape (B, N, 3).
        - UV: pixel coordinates, shape (B, N, 2), with
          u = fx * x / (z + 1e-9) + cx and v = fy * y / (z + 1e-9) + cy.
    '''
    assert 1 == fx.ndim == fy.ndim
    assert T.ndim == 2
    assert R.ndim == 3
    assert P.ndim == 3
    assert isinstance(cx, float)
    assert isinstance(cy, float)

    # Extrinsics on (B, 3, N): X_cam = R @ X + T, then back to (B, N, 3).
    P_cam = R @ P.transpose(1, 2) + T.unsqueeze(-1)
    P_3D = P_cam.transpose(1, 2)

    # Pinhole projection; the small epsilon keeps z == 0 from dividing by zero.
    # Built out of place so P_3D is never modified.
    x, y, z = P_3D[..., 0], P_3D[..., 1], P_3D[..., 2]
    depth = z + 1e-9
    u = x * fx.unsqueeze(-1) / depth + cx
    v = y * fy.unsqueeze(-1) / depth + cy

    return P_3D, torch.stack((u, v), dim=-1)

def project_templates_to_2D(P_pupil, P_iris, R, T, fx, fy, cx, cy):
    '''Project the pupil and iris templates with the same camera.

    Returns:
        (Pupil_3D, UV_pupil, Iris_3D, UV_iris): the camera-frame points and
        pixel coordinates that ``extrinsics_project`` gives for each template.
    '''
    camera = (R, T, fx, fy, cx, cy)
    pupil_cam, pupil_uv = extrinsics_project(P_pupil, *camera)
    iris_cam, iris_uv = extrinsics_project(P_iris, *camera)
    return pupil_cam, pupil_uv, iris_cam, iris_uv

def eyeball_center(in_dict, H, W, args):
    '''Project the eyeball outline circle onto the image plane.

    A circle of radius ``L`` centred at ``T`` is sampled at constant
    camera-frame depth z = T_z, i.e. parallel to the image plane, so no
    rotation is needed. It is projected with the pinhole model: f_x for both
    axes, principal point at the image centre.

    Args:
        in_dict: dict with per-sample tensors 'T' (B, 3), 'L' (B,) and
            'focal' (B, 2). These are presumably the scaled values returned
            in ``out_dict`` by ``render_semantics``; verify against the caller.
        H, W: image height and width.
        args: config dict. The batch size is now taken from ``L``, so a
            partial batch works; the argument is kept for API compatibility.

    Returns:
        Tensor (B, 99, 2) of projected outline points. 100 angles are sampled
        and the first one is dropped.
    '''
    T = in_dict['T']
    L = in_dict['L']
    focal = in_dict['focal']

    cx = W / 2
    cy = H / 2
    fx = focal[:, 0]
    fy = focal[:, 0]  # f_x for both axes, as in the rest of this file

    # Size the template from the data rather than from args.
    B = L.shape[0]
    N_angles = 100
    angles = np.linspace([0] * B, [1.9999 * math.pi] * B, N_angles,
                         axis=-1, dtype=np.float32)
    angles = torch.from_numpy(angles).to(L.device)

    # Circle of radius L around the eyeball centre, at depth T_z.
    radius = L.unsqueeze(-1)  # (B, 1)
    eyeball_X = radius * torch.cos(angles) + T[:, 0:1]
    eyeball_Y = radius * torch.sin(angles) + T[:, 1:2]
    eyeball_Z = torch.zeros_like(eyeball_X) + T[:, 2:3]

    # Pinhole projection. Built out of place; the epsilon guards z == 0,
    # matching the projection in extrinsics_project.
    depth = eyeball_Z + 1e-9
    eyeball_U = eyeball_X * fx.unsqueeze(-1) / depth + cx
    eyeball_V = eyeball_Y * fy.unsqueeze(-1) / depth + cy
    eyeball_UV = torch.stack((eyeball_U, eyeball_V), dim=-1)

    return eyeball_UV[:, 1:, :]


def eye_model_visualize(T1, T2, T3, R1, R2, R3, r_pupil, r_iris, L, fx, fy):
    '''Plot the eye model for one set of raw (un-scaled) parameters.

    The eleven scalars stand in for raw network outputs. They are mapped to
    model values with the 'version_0' preset of ``scale_and_bound``.
    Two figures are produced:

    1. The pupil/iris templates in eye-model coordinates (2D and 3D), with
       the left/right edge points in red.
    2. The templates projected onto a 640x480 image, plus the camera-frame
       3D points and the eyeball sphere. This figure is also saved to
       'eye_model_vis222.jpg'.

    Args:
        T1, T2, T3: raw eyeball-centre translation.
        R1, R2, R3: raw Euler angles (R1 is zeroed by scale_and_bound).
        r_pupil, r_iris, L: raw pupil, iris and eyeball radii.
        fx, fy: raw focal lengths (only fx is used for the projection).
    '''
    # Simulated network output for one sequence with one frame.
    W = 640
    H = 480
    T = torch.tensor([[[T1, T2, T3]]])
    R = torch.tensor([[[R1, R2, R3]]])
    r_pupil = torch.tensor([[r_pupil]])
    r_iris = torch.tensor([[r_iris]])
    L = torch.tensor([[L]])
    focal = torch.tensor([[[fx, fy]]])

    args = {'frames': 1, 'batch_size': 1, 'scale_bound_eye': "version_0"}

    # Scale and bound the predictions.
    T, R, r_pupil, r_iris, L, L_p, focal = scale_and_bound(T, R, r_pupil, r_iris, L, focal, args)

    # Tensor shape compatibility: merge (batch, frame) into a single sample axis.
    T, R, r_pupil, r_iris, L, L_p, focal, _ = tensor_shape_compatibility(
        T, R, r_pupil, r_iris, L, L_p, focal, args)

    # Pupil and iris templates.
    # NOTE(review): a 1x1 template gives one pupil point and one iris point.
    # The __main__ args use 100 angles x 50 radii; confirm the intended
    # resolution here.
    N_angles = 1
    N_radius = 1
    pupil_XYZ, iris_XYZ, edge_idx_iris, edge_idx_pupil = template_generation(
        N_angles, N_radius, r_pupil, r_iris, L_p, args=args)
    pupil_XYZ_i = pupil_XYZ[0]
    iris_XYZ_i = iris_XYZ[0]
    # Edge points to highlight (left and right sides). Indexing a tensor with
    # the whole edge-index dict raised TypeError.
    pupil_edges = edge_idx_pupil['sides']
    iris_edges = edge_idx_iris['sides']

    # Eyeball sphere wireframe, centred at the origin until the second figure.
    eye_radius = L[0].item()
    u, v = np.mgrid[0:2*np.pi:20j, 0:2*np.pi:20j]
    x_sphere = eye_radius * np.cos(u) * np.sin(v)
    y_sphere = eye_radius * np.sin(u) * np.sin(v)
    z_sphere = eye_radius * np.cos(v)

    # Figure 1: everything in eye-model coordinates, before [R|T].
    fig = plt.figure(figsize=plt.figaspect(2.))

    ax_0 = fig.add_subplot(2, 1, 1)
    ax_0.scatter(pupil_XYZ_i[:, 0].detach(), pupil_XYZ_i[:, 1].detach(), marker='x', color='black')
    ax_0.scatter(iris_XYZ_i[:, 0].detach(), iris_XYZ_i[:, 1].detach(), marker='o', color='green')
    ax_0.scatter(pupil_XYZ_i[pupil_edges, 0].detach(), pupil_XYZ_i[pupil_edges, 1].detach(), marker='x', color='red')
    ax_0.scatter(iris_XYZ_i[iris_edges, 0].detach(), iris_XYZ_i[iris_edges, 1].detach(), marker='o', color='red')
    ax_0.set_xlabel('X Label')
    ax_0.set_ylabel('Y Label')

    ax_1 = fig.add_subplot(2, 1, 2, projection='3d')
    ax_1 = pr.plot_basis(ax_1, s = 5)
    ax_1.plot_wireframe(x_sphere, y_sphere, z_sphere, color="grey", alpha=0.1)
    ax_1.scatter(pupil_XYZ_i[:, 0].detach(), pupil_XYZ_i[:, 1].detach(), pupil_XYZ_i[:, 2].detach(), marker='x', color='black', alpha=0.6)
    ax_1.scatter(iris_XYZ_i[:, 0].detach(), iris_XYZ_i[:, 1].detach(), iris_XYZ_i[:, 2].detach(), marker='o', color='green', alpha=0.1)
    ax_1.scatter(pupil_XYZ_i[pupil_edges, 0].detach(), pupil_XYZ_i[pupil_edges, 1].detach(), pupil_XYZ_i[pupil_edges, 2].detach(), marker='x', color='red')
    ax_1.scatter(iris_XYZ_i[iris_edges, 0].detach(), iris_XYZ_i[iris_edges, 1].detach(), iris_XYZ_i[iris_edges, 2].detach(), marker='o', color='red')
    ax_1.set_xlabel('X Label')
    ax_1.set_ylabel('Y Label')
    ax_1.set_zlabel('Z Label')
    plt.show()

    # Euler angles [deg] -> rotation matrix.
    Rotation = euler_to_rotation(R)

    # Project the 3D templates into the 2D image frame.
    pupil_point3D, pupil_UV, iris_point3D, iris_UV = project_templates_to_2D(
        pupil_XYZ, iris_XYZ, Rotation, T,
        focal[:, 0], focal[:, 0], W/2, H/2)
    pupil_UV_i = pupil_UV[0]
    iris_UV_i = iris_UV[0]
    pupil_point3D = pupil_point3D[0]
    iris_point3D = iris_point3D[0]

    # Figure 2: image-plane projection and camera-frame 3D points.
    fig = plt.figure(figsize=plt.figaspect(2.))
    ax_0 = fig.add_subplot(2, 1, 1)

    ax_0.scatter(pupil_UV_i[:, 0].detach(), pupil_UV_i[:, 1].detach(), marker='x', color='black', alpha=0.8, s=1**2)
    ax_0.scatter(iris_UV_i[:, 0].detach(), iris_UV_i[:, 1].detach(), marker='o', color='green', alpha=0.5, s=1**2)
    ax_0.scatter(pupil_UV_i[edge_idx_pupil['left'], 0].detach(), pupil_UV_i[edge_idx_pupil['left'], 1].detach(), marker='x', color='red', alpha=0.8, s=3**2)
    ax_0.scatter(pupil_UV_i[edge_idx_pupil['right'], 0].detach(), pupil_UV_i[edge_idx_pupil['right'], 1].detach(), marker='x', color='red', alpha=0.8, s=3**2)
    ax_0.scatter(iris_UV_i[edge_idx_iris['left'], 0].detach(), iris_UV_i[edge_idx_iris['left'], 1].detach(), marker='o', color='red', alpha=0.5, s=3**2)
    ax_0.scatter(iris_UV_i[edge_idx_iris['right'], 0].detach(), iris_UV_i[edge_idx_iris['right'], 1].detach(), marker='o', color='red', alpha=0.5, s=3**2)
    # Image border and centre lines.
    ax_0.plot([0, W, W, 0, 0], [0, 0, H, H, 0], color='black')
    ax_0.plot([0, W], [H/2, H/2], color='gray')
    ax_0.plot([W/2, W/2], [0, H], color='black')
    ax_0.set_xlabel('X Label')
    ax_0.set_ylabel('Y Label')

    ax_1 = fig.add_subplot(2, 1, 2, projection='3d')
    ax_1 = pr.plot_basis(ax_1, s = 5)
    # Move the eyeball sphere to the eyeball centre in camera coordinates.
    x_sphere += T[0, 0].item()
    y_sphere += T[0, 1].item()
    z_sphere += T[0, 2].item()
    ax_1.plot_wireframe(x_sphere, y_sphere, z_sphere, color="grey", alpha=0.1)

    ax_1.scatter(pupil_point3D[:, 0].detach(), pupil_point3D[:, 1].detach(), pupil_point3D[:, 2].detach(), marker='x', color='black', alpha=0.8, s=1**2)
    ax_1.scatter(iris_point3D[:, 0].detach(), iris_point3D[:, 1].detach(), iris_point3D[:, 2].detach(), marker='o', color='green', alpha=0.5, s=1**2)
    ax_1.scatter(pupil_point3D[edge_idx_pupil['left'], 0].detach(), pupil_point3D[edge_idx_pupil['left'], 1].detach(), pupil_point3D[edge_idx_pupil['left'], 2].detach(), marker='x', color='red', s=3**2)
    ax_1.scatter(pupil_point3D[edge_idx_pupil['right'], 0].detach(), pupil_point3D[edge_idx_pupil['right'], 1].detach(), pupil_point3D[edge_idx_pupil['right'], 2].detach(), marker='x', color='red', s=3**2)
    ax_1.scatter(iris_point3D[edge_idx_iris['left'], 0].detach(), iris_point3D[edge_idx_iris['left'], 1].detach(), iris_point3D[edge_idx_iris['left'], 2].detach(), marker='o', color='red', s=3**2)
    ax_1.scatter(iris_point3D[edge_idx_iris['right'], 0].detach(), iris_point3D[edge_idx_iris['right'], 1].detach(), iris_point3D[edge_idx_iris['right'], 2].detach(), marker='o', color='red', s=3**2)
    ax_1.set_xlabel('X Label')
    ax_1.set_ylabel('Y Label')
    ax_1.set_zlabel('Z Label')
    ax_1.view_init(-45, -90)  # Initial viewing angle

    plt.savefig('eye_model_vis222.jpg', bbox_inches='tight')
    plt.show()


if __name__ == "__main__":
    # Raw (un-scaled) parameters accepted by eye_model_visualize, in call order.
    param_names = ('T1', 'T2', 'T3', 'R1', 'R2', 'R3',
                   'r_pupil', 'r_iris', 'L', 'fx', 'fy')

    # Render configuration. eye_model_visualize builds its own args, so this
    # dict is not passed to it.
    args = {
        'batch_size': 1,
        'frames': 1,
        'scale_bound_eye': 'version_0',
        'temp_n_angles': 100,
        'temp_n_radius': 50,
    }

    # One static render with every raw parameter at 0.
    eye_model_visualize(**{name: 0. for name in param_names})

    # Interactive version: one slider over [-1, 1] per raw parameter
    # (presumably meant for a Jupyter notebook).
    sliders = {
        name: widgets.FloatSlider(value=0.0, min=-1.0, max=1.0, step=0.05)
        for name in param_names
    }
    interact(eye_model_visualize, **sliders)