import json
import os

import matplotlib.pyplot as plt





def split_list(lst, n):
    """Split ``lst`` into ``n`` contiguous slices of near-equal length.

    The first ``len(lst) % n`` slices receive one extra element, so slice
    lengths differ by at most one. Slices keep the type of ``lst`` (list,
    str, tuple, ...). A non-positive ``n`` yields ``[]``; ``n == 0`` raises
    ``ZeroDivisionError``.
    """
    base_size, remainder = divmod(len(lst), n)
    chunks = []
    start = 0
    for chunk_idx in range(n):
        # Early chunks absorb the remainder, one element each.
        size = base_size + (1 if chunk_idx < remainder else 0)
        chunks.append(lst[start:start + size])
        start += size
    return chunks


def filter_data_by_mode(data, struct_names, filter_mode='in', func_mode='std_mean'):
    """Collect one statistic per layer, grouped by network structure.

    Each key of ``data`` is a dotted layer name. Its component 1 is matched
    (by substring) against ``struct_names``, component 2 selects the head
    branch, and the last component must contain ``filter_mode``.

    Args:
        data: mapping ``layer_name -> dict`` of statistics; every kept entry
            must contain the key ``func_mode``.
        struct_names: structure names to collect, e.g. ``['head']`` or
            ``['backbone', 'neck', 'head']``.
        filter_mode: substring required in the last name component
            (e.g. ``'in'``, ``'out'``, ``'weight'``).
        func_mode: key of the statistic to extract from each entry.

    Returns:
        dict with one entry per struct name, values in ``data`` order.
        ``'head'`` maps to ``[cls_values, reg_values]``; every other
        structure maps to a flat list of values.
    """
    # Every requested structure gets a bucket (previously only the first did,
    # so any other name raised KeyError). 'head' is split into cls/reg to
    # match how main() plots it.
    filtered_data = {
        struct_name: [[], []] if struct_name == 'head' else []
        for struct_name in struct_names
    }
    for name, dat in data.items():
        parts = name.split('.')
        if len(parts) < 3:
            # Too few components to tell structure and branch apart; skip.
            continue
        prefix, infix, suffix = parts[1], parts[2], parts[-1]
        if filter_mode not in suffix:
            continue
        for struct_name in struct_names:
            if struct_name not in prefix:
                continue
            bucket = filtered_data[struct_name]
            if struct_name == 'head':
                if 'cls' in infix:
                    bucket[0].append(dat[func_mode])
                elif 'reg' in infix:
                    bucket[1].append(dat[func_mode])
                # centerness and any other head entries belong to neither
                # branch; appending them would break the [cls, reg] shape.
            elif 'centerness' not in infix:
                bucket.append(dat[func_mode])
    return filtered_data

def _plot_file_curves(filtered_data, struct_names, struct_marks, color, label, branch_num):
    """Draw one file's filtered statistics and return the x extent used.

    Structures are laid out left to right starting at x = 0. The 'head'
    branches (cls, reg) share one x range. Each branch is cut into
    ``branch_num`` equal segments drawn as separate lines. Only the first
    line drawn carries ``label``, so the legend gets one entry per file.
    """
    x_start = 0
    for struct_name in struct_names:
        struct_data = filtered_data[struct_name]
        if struct_name == 'head':
            for branch_idx, branch_data in enumerate(struct_data):
                x_idces = list(range(x_start, x_start + len(branch_data)))
                # Separate lines keep segments from being joined; presumably
                # one segment per FPN level -- TODO confirm.
                segments = zip(
                    split_list(x_idces, branch_num),
                    split_list(branch_data, branch_num),
                )
                for seg_x, seg_y in segments:
                    plt.plot(
                        seg_x,
                        seg_y,
                        marker=struct_marks[struct_name][branch_idx],
                        color=color,
                        alpha=0.8,
                        label=label,
                    )
                    label = '_nolegend_'
            # Branches overlap on x; advance past the longest one.
            x_start += max((len(branch) for branch in struct_data), default=0)
        else:
            x_idces = list(range(x_start, x_start + len(struct_data)))
            plt.plot(
                x_idces,
                struct_data,
                marker=struct_marks[struct_name],
                color=color,
                alpha=0.8,
                label=label,
            )
            label = '_nolegend_'
            x_start += len(struct_data)
    return x_start


