"""
This started as a copy of https://bitbucket.org/RSKothari/multiset_gaze/src/master/
with additional changes and modifications to adjust it to our implementation.

Copyright (c) 2021 Rakshit Kothari, Aayush Chaudhary, Reynold Bailey, Jeff Pelz,
and Gabriel Diaz
"""
from datetime import datetime

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import scipy
import copy
import math
import cv2
import os
import torch
#import wandb

from PIL import Image
from skimage import draw
from itertools import chain
from skimage.transform import rescale
from skimage.segmentation.boundaries import find_boundaries
from scipy.ndimage import distance_transform_edt as distance

# from einops import rearrange
from Visualitation_TEyeD.gaze_estimation import generate_gaze_gt

EPS = 1e-40

# Helper classes
class my_ellipse():
    def __init__(self, param):
        '''
        接受参数化形式
        用于处理椭圆的参数化表示和矩阵表示之间的转换。它能够接受椭圆的参数化形式或矩阵形式，进行参数转换，并且可以生成和验证椭圆上的点。
        '''
        self.EPS = 1e-3  # 设置误差阈值
        if param is not list:
            self.param = param
            self.mat = self.param2mat(self.param)
            self.quad = self.mat2quad(self.mat)
        else:
            if param:
                raise Exception('my_ellipse 只接受 numpy 数组')  # 如果param是非空列表，抛出异常

    def param2mat(self, param):
        cx, cy, a, b, theta = tuple(param)  # 解包参数
        H_rot = rotation_2d(-theta)  # 计算旋转矩阵
        H_trans = trans_2d(-cx, -cy)  # 计算平移矩阵

        A, B = 1/a**2, 1/b**2  # 计算半轴倒数的平方
        Q = np.array([[A, 0, 0], [0, B, 0], [0, 0, -1]])  # 构造Q矩阵
        mat = H_trans.T @ H_rot.T @ Q @ H_rot @ H_trans  # 计算总矩阵
        return mat

    def mat2quad(self, mat):
        assert np.sum(np.abs(mat.T - mat)) <= self.EPS, '二次型形式不正确'  # 检查矩阵是否对称
        a, b, c, d, e, f = mat[0,0], 2*mat[0, 1], mat[1,1], 2*mat[0, 2], 2*mat[1, 2], mat[-1, -1]
        return np.array([a, b, c, d, e, f])  # 返回二次型系数

    def quad2param(self, quad):
        mat = self.quad2mat(quad)  # 将二次型转换为矩阵
        param = self.mat2param(mat)  # 将矩阵转换为参数
        return param

    def quad2mat(self, quad):
        a, b, c, d, e, f = tuple(quad)  # 解包二次型系数
        mat = np.array([[a, b/2, d/2], [b/2, c, e/2], [d/2, e/2, f]])  # 构造矩阵
        return mat

    def mat2param(self, mat):
        assert np.sum(np.abs(mat.T - mat)) <= self.EPS, '二次型形式不正确'  # 检查矩阵是否对称
        theta = self.recover_theta(mat)  # 恢复旋转角度
        tx, ty = self.recover_C(mat)  # 恢复平移
        H_rot = rotation_2d(theta)  # 计算旋转矩阵
        H_trans = trans_2d(tx, ty)  # 计算平移矩阵
        mat_norm = H_rot.T @ H_trans.T @ mat @ H_trans @ H_rot  # 归一化矩阵
        major_axis = np.sqrt(1/mat_norm[0,0])  # 计算长轴
        minor_axis = np.sqrt(1/mat_norm[1,1])  # 计算短轴
        area = np.pi*major_axis*minor_axis  # 计算面积
        return np.array([tx, ty, major_axis, minor_axis, theta, area])

    def phi2param(self, xm, ym):
        '''
        根据phi值计算椭圆参数

        参数
        ----------
        Phi : np.array [5, ]
            有关Phi值的信息，请参考ElliFit。
        xm : int
        ym : int

        返回
        -------
        param : np.array [5, ].
            椭圆参数, [cx, cy, a, b, theta]

        '''
        try:
            x0=(self.Phi[2]-self.Phi[3]*self.Phi[1])/((self.Phi[0])-(self.Phi[1])**2)
            y0=(self.Phi[0]*self.Phi[3]-self.Phi[2]*self.Phi[1])/((self.Phi[0])-(self.Phi[1])**2)
            term2=np.sqrt(((1-self.Phi[0])**2+4*(self.Phi[1])**2))
            term3=(self.Phi[4]+(y0)**2+(x0**2)*self.Phi[0]+2*self.Phi[1])
            term1=1+self.Phi[0]
            print(term1, term2, term3)
            b=(np.sqrt(2*term3/(term1+term2)))
            a=(np.sqrt(2*term3/(term1-term2)))
            alpha=0.5*np.arctan2(2*self.Phi[1],1-self.Phi[0])
            model = [x0+xm, y0+ym, a, b, -alpha]
        except:
            print('生成的不适当模型')
            model = [np.nan, np.nan, np.nan, np.nan, np.nan]
        if np.all(np.isreal(model)) and np.all(~np.isnan(model)) and np.all(~np.isinf(model)):
            model = model
        else:
            model = [-1, -1, -1, -1, -1]
        return model

    def recover_theta(self, mat):
        a, b, c, d, e, f = tuple(self.mat2quad(mat))
        if abs(b)<=EPS and a<=c:
            theta = 0.0
        elif abs(b)<=EPS and a>c:
            theta=np.pi/2
        elif abs(b)>EPS and a<=c:
            theta=0.5*np.arctan2(b, (a-c))
        elif abs(b)>EPS and a>c:
            theta = 0.5*np.arctan2(b, (a-c))
        else:
            print('未知条件')
        return theta

    def recover_C(self, mat):
        a, b, c, d, e, f = tuple(self.mat2quad(mat))
        tx = (2*c*d - b*e)/(b**2 - 4*a*c)
        ty = (2*a*e - b*d)/(b**2 - 4*a*c)
        return (tx, ty)

    def transform(self, H):
        '''
        给定变换矩阵 H，修改椭圆
        '''
        mat_trans = np.linalg.inv(H.T) @ self.mat @ np.linalg.inv(H)
        return self.mat2param(mat_trans), self.mat2quad(mat_trans), mat_trans

    def recover_Phi(self):
        '''
        生成 Phi
        '''
        x, y = self.generatePoints(50, 'random')
        data_pts = np.stack([x, y], axis=1)
        ellipseFit = ElliFit(**{'data':data_pts})
        return ellipseFit.Phi

    def verify(self, pts):
        '''
        给定一个 Nx2 点数组，验证椭圆模型
        '''
        N = pts.shape[0]
        pts = np.concatenate([pts, np.ones((N, 1))], axis=1)
        err = 0.0
        for i in range(0, N):
            err+=pts[i, :]@self.mat@pts[i, :].T  # 注意这里的转置是无关的
        return np.inf if (N==0) else err/N

    def generatePoints(self, N, mode):
        '''
        生成椭圆周围的 8 个点。模式决定点之间的一致性。
        mode: str
        'equiAngle' - 周围点的角度 [0:45:360)
        'equiSlope' - 周围点的切线斜率 [-1:0.5:1)
        'random' - 随机生成 N 个点在椭圆周围
        '''

        a = self.param[2]
        b = self.param[3]

        alpha = (a*np.sin(self.param[-1]))**2 + (b*np.cos(self.param[-1]))**2
        beta = (a*np.cos(self.param[-1]))**2 + (b*np.sin(self.param[-1]))**2
        gamma = (a**2 - b**2)*np.sin(2*self.param[-1])

        if mode == 'equiSlope':
            slope_list = [1e-6, 1, 1000, -1]
            K_fun = lambda m_i:  (m_i*gamma + 2*alpha)/(2*beta*m_i + gamma)

            x_2 = [((a*b)**2)/(alpha + beta*K_fun(m)**2 - gamma*K_fun(m)) for m in slope_list]

            x = [(+np.sqrt(val), -np.sqrt(val)) for val in x_2]
            y = []
            for i, m in enumerate(slope_list):
                y1 = -x[i][0]*K_fun(m)
                y2 = -x[i][1]*K_fun(m)
                y.append((y1, y2))
            y_r = np.array(list(chain(*y))) + self.param[1]
            x_r = np.array(list(chain(*x))) + self.param[0]

        if mode == 'equiAngle':

            T = 0.5*np.pi*np.array([-1.5, -1, -0.5, 0, 0.5, 1, 1.5, 2])
            N = len(T)
            x = self.param[2]*np.cos(T)
            y = self.param[3]*np.sin(T)
            H_rot = rotation_2d(self.param[-1])
            X1 = H_rot.dot(np.stack([x, y, np.ones(N, )], axis=0))

            x_r = X1[0, :] + self.param[0]
            y_r = X1[1, :] + self.param[1]

        elif mode == 'random':
            T = 2*np.pi*(np.random.rand(N, ) - 0.5)
            x = self.param[2]*np.cos(T)
            y = self.param[3]*np.sin(T)
            H_rot = rotation_2d(self.param[-1])
            X1 = H_rot.dot(np.stack([x, y, np.ones(N, )], axis=0))
            x_r = X1[0, :] + self.param[0]
            y_r = X1[1, :] + self.param[1]

        else:
            print('模式未定义')

        return x_r, y_r

