import json
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import os



def split_list(lst, n):
    """Split ``lst`` into ``n`` consecutive chunks of near-equal length.

    When ``len(lst)`` is not divisible by ``n``, the leading chunks each
    receive one extra element. Chunks are produced by slicing, so they have
    the same sequence type as ``lst``; trailing chunks are empty when
    ``lst`` has fewer than ``n`` elements.
    """
    base_size, remainder = divmod(len(lst), n)
    chunks = []
    start = 0
    for chunk_idx in range(n):
        # The first `remainder` chunks absorb the leftover elements.
        stop = start + base_size + (1 if chunk_idx < remainder else 0)
        chunks.append(lst[start:stop])
        start = stop
    return chunks


def filter_data_by_mode(data, struct_names, filter_mode='in', func_mode='std_mean'):
    """Collect one statistic per matching entry, grouped by structure and branch.

    Each key of ``data`` is split on ``'.'``: component 1 is matched against
    the structure names, component 2 selects the branch, and the last
    component must contain ``filter_mode`` for the entry to be kept.

    Args:
        data: Mapping from dotted entry name to a mapping of statistics.
        struct_names: Structure names to look for; only the first one gets
            a result bucket.
        filter_mode: Substring required in the last name component.
        func_mode: Key of the statistic read from each kept entry.

    Returns:
        ``{struct_names[0]: [cls_values, reg_values]}``. Values whose branch
        component is neither cls, reg, centerness nor obj are appended
        directly to the outer list.
    """
    # Index 0 gathers the cls branch, index 1 the reg branch.
    collected = {struct_names[0]: [[], []]}

    for entry_name, stats in data.items():
        parts = entry_name.split('.')
        prefix, infix, suffix = parts[1], parts[2], parts[-1]

        if filter_mode not in suffix:
            continue

        for struct_name in struct_names:
            if struct_name not in prefix:
                continue
            if 'cls' in infix:
                collected[struct_name][0].append(stats[func_mode])
            elif 'reg' in infix:
                collected[struct_name][1].append(stats[func_mode])
            elif not ('centerness' in infix or 'obj' in infix):
                collected[struct_name].append(stats[func_mode])
    # TODO: figure out how to handle the head branches.
    return collected