def main(
    file_paths,
    filter_mode,
    func_mode,
    struct_names,
    struct_marks,
    colors,
    branch_num=5,
    save_path='y_function_vis/tmp3.png',
):
    """Overlay per-layer statistics from several JSON files on one figure.

    Each ``(file_path, color)`` pair is loaded, reduced with
    ``filter_data_by_mode`` and drawn in ``color``. Every file starts at
    x = 0, so the curves of different files overlay each other.

    Args:
        file_paths: JSON statistics files; one legend entry per file.
        filter_mode: substring forwarded to ``filter_data_by_mode``.
        func_mode: statistic key forwarded to ``filter_data_by_mode``.
        struct_names: structures to plot, in x-axis order.
        struct_marks: marker per structure. For ``'head'`` this is a
            sequence with one marker per branch (index 0 = cls, 1 = reg).
        colors: one color per file. Files beyond ``len(colors)`` are
            ignored (``zip`` truncation).
        branch_num: number of segments each head branch is split into.
        save_path: output image path, e.g. f'y_function_vis/{func_mode}.png';
            missing parent directories are created.
    """
    plt.figure(figsize=(16, 4), dpi=300)

    # Defined up front so xlim works even when no file is plotted.
    x_end = 0
    for file_path, color in zip(file_paths, colors):
        with open(file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)

        filtered_data = filter_data_by_mode(data, struct_names, filter_mode, func_mode)

        # Legend label: the directory holding the statistics file.
        label = os.path.basename(os.path.dirname(file_path)) or file_path
        file_end = _plot_file_curves(
            filtered_data, struct_names, struct_marks, color, label, branch_num
        )
        x_end = max(x_end, file_end)

    plt.xlim(-1, x_end + 1)
    # Avoid matplotlib's "no artists with labels" warning when nothing was drawn.
    if plt.gca().get_legend_handles_labels()[0]:
        plt.legend()
    plt.grid()
    plt.tight_layout()

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # Save BEFORE show(): once the interactive window is closed, the figure
    # may be gone and savefig() would write an empty canvas.
    plt.savefig(save_path)
    plt.show()
    
    

if __name__ == '__main__':
    # Statistics files to compare; file i is drawn with colors[i].
    # Previously used pair (RetinaNet R18):
    #   "work_statistics/retinanet_r18_fpn_1x_coco/full-eval-statistics_fuse_True.json"
    #   "work_statistics/retinanet_r18_fpn_1x_coco_QFOD/full-eval-statistics_fuse_True.json"
    file_paths = [
        "work_statistics/atss_r50_fpn_1x_coco/full-eval-statistics_fuse_True.json",
        "work_statistics/atss_r50_fpn_1x_coco_QFOD/full-eval-statistics_fuse_True.json",
    ]

    # Entry kind to keep, one of: 'in', 'out', 'weight'.
    filter_mode = 'in'
    # Statistic to plot. Noted as best: std_std, max_mean.
    # Also tried: max_std, std_mean, max/std_3_mean, max/std_3_std.
    func_mode = 'max_mean'

    # Full-network layout, kept for reference:
    #   struct_names = ['backbone', 'neck', 'head']
    #   struct_marks = {'backbone': '.', 'neck': 'x', 'head': ['^', 'v']}
    # Only the head is plotted here. Its markers are [cls, reg];
    # 'o', 'D' is another marker pair that was tried.
    struct_names = ['head']
    struct_marks = {'head': ['^', 'v']}
    colors = ['#F27970', '#2878B5']  # alternative second color: '#54B345'

    main(
        file_paths=file_paths,
        filter_mode=filter_mode,
        func_mode=func_mode,
        struct_names=struct_names,
        struct_marks=struct_marks,
        colors=colors,
    )