def create_experiment_folder_tree(repo_root,
                                  path_exp_records,
                                  exp_name,
                                  is_test=False,
                                  create_tree=True):
    """
    创建实验文件夹结构的函数。
    参数:
    repo_root (str): 仓库的根目录。
    path_exp_records (str): 存放实验记录的目录路径。
    exp_name (str): 实验名称。
    is_test (bool): 是否是测试模式。默认为False。
    create_tree (bool): 是否创建文件夹结构。默认为True。
    返回:
    tuple: 包含路径字典和实验名称字符串的元组。
    """
    if is_test:
        # 如果是测试模式，使用提供的实验名称。
        exp_name_str = exp_name
    else:
        # 否则，生成带有日期时间和随机字符串的实验名称。
        now = datetime.now()
        date_time_str = now.strftime('%y_%m_%d_%H_%M_%S')
        exp_name_str = exp_name + '_' + date_time_str
    # 构建实验文件夹的完整路径。
    path_exp = os.path.join(path_exp_records, exp_name_str)
    # 创建路径字典以存放结果、图像、日志和源代码文件夹的路径。
    path_dict = {}
    for ele in ['results', 'figures', 'logs', 'gpm']:
        path_dict[ele] = os.path.join(path_exp, ele)
        os.makedirs(path_dict[ele], exist_ok=True)  # 创建目录，如果目录存在则忽略。
    # 添加实验文件夹的路径到字典中。
    path_dict['exp'] = path_exp
    # TODO: 不包括隐藏文件，然后重新启用
    # if (not is_test) and create_tree:
    #     # 如果不是测试模式并且需要创建树结构，
    #     # 从仓库根目录复制文件到实验文件夹的'src'目录中。
    #     copy_tree(repo_root, os.path.join(path_exp, 'src'))

    return path_dict, exp_name_str  # 返回路径字典和实验名称字符串。

class ElliFit():
    # 拟合二维数据点（如图像中的点）到一个椭圆模型上。它提供了基于加权和非加权最小二乘法的椭圆拟合，并计算拟合的误差。
    def __init__(self, **kwargs):
        self.data = np.array([])  # 存储输入的二维数据点，形状为Nx2
        self.W = np.array([])     # 存储权重矩阵
        self.Phi = []             # 存储椭圆参数
        self.pts_lim = 6 * 2      # 数据点数目限制，必须大于12
        for k, v in kwargs.items():
            setattr(self, k, v)  # 将传入的关键字参数设置为类的属性
        if np.size(self.W):
            self.weighted = True  # 判断是否使用加权最小二乘法
        else:
            self.weighted = False  # 使用非加权最小二乘法
        if np.size(self.data) > self.pts_lim:
            self.model = self.fit()  # 如果数据点足够多，进行椭圆拟合
            self.error = np.mean(self.fit_error(self.data))  # 计算拟合误差的平均值
        else:
            self.model = [-1, -1, -1, -1, -1]  # 数据点不够时，返回无效的模型参数
            self.Phi = [-1, -1, -1, -1, -1]
            self.error = np.inf  # 将误差设置为无穷大

    def fit(self):
        # 从论文ElliFit实现的椭圆拟合代码
        xm = np.mean(self.data[:, 0])  # 计算数据点x坐标的平均值
        ym = np.mean(self.data[:, 1])  # 计算数据点y坐标的平均值
        x = self.data[:, 0] - xm  # 平移数据点，使其中心在原点
        y = self.data[:, 1] - ym
        X = np.stack([x**2, 2*x*y, -2*x, -2*y, -np.ones((np.size(x), ))], axis=1)  # 构造矩阵X
        Y = -y**2  # 构造矩阵Y
        if self.weighted:
            self.Phi = np.linalg.inv(
                X.T.dot(np.diag(self.W)).dot(X)  # 加权最小二乘法
            ).dot(
                X.T.dot(np.diag(self.W)).dot(Y)
            )
        else:
            try:
                self.Phi = np.matmul(np.linalg.inv(np.matmul(X.T, X)), np.matmul(X.T, Y))  # 非加权最小二乘法
            except:
                self.Phi = -1 * np.ones(5, )  # 发生错误时返回无效的椭圆参数
        try:
            # 从Phi计算椭圆参数
            x0 = (self.Phi[2] - self.Phi[3] * self.Phi[1]) / ((self.Phi[0]) - (self.Phi[1])**2)
            y0 = (self.Phi[0] * self.Phi[3] - self.Phi[2] * self.Phi[1]) / ((self.Phi[0]) - (self.Phi[1])**2)
            term2 = np.sqrt(((1 - self.Phi[0])**2 + 4 * (self.Phi[1])**2))
            term3 = (self.Phi[4] + (y0)**2 + (x0**2) * self.Phi[0] + 2 * self.Phi[1])
            term1 = 1 + self.Phi[0]
            b = (np.sqrt(2 * term3 / (term1 + term2)))
            a = (np.sqrt(2 * term3 / (term1 - term2)))
            alpha = 0.5 * np.arctan2(2 * self.Phi[1], 1 - self.Phi[0])
            model = [x0 + xm, y0 + ym, a, b, -alpha]  # 将椭圆参数平移回原来的位置
        except:
            print('生成的不适当模型')
            model = [np.nan, np.nan, np.nan, np.nan, np.nan]  # 发生错误时返回无效的模型参数
        if np.all(np.isreal(model)) and np.all(~np.isnan(model)) and np.all(~np.isinf(model)):
            model = model  # 确保模型参数有效
        else:
            model = [-1, -1, -1, -1, -1]  # 无效的模型参数
        return model

    def fit_error(self, data):
        # 通用函数来找到残差
        # model: xc, yc, a, b, theta
        term1 = (data[:, 0] - self.model[0]) * np.cos(self.model[-1])
        term2 = (data[:, 1] - self.model[1]) * np.sin(self.model[-1])
        term3 = (data[:, 0] - self.model[0]) * np.sin(self.model[-1])
        term4 = (data[:, 1] - self.model[1]) * np.cos(self.model[-1])
        res = (1 / self.model[2]**2) * (term1 - term2)**2 + \
              (1 / self.model[3]**2) * (term3 + term4)**2 - 1
        return np.abs(res)  # 返回残差的绝对值


