# -*- coding: UTF-8 -*-
'''
@Project ：Model-aware_3D_Eye_Gaze 
@File    ：testpara.py
@Author  ：xyf
@Date    ：2024/10/22 13:15 
'''
import time
from fvcore.nn import FlopCountAnalysis, flop_count_table
import torch
import torch.nn as nn
import torch.nn.functional as F
from args_maker import make_args
from helperfunctions.utils import get_nparams
from models.resnet.res_50_3 import res_50_3

# Benchmark script: builds the res_50_3 model from the project's argument
# parser, prints its parameter count, times a forward pass on a random input
# and prints a FLOP breakdown table.
args = vars(make_args())
norm = nn.BatchNorm2d
model = res_50_3(args, norm=norm, act_func=F.leaky_relu)

count = 0  # accumulated forward-pass time in seconds across all iterations
print(get_nparams(model))

# Switch to evaluation mode once, before any timing; it does not need to be
# repeated on every iteration.
model.eval()

for _ in range(1):
    # Example input: a single random 4-channel 240x320 tensor under the
    # "image" key (the eye image, per the original author's note).
    data_dict = {}
    data_dict["image"] = torch.randn(1, 4, 240, 320)

    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring short durations; time.time() can jump with system clock
    # adjustments.
    start_network = time.perf_counter()
    # Disable autograd so the measurement reflects pure inference and is not
    # inflated by recording the backward graph.
    with torch.no_grad():
        out_dict = model(data_dict, args)
    end_time = time.perf_counter() - start_network
    count += end_time
    print(end_time, "seconds elapsed")

    # Compute the model's FLOPs and print them as a per-module table.
    flops = FlopCountAnalysis(model, (data_dict, args))
    print(flop_count_table(flops))

# Average forward time previously recorded by the original author with this
# script: 0.07276225090026856 s.