#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This started as a copy of https://bitbucket.org/RSKothari/multiset_gaze/src/master/ 
with additional changes and modifications to adjust it to our implementation. 

Copyright (c) 2021 Rakshit Kothari, Aayush Chaudhary, Reynold Bailey, Jeff Pelz, 
and Gabriel Diaz
"""

import re
import os
import traceback

import cv2
import h5py
import copy
import torch
import pickle
import random

import numpy as np
import scipy.io as scio

import matplotlib.pyplot as plt

from helperfunctions.data_augment import augment, flip
from torch.utils.data import Dataset

from helperfunctions.hfunctions import simple_string, one_hot2dist
from helperfunctions.hfunctions import pad_to_shape, get_ellipse_info
from helperfunctions.hfunctions import extract_datasets, scale_by_ratio
from helperfunctions.hfunctions import fix_ellipse_axis_angle, dummy_data

from Visualitation_TEyeD.gaze_estimation import generate_gaze_gt 

from helperfunctions.utils import normPts

from sklearn.model_selection import StratifiedKFold, train_test_split

os.environ.update({"HDF5_USE_FILE_LOCKING": "FALSE"})  # Disable HDF5 file locking


class MaskToTensor:
    """Callable that turns an array-like label mask into a ``torch.long`` tensor."""

    def __call__(self, img):
        # np.array (not asarray) always yields a fresh contiguous int32 copy,
        # which torch.from_numpy requires (it rejects negative strides).
        mask_i32 = np.array(img, dtype=np.int32)
        return torch.from_numpy(mask_i32).to(torch.int64)


class DataLoader_riteyes(Dataset):
    """Dataset of fixed-length frame sequences read from per-archive H5 files.

    Each item covers ``num_frames`` rows of the fold's image list that all
    belong to the same archive. ``__getitem__`` returns a dict of per-frame
    arrays (image, pupil/iris ellipse fits, pupil centre, gaze vector and,
    depending on configuration/archive, masks and extra landmark
    annotations) together with per-frame availability flags.
    """

    # Archives whose name contains one of these substrings are read with the
    # extra annotations in _EXTRA_FIELDS, and their masks are remapped in
    # __getitem__ (sclera class dropped).
    _MORE_FEATURE_DATASETS = ('Dikablis', 'LPW', 'KaleidoEYE')

    # (H5 key, output key, per-frame width used for the -1 placeholder).
    _EXTRA_FIELDS = (
        ('Eyeball', 'eyeball', 4),
        ('pupil_lm_2D', 'pupil_lm_2D', 17),
        ('pupil_lm_3D', 'pupil_lm_3D', 25),
        ('iris_lm_2D', 'iris_lm_2D', 17),
        ('iris_lm_3D', 'iris_lm_3D', 25),
    )

    def __init__(self,
                 dataDiv_Obj,
                 path2data,
                 cond,
                 augFlag=False,
                 size=(480, 640),
                 fold_num=0,
                 num_frames=4,
                 sort='random',
                 args=None,
                 scale=False):
        '''
        Build the list of frame sequences for one split.

        Args:
            dataDiv_Obj: split object exposing ``arch`` (archive names) and
                ``folds[fold_num]['<split>_idx']``, an array whose rows are
                [image number, archive number].
            path2data: directory containing the ``<archive>.h5`` files.
            cond: split name; any string containing 'train', 'valid' or 'test'.
            augFlag: enable augmentation and random flipping.
            size: stored as ``self.size``.
            fold_num: index into ``dataDiv_Obj.folds``.
            num_frames: number of frames per sequence.
            sort: ordering strategy, see :meth:`sort`.
            args: dict providing 'net_ellseg_head', 'loss_w_rend_pred_2_gt_edge',
                'loss_w_rend_gt_2_pred', 'loss_w_rend_pred_2_gt' and
                'train_data_percentage'.
            scale: if truthy, ratio passed to ``scale_by_ratio``.
        '''
        self.mode = cond
        self.init_frames = num_frames

        self.ellseg = args['net_ellseg_head']
        # Masks are only needed when a mask-based loss or the EllSeg head is on.
        self.train_with_mask = args['loss_w_rend_pred_2_gt_edge'] \
                                or args['loss_w_rend_gt_2_pred'] \
                                or args['loss_w_rend_pred_2_gt'] \
                                or args['net_ellseg_head']

        # Normalise the split name to the key used in dataDiv_Obj.folds.
        cond = 'train_idx' if 'train' in cond else cond
        cond = 'valid_idx' if 'valid' in cond else cond
        cond = 'test_idx' if 'test' in cond else cond

        # Operational variables
        self.arch = dataDiv_Obj.arch  # Available archives
        self.size = size  # Expected size of images
        self.scale = scale
        self.imList = dataDiv_Obj.folds[fold_num][cond]  # Image list
        self.augFlag = augFlag  # Augmentation flag
        self.equi_var = True  # Default is always True
        self.path2data = path2data  # Path to expected H5 files
        self.more_feature = False
        self.fileObjs = {}  # archive name -> open h5py.File, filled lazily

        # Augmentation choices passed to augment().
        # TODO [0, 1, 8, 9, 11]
        self.augger = augment(choice_list=[0, 1, 11], mask_available=self.train_with_mask) if augFlag else []
        self.flipper = flip()

        # Append the dataset index of each row's archive as a third column.
        ds_present, ds_index = extract_datasets(self.arch[self.imList[:, 1]])
        self.imList = np.hstack([self.imList, ds_index[:, np.newaxis]])

        # TODO: support more than one dataset by repeating under-represented
        # datasets so every dataset contributes the same number of images.

        # Drop trailing rows so the list splits evenly into sequences, then
        # reshape to (num_sequences, init_frames, 3).
        del_rows = self.imList.shape[0] % self.init_frames
        if del_rows != 0:
            self.imList = self.imList[:-del_rows]
        self.imList = np.reshape(self.imList, (-1, self.init_frames, 3))

        # Keep only sequences whose frames all come from the same archive.
        # A vectorised mask is used because deleting rows while iterating over
        # indices skips the row that follows each deletion.
        same_archive = np.all(self.imList[:, :, 1] == self.imList[:, :1, 1], axis=1)
        self.imList = self.imList[same_archive]

        self.sort(sort)

        if (cond == 'train_idx'):
            # Keep only the configured fraction of training sequences.
            num_of_entries = int(self.imList.shape[0] * args['train_data_percentage'])
            self.imList = self.imList[:num_of_entries]

        print(f'Split: {self.mode}')
        print(f'Num: {self.imList.shape[0]*self.imList.shape[1]}')
        print(f'Train perc.: {args["train_data_percentage"]*100.0}%')

    def sort(self, sort, batch_size=None):
        '''
        Reorder the sequences in ``self.imList`` in place.

        Args:
            sort: one of
                'ordered'         - unique sequences in lexicographic order,
                'nothing'         - leave as is,
                'semiordered'     - shuffle, then sort by archive number,
                'random'          - shuffle (default),
                'mutliset_random' - shuffle, then interleave datasets (also
                                    accepted as 'multiset_random'); every
                                    dataset must have the same sequence count,
                'one_by_one_ds'   - shuffle, then emit ``batch_size`` runs of
                                    sequences taken from a single dataset.
            batch_size: run length, required by 'one_by_one_ds'.
        '''
        # All frames of a sequence share the archive (and hence the dataset),
        # so frame 0 identifies the sequence; this also works for 1 frame.
        if sort == 'ordered':
            self.imList = np.unique(self.imList, axis=0)

        elif sort == 'nothing':
            pass

        elif sort == 'semiordered':
            self.sort(sort='random')
            loc = np.argsort(self.imList[:, 0, 1])
            self.imList = self.imList[loc, :]

        elif sort == 'random':
            loc = np.random.permutation(self.imList.shape[0])
            self.imList = self.imList[loc, :]

        elif sort in ('mutliset_random', 'multiset_random'):
            self.sort('random')
            ds_col = self.imList[:, 0, 2]
            per_ds = [self.imList[ds_col == ds] for ds in np.unique(ds_col)]
            # (K, D, F, 3) -> (K*D, F, 3): consecutive sequences cycle datasets.
            temp_imList = np.stack(per_ds, axis=1).reshape(-1, self.init_frames, 3)
            print(temp_imList.shape)
            print(self.imList.shape)
            assert temp_imList.shape == self.imList.shape, 'Incorrect reshaping'
            self.imList = temp_imList

        elif sort == 'one_by_one_ds':
            if batch_size is None:
                raise ValueError("sort='one_by_one_ds' requires batch_size")
            self.sort('random')
            ds_col = self.imList[:, 0, 2]
            per_ds = [self.imList[ds_col == ds] for ds in np.unique(ds_col)]
            n_ds = len(per_ds)

            chunks = []
            counter = 0
            exhausted = n_ds == 0
            while not exhausted:
                counter += 1
                start = (counter - 1) * batch_size
                stop = counter * batch_size
                # Visit datasets in a fresh random order each round.
                for pos in random.sample(range(n_ds), n_ds):
                    if stop <= per_ds[pos].shape[0]:
                        chunks.append(per_ds[pos][start:stop, ...])
                    else:
                        # A dataset has been completely sampled.
                        exhausted = True
                        break
            self.imList = np.concatenate(chunks, axis=0) if chunks else self.imList[:0]

        else:
            import sys
            sys.exit('Incorrect sorting options')

    def __len__(self):
        '''Number of sequences.'''
        return self.imList.shape[0]

    def __del__(self):
        '''Close every H5 file opened by readEntry_new.'''
        # getattr: __init__ may have failed before fileObjs existed.
        for h5_file_obj in getattr(self, 'fileObjs', {}).values():
            h5_file_obj.close()

    def __getitem__(self, idx):
        '''
        Read one sequence and return a dict with all of its information.

        Images and masks are volumes of ``init_frames`` grayscale frames.
        On top of the entries produced by readEntry_new this adds
        'distance_map' and 'spatial_weights' (EllSeg with masks) and always
        'iris_ellipse_norm', 'pupil_ellipse_norm' and 'pupil_center_norm'.
        If reading or augmenting fails, the error is printed and placeholder
        data from dummy_data() is used instead.
        '''
        numClasses = 3  # other: 0, iris: 1, pupil: 2 (after the remap below)

        try:
            data_dict = self.readEntry_new(idx)

            if self.scale:
                data_dict = scale_by_ratio(data_dict, self.scale, self.train_with_mask)

            data_dict = self.augger(data_dict) if self.augFlag else data_dict

            # Explicit raises (not asserts) so the checks survive ``python -O``;
            # a failure falls through to the dummy-data path below.
            if not data_dict['image'].max() <= 255:
                raise ValueError('最大亮度应 <=255')
            if not data_dict['image'].min() >= 0:
                raise ValueError('最小亮度应 >=0')
            if not (data_dict['image'].shape[2] == 320 and data_dict['image'].shape[1] == 240):
                raise ValueError('之前的功能未实现，不同于320*240的尺寸')

            # Flip with probability 0.2, only when augmenting.
            if (np.random.rand(1) > 0.8) and self.augFlag:
                data_dict = self.flipper(data_dict, self.train_with_mask)

        except Exception:
            print('读取和处理数据时出错！')
            traceback.print_exc()
            # The failing entry is not re-read here: a second read could raise
            # again and escape this handler.
            im_num = self.imList[idx, :, 0]
            arch_num = self.imList[idx, :, 1]
            archStr = self.arch[arch_num[0]]
            print('错误的样本编号: {}'.format(im_num))
            print('错误的存档编号: {}'.format(arch_num))
            print('错误的存档名称: {}'.format(archStr))
            data_dict = dummy_data(shape=(self.init_frames, 480 // 2, 640 // 2))

        num_of_frames = self.init_frames
        height = data_dict['image'].shape[1]
        width = data_dict['image'].shape[2]

        for i in range(num_of_frames):
            data_dict['pupil_ellipse'][i] = fix_ellipse_axis_angle(data_dict['pupil_ellipse'][i])
            data_dict['iris_ellipse'][i] = fix_ellipse_axis_angle(data_dict['iris_ellipse'][i])

        if self.train_with_mask:
            spatial_weights_list = []
            distance_map_list = []
            if self.more_feature and np.all(data_dict['mask_available']):
                # Drop the sclera class: 1 -> 0, 2 (iris) -> 1, 3 (pupil) -> 2.
                data_dict['mask'][data_dict['mask'] == 1] = 0
                data_dict['mask'][data_dict['mask'] == 2] = 1
                data_dict['mask'][data_dict['mask'] == 3] = 2

            if self.ellseg:
                for i in range(num_of_frames):
                    if data_dict['mask_available'][i]:
                        # Edge weight map: 1 everywhere, 21 on dilated mask edges.
                        spatial_weights = cv2.Canny(data_dict['mask'][i].astype(np.uint8), 0, 1) / 255
                        spatial_weights_list.append(1 + cv2.dilate(spatial_weights, (3, 3), iterations=1) * 20)

                        # One distance map per class (pupil: 2, iris: 1, other: 0).
                        distance_map = np.zeros(((3,) + data_dict['image'][i].shape))
                        for k in range(0, numClasses):
                            distance_map[k, ...] = one_hot2dist(data_dict['mask'][i].astype(np.uint8) == k)
                        distance_map_list.append(distance_map)
                    else:
                        distance_map_list.append(np.zeros(((3,) + data_dict['image'][i].shape)))
                        spatial_weights_list.append(np.zeros_like(data_dict['mask'][i]))

                data_dict['distance_map'] = np.stack(distance_map_list, axis=0)
                data_dict['spatial_weights'] = np.stack(spatial_weights_list, axis=0)

        # Standardise the image unless it comes from dummy_data (negative
        # im_num, zero std); the raw image is used in that case.
        data_to_torch = data_dict['image']
        if (np.all(data_dict['im_num'] >= 0)):
            pic = data_dict['image']
            data_to_torch = (pic - pic.mean()) / pic.std()

        # Zero the input and flag every frame bad if standardisation produced
        # inf/NaN (e.g. a constant image with std == 0).
        if np.any(~np.isfinite(data_to_torch)):
            print('NaN')
            data_to_torch = np.zeros_like(data_dict['image']).astype(np.uint8)
            data_dict['is_bad'] = np.ones(num_of_frames, dtype=bool)

        if self.train_with_mask:
            data_dict['mask'] = MaskToTensor()(data_dict['mask']).to(torch.long)

        data_dict['image'] = data_to_torch

        # Homography mapping pixel coordinates to [-1, 1].
        if self.equi_var:
            sc = max([width, height])
            H = np.array([[2 / sc, 0, -1], [0, 2 / sc, -1], [0, 0, 1]])
        else:
            H = np.array([[2 / width, 0, -1], [0, 2 / height, -1], [0, 0, 1]])

        iris_ellipse_norm_list = []
        pupil_ellipse_norm_list = []
        pupil_center_norm_list = []
        for i in range(num_of_frames):
            if not data_dict['is_bad'][i]:
                iris_ellipse_norm_list.append(get_ellipse_info(data_dict['iris_ellipse'][i], H,
                                                               data_dict['iris_ellipse_available'][i])[1])
                pupil_ellipse_norm_list.append(get_ellipse_info(data_dict['pupil_ellipse'][i], H,
                                                                data_dict['pupil_ellipse_available'][i])[1])
                pupil_center_norm_list.append(normPts(data_dict['pupil_center'][i],
                                                      np.array([width, height]),
                                                      by_max=self.equi_var))
            else:
                iris_ellipse_norm_list.append(-1 * np.ones((5,)))
                pupil_center_norm_list.append(-1 * np.ones((2,)))
                pupil_ellipse_norm_list.append(-1 * np.ones((5,)))

        data_dict['iris_ellipse_norm'] = np.stack(iris_ellipse_norm_list, axis=0)
        data_dict['pupil_ellipse_norm'] = np.stack(pupil_ellipse_norm_list, axis=0)
        data_dict['pupil_center_norm'] = np.stack(pupil_center_norm_list, axis=0)
        return data_dict

    def readEntry_new(self, idx):
        '''
        Read one sequence of frames and all their ground truths from H5.

        Only the rows listed in ``self.imList[idx]`` are loaded. Mask
        annotations follow the OpenEDS definitions:
            0 -> Background
            1 -> Sclera (if available)
            2 -> Iris
            3 -> Pupil
        Annotations missing from the archive are replaced by -1 placeholders
        of shape (num_frames, width) and flagged unavailable.
        Also sets ``self.more_feature`` for the archive being read.
        '''
        im_num = self.imList[idx, ..., 0]
        set_num = self.imList[idx, ..., 2]
        arch_num = self.imList[idx, ..., 1]

        archStr = self.arch[arch_num[0]]
        archName = archStr.split(':')[0]

        # Keep H5 files open across calls so later reads reuse the handle.
        if archName not in self.fileObjs:
            self.fileObjs[archName] = h5py.File(os.path.join(self.path2data,
                                                             str(archName)+'.h5'),
                                                'r', swmr=True)
        f = self.fileObjs[archName]

        self.more_feature = False

        num_of_frames = self.init_frames

        image = f['Images'][im_num, ...]

        pupil_center, pupil_center_available = self._read_optional(
            f['pupil_loc'], im_num, num_of_frames, 2)

        # Mask without skin; rejected unless it contains pupil pixels.
        if self.train_with_mask:
            # Placeholder covers the whole (frames, H, W) volume so that
            # per-frame indexing in __getitem__ works.
            mask_placeholder_shape = image.shape[:3]
            if len(f['Masks']) != 0:
                mask_noSkin = f['Masks'][im_num, ...]
                mask_available = [True] * num_of_frames
                any_pupil = np.any(mask_noSkin == 3)
                # NOTE(review): self.more_feature is always False here (reset
                # above, assigned further down), so the iris test re-checks the
                # pupil label. Kept to avoid changing which samples are
                # accepted; confirm whether label 2 was intended.
                if self.more_feature:
                    any_iris = np.any(mask_noSkin == 2)
                else:
                    any_iris = np.any(mask_noSkin == 3)
                if not (any_pupil and any_iris):
                    mask_noSkin = -np.ones(mask_placeholder_shape)
                    mask_available = [False] * num_of_frames
            else:
                mask_noSkin = -np.ones(mask_placeholder_shape)
                mask_available = [False] * num_of_frames

        pupil_param, pupil_ellipse_available = self._read_optional(
            f['Fits']['pupil'], im_num, num_of_frames, 5)
        # Iris parameters come from the iris fits.
        iris_param, iris_ellipse_available = self._read_optional(
            f['Fits']['iris'], im_num, num_of_frames, 5)
        gaze_vector, gaze_vector_available = self._read_optional(
            f['Gaze_vector'], im_num, num_of_frames, 3)

        # Extra annotations are read only for the datasets in
        # _MORE_FEATURE_DATASETS.
        self.more_feature = any(name in archName for name in self._MORE_FEATURE_DATASETS)
        extras = {}
        if self.more_feature:
            for h5_key, out_key, width in self._EXTRA_FIELDS:
                extras[out_key] = self._read_optional(f[h5_key], im_num, num_of_frames, width)

        data_dict = {}
        if self.train_with_mask:
            data_dict['mask'] = mask_noSkin
        data_dict['image'] = image
        data_dict['ds_num'] = set_num
        data_dict['pupil_center'] = pupil_center.astype(np.float32)
        data_dict['iris_ellipse'] = iris_param.astype(np.float32)
        data_dict['pupil_ellipse'] = pupil_param.astype(np.float32)
        data_dict['gaze_vector'] = gaze_vector.astype(np.float32)
        for out_key, (values, _) in extras.items():
            data_dict[out_key] = values.astype(np.float32)

        is_bad_list = [False] * num_of_frames

        # Sanity check: out-of-range mask labels without a valid mask.
        if self.train_with_mask:
            if (np.any(data_dict['mask'] < 0) or np.any(data_dict['mask'] > 3)) and not np.all(mask_available):
                is_bad_list = [True] * num_of_frames

        # Ability to traceback
        data_dict['im_num'] = im_num
        data_dict['archName'] = archName

        def _flags(available, values):
            # All-False when only the -1 placeholder is present.
            if np.all(values == -1):
                return np.stack([False] * num_of_frames, axis=0)
            return np.stack(available, axis=0)

        # Keep flags as separate entries
        if self.train_with_mask:
            data_dict['mask_available'] = np.stack(mask_available, axis=0)
        data_dict['pupil_center_available'] = _flags(pupil_center_available, pupil_center)
        data_dict['iris_ellipse_available'] = _flags(iris_ellipse_available, iris_param)
        data_dict['pupil_ellipse_available'] = _flags(pupil_ellipse_available, pupil_param)

        data_dict['is_bad'] = np.stack(is_bad_list, axis=0)
        data_dict['gaze_vector_available'] = np.stack(gaze_vector_available, axis=0)
        for out_key, (_, available) in extras.items():
            data_dict[out_key + '_available'] = np.stack(available, axis=0)

        return data_dict

    @staticmethod
    def _read_optional(dset, im_num, num_frames, width):
        '''
        Read rows ``im_num`` of ``dset``, or a -1 placeholder if it is empty.

        Returns:
            (values, availability): ``values`` is ``dset[im_num, ...]`` or
            ``-np.ones((num_frames, width))``; ``availability`` is a list of
            ``num_frames`` booleans.
        '''
        if len(dset) != 0:
            return dset[im_num, ...], [True] * num_frames
        return -np.ones((num_frames, width)), [False] * num_frames

def listDatasets(AllDS):
    '''
    Return ``(dataset_names, subset_names)``: the sorted unique values of the
    'dataset' and 'subset' columns of ``AllDS``.
    '''
    return tuple(np.unique(AllDS[column]) for column in ('dataset', 'subset'))


# Merge the data from multiple archive key files into one unified dictionary for later analysis and processing.
def readArchives(path2arc_keys):
    '''
    Merge every archive key file in ``path2arc_keys`` into one dictionary.

    Each file is loaded with ``scipy.io.loadmat`` and is expected to provide
    'archive', 'dataset', 'subset', 'subject_id', 'pupil_loc', 'resolution'
    and a 'Fits' struct with an 'iris' field. Pupil and iris centres are
    divided by the per-image resolution (flipped to [W, H]); missing centres
    are filled with -1 before that division, so they stay negative.

    Returns:
        dict with keys 'archive', 'dataset', 'subset', 'subject_id',
        'im_num', 'pupil_loc', 'iris_loc', each concatenated over all files.
    '''
    AllDS = {'archive': [], 'dataset': [], 'subset': [], 'subject_id': [],
             'im_num': [], 'pupil_loc': [], 'iris_loc': []}

    # Sorted so the concatenation order (and thus every downstream index and
    # split) does not depend on the filesystem's directory order.
    for chunk in sorted(os.listdir(path2arc_keys)):
        # Load archive key
        chunkData = scio.loadmat(os.path.join(path2arc_keys, chunk))
        N = np.size(chunkData['archive'])
        pupil_loc = chunkData['pupil_loc']
        subject_id = chunkData['subject_id']

        subset = chunkData['subset']
        # Test the size first: truth-testing an empty ndarray is deprecated.
        if np.size(subset) == 0 or not subset:
            print('{} does not have subsets.'.format(chunkData['dataset']))
            chunkData['subset'] = 'none'

        # loadmat returns ndarrays, so an empty array also means "missing".
        if isinstance(pupil_loc, list) or np.size(pupil_loc) == 0:
            print('{} does not have pupil center locations'.format(chunkData['dataset']))
            pupil_loc = -1*np.ones((N, 2))

        if chunkData['Fits']['iris'][0, 0].size == 0:
            print('{} does not have iris center locations'.format(chunkData['dataset']))
            iris_loc = -1*np.ones((N, 2))
        else:
            iris_loc = chunkData['Fits']['iris'][0, 0][:, :2]

        loc = np.arange(0, N)
        res = np.flip(chunkData['resolution'], axis=1)  # Flip the resolution to [W, H]

        AllDS['im_num'].append(loc)
        AllDS['subset'].append(np.repeat(chunkData['subset'], N))
        AllDS['dataset'].append(np.repeat(chunkData['dataset'], N))
        AllDS['archive'].append(chunkData['archive'].reshape(-1)[loc])
        AllDS['iris_loc'].append(iris_loc[loc, :]/res[loc, :])
        AllDS['pupil_loc'].append(pupil_loc[loc, :]/res[loc, :])
        AllDS['subject_id'].append(subject_id)

    # Concat all entries into one giant list
    for key, val in AllDS.items():
        AllDS[key] = np.concatenate(val, axis=0)
    return AllDS


def rmDataset(AllDS, rmSet):
    '''
    Remove every dataset whose simplified name occurs in ``rmSet``.

    Works on a deep copy; ``AllDS`` itself is not modified.
    '''
    dsData = copy.deepcopy(AllDS)
    for name in listDatasets(dsData)[0]:
        if not simple_string(name) in simple_string(rmSet):
            continue
        matches = dsData['dataset'] == name
        dsData = copy.deepcopy(rmEntries(dsData, matches))
    return dsData


def rmSubset(AllDS, rmSet):
    '''
    Remove every subset whose simplified name occurs in ``rmSet``.

    Works on a deep copy; ``AllDS`` itself is not modified.
    '''
    dsData = copy.deepcopy(AllDS)
    # Match against subset names (index 1), since removal filters the
    # 'subset' column; dataset names would almost never match it.
    subset_list = listDatasets(dsData)[1]
    for name in subset_list:
        if simple_string(name) in simple_string(rmSet):
            # rmEntries already returns a deep copy.
            dsData = rmEntries(dsData, dsData['subset'] == name)
    return dsData


def selDataset(AllDS, selSet):
    '''
    Keep only the datasets whose simplified name occurs in ``selSet``.

    Works on a deep copy; ``AllDS`` itself is not modified.
    '''
    dsData = copy.deepcopy(AllDS)
    for name in listDatasets(dsData)[0]:
        if simple_string(name) in simple_string(selSet):
            continue
        unwanted = dsData['dataset'] == name
        dsData = copy.deepcopy(rmEntries(dsData, unwanted))
    return dsData


def selSubset(AllDS, selSubset):
    '''
    Keep only the subsets whose simplified name occurs in ``selSubset``.

    Subsets whose name contains one of 'Dikablis', 'LPW', 'KaleidoEYE' or
    'UnityEyes' are compared using only their first two '_'-separated parts;
    other subset names are compared whole. Works on a deep copy.
    '''
    truncated_datasets = ('Dikablis', 'LPW', 'KaleidoEYE', 'UnityEyes')
    dsData = copy.deepcopy(AllDS)
    subset_list = listDatasets(dsData)[1]

    def _match_key(name):
        # Name used for matching; needs at least two parts to be truncated.
        parts = name.split('_')
        if any(ds in name for ds in truncated_datasets) and len(parts) >= 2:
            return parts[0] + '_' + parts[1]
        return name

    for name in subset_list:
        if simple_string(_match_key(name)) not in simple_string(selSubset):
            # rmEntries already returns a deep copy.
            dsData = rmEntries(dsData, dsData['subset'] == name)

    return dsData


def rmEntries(AllDS, ent):
    '''
    Return a deep copy of ``AllDS`` without the rows flagged in ``ent``.

    Args:
        AllDS: record dict whose per-entry columns share a leading axis.
        ent: boolean mask over entries; True marks rows to drop.
    '''
    keep = ~ent
    dsData = copy.deepcopy(AllDS)
    for column in ('subject_id', 'pupil_loc', 'iris_loc', 'archive',
                   'dataset', 'im_num', 'subset'):
        dsData[column] = AllDS[column][keep]
    return dsData


def generate_strat_indices(AllDS):
    '''
    Remove images whose pupil centre lies within 10% of the image border,
    then compute stratification bins.

    Images with a negative (missing) pupil centre are not removed by the
    border test. Ellipse centres (pupil, or iris where the pupil is missing)
    are binned on a 5-bin-per-axis grid per archive, and entries whose bin
    holds at most 2 entries are dropped.

    Returns:
        (indx, AllDS): zero-based bin index of each remaining entry, and the
        cleaned data record.
    '''
    # Centres near any border, excluding missing (negative) ones.
    loc_oBounds = (AllDS['pupil_loc'] < 0.10) | (AllDS['pupil_loc'] > 0.90)
    loc_oBounds = np.any(loc_oBounds, axis=1)
    loc_nExist = np.any(AllDS['pupil_loc'] < 0, axis=1)
    loc = loc_oBounds & ~loc_nExist

    AllDS = rmEntries(AllDS, loc)

    # Ellipse centres: pupil, falling back to the iris where the pupil is
    # missing.
    # NOTE(review): this aliases AllDS['pupil_loc'], so the returned record
    # also carries the iris substitution. Kept to preserve behavior.
    loc_nExist = np.any(AllDS['pupil_loc'] < 0, axis=1)
    ellipse_centers = AllDS['pupil_loc']
    ellipse_centers[loc_nExist, :] = AllDS['iris_loc'][loc_nExist, :]

    # Bin edges of a 2D histogram of the centres.
    numBins = 5
    _, edgeList = np.histogramdd(ellipse_centers, bins=numBins)
    xEdges, yEdges = edgeList

    # Archive number of each entry.
    archNum = np.unique(AllDS['archive'],
                        return_index=True,
                        return_inverse=True)[2]
    # Bin id of each centre along x and y.
    binx = np.digitize(ellipse_centers[:, 0], xEdges, right=True)
    biny = np.digitize(ellipse_centers[:, 1], yEdges, right=True)
    # Flatten (x bin, y bin, archive) into a single stratum index.
    indx = np.ravel_multi_index((binx, biny, archNum),
                                (numBins+1, numBins+1, np.max(archNum)+1))
    indx = indx - np.min(indx)  # Make the indices start at zero

    print('原始条目数量：{}'.format(np.size(binx)))
    bins, counts = np.unique(indx, return_counts=True)

    # Drop entries in sparse bins (<= 2 entries) in a single pass; this
    # gives the same result as removing each sparse bin in turn, without
    # deep-copying the record once per bin.
    sparse = np.isin(indx, bins[counts <= 2])
    indx = indx[~sparse]
    if np.any(sparse):
        AllDS = rmEntries(AllDS, sparse)
    print('分层后的条目数量：{}'.format(np.size(indx)))
    return indx, AllDS


# Generate the file list and split the data into training, validation and test sets.
def generate_fileList(AllDS, mode='vanilla', notest=True):
    '''
    Partition AllDS into train / validation / test index sets.

    Parameters
    ----------
    AllDS : dict
        Data record. This function reads 'archive', 'subject_id' and
        'im_num'; generate_strat_indices reads further keys.
    mode : str
        Split strategy, matched by substring:
          * contains 'vanilla': a single, unshuffled split. If mode holds
            exactly one number (e.g. 'vanilla80'), it is the train
            percentage used when a test set is carved out. Otherwise 90%
            is used.
          * contains 'fold': stratified K-fold split. The first number in
            mode is K (e.g. 'fold5'). A 10% test set is held out first.
          * contains 'none': no split; every image appears in train,
            validation and test.
        The branches are checked in the order above. A later match
        replaces an earlier one.
    notest : bool
        Only used in 'vanilla' mode. When True, no test set is created.

    Returns
    -------
    Datasplit
        Per-fold index arrays whose columns are (image number, subject id).

    Raises
    ------
    ValueError
        If mode contains none of 'vanilla', 'fold' or 'none'.
    '''
    # Drop border samples and get the stratification bin of every sample.
    indx, AllDS = generate_strat_indices(AllDS)

    # Identify each sample by 'archive:subject_id'.
    subject_identifier = [arch + ':' + subj
                          for arch, subj in zip(AllDS['archive'], AllDS['subject_id'])]
    archNum = np.unique(subject_identifier, return_inverse=True)[1]

    # Columns: image number, subject id, stratification bin.
    feats = np.stack([AllDS['im_num'], archNum, indx], axis=1)

    validPerc = .10
    data_div = None

    if 'vanilla' in mode:
        params = re.findall(r'\d+', mode)
        if len(params) == 1:
            trainPerc = float(params[0])/100
        else:
            trainPerc = 1 - validPerc
        # NOTE(review): trainPerc only affects the train/test split, which
        # is skipped when notest=True. The validation split is always
        # validPerc of whatever remains for training.
        print('Training data set to {}%. Validation data set to {}%.'.format(
                    int(100*trainPerc), int(100*validPerc)))

        data_div = Datasplit(1, subject_identifier)
        if not notest:
            # Deterministic, order-based split (no shuffling).
            train_feats, test_feats = train_test_split(feats,
                                                       train_size=trainPerc,
                                                       stratify=None,
                                                       shuffle=False)
        else:
            train_feats = feats
            test_feats = []

        # Hold out validPerc of the training rows, again without shuffling.
        train_feats, valid_feats = train_test_split(train_feats,
                                                    test_size=validPerc,
                                                    shuffle=False)
        data_div.assignIdx(0, train_feats, valid_feats, test_feats)

    if 'fold' in mode:
        K = int(re.findall(r'\d+', mode)[0])

        data_div = Datasplit(K, subject_identifier)
        skf = StratifiedKFold(n_splits=K, shuffle=True)
        # Hold out a stratified test set, then run K folds on the rest.
        train_feats, test_feats = train_test_split(feats,
                                                   train_size=1 - validPerc,
                                                   stratify=indx)
        folds = skf.split(train_feats, train_feats[:, -1])
        for fold_num, (train_loc, valid_loc) in enumerate(folds):
            data_div.assignIdx(fold_num,
                               train_feats[train_loc, :],
                               train_feats[valid_loc, :],
                               test_feats)

    if 'none' in mode:
        # Every image in all three sets, so nothing can be mixed up.
        data_div = Datasplit(1, subject_identifier)
        data_div.assignIdx(0, feats, feats, feats)

    if data_div is None:
        raise ValueError(
            "Unknown split mode {!r}: expected it to contain 'vanilla', "
            "'fold' or 'none'".format(mode))

    return data_div



def generateIdx(samplesList, batch_size):
    '''
    Shuffle samples and split them into consecutive batches.

    Parameters
    ----------
    samplesList : array-like
        One row per sample. In this module the first column is the image
        number and the second column identifies the source file.
    batch_size : int
        Number of samples per batch. The last batch may be smaller.

    Returns
    -------
    list of np.ndarray
        Shuffled batches of rows. An empty list is returned when
        samplesList has no entries. The caller's array is not reordered.

    Raises
    ------
    ValueError
        If batch_size is smaller than 1 and samplesList is non-empty.
    '''
    if np.size(samplesList) == 0:
        return []
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))

    # np.random.shuffle permutes along the first axis in place, so shuffle
    # a copy to leave the caller's array untouched.
    shuffled = np.array(samplesList, copy=True)
    np.random.shuffle(shuffled)
    return [shuffled[start:start + batch_size]
            for start in range(0, shuffled.shape[0], batch_size)]

def foldInfo():
    '''Return a new, empty index record for one fold of a Datasplit.'''
    return dict(train_idx=[], valid_idx=[], test_idx=[])

class Datasplit():
    '''
    Hold the train / validation / test assignment of every fold.

    Attributes
    ----------
    splits : int
        Number of folds (K).
    folds : list of dict
        One record per fold, created by foldInfo(), with keys
        'train_idx', 'valid_idx' and 'test_idx'.
    arch : np.ndarray
        Sorted unique values of the identifiers passed as ``archs``.
    '''

    def __init__(self, K, archs):
        self.splits = K
        self.folds = [foldInfo() for _ in range(self.splits)]
        self.arch = np.unique(archs)

    def assignIdx(self, foldNum, train_idx, valid_idx, test_idx):
        '''
        Validate and store the three index sets of fold ``foldNum``.

        Each argument is either a list, which is stored as a new empty
        list, or a 2-D array. For an array only the first two columns
        (image number and source id) are kept; extra columns such as the
        stratification bin are dropped. All three sets are checked with
        checkUnique before anything is stored.
        '''
        subsets = (('train_idx', train_idx),
                   ('valid_idx', valid_idx),
                   ('test_idx', test_idx))
        for _, idx in subsets:
            self.checkUnique(idx)
        for name, idx in subsets:
            self.folds[foldNum][name] = [] if type(idx) is list else idx[:, :2]

    def checkUnique(self, ID):
        '''
        Print a warning once for every distinct value in column 1 of ``ID``
        whose rows contain a repeated image number in column 0.
        Lists are ignored.
        '''
        if type(ID) is list:
            return
        image_nums = ID[:, 0]
        chunk_ids = ID[:, 1]
        for chunk_id in np.unique(chunk_ids):
            in_chunk = image_nums[chunk_ids == chunk_id]
            if np.unique(in_chunk).size != in_chunk.size:
                print('Not unique! WARNING')


if __name__ == "__main__":
    # Validate every dataset archive and report the total number of images.
    # (Run sandbox.py to validate the dataloader itself.)
    path2data = r'D:\Xiao\DataSet\TEyeD'
    path2arc_keys = os.path.join(path2data, 'MasterKey')

    AllDS = readArchives(path2arc_keys)
    datasets_present, subsets_present = listDatasets(AllDS)

    for header, selected in (('Datasets selected ---------', datasets_present),
                             ('Subsets selected ---------', subsets_present)):
        print(header)
        print(selected)

    dataDiv_Obj = generate_fileList(AllDS, mode='vanilla')
    # Sum the rows of every non-empty split in the first fold.
    split_sizes = [idx.shape[0]
                   for idx in dataDiv_Obj.folds[0].values()
                   if len(idx) > 0]
    print('Test Total number of images: {}'.format(np.sum(split_sizes)))

    with open('CurCheck.pkl', 'wb') as fid:
        pickle.dump(dataDiv_Obj, fid)