class ransac():
    # 实现了RANSAC（随机抽样一致性）算法，该算法通过随机抽样和拟合模型来识别和排除离群点（异常值），以找到最优的模型参数。具体来说，ransac类用于在存在噪声和离群点的数据集中拟合一个最优的椭圆模型。
    def __init__(self, data, model, n_min, mxIter, Thres, n_good):
        self.data = data  # 输入数据集
        self.num_pts = data.shape[0]  # 数据点的数量
        self.model = model  # 用于拟合数据的模型类
        self.n_min = n_min  # 每次随机抽样的最小点数
        self.D = n_good if n_min < n_good else n_min  # 设定内点的阈值，确保至少有 n_min 个内点
        self.K = mxIter  # 最大迭代次数
        self.T = Thres  # 残差阈值，用于判断一个点是否为内点
        self.bestModel = self.model(**{'data': data})  # 使用所有数据点拟合初始模型

    def loop(self):
        i = 0
        if self.num_pts > self.n_min:  # 如果数据点的数量大于最小抽样点数
            while i <= self.K:  # 迭代直到达到最大迭代次数
                inlr = np.random.choice(self.num_pts, self.n_min, replace=False)  # 从数据集中随机选择 n_min 个点作为内点
                loc_inlr = np.in1d(np.arange(0, self.num_pts), inlr)  # 标记内点的位置
                outlr = np.where(~loc_inlr)[0]  # 标记外点的位置
                potModel = self.model(**{'data': self.data[loc_inlr, :]})  # 使用内点拟合潜在模型
                listErr = potModel.fit_error(self.data[~loc_inlr, :])  # 计算潜在模型对外点的拟合误差
                inlr_num = np.size(inlr) + np.sum(listErr < self.T)  # 计算满足误差阈值的点数（内点数）
                if inlr_num > self.D:  # 如果满足内点阈值
                    pot_inlr = np.concatenate([inlr, outlr[listErr < self.T]], axis=0)  # 将满足误差阈值的外点加入内点集合
                    loc_pot_inlr = np.in1d(np.arange(0, self.num_pts), pot_inlr)  # 更新内点的位置
                    betterModel = self.model(**{'data': self.data[loc_pot_inlr, :]})  # 使用新的内点集合拟合更好的模型
                    if betterModel.error < self.bestModel.error:  # 如果新的模型误差更小
                        self.bestModel = betterModel  # 更新最优模型
                i += 1
        else:
            # 如果数据点数量小于等于最小抽样点数，直接返回使用所有数据点拟合的模型
            self.bestModel = self.model(**{'data': self.data})
        return self.bestModel  # 返回最优模型



# Helper functions
def rotation_2d(theta):
    # 返回逆时针方向的二维旋转矩阵
    c, s = np.cos(theta), np.sin(theta)
    H_rot = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1]])  # 构建旋转矩阵
    return H_rot

def trans_2d(cx, cy):
    # 返回二维平移矩阵
    H_trans = np.array([[1.0, 0.0, cx], [0.0, 1.0, cy], [0.0, 0.0, 1]])  # 构建平移矩阵
    return H_trans

def scale_2d(sx, sy):
    # 返回二维缩放矩阵
    H_scale = np.array([[sx, 0.0, 0.0], [0.0, sy, 0.0], [0.0, 0.0, 1]])  # 构建缩放矩阵
    return H_scale

def mypause(interval):
    # 暂停一定时间，用于动画显示
    backend = plt.rcParams['backend']
    if backend in matplotlib.rcsetup.interactive_bk:
        figManager = matplotlib._pylab_helpers.Gcf.get_active()
        if figManager is not None:
            canvas = figManager.canvas
            if canvas.figure.stale:
                canvas.draw()  # 重绘画布
            canvas.start_event_loop(interval)  # 开始事件循环，暂停 interval 时间
            return

def transformPoints(x, y, H):
    # 使用变换矩阵 H 变换点 (x, y)
    N = np.size(x)
    pts = np.stack([x, y, np.ones(N, )], axis=1) if (N > 1) else np.array([x, y, 1])  # 构建齐次坐标
    pts = H.dot(pts.T)  # 应用变换矩阵
    ox = pts[0, :] if N > 1 else pts[0]  # 提取变换后的 x 坐标
    oy = pts[1, :] if N > 1 else pts[1]  # 提取变换后的 y 坐标
    return (ox, oy)

def fillHoles(mask):
    # 填充掩码中的孔洞
    x_hole, y_hole = np.where(mask == 0)  # 找到孔洞位置
    for x, y in zip(x_hole, y_hole):
        # 用邻域平均值填充孔洞
        opts = mask[x-2:x+2, y-2:y+2].reshape(-1)
        if (not isinstance(opts, list)) & (opts.size != 0) & (sum(opts) != 0):
            mask[x, y] = np.round(np.mean(opts[opts != 0]))  # 用非零值的平均值填充
    return mask

def dummy_data(shape):
    # 生成虚拟数据
    num_of_frames = shape[0]
    true_list = [True] * num_of_frames
    false_list = [False] * num_of_frames

    data_dict = {}
    data_dict['is_bad'] = np.stack(true_list, axis=0)  # 所有帧都标记为坏的

    data_dict['mask'] = -1*np.ones(shape)  # 掩码初始化为 -1
    data_dict['image'] = np.zeros(shape, dtype=np.uint8)  # 图像初始化为 0
    data_dict['ds_num'] = -1*np.ones(num_of_frames)  # 初始化垃圾数据
    data_dict['pupil_center'] = -1*np.ones((num_of_frames, 2))  # 初始化瞳孔中心
    data_dict['iris_ellipse'] = -1*np.ones((num_of_frames, 5))  # 初始化虹膜椭圆
    data_dict['pupil_ellipse'] = -1*np.ones((num_of_frames, 5))  # 初始化瞳孔椭圆

    data_dict['im_num'] = -1*np.ones(num_of_frames)  # 初始化垃圾数据
    data_dict['archName'] = 'do_the_boogie!'  # 设置档案名

    # 独立的标志位
    data_dict['mask_available'] = np.stack(false_list, axis=0)
    data_dict['pupil_center_available'] = np.stack(false_list, axis=0)
    data_dict['iris_ellipse_available'] = np.stack(false_list, axis=0)
    data_dict['pupil_ellipse_available'] = np.stack(false_list, axis=0)
    return data_dict

