#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This started as a copy of https://bitbucket.org/RSKothari/multiset_gaze/src/master/
and has since been modified to fit our implementation.

Copyright (c) 2021 Rakshit Kothari, Aayush Chaudhary, Reynold Bailey, Jeff Pelz, 
and Gabriel Diaz
"""

import os
import cv2
import sys
import argparse

import numpy as np
import pandas as pd
import multiprocessing as mp

import deepdish as dd
import scipy.io as scio

from skimage import draw

sys.path.append('..')
from helperfunctions.hfunctions import plot_segmap_ellpreds
from helperfunctions.hfunctions import generateEmptyStorage

from Visualitation_TEyeD.gaze_estimation import draw_gaze, draw_landmark

def make_args():
    """Parse the command-line options of the conversion script.

    Returns:
        argparse.Namespace with two string attributes:
        ``path2ds`` (dataset root) and ``path_data`` (TEyeD Dikablis folder).
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ('--path2ds', r'D:\Xiao\DataSet\SRGaze', 'path to datasets'),
        ('--path_data', r'D:\Xiao\DataSet\SRGaze\Dikablis',
         'path to TEyeD Dikablis eye videos'),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    return cli.parse_args()


def _read_annotation(path):
    """Read one ';'-delimited annotation file and return it as a NumPy array.

    Malformed lines are skipped (``on_bad_lines="skip"``), so the number of
    returned rows can be smaller than the number of lines in the file.
    """
    return pd.read_csv(path, on_bad_lines="skip", delimiter=';').to_numpy()


def _load_frame_sample(frame_path, ellipse_row, eyeball_row, scale_x, scale_y):
    """Load, resize and label a single frame.

    Args:
        frame_path: path to the ``.jpg`` frame (read as grayscale).
        ellipse_row: ``(angle, center_x, center_y, width, height)`` pupil
            ellipse in annotation coordinates (angle in degrees).
        eyeball_row: ``(radius, x, y, z)`` eyeball parameters in annotation
            coordinates.
        scale_x, scale_y: factors mapping annotation x / y values onto the
            320x240 output image.

    Returns:
        A dict with keys ``frame``, ``mask``, ``model_pupil``, ``pupil_loc``
        and ``eyeball``, or None when the file is missing or unreadable
        (a message is printed in that case).
    """
    if not os.path.exists(frame_path):
        print(f"Frame not found: {frame_path}")
        return None

    frame = cv2.imread(frame_path, cv2.IMREAD_GRAYSCALE)
    if frame is None:
        print(f"Frame could not be read: {frame_path}")
        return None

    frame = cv2.resize(frame, (320, 240), interpolation=cv2.INTER_LANCZOS4)

    # Eyeball: x/y scaled per axis, radius scaled with the y factor, z kept.
    radius, x, y, z = eyeball_row
    eyeball = [radius * scale_y, x * scale_x, y * scale_y, z]

    # Pupil ellipse: scale centre and axis lengths; the angle is unchanged.
    angle, center_x, center_y, width, height = ellipse_row
    model_pupil = [angle,
                   center_x * scale_x, center_y * scale_y,
                   width * scale_x, height * scale_y]

    # Reorder to [cx, cy, width, height, angle], halve the axis lengths,
    # convert the angle to radians with a -90 degree offset, then swap the
    # two halved axes -> [cx, cy, height/2, width/2, angle_rad].
    model_pupil = np.roll(model_pupil, shift=-1)
    model_pupil[2:4] = model_pupil[2:4] / 2
    model_pupil[-1] = np.deg2rad(model_pupil[-1] - 90)
    model_pupil[[2, 3]] = model_pupil[[3, 2]]
    pupil_loc = model_pupil[:2]

    # Rasterise the pupil ellipse into a label mask (pupil label = 3). The
    # frame is already 320x240, so the mask needs no further resizing.
    rr_p, cc_p = draw.ellipse(round(model_pupil[1]),
                              round(model_pupil[0]),
                              round(model_pupil[3]),
                              round(model_pupil[2]),
                              shape=(frame.shape[0], frame.shape[1]),
                              rotation=-model_pupil[4])
    mask = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.int32)
    mask[rr_p, cc_p] = 3

    return {
        'frame': frame,
        'mask': mask,
        'model_pupil': model_pupil,
        'pupil_loc': pupil_loc,
        'eyeball': eyeball,
    }


