import json
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import os



def split_list(lst, n):
    """Split ``lst`` into ``n`` contiguous slices whose lengths differ by at most one.

    The first ``len(lst) % n`` slices each receive one extra element. When
    ``n > len(lst)`` the trailing slices are empty. ``n == 0`` raises
    ZeroDivisionError (from ``divmod``).
    """
    base_size, remainder = divmod(len(lst), n)
    pieces = []
    start = 0
    for piece_idx in range(n):
        stop = start + base_size + (1 if piece_idx < remainder else 0)
        pieces.append(lst[start:stop])
        start = stop
    return pieces


def filter_data_by_mode(data, struct_names, filter_mode='in', func_mode='std_mean'):
    """Group per-layer statistics by model structure for plotting.

    Args:
        data: mapping from dotted layer names to per-layer stat dicts. Each name
            is split on '.': the 2nd component is substring-matched against
            ``struct_names``, the 3rd decides the head branch (cls / reg), and
            the last component must contain ``filter_mode``.
            NOTE(review): naming scheme inferred from the indexing below —
            names presumably look like '<root>.<struct>.<module>...<kind>';
            confirm against the statistics dump.
        struct_names: at least three structure names. ``struct_names[2]`` is
            the detection head; every other entry is a single-branch structure.
        filter_mode: substring the last name component must contain
            (e.g. '_in' or '_out').
        func_mode: key of the statistic to pull out of each stat dict.

    Returns:
        dict mapping each structure name to a list of values in ``data``
        iteration order. The head maps instead to ``[cls_values, reg_values]``.
    """
    head_name = struct_names[2]
    filtered_data = {struct_name: [] for struct_name in struct_names}
    filtered_data[head_name] = [[], []]  # [cls branch, reg branch]

    for name, dat in data.items():
        parts = name.split('.')
        if len(parts) < 3:
            # Too few components to read a structure and module name; skip it.
            continue
        prefix, infix, suffix = parts[1], parts[2], parts[-1]
        if filter_mode not in suffix:
            continue

        for struct_name in struct_names:
            if struct_name not in prefix:
                continue
            if struct_name != head_name:
                # Backbone / neck layers form a single series, regardless of
                # what their module name happens to contain.
                filtered_data[struct_name].append(dat[func_mode])
            elif 'cls' in infix:
                filtered_data[head_name][0].append(dat[func_mode])
            elif 'reg' in infix:
                filtered_data[head_name][1].append(dat[func_mode])
            # Any other head layer (centerness, objectness, ...) belongs to
            # neither plotted branch. It is skipped instead of being appended
            # into the [cls, reg] container, which would corrupt its structure.
    return filtered_data

def _plot_head(struct_data, base_cord, branch_num, marks, color):
    """Plot the head's [cls, reg] series from x=base_cord; return the next free x index."""
    last_len = 0
    for branch_idx, partial_data in enumerate(struct_data):
        # cls and reg share the same x positions; only their markers differ.
        x_idces = list(range(base_cord, base_cord + len(partial_data)))
        chunks = zip(split_list(x_idces, branch_num), split_list(partial_data, branch_num))
        for x_chunk, y_chunk in chunks:
            if not x_chunk:
                # There are fewer head layers than branch_num, so this chunk is empty.
                continue
            # A dashed separator marks the start of each chunk.
            plt.axvline(
                x=x_chunk[0] - 0.5,
                linestyle='--',
                linewidth=2,
                alpha=0.3,
                color='#999999',
            )
            plt.plot(
                x_chunk,
                y_chunk,
                marker=marks[branch_idx],
                markersize=9,
                color=color,
                alpha=0.7,
            )
        last_len = len(partial_data)
    # As before, the x range advances by the length of the last (reg) branch.
    return base_cord + last_len


def _plot_flat(struct_data, base_cord, mark, color):
    """Plot a single-branch structure's series from x=base_cord; return the next free x index."""
    x_idces = list(range(base_cord, base_cord + len(struct_data)))
    plt.plot(
        x_idces,
        struct_data,
        marker=mark,
        markersize=9,
        color=color,
        alpha=0.7,
    )
    return base_cord + len(struct_data)