def fix_batch(data_dict):
    # 修复批处理数据中的坏帧
    loc_bad = np.where(np.array(data_dict['is_bad']))[0]  # 找到坏帧位置
    loc_good = np.where(np.array(~data_dict['is_bad']))[0]  # 找到好帧位置

    for bad_idx in loc_bad.tolist():
        random_good_idx = int(np.random.choice(loc_good, 1).item())  # 随机选择一个好帧
        print('replacing {} with {}'.format(bad_idx, random_good_idx))

        for key in data_dict.keys():
            data_dict[key][bad_idx] = data_dict[key][random_good_idx]  # 用好帧数据替换坏帧
    return data_dict

def one_hot2dist(posmask):
    # 将单通道掩码转换为距离图
    h, w = posmask.shape
    mxDist = np.sqrt((h-1)**2 + (w-1)**2)  # 计算最大距离
    if np.any(posmask):
        assert len(posmask.shape) == 2
        res = np.zeros_like(posmask)
        posmask = posmask.astype(np.bool)
        if posmask.any():
            negmask = ~posmask
            res = distance(negmask) * negmask - (distance(posmask) - 1) * posmask  # 计算距离图
        res = res / mxDist  # 归一化
    else:
        res = np.zeros_like(posmask)  # 如果没有有效元素，返回全零矩阵
    return res

def label2onehot(Label):
    # 将标签转换为 one-hot 编码
    Label = (np.arange(4) == Label[..., None]).astype(np.uint8)  # 转换为 one-hot 编码
    Label = np.rollaxis(Label, 2)  # 轴转换
    return Label

def clean_mask(mask):
    '''
    输入：HXWXC 掩码
    输出：清理后的掩码
    通过收缩和扩展边缘图像清理掩码
    '''
    outmask = np.zeros_like(mask)
    classes_available = np.unique(mask)  # 获取掩码中存在的类别
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))  # 定义形态学操作的核
    for cls_idx in np.nditer(classes_available):
        I = 255 * np.uint8(mask == cls_idx)  # 转换为二值图像
        I = cv2.erode(I, kernel, iterations=1)  # 腐蚀操作
        I = cv2.dilate(I, kernel, iterations=1)  # 膨胀操作
        outmask[I.astype(np.bool)] = cls_idx  # 更新输出掩码
    return outmask



def simple_string(ele):
    '''
    ele: 需要去除所有特殊字符并转换为小写的字符串
    返回简化后的字符串
    '''
    if type(ele) is list:
        # 如果输入是列表，则对每个元素进行处理
        ele = [''.join(e.lower() for e in str(string_ele) if e.isalnum())
               for string_ele in ele]
    else:
        # 如果输入不是列表，则直接处理
        ele = ''.join(e.lower() for e in str(ele) if e.isalnum())
    return ele


def scale_by_ratio(data_dict, scale_ratio, train_with_mask):
    '''
    根据比例因子对数据进行缩放

    data_dict: 数据字典
    scale_ratio: 缩放比例因子
    train_with_mask: 是否使用掩码训练

    返回缩放后的数据字典
    '''

    more_features = False

    if 'Dikablis' or 'LPW' in data_dict['archName']:
        more_features = True

    num_of_frames = data_dict['image'].shape[0]

    dsize = (round(scale_ratio*data_dict['image'].shape[2]),
             round(scale_ratio*data_dict['image'].shape[1]))

    H = np.array([[scale_ratio, 0, 0],
                  [0, scale_ratio, 0],
                  [0, 0, 1]])

    image_list = []
    if train_with_mask:
        mask_list = []

    for i in range(num_of_frames):
        # 缩放图像
        image_list.append(cv2.resize(data_dict['image'][i], dsize,
                                        interpolation=cv2.INTER_LANCZOS4))
        if train_with_mask:
            # 缩放掩码
            mask_list.append(cv2.resize(data_dict['mask'][i],  dsize,
                                       interpolation=cv2.INTER_NEAREST))

        if data_dict['pupil_ellipse_available'][i]:
            # 转换瞳孔椭圆
            data_dict['pupil_ellipse'][i] =\
                my_ellipse(data_dict['pupil_ellipse'][i]).transform(H)[0][:-1]

        if data_dict['iris_ellipse_available'][i]:
            # 转换虹膜椭圆
            data_dict['iris_ellipse'][i] = \
                my_ellipse(data_dict['iris_ellipse'][i]).transform(H)[0][:-1]

        if data_dict['pupil_center_available'][i]:
            # 转换瞳孔中心
            data_dict['pupil_center'][i] = H[:2, :2].dot(data_dict['pupil_center'][i])

        if more_features:
            # 更多特征缩放
            data_dict['eyeball'][i] *= scale_ratio
            data_dict['iris_lm_2D'][i] *= scale_ratio
            data_dict['pupil_lm_2D'][i] *= scale_ratio


    data_dict['image'] = np.stack(image_list, axis=0)
    if train_with_mask:
        data_dict['mask'] = np.stack(mask_list, axis=0)

    return data_dict


def pad_to_shape(data_dict, to_size, mode='edge'):
    '''
    将图像填充并将椭圆转换为所需的形状。

    data_dict: 数据字典
    to_size: 所需形状的尺寸
    mode: 填充模式，默认为边缘复制

    返回填充后的数据字典
    '''

    assert len(data_dict['image'].shape) == 3, 'Image required to be grayscale and 1 more dimension to create the volume'
    num_of_frames , r_in, c_in = data_dict['image'].shape
    r_out, c_out = to_size

    inc_r = 0.5*(r_out - r_in)
    inc_c = 0.5*(c_out - c_in)

    for i in range(num_of_frames):
        # 对每个帧进行填充
        data_dict['mask'][i] = np.pad(data_dict['mask'][i],
                                   ((math.floor(inc_r), math.ceil(inc_r)),
                                    (math.floor(inc_c), math.ceil(inc_c))),
                                   mode='constant')
        data_dict['image'][i] = np.pad(data_dict['image'][i],
                                    ((math.floor(inc_r), math.ceil(inc_r)),
                                     (math.floor(inc_c), math.ceil(inc_c))),
                                    mode='edge')

        if data_dict['pupil_center_available'][i]:
            # 调整瞳孔中心的坐标
            data_dict['pupil_center'][i,:2] += np.array([inc_c, inc_r])

        if data_dict['pupil_ellipse_available'][i]:
            # 调整瞳孔椭圆的坐标
            data_dict['pupil_ellipse'][i,:2] += np.array([inc_c, inc_r])

        if data_dict['iris_ellipse_available'][i]:
            # 调整虹膜椭圆的坐标
            data_dict['iris_ellipse'][i,:2] += np.array([inc_c, inc_r])

    # 指定图像和帧数的维度
    output_size = (num_of_frames, r_out, c_out)

    assert data_dict['image'].shape == output_size, 'Padded image must match shape'

    return data_dict


