from einops import rearrange

from helperfunctions import DataProcessFuncs as funcs
import numpy as np
import time
from cv2 import Rodrigues
import sklearn.manifold
from scipy.optimize import least_squares

from scripts import angular_error


def InferenceGPM(p, pts, debug=False):
    # pts 减去参数 p 的第 10~12 位（偏移量），将点云中心平移到原点附近
    pts = pts - p[10:13]

    # 对每个点进行归一化处理（L2 范数），使点投影到单位球面上
    pts = pts / np.linalg.norm(pts, ord=2, axis=1, keepdims=True)

    # 根据 p 的前三个参数计算旋转矩阵 R，Rodrigues 将旋转向量转为旋转矩阵
    R, _ = Rodrigues(p[:3])

    # 将归一化的点旋转，调整姿态
    pts_rotated = (R @ pts.reshape((pts.shape[0], 3, 1))).reshape(pts.shape)

    # 计算 yaw 角（偏航角），通过点的旋转坐标计算（x 和 z 分量）
    yaw = np.arctan2(-pts_rotated[:, 0], -pts_rotated[:, 2])

    # 对 yaw 角做线性变换，p[6] 是缩放系数，p[7] 是偏置
    yaw = p[6] * yaw + p[7]

    # 计算 pitch 角（俯仰角），通过点的旋转坐标计算（y 分量）
    pitch = np.arcsin(-pts_rotated[:, 1])

    # 对 pitch 角做线性变换，p[8] 是缩放系数，p[9] 是偏置
    pitch = p[8] * pitch + p[9]

    # 把 yaw 和 pitch 合并成数组，调用 funcs.gazeTo3d_array 转换为三维注视方向向量并返回
    return funcs.gazeTo3d_array(np.hstack((yaw.reshape(len(yaw), 1), pitch.reshape(len(pitch), 1))))


def SolveGPM(p, pts, label):
    # 调用 InferenceGPM 获取预测的注视方向，再计算与真实标签的角度误差（批量）
    return funcs.angular_batch(InferenceGPM(p, pts), label)


class GeodesicProjection():
    def __init__(self, logger, pts, label, params=None):
        if params is None:
            # label 转换成三维注视向量
            # self.label = funcs.gazeTo3d_array(label)
            self.label = label

            # 初始化参数 p0：旋转向量（R0_x, R0_y, R0_z, R1_x, R1_y, R1_z），
            # 线性缩放系数和平移（k0, b0, k1, b1），以及偏移量（O_x, O_y, O_z）
            self.p0 = [0.1, 0.1,  0.1,  0.1,   0.1, 0.1,   1,  0, 1,   0,   0,   0,   4.48]

            # 传入的点云数据
            self.pts = pts

            # 利用最小二乘拟合，求解参数 p，使得预测结果与真实标签误差最小
            ls_result = least_squares(
                SolveGPM,
                self.p0,
                args=(self.pts, self.label),
                bounds=(
                    [-4, -4, -4, -4, -4, -4,  -np.inf, -np.pi, -np.inf, -np.pi, -np.inf, -np.inf, -np.inf],
                    [4, 4, 4, 4, 4, 4, np.inf, np.pi, np.inf, np.pi, np.inf, np.inf, np.inf]
                )
            )
            self.result = ls_result.x  # 优化得到的参数

            origin = self.result[10:13]  # 偏移量（球心）
            Rs = np.sqrt(np.sum((pts - origin) ** 2, axis=1))  # 计算每个点到球心的距离（半径）

            # 将半径均值附加到参数末尾，表示球半径的估计值
            self.result = np.hstack((self.result, np.mean(Rs)))
        else:
            # 直接使用传入的参数 params，不进行拟合
            self.pts = pts
            self.result = params
            origin = self.result[10:13]
            Rs = np.sqrt(np.sum((pts - origin) ** 2, axis=1))

        # 这里注释掉的 print 可用于打印拟合误差   打印每个数据点到球心的距离（半径）
        logger.write(f'GeodesicProjection: Sphere Error - {np.mean(abs(Rs - self.result[-1])/self.result[-1] * 100):.2f}%, {np.mean(np.sqrt(Rs))}')


    def GetResult(self):
        # 返回拟合参数，后面两个返回 None，可能预留给未来扩展
        return self.result, None, None

    def __call__(self, pts):
        # 使对象本身可调用，给定点集返回当前拟合参数的预测注视向量
        return InferenceGPM(self.result, pts, debug=False)


def ISOMap(data, dim=3, n_neighbors=300, fitter=None, verbose=False):
    begin = time.time()
    data = rearrange(data, 'B N d -> (B N) d')

    if n_neighbors is None:
        # 如果邻居数未设置，默认设置为样本数的 40%
        n_neighbors = int(data.shape[0] * 0.4)

    if fitter is None:
        if verbose:
            print(f"[State info ] Start New Isomap ...")
        # 新建 sklearn 的 Isomap 拟合器，设定邻居数和目标维度
        ISO_fitter = sklearn.manifold.Isomap(n_neighbors=n_neighbors, n_components=dim)
        # 对数据进行拟合，计算低维表示
        fitter = ISO_fitter.fit(data)
    else:
        if verbose:
            print(f"[State info ] ISOMap from Existing Param...")

    # 利用拟合器转换原始数据到低维空间
    PGF = fitter.transform(data)

    if verbose:
        print(f"[State info ] Successful! ISOMap complete in {(time.time() - begin):.2f}s")

    return PGF, fitter


def FitGaze(logger, PGF, gaze_label, param=None, verbose=True):
    if param is None:
        if verbose:
            print(f'[State info ] Start a new Geodesic Projection....')

        # 创建 GeodesicProjection 对象，拟合参数
        model = GeodesicProjection(logger, PGF, gaze_label, param)
    else:
        if verbose:
            print(f'[State info ] Geodesic Projection from Existing Parameter!')
        # 直接使用已有参数
        model = GeodesicProjection(logger, PGF, None, param)

    # 获取拟合参数
    fitting_param, _, _ = model.GetResult()

    # 球心坐标（偏移量）和半径
    sphere = fitting_param[10:]

    # 用拟合参数对数据做推理，得到预测注视向量
    preds = model(PGF)

    if verbose:
        # 输出拟合误差（角度误差，单位度）和球半径信息
        # logger.write(f'[State info ] Fitting Error: {np.mean(funcs.angular_batch(preds, gaze_label)) * 180 / np.pi}')
        logger.write(f'[State info ] Fitting Error: {np.mean(angular_error(preds, gaze_label))}')
        logger.write(f'sphere: {sphere}')

    # 返回拟合参数，预测结果，拟合模型
    return fitting_param, preds, model