def main(
    file_paths,
    filter_mode,
    func_mode,
    branch_num,
    struct_names,
    struct_marks,
    colors,
    func_name=None,
    save_folder=None,
):
    """Plot one statistic per layer for several JSON files and save it as a PDF.

    Every file in ``file_paths`` is drawn in the color at the same position
    of ``colors``. For the ``'head'`` structure the cls and reg branches share
    the same x positions and are each split into ``branch_num`` separately
    drawn segments, with dashed vertical lines between segments.

    Args:
        file_paths: Paths of the JSON statistics files to plot.
        filter_mode: Substring forwarded to ``filter_data_by_mode``; also
            part of the output file name.
        func_mode: Key of the statistic to plot; also part of the output
            file name (``'/'`` is replaced by ``'-'``).
        branch_num: Number of segments each head branch is split into.
        struct_names: Structure names to plot, in x-axis order.
        struct_marks: Marker per structure; for ``'head'`` a pair
            ``[cls_marker, reg_marker]``. The legend always reads
            ``struct_marks['head']``.
        colors: One color per file. The legend labels ``colors[0]`` as
            'Baseline' and ``colors[1]`` as 'OAR (Ours)', so at least two
            colors are required.
        func_name: Y-axis label. When omitted, the module-level
            ``func_name`` is used if defined, otherwise ``func_mode``.
        save_folder: Output directory. When omitted, the module-level
            ``save_folder`` is used if defined, otherwise ``'.'``.
    """
    # Keep working when run through the script entry point, which provides
    # these two settings as module-level names instead of arguments.
    if func_name is None:
        func_name = globals().get('func_name', func_mode)
    if save_folder is None:
        save_folder = globals().get('save_folder', '.')

    plt.figure(figsize=(11, 3), dpi=300)
    plt.subplots_adjust(left=0.1, bottom=0.23, right=0.99, top=0.96)  # figure margins

    # Defined up front so the axis setup below still works with no input files.
    base_cord = 0
    for file_path, color in zip(file_paths, colors):
        with open(file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)

        # Keep only the entries selected by filter_mode / func_mode.
        filtered_data = filter_data_by_mode(data, struct_names, filter_mode, func_mode)

        # Every file is drawn over the same x range, so restart the offset.
        base_cord = 0
        for struct_name in struct_names:
            struct_data = filtered_data[struct_name]
            if struct_name == 'head':
                longest_branch = 0
                for lt, partial_data in enumerate(struct_data):
                    longest_branch = max(longest_branch, len(partial_data))
                    x_idces = list(range(base_cord + 1, base_cord + len(partial_data) + 1))
                    x_idces_splited = split_list(x_idces, branch_num)
                    partial_data_splited = split_list(partial_data, branch_num)
                    for it in range(branch_num):
                        # A branch with fewer points than branch_num yields
                        # empty segments; there is nothing to draw for them.
                        if not x_idces_splited[it]:
                            continue
                        plt.plot(
                            x_idces_splited[it],
                            partial_data_splited[it],
                            marker=struct_marks[struct_name][lt],
                            markersize=9,
                            color=color,
                            alpha=0.7,
                        )
                        if it != branch_num - 1:
                            # Separator between this segment and the next one.
                            plt.axvline(
                                x=x_idces_splited[it][-1] + 0.5,
                                linestyle='--',
                                linewidth=2,
                                alpha=0.3,
                                color='#999999',
                            )
                # cls and reg overlap on the x axis, so advance by the longer one.
                base_cord += longest_branch
            else:
                x_idces = list(range(base_cord, base_cord + len(struct_data)))
                base_cord += len(struct_data)
                plt.plot(
                    x_idces,
                    struct_data,
                    marker=struct_marks[struct_name],
                    color=color,
                    alpha=0.8,
                )

    # Axes: limits, labels and ticks.
    plt.xlim(0.5, base_cord + 0.5)
    plt.ylabel(func_name, fontsize=18)
    plt.xlabel('Layer Index', fontsize=18)
    plt.xticks(fontsize=18)  # x tick label font size
    plt.yticks(fontsize=18)  # y tick label font size
    plt.gca().set_xticks(list(range(1, base_cord + 1, 2)))

    # Legend: one line per model color, one marker per head branch.
    legend_elements = [
        Line2D([0], [0], color=colors[0], lw=3, label='Baseline'),
        Line2D([0], [0], color=colors[1], lw=3, label='OAR (Ours)'),
        Line2D([0], [0], marker=struct_marks['head'][0], color='w', markerfacecolor='black', markersize=11, label='Cls Branch'),
        Line2D([0], [0], marker=struct_marks['head'][1], color='w', markerfacecolor='black', markersize=11, label='Reg Branch'),
    ]
    plt.legend(handles=legend_elements, loc='upper left', fontsize=12, framealpha=0.45)

    plt.grid(axis='y')

    # Save before showing: with an interactive backend the figure may be
    # discarded once the window opened by show() is closed, which would make
    # a later savefig() write an empty canvas.
    if save_folder:
        os.makedirs(save_folder, exist_ok=True)
    file_tag = func_mode.replace('/', '-')  # '/' is not valid inside a file name
    plt.savefig(
        os.path.join(save_folder, f'{filter_mode}_{file_tag}_head.pdf')
    )
    plt.show()
    

if __name__ == '__main__':
    # Statistics files to compare; files and colors are paired in order.
    # Other runs that can be swapped in (each with a matching *_QFOD folder):
    #   work_statistics/retinanet_r50_fpn_1x_coco/full-eval-statistics_fuse_True.json
    #   work_statistics/atss_r50_fpn_1x_coco/full-eval-statistics_fuse_True.json
    #   work_statistics/yolox_tiny_coco/full-eval-statistics_fuse_True.json
    file_paths = [
        "work_statistics/retinanet_r18_fpn_1x_coco/full-eval-statistics_fuse_True.json",
        "work_statistics/retinanet_r18_fpn_1x_coco_QFOD/full-eval-statistics_fuse_True.json",
    ]
    colors = ['#F27970', '#2878B5']  # alternative second color: #54B345

    # Number of segments per head branch: 5 here, use 3 for YOLOX.
    branch_num = 5

    # main() reads save_folder and func_name as module-level names.
    save_folder = 'work_statistics/retinanet_r18_fpn_1x_coco/_full_visualization'

    # Statistic key -> y-axis label. Other keys seen in the statistics files:
    # max_mean, max_std, max/std_3_mean, max/std_3_std.
    axis_labels = {
        'std_mean': 'Mean of Stds',
        'std_std': 'Std of Stds',
    }
    func_mode = 'std_std'
    func_name = axis_labels[func_mode]

    filter_mode = '_in'  # '_in' or '_out'

    struct_names = ['head']
    # First marker is for the cls branch, second for the reg branch
    # (alternative pair: 'o', 'D').
    head_markers = ['^', 'v']
    struct_marks = {struct_names[0]: head_markers}

    main(
        file_paths=file_paths,
        filter_mode=filter_mode,
        func_mode=func_mode,
        branch_num=branch_num,
        struct_names=struct_names,
        struct_marks=struct_marks,
        colors=colors,
    )