class mod_scalar():
    '''
    线性模型的标量调整器
    '''

    def __init__(self, xlims, ylims):
        # 初始化模型参数
        self.slope = np.diff(ylims) / np.diff(xlims)
        self.intercept = ylims[1] - self.slope * xlims[1]
        self.xlims = xlims
        self.ylims = ylims

    def get_scalar(self, x_input):
        # 获取输入对应的标量值

        if x_input > self.xlims[1]:
            return self.ylims[1]

        if x_input < self.xlims[0]:
            return self.ylims[0]

        return self.slope * x_input + self.intercept



# def linVal(x, xlims, ylims, offset):
#     '''
#     Given xlims (x_min, x_max) and ylims (y_min, y_max), i.e, start and end,
#     compute the value of y=f(x). Offset contains the x0 such that for all x<x0,
#     y is clipped to y_min.
#     '''
#     if x < offset:
#         return ylims[0]
#     elif x > xlims[1]:
#         return ylims[1]
#     else:
#         y = (np.diff(ylims)/np.diff(xlims))*(x - offset)
#         return y.item()

def getValidPoints(LabelMat, isPartSeg=True, legacy=True):
    '''
    获取有效的数据点坐标。

    LabelMat: 标签矩阵
    isPartSeg: 是否是部分分割（针对不同的标签类型）
    legacy: 是否使用旧方法进行处理

    返回瞳孔点和虹膜点的坐标
    '''

    if legacy:
        # 将标签矩阵转换为0到255的值
        im = np.uint8(255 * LabelMat.astype(np.float32) / LabelMat.max())

        # 查找掩码和反转掩码的边缘，以确保处理内部和外部边缘的点
        edges = cv2.Canny(im, 50, 100) + cv2.Canny(255 - im, 50, 100)

    else:
        # 使用skimage中的便捷函数
        edges = find_boundaries(LabelMat)

    r, c = np.where(edges)

    # 初始化有效点的列表
    pupilPts, irisPts = [], []
    for loc in zip(c, r):
        temp = LabelMat[loc[1] - 1:loc[1] + 2, loc[0] - 1:loc[0] + 2]

        # 过滤掉无效点
        if isPartSeg:
            # 瞳孔点不能包含巩膜或皮肤
            condPupil = np.any(temp == 0) or np.any(temp == 1) or (temp.size == 0)

            # 虹膜点不能包含皮肤或瞳孔
            condIris = np.any(temp == 0) or np.any(temp == 3) or (temp.size == 0)
        else:

            # 瞳孔点不能包含皮肤
            condPupil = np.any(temp == 0) or (temp.size == 0)

            # 虹膜点不能包含瞳孔
            condIris = np.any(temp == 2) or (temp.size == 0)

        # 保留有效点
        pupilPts.append(np.array(loc)) if not condPupil else None
        irisPts.append(np.array(loc)) if not condIris else None

    pupilPts = np.stack(pupilPts, axis=0) if len(pupilPts) > 0 else []
    irisPts = np.stack(irisPts, axis=0) if len(irisPts) > 0 else []
    return pupilPts, irisPts


def stackall_Dict(D):
    '''
    将字典中的所有值堆叠起来。

    D: 输入的字典

    返回堆叠后的字典
    '''
    for key, value in D.items():
        if value:
            # 确保值不为空
            if type(D[key]) is list:
                print('Stacking: {}'.format(key))
                D[key] = np.stack(value, axis=0)
            elif type(D[key]) is dict:
                stackall_Dict(D[key])
    return D


def extract_datasets(subsets):
    '''
    提取数据集。

    subsets: 包含字符串数组的输入

    返回数据集的存在与id
    '''
    ds_idx = [str(ele).split('_')[0] for ele in np.nditer(subsets)]
    ds_present, ds_id = np.unique(ds_idx, return_inverse=True)
    return ds_present, ds_id


def convert_to_list_entries(data_dict):
    '''
    将数据字典转换为列表。

    data_dict: 输入的数据字典

    返回转换后的列表
    '''
    # 如果数据在torch中，则将其移回numpy
    for key, item in data_dict.items():
        if 'torch' in str(type(item)):
            data_dict[key] = item.detach().cpu().squeeze().numpy()

    # 根据字典生成空模板
    num_entries = data_dict[key].shape[0]
    out = []
    for ii in range(num_entries):
        out.append(
            {key: item[ii] for key, item in data_dict.items() if 'numpy' in str(type(item)) and 'latent' not in key})
    return out


def plot_images_with_annotations(data_dict,
                                 args,
                                 show=True,
                                 write=None,
                                 rendering=False,
                                 mask=False,
                                 subplots=None,
                                 is_predict=True,
                                 plot_annots=True,
                                 remove_saturated=True,
                                 is_list_of_entries=True,
                                 mode=None,
                                 epoch=0,
                                 batch=0
                                 ):
    '''
    绘制带注释的图像。

    data_dict: 数据字典
    args: 参数
    show: 是否显示图像
    write: 是否写入图像
    rendering: 是否进行渲染
    mask: 是否包含掩码
    subplots: 子图
    is_predict: 是否为预测
    plot_annots: 是否绘制注释
    remove_saturated: 是否移除饱和点
    is_list_of_entries: 是否为列表的条目
    mode: 模式
    epoch: epoch
    batch: batch

    返回None
    '''
    if not is_list_of_entries:
        list_data_dict = convert_to_list_entries(copy.deepcopy(data_dict))
    else:
        list_data_dict = data_dict

    if args['frames'] > 9:
        num_entries = 9
    else:
        num_entries = args['frames']

    if subplots:
        rows, cols = subplots
    else:
        rows = round(min(np.floor(10 ** 0.5), 4)) -1
        cols = round(min(np.floor(10 ** 0.5), 4)) -1

    fig, axs = plt.subplots(rows, cols, squeeze=True)

    idx = 0
    for i in range(rows):
        for j in range(cols):
            if (idx < num_entries):
                # 仅绘制列表范围内的条目
                if plot_annots:
                    out_image = draw_annots_on_image(list_data_dict[idx],
                                                     is_predict=is_predict,
                                                     mask=mask,
                                                     rendering=rendering,
                                                     intensity_maps=100,
                                                     remove_saturated=remove_saturated)
                    axs[i, j].imshow(out_image)
                else:
                    print('plot_annots=false')
                    axs[i, j].imshow(list_data_dict[0]['image'][idx])
            idx += 1
    if show:
        plt.show(block=False)

    if write:
        os.makedirs(os.path.dirname(write), exist_ok=True)
        fig.savefig(write, dpi=150, bbox_inches='tight')
    plt.close('all')
    return None


