from pathlib import Path

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Benchmark scores per method. In each dict, 'Method' lists the method names
# and every other key is a dataset whose list is index-aligned with 'Method'.
#
# NOTE(review): the values stored under em_data are larger than the matching
# f1_data values almost everywhere. With a strict exact-match metric one would
# expect F1 >= EM, so either "EM" here is a more lenient match (e.g. answer
# containment) or the two tables are swapped -- confirm against the raw results.

# Scores plotted with the "EM Score" axis title.
em_data = {
    'Method': [
        'NaiveRAG', 'ToG-2', 'GraphRAG', 'LightRAG', 'MiniRAG', "HippoRAG-2", 'Ours',
    ],
    'HotpotQA': [0.634, 0.308, 0.337, 0.308, 0.213, 0.612, 0.639],
    '2WikiMultihopQA': [0.382, 0.401, 0.439, 0.420, 0.125, 0.491, 0.500],
    'Musique': [0.230, 0.103, 0.109, 0.082, 0.067, 0.212, 0.220],
}

# Scores plotted with the "F1 Score" axis title.
f1_data = {
    'Method': [
        'NaiveRAG', 'ToG-2', 'GraphRAG', 'LightRAG', 'MiniRAG', "HippoRAG-2", 'Ours',
    ],
    'HotpotQA': [0.365, 0.153, 0.011, 0.013, 0.012, 0.544, 0.516],
    '2WikiMultihopQA': [0.189, 0.194, 0.018, 0.023, 0.018, 0.254, 0.267],
    'Musique': [0.143, 0.105, 0.008, 0.009, 0.007, 0.145, 0.133],
}

# Dataset keys shown along the x-axis, in display order.
DATASETS = ['HotpotQA', '2WikiMultihopQA', 'Musique']

# Directory the PDF figures are written to (relative to the working directory).
OUTPUT_DIR = Path("scripts/figs")


def _split_palette(palette, n_groups):
    """Split *palette* into ``n_groups`` contiguous, non-empty slices.

    The slice boundaries are the evenly spaced percentiles of the palette's
    index range (0 %, ..., 100 %), truncated to integers; the final slice runs
    to the end of the palette.

    Args:
        palette: Sequence of colour strings.
        n_groups: Number of slices to produce (one per method).

    Returns:
        A list of ``n_groups`` slices of *palette*.

    Raises:
        ValueError: If ``n_groups`` is below 1 or exceeds ``len(palette)``;
            in the latter case some slice would be empty and no colour could
            be taken from it.
    """
    n_colors = len(palette)
    if n_groups < 1:
        raise ValueError("at least one method is required to pick colours")
    if n_groups > n_colors:
        raise ValueError(
            f"palette has {n_colors} colours but {n_groups} methods need one each"
        )
    if n_groups == 1:
        # The percentile step below divides by (n_groups - 1); a single method
        # simply gets the whole palette as its slice.
        return [palette[:]]

    bounds = [
        int(np.percentile(range(n_colors), i * 100 / (n_groups - 1)))
        for i in range(n_groups)
    ]
    bounds.append(n_colors)
    return [palette[lo:hi] for lo, hi in zip(bounds, bounds[1:])]


def _build_figure(data, title, method_colors):
    """Build a grouped bar chart with one bar trace per method.

    Args:
        data: Dict with a 'Method' list and one score list per key in
            ``DATASETS``, index-aligned with 'Method'.
        title: Y-axis title; the legend is shown only when it equals
            "F1 Score".
        method_colors: Mapping from method name to bar colour.

    Returns:
        The configured ``plotly.graph_objects.Figure``.
    """
    x_labels = [f"<b>{name}</b>" for name in DATASETS]  # bold tick labels

    fig = go.Figure()
    # One trace per method (not per dataset) so bars are grouped by dataset.
    for i, method in enumerate(data['Method']):
        fig.add_trace(go.Bar(
            name=method,
            x=x_labels,
            y=[data[dataset][i] for dataset in DATASETS],
            marker_color=method_colors[method],
            marker_line_color='rgba(0,0,0,0.5)',
            marker_line_width=1,
            opacity=0.8,
            legendgroup=method,
            showlegend=True,
        ))

    fig.update_layout(
        xaxis_title='<b>Datasets</b>',
        yaxis_title=f'<b>{title}</b>',
        barmode='group',
        template='plotly_white',
        # NOTE(review): the font `weight` key is presumably only accepted by
        # recent plotly releases -- confirm the pinned version supports it.
        font=dict(size=14, color='black', family="Times New Roman", weight='bold'),
        # NOTE(review): bordercolor is set but borderwidth is not; confirm the
        # legend border is actually meant to be visible.
        legend=dict(
            xanchor="right",
            bordercolor="Black",
        ),
        plot_bgcolor='white',
        showlegend=title == "F1 Score",
    )
    return fig


# Reversed RdBu palette; any other plotly colour list can be substituted here.
palette = px.colors.sequential.RdBu_r

# Make sure the target directory exists before any figure is exported.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

for data, title in [(em_data, "EM Score"), (f1_data, "F1 Score")]:
    # The palette has more colours than methods: cut it into one slice per
    # method and use the last colour of each slice.
    colors_split = _split_palette(palette, len(data['Method']))
    print("colors_split:", colors_split)
    colors = [group[-1] for group in colors_split]
    print("colors:", colors)

    method_colors = dict(zip(data['Method'], colors))
    fig = _build_figure(data, title, method_colors)

    # File name is the first word of the title, lower-cased ("em" / "f1").
    output_path = OUTPUT_DIR / f"{title.split(' ')[0].lower()}_bar_chart.pdf"
    fig.write_image(str(output_path))