def process_entry(args, num_shuffles=50):
    """Convert one recording into the multiset_gaze ``.h5`` / ``.mat`` pair.

    Each readable frame is loaded once and then emitted ``num_shuffles``
    times, once per random permutation of the frame order.

    Args:
        args: dict with keys ``vid_name``, ``path2ds``, ``path_annot`` and
            ``path_images``. ``path_annot`` is used as a file-name prefix:
            ``pupil_eli.txt``, ``eye_ball.txt`` and ``gaze_vec.txt`` are
            appended to it directly.
        num_shuffles: number of random orderings to generate.

    Side effects:
        Writes ``<path2ds>/All_SR/<vid_name>.h5`` (deepdish) and
        ``<path2ds>/MasterKey_SR/<vid_name>.mat`` (scipy), creating the
        directories if needed. Nothing is written if no frame is readable.
    """
    vid_name_ext = args['vid_name']
    PATH_DS = os.path.join(args['path2ds'], 'All_SR')
    PATH_MASTER = os.path.join(args['path2ds'], 'MasterKey_SR')
    ds_name = '{}'.format(vid_name_ext)

    # Read the annotation files.
    pupil_ellipses = _read_annotation(args['path_annot'] + 'pupil_eli.txt')
    eye_ball = _read_annotation(args['path_annot'] + 'eye_ball.txt')
    gaze_vector = _read_annotation(args['path_annot'] + 'gaze_vec.txt')

    # Column 0 is the frame number. The last column is dropped as well
    # (presumably an empty column from a trailing ';' -- verify on the files).
    valid_frames = pupil_ellipses[:, 0].astype(int)
    pupil_ellipses = pupil_ellipses[..., 1:-1]
    eye_ball = eye_ball[..., 1:-1]
    gaze_vector = gaze_vector[..., 1:-1]

    # The three files are paired by row position. Skipped bad lines can make
    # their lengths differ; use the common prefix instead of crashing with
    # an IndexError during fancy indexing.
    num_rows = min(len(valid_frames), len(eye_ball), len(gaze_vector))
    if not (len(valid_frames) == len(eye_ball) == len(gaze_vector)):
        print(f"Warning: annotation row counts differ for {vid_name_ext} "
              f"(pupil={len(valid_frames)}, eyeball={len(eye_ball)}, "
              f"gaze={len(gaze_vector)}); using the first {num_rows} rows.")
    valid_frames = valid_frames[:num_rows]
    pupil_ellipses = pupil_ellipses[:num_rows]
    eye_ball = eye_ball[:num_rows]
    gaze_vector = gaze_vector[:num_rows]

    # Initialise the output containers.
    Data = {
        'Images': [],
        'dataset': 'SRGaze',
        'subset': '{}'.format(vid_name_ext),
        'resolution': [],
        'archive': [],
        'Info': [],
        'Masks': [],
        'subject_id': [],
        'Fits': {'pupil': [], 'iris': []},
        'pupil_loc': [],
        'Eyeball': [],
        'Gaze_vector': [],
        'timestamp': []
    }

    keydict = {
        'dataset': 'SRGaze',
        'subset': '{}'.format(vid_name_ext),
        'resolution': [],
        'subject_id': [],
        'archive': [],
        'Info': [],
        'Fits': {'pupil': [], 'iris': []},
        'pupil_loc': []
    }

    # Map annotation coordinates onto the 320x240 output. These factors
    # assume a 400x400 annotation frame (TODO confirm against the source
    # image size).
    scale_x = 320 / 400
    scale_y = 240 / 400

    # Load every frame exactly once; the shuffles only reorder these samples,
    # so re-reading and re-processing each image per shuffle is unnecessary.
    img_folder = args['path_images']
    samples = [
        _load_frame_sample(os.path.join(img_folder, f"frame_{fr_idx:04d}.jpg"),
                           pupil_ellipses[row], eye_ball[row],
                           scale_x, scale_y)
        for row, fr_idx in enumerate(valid_frames)
    ]

    for shuffle_idx in range(num_shuffles):
        print(f"Shuffle iteration {shuffle_idx + 1}/{num_shuffles} for {vid_name_ext}")
        # Randomly permute the frame order for this shuffle.
        for row in np.random.permutation(len(valid_frames)):
            sample = samples[row]
            if sample is None:
                continue

            fr_idx = valid_frames[row]
            imName_Full = 'SRGaze-{}-frame-{}-shuffle-{}'.format(vid_name_ext, fr_idx, shuffle_idx)
            # Timestamp derived from the frame number at 25 fps.
            timestamp = (1 / 25) * fr_idx

            Data['Info'].append(imName_Full)
            Data['Masks'].append(sample['mask'])
            Data['Images'].append(sample['frame'])
            Data['pupil_loc'].append(sample['pupil_loc'])
            Data['subject_id'].append('0')
            Data['Eyeball'].append(sample['eyeball'])
            Data['Gaze_vector'].append(gaze_vector[row])
            Data['timestamp'].append(timestamp)
            Data['Fits']['pupil'].append(sample['model_pupil'])

            keydict['archive'].append(ds_name)
            keydict['resolution'].append(sample['frame'].shape)
            keydict['pupil_loc'].append(sample['pupil_loc'])
            keydict['subject_id'].append('0')
            keydict['Fits']['pupil'].append(sample['model_pupil'])

    num_samples = len(Data['Info'])
    if num_samples == 0:
        # np.stack on empty lists would raise; report and skip saving instead.
        print(f"No readable frames for {vid_name_ext}; nothing saved.")
        return

    # Stack the per-sample lists into arrays.
    Data['Masks'] = np.stack(Data['Masks'], axis=0)
    Data['Images'] = np.stack(Data['Images'], axis=0)
    Data['pupil_loc'] = np.stack(Data['pupil_loc'], axis=0)
    Data['subject_id'] = np.stack(Data['subject_id'], axis=0)
    Data['Masks_noSkin'] = Data['Masks']
    Data['Fits']['pupil'] = np.stack(Data['Fits']['pupil'], axis=0)
    Data['Eyeball'] = np.stack(Data['Eyeball'], axis=0)
    Data['Gaze_vector'] = np.stack(Data['Gaze_vector'], axis=0)
    Data['timestamp'] = np.stack(Data['timestamp'], axis=0)

    keydict['resolution'] = np.stack(keydict['resolution'], axis=0)
    keydict['subject_id'] = np.stack(keydict['subject_id'], axis=0)
    keydict['pupil_loc'] = np.stack(keydict['pupil_loc'], axis=0)
    keydict['archive'] = np.stack(keydict['archive'], axis=0)
    keydict['Fits']['pupil'] = np.stack(keydict['Fits']['pupil'], axis=0)

    # Report the number of samples actually written, which excludes skipped frames.
    print(f"Processed frames: {num_samples}")
    os.makedirs(PATH_DS, exist_ok=True)
    os.makedirs(PATH_MASTER, exist_ok=True)
    dd.io.save(os.path.join(PATH_DS, ds_name + '.h5'), Data)
    scio.savemat(os.path.join(PATH_MASTER, str(ds_name) + '.mat'), keydict, appendmat=True)


if __name__ == '__main__':

    args = vars(make_args())
    path_videos = os.path.join(args['path_data'], 'IMAGES')
    path_annots = os.path.join(args['path_data'], 'ANNOTATIONS')

    # Only KaleidoEYE recordings are converted. Filter first so the progress
    # denominator counts only folders that are actually processed, and sort
    # so the processing order is deterministic across runs.
    target_folders = sorted(name for name in os.listdir(path_videos)
                            if 'KaleidoEYE' in name)

    for position, folder_name in enumerate(target_folders, start=1):
        args['vid_name'] = os.path.splitext(folder_name)[0]
        args['path_images'] = os.path.join(path_videos, folder_name)
        # Used as a file-name prefix by process_entry (annotation file names
        # are appended directly, without a path separator).
        args['path_annot'] = os.path.join(path_annots, folder_name)
        print(folder_name)
        print('Processing folder ({}/{})'.format(position, len(target_folders)))

        process_entry(args, num_shuffles=50)

    print('Total folders processed: {}'.format(len(target_folders)))