def draw_annots_on_image(data_dict,
                         remove_saturated=True,
                         intensity_maps=100,
                         pupil_index=2,
                         iris_index=1,
                         is_predict=True,
                         rendering=False,
                         mask = False):

    image = data_dict['image']
    image = image - image.min()
    image = (255*(image/image.max())).astype(np.uint8)
    out_image = np.stack([image]*3, axis=2)

    assert len(image.shape) == 2, 'Image must be grayscale'
    height, width = image.shape

    if remove_saturated:
        loc_image_non_sat = image <= (255-intensity_maps)
    else:
        loc_image_non_sat = image <= 255

    if rendering == False:
        # if is_predict or data_dict['pupil_center_available']:
        #     pupil_center = data_dict['pupil_center']
        #     [rr, cc] = draw.disk((pupil_center[1].clip(6, height-6),
        #                         pupil_center[0].clip(6, width-6)),
        #                         radius=5)
        #     out_image[rr, cc, :] = 255

        if is_predict or data_dict['pupil_ellipse_available']:
            pupil_ellipse = data_dict['pupil_ellipse']
            # print("Pupil: {}".format(type(pupil_ellipse)))
            loc_pupil = data_dict['mask'] == pupil_index
            out_image[..., 0] +=  (intensity_maps*loc_pupil*loc_image_non_sat).astype(np.uint8)
            out_image[..., 1] +=  (intensity_maps*loc_pupil*loc_image_non_sat).astype(np.uint8)
        #
        #     if np.all(np.abs(pupil_ellipse[0:4]) > 5):
        #         [rr_p, cc_p] = draw.ellipse_perimeter(round(pupil_ellipse[1]),
        #                                             round(pupil_ellipse[0]),
        #                                             round(pupil_ellipse[3]),
        #                                             round(pupil_ellipse[2]),
        #                                             orientation=pupil_ellipse[4],
        #                                             shape=image.shape)
        #         rr_p = rr_p.clip(6, image.shape[0]-6)
        #         cc_p = cc_p.clip(6, image.shape[1]-6)
        #
        #         out_image[rr_p, cc_p, 0] = 255

        if is_predict or data_dict['iris_ellipse_available']:
            iris_ellipse = data_dict['iris_ellipse']
            # print("Iris: {}".format(type(iris_ellipse)))
            loc_iris = data_dict['mask'] == iris_index
            out_image[..., 1] +=  (intensity_maps*loc_iris*loc_image_non_sat).astype(np.uint8)

            # if np.all(np.abs(iris_ellipse[0:4]) > 5):
            #     [rr_i, cc_i] = draw.ellipse_perimeter(round(iris_ellipse[1]),
            #                                         round(iris_ellipse[0]),
            #                                         round(iris_ellipse[3]),
            #                                         round(iris_ellipse[2]),
            #                                         orientation=iris_ellipse[4],
            #                                         shape=image.shape)
            #     rr_i = rr_i.clip(6, image.shape[0]-6)
            #     cc_i = cc_i.clip(6, image.shape[1]-6)
            #
            #     out_image[rr_i, cc_i, 2] = 255
        # loc_gaze = data_dict['gaze_vector'] == 1

        # out_image[..., 1] +=  (intensity_maps*loc_gaze*loc_image_non_sat).astype(np.uint8)
        # out_image[..., 2] +=  (intensity_maps*loc_gaze*loc_image_non_sat).astype(np.uint8)
    else:
        if mask:
            loc_pupil = data_dict['mask'] == pupil_index
            out_image[..., 0] +=  (intensity_maps*loc_pupil*loc_image_non_sat).astype(np.uint8)
            out_image[..., 1] +=  (intensity_maps*loc_pupil*loc_image_non_sat).astype(np.uint8)

            loc_iris = data_dict['mask'] == iris_index
            out_image[..., 1] +=  (intensity_maps*loc_iris*loc_image_non_sat).astype(np.uint8)
        else:
            loc_gaze = data_dict['gaze_img'] == 1
            out_image[..., 1] += (intensity_maps * loc_gaze * loc_image_non_sat).astype(np.uint8)
            out_image[..., 2] += (intensity_maps * loc_gaze * loc_image_non_sat).astype(np.uint8)

    # # 获取掩码数据
    # mask = data_dict['mask']
    #
    # # 绘制掩码图像
    # # plt.figure(figsize=(6, 6))
    # # plt.imshow(mask, cmap='gray')  # 使用灰度色彩映射显示掩码
    # # plt.colorbar()  # 添加颜色条，用于显示掩码值的范围
    # # plt.title('Mask Visualization')  # 设置标题
    # # plt.axis('off')  # 关闭坐标轴显示
    # # plt.show()  # 显示掩码图像
    #
    # plt.imsave('output_image.png', out_image)
    # plt.imshow(out_image)
    # plt.show()
    # plt.savefig(out_image)
    return out_image.astype(np.uint8)


def merge_two_dicts(dict_A, dict_B):
    '''

    Parameters
    ----------
    dict_A : DICT
        Regalar dictionary.
    dict_B : DICT
        Regular dictionary.

    Returns
    -------
    dict_C : DICT.

    '''
    dict_C = dict_A.copy()
    dict_C.update(dict_B)
    return dict_C


def get_ellipse_info(param, H, cond):
    '''
    获取椭圆信息。

    Parameters
    ----------
    param : np.array
        给定椭圆参数，返回以下内容：
            a) 沿周边的点
            b) 标准化的椭圆参数
    H: np.array 3x3
        将椭圆转换为标准化坐标的归一化矩阵
    cond : bool
        条件判断

    Returns
    -------
    normParam : np.array
        标准化的椭圆参数
    elPts : np.array
        沿椭圆周边的点
    '''

    if cond:
        # 将椭圆参数转换为标准化坐标，并获取沿周边的点
        norm_param = my_ellipse(param).transform(H)[0][:-1]  # 我们不需要面积
        elPts = my_ellipse(norm_param).generatePoints(50, 'equiAngle') # 规则点
        elPts = np.stack(elPts, axis=1)

        # 修正椭圆的轴和角度
        norm_param = fix_ellipse_axis_angle(norm_param)

    else:
        # 椭圆不存在时返回-1
        norm_param = -np.ones((5, ))
        elPts = -np.ones((8, 2))
    return elPts, norm_param


def fix_ellipse_axis_angle(ellipse):
    '''
    修正椭圆的轴和角度。

    ellipse : np.array
        椭圆参数

    Returns
    -------
    ellipse : np.array
        修正后的椭圆参数
    '''
    ellipse = copy.deepcopy(ellipse)
    if ellipse[3] > ellipse[2]:
        # 如果短轴大于长轴，则交换长短轴
        ellipse[[2, 3]] = ellipse[[3, 2]]
        ellipse[4] += np.pi / 2

    if ellipse[4] > np.pi:
        ellipse[4] += -np.pi
    elif ellipse[4] < 0:
        ellipse[4] += np.pi

    return ellipse