def main(
    file_paths,
    filter_mode,
    func_mode,
    branch_num,
    struct_names,
    struct_marks,
    colors,
    func_name=None,
    save_folder=None,
):
    """Overlay per-layer statistics of several models and save the figure as a PDF.

    Args:
        file_paths: JSON statistics files, one per model. Each file is parsed with
            ``json.load`` and grouped with ``filter_data_by_mode``.
        filter_mode: substring selecting which tensors to plot (e.g. '_in').
        func_mode: statistic key to plot. Any '/' is replaced by '-' in the
            output file name.
        branch_num: number of equal chunks each head branch series is split
            into. Chunks are separated by dashed vertical lines.
            NOTE(review): presumably one chunk per feature level — confirm.
        struct_names: structure names. ``struct_names[2]`` is the head, which
            matches ``filter_data_by_mode``.
        struct_marks: marker per structure. The head entry is a
            ``[cls_marker, reg_marker]`` pair.
        colors: line color per file. ``colors[i]`` is used for ``file_paths[i]``.
            The first two colors are labelled 'Baseline' and 'OAR (Ours)' in
            the legend.
        func_name: y-axis label. Defaults to the module-level ``func_name`` set
            by the script entry point, or ``func_mode`` when none exists.
        save_folder: output directory, created if missing. Defaults to the
            module-level ``save_folder``, or '.' when none exists.
    """
    # Backward compatibility: the script entry point defines these as module globals.
    if func_name is None:
        func_name = globals().get('func_name', func_mode)
    if save_folder is None:
        save_folder = globals().get('save_folder', '.')

    head_name = struct_names[2]
    plt.figure(figsize=(20, 2.5), dpi=300)
    plt.subplots_adjust(left=0.05, bottom=0.27, right=0.99, top=0.96)

    base_cord = 1  # keeps the axis limits defined even if file_paths is empty
    for file_path, color in zip(file_paths, colors):
        with open(file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        filtered_data = filter_data_by_mode(data, struct_names, filter_mode, func_mode)

        # Every file restarts at layer index 1, so the models overlay layer by layer.
        base_cord = 1
        for struct_name in struct_names:
            struct_data = filtered_data[struct_name]
            if struct_name == head_name:
                base_cord = _plot_head(
                    struct_data, base_cord, branch_num, struct_marks[struct_name], color
                )
            else:
                base_cord = _plot_flat(struct_data, base_cord, struct_marks[struct_name], color)

    # Axes. Per-model y-limits used before: YOLOX SOS (-0.1, 0.5), YOLOX MOS (0, 2),
    # ATSS SOS (-0.05, 0.15), ATSS MOS (0, 0.7).
    plt.xlim(0.5, base_cord)
    plt.ylabel(func_name, fontsize=17)
    plt.xlabel('Layer Index', fontsize=17)
    plt.xticks(fontsize=17)
    plt.yticks(fontsize=17)
    plt.gca().set_xticks(list(range(1, base_cord + 1, 2)))  # label every other layer

    # Legend: one color entry per model (only for colors actually supplied),
    # plus marker-only entries for each structure and head branch.
    series_labels = ['Baseline', 'OAR (Ours)']
    legend_elements = [
        Line2D([0], [0], color=line_color, lw=3, label=label)
        for line_color, label in zip(colors, series_labels)
    ]
    head_marks = struct_marks[head_name]
    legend_elements += [
        Line2D([0], [0], marker=struct_marks[struct_names[0]], color='w', markerfacecolor='black', markersize=11, label='Backbone'),
        Line2D([0], [0], marker=struct_marks[struct_names[1]], color='w', markerfacecolor='black', markersize=9, label='Neck'),
        Line2D([0], [0], marker=head_marks[0], color='w', markerfacecolor='black', markersize=11, label='Cls Branch'),
        Line2D([0], [0], marker=head_marks[1], color='w', markerfacecolor='black', markersize=11, label='Reg Branch'),
    ]
    plt.legend(
        handles=legend_elements,
        loc='upper left',
        fontsize=13,
        framealpha=0.45,
        ncol=2,
        columnspacing=1.0,
        handletextpad=0.5,
    )
    plt.grid(axis='y')

    # Save BEFORE show(). With interactive backends, show() blocks and the figure
    # is destroyed when its window closes, so a later savefig() writes a blank canvas.
    safe_func_mode = func_mode.replace('/', '-')  # '/' in a file name would imply sub-directories
    if save_folder:
        os.makedirs(save_folder, exist_ok=True)
    plt.savefig(os.path.join(save_folder, f'{filter_mode}_{safe_func_mode}_head.pdf'))
    plt.show()
    plt.close()
    
    

if __name__ == '__main__':
    # ---- Inputs: one statistics JSON per model; colors[i] is drawn for file_paths[i]. ----
    # Earlier configurations, kept for reference:
    #   * work_statistics/<model>/full-eval-statistics_fuse_True.json for
    #     retinanet_r18/r50_fpn_1x_coco, atss_r50_fpn_1x_coco, yolox_tiny_coco
    #     (each with a *_QFOD counterpart), plotted with
    #     func_mode 'std_mean' ("Mean of Stds") or 'std_std' ("Std of Stds").
    #     Other statistics that looked good: max_mean, max_std, max/std_3_mean, max/std_3_std.
    #   * work_dirs/ptq/WminmaxAemamse_retinanet_r18_fpn_coco_w4a4/{smallq4b,q4b}_fuse_True.json
    #     vs WminmaxAemamse_retinanet_r18_fpn_coco_ori-qfod_w4a4_QFOD/{smallq4b,q4b}_fuse_True.json.
    file_paths = [
        "work_dirs/ptq/WminmaxAemamse_atss_r50_fpn_coco_w4a4/1q4b_fuse_True.json",
        "work_dirs/ptq/WminmaxAemamse_atss_r50_fpn_coco_ori-qfod_w4a4/1q4b_fuse_True.json",
    ]
    colors = ['#F27970', '#2878B5']  # baseline, ours (alternative color for ours: #54B345)

    # ---- What to plot ----
    func_mode = 'mse_loss'
    func_name = r'Cumulated $E^Q$'  # y-axis label; read by main() as a module-level name
    filter_mode = '_in'  # '_in' or '_out'
    branch_num = 5  # 3 for YOLOX

    # ---- Model structure: struct_names[2] is the head ----
    struct_names = ['backbone', 'neck', 'head']
    # Head markers are [cls, reg]. Alternatives tried: 'o', 'D' / '^', 'v'.
    struct_marks = dict(zip(struct_names, ['o', 'D', ['^', 'v']]))

    # Output directory for the PDF; read by main() as a module-level name.
    save_folder = 'work_statistics/atss_r50_fpn_1x_coco/_full_visualization'

    main(file_paths, filter_mode, func_mode, branch_num, struct_names, struct_marks, colors)