import pickle
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import FuncFormatter

def scientific_format(x, pos):
    """Render a tick value in scientific notation with one fractional digit.

    Intended as a matplotlib ``FuncFormatter`` callback: *x* is the tick
    value; *pos* (the tick index) is accepted for API compatibility and
    ignored. Example: ``12345 -> '1.2E+04'``.
    """
    return format(x, '.1E')

# Tick formatter that renders values through scientific_format (e.g. 1.2E+04).
formatter = FuncFormatter(func=scientific_format)

def _select_entries(data_dict, target_mode):
    """Return a new dict holding only the entries whose key contains *target_mode*."""
    return {name: value for name, value in data_dict.items() if target_mode in name}


def _value_range(data_dict):
    """Return the true (min, max) over every array of every entry in *data_dict*.

    Empty arrays are skipped; returns ``(0.0, 0.0)`` when nothing is left to
    inspect.
    """
    mins, maxs = [], []
    for datas in data_dict.values():
        for data in datas:
            data = np.asarray(data)
            if data.size:
                mins.append(data.min())
                maxs.append(data.max())
    if not mins:
        return 0.0, 0.0
    return min(mins), max(maxs)


def _percentile_traces(data, low=1, high=99):
    """Return two arrays: the *low*-th and *high*-th percentile of each array in *data*."""
    bounds = np.array([np.percentile(d, [low, high]) for d in data]).reshape(-1, 2)
    return bounds[:, 0], bounds[:, 1]


if __name__ == '__main__':
    from pathlib import Path

    data_path = 'work_statistics/retinanet_r18_fpn_1x_coco/save-act_100_fuse_True.npz'
    output_path = 'y_function_vis/tmp4_with_percentiles.pdf'
    target_mode = 'out'

    # NOTE(review): despite the .npz suffix the file is read with pickle, so it
    # must be a pickled dict. pickle can execute arbitrary code while loading:
    # only open trusted files.
    with open(data_path, 'rb') as tarfile:
        loaded_data_dict = pickle.load(tarfile)
    # Keep only the entries whose name contains target_mode.
    loaded_data_dict = _select_entries(loaded_data_dict, target_mode)

    # Data-driven histogram range (used if the override below is removed).
    global_min, global_max = _value_range(loaded_data_dict)

    print('非法改装min/max！！')
    # HACK: deliberately override the computed range with a fixed window.
    # np.histogram ignores values outside the bin edges, so values outside
    # [-3, 3] are not counted in the surface.
    global_max = 3
    global_min = -3

    # Colormap: dark blue (held until 0.1), then green, then yellow.
    colors = [(0, '#352A88'), (0.1, '#352A88'), (0.6, '#51BA8D'), (1, '#F9AF2B')]
    cmap = LinearSegmentedColormap.from_list('custom_cmap', colors)

    bin_num = 200  # number of bin *edges*, i.e. bin_num - 1 histogram bins
    magnitudes_bins = np.linspace(global_min, global_max, bin_num)
    magnitudes = (magnitudes_bins[:-1] + magnitudes_bins[1:]) / 2  # bin centres

    # One subplot column per entry; at least two to keep the original layout.
    ncols = max(2, len(loaded_data_dict))
    fig = plt.figure(figsize=(8 * ncols, 6.5))
    # Negative wspace intentionally overlaps panels to trim 3D whitespace.
    plt.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99, wspace=-0.4, hspace=.0)

    for i, (name, data) in enumerate(loaded_data_dict.items()):
        ax = fig.add_subplot(1, ncols, i + 1, projection='3d')

        # Per-entry iteration count, so entries of different lengths still line up.
        num_iters = len(data)
        iters = np.arange(num_iters)
        X, Y = np.meshgrid(magnitudes, iters)

        # Row k of Z is the histogram of the k-th array in this entry.
        Z = np.array([np.histogram(d, bins=magnitudes_bins)[0] for d in data])
        ax.plot_surface(X, Y, Z, cmap=cmap, rstride=1, cstride=1, linewidth=0,
                        edgecolor='black', antialiased=True)

        # Mark the 1st and 99th percentile of each iteration on the floor (z=0).
        percentiles_1, percentiles_99 = _percentile_traces(data, 1, 99)
        floor = np.zeros(num_iters)
        ax.scatter(percentiles_1, iters, floor, color='red', s=10, label='1% percentile')
        ax.scatter(percentiles_99, iters, floor, color='blue', s=10, label='99% percentile')

        ax.set_xlabel('magnitude')
        ax.set_ylabel('iter')
        ax.yaxis.label.set_rotation(90)
        ax.set_zlabel('num', labelpad=32)
        ax.zaxis.set_major_formatter(formatter)
        ax.tick_params(axis='z', pad=18)  # distance between z tick labels and axis
        ax.zaxis.label.set_rotation(90)

        # Both x and y axes are drawn reversed (limits given high -> low).
        ax.set_xlim(global_max, global_min)
        ax.set_ylim(num_iters, 0)
        # Rotate the camera so the magnitude axis faces the viewer.
        ax.view_init(elev=25, azim=87)

        ax.legend()

    # Save BEFORE show(): once the interactive window is closed pyplot may have
    # released the figure, and savefig() would then write an empty page.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path)
    plt.show()