# Plot segmentation output, pupil and iris ellipses
def plot_segmap_ellpreds(image, seg_map, pupil_ellipse, iris_ellipse, thres=50, plot_ellipses=True):
    '''
    绘制分割图和瞳孔/虹膜椭圆。

    Parameters
    ----------
    image : np.array
        输入图像
    seg_map : np.array
        分割图
    pupil_ellipse : np.array
        瞳孔椭圆参数
    iris_ellipse : np.array
        虹膜椭圆参数
    thres : int, optional
        阈值，默认为50
    plot_ellipses : bool, optional
        是否绘制椭圆，默认为True

    Returns
    -------
    out_image : np.array
        输出图像
    '''

    # 根据分割图像获取虹膜和瞳孔位置
    loc_iris = seg_map == 1
    loc_pupil = seg_map == 2

    # 创建与输入图像相同大小的输出图像
    out_image = np.stack([image]*3, axis=2)

    # 根据阈值处理饱和区域
    loc_image_non_sat = image < (255-thres)

    # 给虹膜添加绿色
    out_image[..., 1] = out_image[..., 1] + thres*loc_iris*loc_image_non_sat
    rr, cc = np.where(loc_iris & ~loc_image_non_sat)

    # 对于饱和的虹膜位置，修正为绿色
    out_image[rr, cc, 0] = 0
    out_image[rr, cc, 1] = 255
    out_image[rr, cc, 2] = 0

    # 给瞳孔添加黄色
    out_image[..., 0] = out_image[..., 0] + thres*loc_pupil*loc_image_non_sat
    out_image[..., 1] = out_image[..., 1] + thres*loc_pupil*loc_image_non_sat
    rr, cc = np.where(loc_pupil & ~loc_image_non_sat)

    # 对于饱和的瞳孔位置，修正为黄色
    out_image[rr, cc, 0] = 255
    out_image[rr, cc, 1] = 255
    out_image[rr, cc, 2] = 0

    # 绘制虹膜椭圆和瞳孔椭圆
    if plot_ellipses:
        # 绘制虹膜椭圆
        [rr_i, cc_i] = draw.ellipse_perimeter(round(iris_ellipse[1]),
                                              round(iris_ellipse[0]),
                                              round(iris_ellipse[3]),
                                              round(iris_ellipse[2]),
                                              orientation=iris_ellipse[4])

        # 绘制瞳孔椭圆
        [rr_p, cc_p] = draw.ellipse_perimeter(round(pupil_ellipse[1]),
                                              round(pupil_ellipse[0]),
                                              round(pupil_ellipse[3]),
                                              round(pupil_ellipse[2]),
                                              orientation=pupil_ellipse[4])

        # 限制边界内的显示
        rr_i = rr_i.clip(6, image.shape[0]-6)
        rr_p = rr_p.clip(6, image.shape[0]-6)
        cc_i = cc_i.clip(6, image.shape[1]-6)
        cc_p = cc_p.clip(6, image.shape[1]-6)

        # 将虹膜和瞳孔椭圆标记为蓝色和红色
        out_image[rr_i, cc_i, ...] = np.array([0, 0, 255])
        out_image[rr_p, cc_p, ...] = np.array([255, 0, 0])

    return out_image.astype(np.uint8)



# Data extraction helpers
def generateEmptyStorage(name, subset):
    '''
    生成一个带有所有相关字段的空字典。
    这有助于在所有数据集之间保持一致性。

    Parameters
    ----------
    name : str
        数据集名称
    subset : str
        数据集子集名称

    Returns
    -------
    Data : dict
        包含所有字段的空字典
    Key : dict
        包含所有字段的空字典，用于记录关键信息
    '''
    # 创建空字典 Data，包含所有相关字段
    Data = {
        'Images': [],  # 灰度图像
        'event': [],  # 灰度图像
        'dataset': name,  # 数据集
        'subset': subset,  # 子集
        'resolution': [],  # 图像分辨率
        'archive': [],  # H5文件名
        'Info': [],  # 原始图像路径
        'Masks': [],  # 掩模
        'Masks_pupil_in_iris': [],  # 虹膜内的瞳孔掩模
        'Masks_noSkin': [],  # 仅包含虹膜和瞳孔的掩模
        'subject_id': [],  # 主体ID（如果可用）
        'Fits': {'pupil': [], 'iris': []},  # 瞳孔和虹膜拟合
        'pupil_loc': [],  # 瞳孔位置
        'pupil_in_iris_loc': [],  # 虹膜内瞳孔位置
        'Eyeball': [],  # 眼球
        'Gaze_vector': [],  # 视线向量
        'pupil_lm_2D': [],  # 瞳孔的2D关键点
        'pupil_lm_3D': [],  # 瞳孔的3D关键点
        'iris_lm_2D': [],  # 虹膜的2D关键点
        'iris_lm_3D': [],  # 虹膜的3D关键点
        'timestamp': []  # 时间戳
    }

    # 创建空字典 Key，用于记录关键信息
    Key = {
        'dataset': name,  # 数据集
        'subset': subset,  # 子集
        'resolution': [],  # 图像分辨率
        'subject_id': [],  # 主体ID（如果可用）
        'archive': [],  # H5文件名
        'Info': [],  # 原始图像路径
        'Fits': {'pupil': [], 'iris': []},  # 瞳孔和虹膜拟合
        'pupil_loc': [],  # 瞳孔位置
        'pupil_in_iris_loc': []  # 虹膜内瞳孔位置
    }

    return Data, Key



def plot_2D_hist(x, y, x_lims, y_lims, str_save='temp.jpg', axs=None):
    '''
    绘制二维直方图。

    Parameters
    ----------
    x : array_like
        x轴数据
    y : array_like
        y轴数据
    x_lims : tuple
        x轴范围
    y_lims : tuple
        y轴范围
    str_save : str, optional
        保存图像的文件路径，默认为'temp.jpg'
    axs : matplotlib.axes.Axes, optional
        要绘制直方图的轴，默认为None

    Returns
    -------
    None
    '''
    # 计算直方图
    H, xedges, yedges = np.histogram2d(x, y,
                                        bins=64,
                                        range=[x_lims, y_lims],
                                        density=False)

    # 对数尺度以获得更好的视觉效果
    H = np.log(H + 1)

    # 归一化以进行显示
    H = H - H.min()
    H = H / H.max()

    if axs is None:
        # 如果没有提供轴，则创建一个新的图像并显示直方图
        fig, axs = plt.subplots()
        axs.imshow(H,
                   interpolation='lanczos',
                   origin='upper', extent=tuple(x_lims + y_lims))
        # 保存图像
        fig.savefig(str_save, dpi=600, transparent=True, bbox_inches='tight')
    else:
        # 在提供的轴上显示直方图
        axs.imshow(H,
                   interpolation='lanczos',
                   origin='upper', extent=tuple(x_lims + y_lims))


def measure_contrast(image, mask=None):
    '''
    计算图像的对比度和掩模区域内的对比度直方图。

    Parameters
    ----------
    image : array_like
        输入图像
    mask : array_like, optional
        图像掩模，默认为None

    Returns
    -------
    hist : array_like
        图像的对比度直方图
    hist_lb : array_like
        掩模区域内的对比度直方图
    '''
    size = (3, 3)  # 窗口大小，即3x3窗口

    num_patches = (image.shape[0] - size[0] + 1) * (image.shape[1] - size[1] + 1)
    # 切分图像为小块
    patches = np.lib.stride_tricks.sliding_window_view(image, size)
    patches = patches.reshape(num_patches, 9)

    try:
        # 计算每个小块的标准差，即对比度
        contrast = patches.std(axis=1)
    except:
        import pdb;
        pdb.set_trace()

    # 计算图像的对比度直方图
    hist = scipy.ndimage.histogram(contrast.flatten(), 0, 32, 64) / num_patches

    hist_lb = []
    if mask is not None:
        # 如果提供了掩模，则计算掩模区域内的对比度直方图
        label_patches = mask[size[0] - 2:-size[0] + 2,
                         size[1] - 2:-size[1] + 2]
        label_patches = label_patches.reshape(num_patches, )

        temp = []
        for label in range(0, 3):
            loc = label_patches == label
            if np.sum(loc) >= 1:
                temp.append(scipy.ndimage.histogram(contrast[loc],
                                                     0, 64, 64) / np.sum(loc))
            else:
                temp.append(np.zeros(64, ))
        temp = np.stack(temp, axis=0)
        hist_lb.append(temp)

    return hist, np.stack(hist_lb, axis=0)


