# Copyright (c) Prophesee S.A.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.


import os
import sys
import datetime
import numpy as np


EV_TYPE = [('t', 'u4'), ('_', 'i4')]  # Event2D
EV_STRING = 'Event2D'

######################################################################################
# dat_events_tools
######################################################################################

def load_td_data(filename, ev_count=-1, ev_start=0):
    """
    Loads TD data from files generated by the StreamLogger consumer for Event2D
    events [ts,x,y,p]. The type ID in the file header must be 0.
    args :
        - path to a dat file
        - number of event (all if set to the default -1)
        - index of the first event

    return :
        - dat, a dictionary like structure containing the fields ts, x, y, p
    """

    with open(filename, 'rb') as f:
        _, ev_type, ev_size, _ = parse_header(f)
        if ev_start > 0:
            f.seek(ev_start * ev_size, 1)

        dtype = EV_TYPE
        dat = np.fromfile(f, dtype=dtype, count=ev_count)
        xyp = None
        if ('_', 'i4') in dtype:
            x = np.bitwise_and(dat["_"], 16383)
            y = np.right_shift(
                np.bitwise_and(dat["_"], 268419072), 14)
            p = np.right_shift(np.bitwise_and(dat["_"], 268435456), 28)
            xyp = (x, y, p)
        return _dat_transfer(dat, dtype, xyp=xyp)


def _dat_transfer(dat, dtype, xyp=None):
    """
    Transfers the fields present in dtype from an old datastructure to a new datastructure
    xyp should be passed as a tuple
    args :
        - dat vector as directly read from file
        - dtype _numpy dtype_ as a list of couple of field name/ type eg [('x','i4'), ('y','f2')]
        - xyp optional tuple containing x,y,p extracted from a field '_'and untangled by bitshift and masking
    """
    variables = []
    xyp_index = -1
    for i, (name, _) in enumerate(dtype):
        if name == '_':
            xyp_index = i
            continue
        variables.append((name, dat[name]))
    if xyp and xyp_index == -1:
        print("Error dat didn't contain a '_' field !")
        return
    if xyp_index >= 0:
        dtype = dtype[:xyp_index] + [('x', 'i2'), ('y', 'i2'), ('p', 'i2')] + dtype[xyp_index + 1:]
    new_dat = np.empty(dat.shape[0], dtype=dtype)
    if xyp:
        new_dat["x"] = xyp[0].astype(np.uint16)
        new_dat["y"] = xyp[1].astype(np.uint16)
        new_dat["p"] = xyp[2].astype(np.uint16)
    for (name, arr) in variables:
        new_dat[name] = arr
    return new_dat

def stream_td_data(file_handle, buffer, dtype, ev_count=-1):
    """
    从打开的文件句柄中流式传输数据
    参数：
        - file_handle：文件对象
        - buffer：预先分配的缓冲区，用于存储事件
        - dtype：期望的字段
        - ev_count：事件数
    """
    # 从文件句柄中读取数据
    dat = np.fromfile(file_handle, dtype=dtype, count=ev_count) # 从文件句柄中读取数据，并将数据存储到 NumPy 数组 dat 中
    count = len(dat['t']) # 获取读取到的事件数量并赋值给 count

    # 将数据存储到缓冲区中
    for name, _ in dtype:
        if name == '_':
            # 解析x、y和p字段
            buffer['x'][:count] = np.bitwise_and(dat["_"], 16383) # 16383 的二进制表示为 11111111111111，即取 dat["_"] 的低 14 位。提取dat["_"]的低14位，即提取x坐标。
            buffer['y'][:count] = np.right_shift(np.bitwise_and(dat["_"], 268419072), 14) # 将 dat["_"] 和 268419072(即二进制数 11111111111111110000000000000000)进行按位与运算，然后再右移 14 位。提取dat["_"]的第29位到第42位，即提取y坐标。
            buffer['p'][:count] = np.right_shift(np.bitwise_and(dat["_"], 268435456), 28)  # 268435456 的二进制表示为 1 0000 0000 0000 0000 0000 0000 0000。然后将结果右移28位，提取dat["_"]的高位，即提取t。
        else:
            # 解析其他字段
            buffer[name][:count] = dat[name] # 将读取到的数据存储到缓冲区中

def count_events(filename):
    """
    Returns the number of events in a dat file
    args :
        - path to a dat file
    """
    with open(filename, 'rb') as f:
        bod, _, ev_size, _ = parse_header(f)
        f.seek(0, os.SEEK_END)
        eod = f.tell()
        if (eod - bod) % ev_size != 0:
            raise Exception("unexpected format !")
        return (eod - bod) // ev_size

