import pickle
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import FuncFormatter
import os



def _scientific_format(x, pos):
    """Tick formatter: render *x* in one-decimal scientific notation (e.g. 1.2E+03).

    *pos* is the tick position supplied by matplotlib's FuncFormatter; unused.
    """
    return f'{x:.1E}'


def _filter_by_mode(data_dict, mode):
    """Return a new dict keeping only the entries whose key contains *mode*.

    Matching is a plain substring test. Insertion order is preserved, so the
    subplot order follows the order of entries in the loaded file.
    """
    return {name: datas for name, datas in data_dict.items() if mode in name}


def _value_range(entries):
    """Return ``(min, max)`` over every array of every entry in *entries*.

    *entries* maps a name to a sequence of arrays (one per iteration).
    Raises ValueError when there is no array at all to measure.
    """
    arrays = [d for datas in entries.values() for d in datas]
    if not arrays:
        raise ValueError('no activation arrays to compute a range from')
    return (min(float(np.min(d)) for d in arrays),
            max(float(np.max(d)) for d in arrays))


def _histograms(data, bins):
    """Return a ``(len(data), len(bins) - 1)`` array of histogram counts.

    Row ``j`` holds the histogram of ``data[j]`` over the bin edges *bins*.
    Values outside ``[bins[0], bins[-1]]`` are not counted (np.histogram semantics).
    """
    n_bins = len(bins) - 1
    if len(data) == 0:
        # Keep the result 2-D so callers can still index hist.shape[1].
        return np.zeros((0, n_bins), dtype=np.int64)
    return np.array([np.histogram(d, bins=bins)[0] for d in data])


def _percentile_markers(data, hist, bins, low_q=0.1, high_q=99.9):
    """Place the low/high percentile of each iteration on the histogram surface.

    Args:
        data: sequence of arrays, one per iteration (row of *hist*).
        hist: 2-D count array as returned by ``_histograms(data, bins)``.
        bins: bin edges used to build *hist*.
        low_q, high_q: percentiles in the 0-100 range passed to np.percentile.

    Returns:
        ``(low_x, low_z, high_x, high_z)``: the percentile values per iteration
        and the histogram count of the bin each value falls into. Values outside
        the bin range are clamped to the first/last bin so the marker stays on
        the plotted surface.
    """
    n_bins = hist.shape[1]
    rows = np.arange(len(data))
    low_x = np.array([np.percentile(d, low_q) for d in data], dtype=float)
    high_x = np.array([np.percentile(d, high_q) for d in data], dtype=float)
    # np.digitize returns i with bins[i-1] <= x < bins[i]; shift to a 0-based
    # bin index and clamp out-of-range values onto the edge bins.
    low_idx = np.clip(np.digitize(low_x, bins) - 1, 0, n_bins - 1)
    high_idx = np.clip(np.digitize(high_x, bins) - 1, 0, n_bins - 1)
    return low_x, hist[rows, low_idx], high_x, hist[rows, high_idx]


if __name__ == '__main__':
    # data_path = 'work_statistics/retinanet_r18_fpn_1x_coco/save-act_100_fuse_True.npz'
    data_path = 'work_statistics/retinanet_r18_fpn_1x_coco_QFOD/save-act_100_fuse_True.npz'
    target_mode = 'out'
    # Hard-coded histogram x-range that replaces the range measured from the
    # data. Set to None to use the measured min/max instead.
    range_override = (-3, 6)
    bin_num = 200
    low_q, high_q = 0.1, 99.9  # percentiles marked on each surface

    # NOTE: despite the .npz suffix this file is read with pickle. Unpickling
    # can execute arbitrary code, so only load files you produced yourself.
    with open(data_path, 'rb') as tarfile:
        loaded_data_dict = pickle.load(tarfile)
    # Keep only the entries whose name contains target_mode.
    loaded_data_dict = _filter_by_mode(loaded_data_dict, target_mode)
    if not loaded_data_dict:
        raise ValueError(f'no entries in {data_path} contain {target_mode!r}')

    # Histogram x-range.
    if range_override is None:
        global_min, global_max = _value_range(loaded_data_dict)
    else:
        print('非法改装min/max！！')
        global_min, global_max = range_override
    num_iters = max(len(datas) for datas in loaded_data_dict.values())

    # Custom colormap: blue -> green -> orange.
    # colors = [(0, '#352A88'), (0.5, 'yellow'), (1, 'green')]
    colors = [(0, '#123EA0'), (0.1, '#123EA0'), (0.6, '#51BA8D'), (1, '#F9AF2B')]
    cmap = LinearSegmentedColormap.from_list('custom_cmap', colors)
    formatter = FuncFormatter(_scientific_format)

    magnitudes_bins = np.linspace(global_min, global_max, bin_num)
    magnitudes = (magnitudes_bins[:-1] + magnitudes_bins[1:]) / 2  # bin centres

    fig = plt.figure(figsize=(16, 6.5))
    # Negative wspace deliberately pulls the 3D subplots closer together.
    plt.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99, wspace=-0.2, hspace=.0)

    n_plots = len(loaded_data_dict)
    for i, (name, data) in enumerate(loaded_data_dict.items()):
        ax = fig.add_subplot(1, n_plots, i + 1, projection='3d')

        # Grid sized to this entry's own number of iterations.
        iters = np.arange(len(data))
        X, Y = np.meshgrid(magnitudes, iters)
        Z = _histograms(data, magnitudes_bins)

        ax.plot_surface(X, Y, Z, cmap=cmap, rstride=1, cstride=1, linewidth=0,
                        edgecolor='black', antialiased=True, alpha=0.6)

        # Mark the low/high percentiles of every iteration on the surface.
        low_x, low_z, high_x, high_z = _percentile_markers(
            data, Z, magnitudes_bins, low_q, high_q)
        ax.scatter(low_x, iters, low_z, color='red', s=14, alpha=1, zorder=100,
                   label=f'{low_q}% percentile')
        ax.scatter(high_x, iters, high_z, color='black', s=14, alpha=1, zorder=100,
                   label=f'{high_q}% percentile')

        # Axis labels and ticks.
        ax.set_xlabel('magnitude', fontsize=16)
        ax.set_ylabel('iter', fontsize=16)
        ax.yaxis.label.set_rotation(90)
        ax.set_zlabel('num', labelpad=36, fontsize=16)
        ax.zaxis.set_major_formatter(formatter)
        ax.tick_params(axis='x', labelsize=12)
        ax.tick_params(axis='y', labelsize=12)
        ax.tick_params(axis='z', labelsize=12, pad=18)  # distance of z tick labels from the axis
        ax.zaxis.label.set_rotation(90)

        # Reversed limits plus this view angle make the magnitude axis face the viewer.
        ax.set_xlim(global_max, global_min)
        ax.set_ylim(num_iters, 0)
        # ax.view_init(elev=20, azim=85)
        ax.view_init(elev=25, azim=85)

    new_file_name = f'K{num_iters}_act_{target_mode}.pdf'
    new_path = os.path.join(os.path.dirname(data_path), new_file_name)
    # Save BEFORE show(): after the interactive window is closed the figure may
    # be destroyed, and savefig would then write an empty page.
    plt.savefig(new_path)
    plt.show()