def image_contrast(image, scales=[1, ], by_category=None):
    '''
    计算图像的对比度。

    Parameters
    ----------
    image : array_like
        输入图像
    scales : list, optional
        要应用的缩放比例列表，默认为[1, ]
    by_category : array_like, optional
        按类别的掩模图像，默认为None

    Returns
    -------
    contrast : array_like
        图像的对比度数组
    contrast_by_class : array_like
        按类别的图像对比度数组
    '''
    contrast = []
    contrast_by_class = []

    for scale in scales:
        dsize = tuple(int(ele * scale) for ele in image.shape)
        data = cv2.resize(image, dsize, interpolation=cv2.INTER_LANCZOS4)

        if by_category is not None:
            mask = cv2.resize(by_category, dsize,
                              interpolation=cv2.INTER_NEAREST)
        else:
            mask = None

        out = measure_contrast(data, mask)

        contrast.append(out[0])

        if by_category is not None:
            contrast_by_class.append(out[1])

    contrast = np.stack(contrast, axis=0)
    if by_category is not None:
        contrast_by_class = np.stack(contrast_by_class, axis=0)

    return contrast, contrast_by_class


def construct_mask_from_ellipse(ellipses, res):
    '''
    根据椭圆构造掩模。

    Parameters
    ----------
    ellipses : array_like
        椭圆参数数组
    res : tuple
        掩模的分辨率

    Returns
    -------
    mask : array_like
        掩模数组
    '''
    if len(ellipses.shape) == 1:
        ellipses = ellipses[np.newaxis, np.newaxis, :]
        B = 1
        F = 1
    else:
        B = ellipses.shape[0]
        F = ellipses.shape[1]

    mask = np.zeros((B, F) + res)
    for b in range(B):
        for frame in range(F):
            ellipse = ellipses[b, frame, ...].tolist()
            [rr, cc] = draw.ellipse(round(ellipse[1]),
                                    round(ellipse[0]),
                                    round(ellipse[3]),
                                    round(ellipse[2]),
                                    shape=res,
                                    rotation=-ellipse[4])
            rr = np.clip(rr, 0, res[0] - 1)
            cc = np.clip(cc, 0, res[1] - 1)
            mask[b, frame, rr, cc] = 1
    return mask.astype(bool)


def contruct_ellipse_from_mask(mask, pupil_c, iris_c):
    '''
    根据掩模构造椭圆。

    Parameters
    ----------
    mask : array_like
        输入的掩模数组
    pupil_c : array_like
        瞳孔中心坐标
    iris_c : array_like
        虹膜中心坐标

    Returns
    -------
    pupil_ellipse : array_like
        瞳孔椭圆参数
    iris_ellipse : array_like
        虹膜椭圆参数
    '''
    assert mask.max() <= 3, 'Highest class label cannot exceed 3'
    assert mask.min() >= 0, 'Lowest class label cannot be below 0'
    pts_pupil, pts_iris = getValidPoints(mask, isPartSeg=False, legacy=False)

    irisFit = ElliFit(data=pts_iris)
    pupilFit = ElliFit(data=pts_pupil)

    iris_ellipse = irisFit.model[:5]
    # iris_ellipse[:2] = iris_c

    pupil_ellipse = pupilFit.model[:5]
    # pupil_ellipse[:2] = pupil_c

    return pupil_ellipse, iris_ellipse


def generate_rend_masks(rend_dict, H, W, iterations):
    '''
    生成渲染掩模。

    Parameters
    ----------
    rend_dict : dict
        渲染字典
    H : int
        高度
    W : int
        宽度
    iterations : int
        迭代次数

    Returns
    -------
    rend_dict : dict
        更新后的渲染字典
    '''
    channel = 3
    rendering_mask = torch.zeros((iterations, channel, H, W))
    mask_gaze = torch.zeros((iterations, H, W))

    for frame in range(iterations):
        mask_pupil = np.zeros((H, W))
        mask_iris = np.zeros((H, W))
        gaze_line = np.zeros((H, W))

        # 生成包含瞳孔和虹膜的分割掩模
        pupil_perimeter_point = rend_dict['pupil_UV'][frame][rend_dict['edge_idx_pupil']['outline']].to(
            int).detach().cpu().numpy()
        iris_perimeter_point = rend_dict['iris_UV'][frame][rend_dict['edge_idx_iris']['outline']].to(
            int).detach().cpu().numpy()

        cv2.drawContours(mask_pupil, [pupil_perimeter_point], -1, 2, -1)
        cv2.drawContours(mask_iris, [iris_perimeter_point], -1, 1, -1)

        # combined_mask = cv2.addWeighted(mask_pupil, 0.5, mask_iris, 0.5, 0)
        # cv2.imshow('Combined Mask', combined_mask)
        # cv2.waitKey(0)
        # cv2.destroyAllWindows()

        rendering_mask[frame, 1] = torch.from_numpy(mask_pupil)
        rendering_mask[frame, 2] = torch.from_numpy(mask_iris)

        # 生成注视线向量和眼球
        eyeball_circle_points = rend_dict['eyeball_circle'][frame].to(int).detach().cpu().numpy()
        center_pupil_gaze = rend_dict['pupil_UV'][frame][0].to(int).detach().cpu().numpy()
        center_eyeball_gaze = rend_dict['eyeball_c_UV'][frame].to(int).detach().cpu().numpy()

        cv2.drawContours(gaze_line, [eyeball_circle_points], -1, 1, 3)
        cv2.line(gaze_line, (center_eyeball_gaze[0], center_eyeball_gaze[1]),
                 (center_pupil_gaze[0], center_pupil_gaze[1]),
                 1, 3)

        mask_gaze[frame] = torch.from_numpy(gaze_line)

    rend_dict['predict'] = rendering_mask
    rend_dict['mask_gaze'] = mask_gaze

    return rend_dict


def assert_torch_invalid(X, string):
    '''
    断言 Torch 张量是否有效。

    Parameters
    ----------
    X : torch.Tensor
        要检查的张量
    string : str
        异常消息字符串
    '''
    assert not torch.isnan(X).any(), print('NaN problem [' + string + ']')
    assert not torch.isinf(X).any(), print('inf problem [' + string + ']')
    assert torch.all(torch.isfinite(X)), print('Some elements not finite problem [' + string + ']')
    assert X.numel() > 0, print('Empty tensor problem [' + string + ']')