def parse_header(f):
    """
    解析 dat 文件的头部信息
    参数：
        - f: 一个指向 dat 文件的文件句柄
    返回值：
        - int: 文件游标在头部信息之后的位置
        - int: 事件类型
        - int: 事件大小(以字节为单位)
        - size (height, width) tuple of int or None: 高度和宽度的整数元组或者 None
    """
    f.seek(0, os.SEEK_SET)  # 将文件指针移动到文件开头
    bod = None  # 用于记录当前文件指针位置的变量
    end_of_header = False  # 用于判断是否已经读取完头部信息的标志位
    header = []  # 用于存储头部信息的列表
    num_comment_line = 0  # 用于记录有多少行处理完成了
    size = [None, None]  # 用于存储高度和宽度的列表，初始值为 [None, None]

    # 解析头部信息
    while not end_of_header:
        bod = f.tell()  # 获取当前文件指针位置并赋值给 bod
        line = f.readline()  # 读取一行数据并赋值给 line
        if sys.version_info > (3, 0): # 判断 Python 版本是否大于 3.0
            first_item = line.decode("latin-1")[:2] # 将 line 解码为 latin-1 编码，并取前两个字符
        else:
            first_item = line[:2] # 如果 Python 版本小于等于 3.0，则不需要解码，直接取前两个字符

        if first_item != '% ': # 如果前两个字符不是 "% "，则表示读取完了头部信息
            end_of_header = True
        else: # 否则，表示还没有读取完头部信息
            words = line.split() # 将 line 按空格分割，并将分割后的字符串列表赋值给 words
            if len(words) > 1:
                if words[1] == 'Date':
                    header += ['Date', words[2] + ' ' + words[3]]
                if words[1] == 'Height' or words[1] == b'Height':  # 对 Python 3(以及 Python2)兼容性处理
                    size[0] = int(words[2])
                    header += ['Height', words[2]]
                if words[1] == 'Width' or words[1] == b'Width':  # 对 Python 3(以及 Python2)兼容性处理
                    size[1] = int(words[2])
                    header += ['Width', words[2]]
            else:
                header += words[1:3]
            num_comment_line += 1 # 更新已处理行数

    f.seek(bod, os.SEEK_SET)  # 将文件指针移回原来的位置

    '''确保与以前的文件兼容性。'''
    if num_comment_line > 0:
        # 读取事件类型。将一个字节的数据转换成无符号整数类型。
        ev_type = np.frombuffer(f.read(1), dtype=np.uint8)[0]
        # 读取事件大小。将一个字节的数据转换成无符号整数类型。
        ev_size = np.frombuffer(f.read(1), dtype=np.uint8)[0]
    else:
        ev_type = 0
        ev_size = sum([int(n[-1]) for _, n in EV_TYPE])
        # int(n[-1]) 表示将子列表的最后一个字符转换成整数。整个列表推导式的含义就是将 EV_TYPE 中每个子列表的最后一个字符转换成整数，并将这些整数组成一个新的列表。
        # 例如，如果 EV_TYPE 是 [['a', '3'], ['b', '5'], ['c', '7']],那么这个列表推导式就会生成 [3, 5, 7] 这个列表。

    bod = f.tell() # 获取当前文件指针位置并赋值给 bod
    return bod, ev_type, ev_size, size # 返回文件游标在头部信息之后的位置、事件类型、事件大小和高度、宽度

def write_header(filename, height=240, width=320, ev_type=0):
    """
    write header for a dat file
    """
    # 首先检查给定的高度和宽度是否超过了 .dat 文件格式中允许的最大范围
    if max(height, width) > 2**14 - 1:
        raise ValueError('Coordinates value exceed maximum range in'
                         ' binary .dat file format max({:d},{:d}) vs 2^14 - 1'.format(
                             height, width))
    # 然后，打开文件句柄 f，以写入模式打开指定的文件
    f = open(filename, 'w')

    # 接下来，使用 f.write 函数写入头部信息。
    # 首先写入一行注释，指明数据文件包含的事件类型。然后写入版本信息。
    # 接着，获取当前的日期和时间，并将其写入文件中。
    # 最后，写入高度和宽度。
    f.write('% Data file containing {:s} events.\n'
            '% Version 2\n'.format(EV_STRING[ev_type]))
    now = datetime.datetime.utcnow()
    f.write("% Date {}-{}-{} {}:{}:{}\n".format(now.year,
                                                now.month, now.day, now.hour,
                                                now.minute, now.second))

    f.write('% Height {:d}\n'
            '% Width {:d}\n'.format(height, width))
    
    # 然后，计算事件类型的大小，并将事件类型和大小以字节为单位的形式写入文件
    ev_size = sum([int(b[-1]) for _, b in EV_TYPE])
    np.array([ev_type, ev_size], dtype=np.uint8).tofile(f)
    # 最后，使用 f.flush() 函数将缓冲区中的数据刷新到文件中，并返回文件句柄 f。
    f.flush()
    return f

# 名为 write_event_buffer 的函数，用于将事件数据写入文件对象 f 中
def write_event_buffer(f, buffers): # 该函数接受一个文件对象 f 和一个包含事件数据的字典 buffers 作为参数。
    """
    writes events of fields x,y,p,t into the file object f
    """
    # pack data as events
    dtype = EV_TYPE
    data_to_write = np.empty(len(buffers['t']), dtype=dtype)

    for (name, typ) in buffers.dtype.fields.items():
        if name == 'x':
            x = buffers['x'].astype('i4') # 对于字段名为 'x'，将其转换为有符号的 32 位整数类型
        elif name == 'y':
            y = np.left_shift(buffers['y'].astype('i4'), 14) # 对于字段名为 'y'，将其左移 14 位，并转换为有符号的 32 位整数类型
        elif name == 'p':
            buffers['p'] = (buffers['p'] == 1).astype(buffers['p'].dtype)
            p = np.left_shift(buffers['p'].astype("i4"), 28) # 对于字段名为 'p'，将其转换为布尔类型，并将其左移 28 位，并转换为有符号的 32 位整数类型
        else:
            data_to_write[name] = buffers[name].astype(typ[0]) # 对于其他字段名，将其转换为对应的数据类型

    # 对于字段名为 '_'，将经过处理后的 'x'、'y' 和 'p' 字段相加，并存储到 data_to_write['_'] 中
    data_to_write['_'] = x + y + p

    # write data
    data_to_write.tofile(f)
    f.flush() # 使用 data_to_write.tofile(f) 将数据写入文件对象 f 中，并使用 f.flush() 刷新缓冲区，确保数据被写入文件
