import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import pandas as pd

def calculate_correct_elo(win_rates_dict, reference='tog3', base_rating=1600):
    """
    Convert head-to-head win rates against a reference method into Elo ratings.

    Under the Elo model a method with rating R_m beats the reference (rating
    R_ref) with probability p = 1 / (1 + 10 ** ((R_ref - R_m) / 400)).
    Solving for R_m gives R_m = R_ref - 400 * log10(1 / p - 1).

    Args:
        win_rates_dict: mapping of method name -> that method's win rate
            against the reference method. An entry keyed by the reference
            itself is ignored.
        reference: name of the anchor method (default 'tog3').
        base_rating: rating assigned to the reference (default 1600).

    Returns:
        dict mapping method name -> Elo rating; the reference comes first,
        followed by the other methods in the input's iteration order.

    Raises:
        ValueError: if a win rate is not strictly between 0 and 1, where the
            Elo rating would be infinite or undefined.
    """
    ratings = {reference: base_rating}

    for method, win_rate in win_rates_dict.items():
        if method == reference:
            continue
        # p == 0 divides by zero and p == 1 takes log10(0); both (and NaN or
        # out-of-range values) have no finite Elo rating.
        if not 0 < win_rate < 1:
            raise ValueError(
                f"win rate of {method!r} against {reference!r} must be "
                f"strictly between 0 and 1, got {win_rate!r}"
            )
        # rating_diff = R_ref - R_method (see docstring derivation).
        rating_diff = 400 * np.log10((1 / win_rate) - 1)
        ratings[method] = base_rating - rating_diff

    return ratings

# Head-to-head results on the four benchmark datasets:
# dataset name -> {method key: that method's win rate against ToG-3}.
datasets = {
    'Agriculture': {
        'graph': 0.476,
        'light': 0.396,
        'naive': 0.287,
        'hippo': 0.256,
    },
    'CS': {
        'graph': 0.492,
        'light': 0.460,
        'naive': 0.333,
        'hippo': 0.320,
    },
    'Legal': {
        'graph': 0.484,
        'light': 0.324,
        'naive': 0.112,
        'hippo': 0.204,
    },
    'Mix': {
        'graph': 0.512,
        'light': 0.496,
        'naive': 0.360,
        'hippo': 0.316,
    },
}

# Display order of the methods; ToG-3 is listed first so it ends up in the
# top-left corner of every heatmap. Labels are index-aligned with the keys.
methods = ['tog3', 'graph', 'light', 'naive', 'hippo']
method_labels = ['ToG-3', 'GraphRAG', 'LightRAG', 'NaiveRAG', 'HippoRAG-2']

# 2x2 grid with one heatmap panel per dataset; panel titles are the bold
# dataset names in the order they appear in `datasets`.
fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=[f'<b>{name}</b>' for name in datasets],
    horizontal_spacing=0.12,
    vertical_spacing=0.15,
    specs=[[{"type": "heatmap"} for _ in range(2)] for _ in range(2)],
)
fig.update_annotations(font=dict(size=12))

# ICLR-style diverging palette: deep blue for low win rates, white at 50%,
# deep red for high win rates.
iclr_colorscale = [
    [stop, color]
    for stop, color in zip(
        (0.0, 0.25, 0.5, 0.75, 1.0),
        ('#2166ac', '#4393c3', '#f7f7f7', '#d6604d', '#b2182b'),
    )
]

# One heatmap per dataset: cell (row method, column method) holds the
# Elo-implied probability that the row method beats the column method.
for panel_idx, (dataset_name, win_rates) in enumerate(datasets.items()):
    elo_ratings = calculate_correct_elo(win_rates)

    n_methods = len(methods)
    win_rate_matrix = np.zeros((n_methods, n_methods))
    for r, attacker in enumerate(methods):
        for c, defender in enumerate(methods):
            if r == c:
                # A method playing itself is an even match.
                win_rate_matrix[r, c] = 0.5
                continue
            gap = elo_ratings[attacker] - elo_ratings[defender]
            win_rate_matrix[r, c] = 1 / (1 + 10 ** (-gap / 400))

    # Plotly draws the first heatmap row at the bottom, so flip the rows (and
    # their labels) to put ToG-3 in the top-left corner.
    win_rate_matrix = win_rate_matrix[::-1]
    reversed_method_labels = method_labels[::-1]
    print(f"\n{dataset_name} Win Rate Matrix:")
    print(win_rate_matrix)

    # Panels fill the 2x2 grid row by row (1-based indices).
    row = panel_idx // 2 + 1
    col = panel_idx % 2 + 1

    x_labels = [f"<b>{label}</b>" for label in method_labels]
    y_labels = [f"<b>{label}</b>" for label in reversed_method_labels]

    fig.add_trace(
        go.Heatmap(
            z=win_rate_matrix,
            x=x_labels,
            y=y_labels,
            colorscale="RdBu_r",
            zmin=0,
            zmax=1,
            colorbar=dict(
                title=dict(text="Win Rate", font=dict(size=12)),
                thickness=15,
            ),
            text=np.round(win_rate_matrix, 3),
            texttemplate='%{text}',
            textfont={"size": 7, "color": "black"},
            # Every panel shares zmin/zmax, so one colorbar (drawn by the
            # top-right panel) serves the whole figure.
            showscale=(panel_idx == 1),
            xgap=1,
            ygap=1,
        ),
        row=row,
        col=col,
    )

# Figure-wide styling: 7 pt black Times New Roman text on white backgrounds.
# Canvas size and margins are left at Plotly's defaults.
fig.update_layout(
    font=dict(family='Times New Roman', size=7, color='black'),
    paper_bgcolor='white',
    plot_bgcolor='white',
    showlegend=True,
    legend=dict(font=dict(size=7, family="Times New Roman", weight='bold')),
)

# To force square heatmap cells, set scaleanchor (to that panel's own x-axis)
# and scaleratio=1 on each panel's y-axis, with constrain='domain' on the
# matching x-axis.

# fig.show()

# Summary table: integer Elo rating of every method on every dataset,
# computed by calculate_correct_elo (ToG-3 is the anchor).
print("ELO Ratings Across Datasets:")
print("=" * 65)
elo_data = []
for dataset_name, win_rates in datasets.items():
    elo_ratings = calculate_correct_elo(win_rates)
    # Round to the nearest integer instead of truncating: a bare int() would
    # report e.g. 1526.7 as 1526. The outer int() keeps the value a Python int
    # on NumPy < 2, where round() of a numpy float returns a float.
    elo_data.append(
        [dataset_name] + [int(round(elo_ratings[method])) for method in methods]
    )

df_elo = pd.DataFrame(elo_data, columns=['Dataset'] + method_labels)
print(df_elo.to_string(index=False))

# fig.show()  # uncomment for an interactive preview
# Save the heatmap grid as a PDF via Plotly's static image export.
fig.write_image("elo_win_rates_heatmap.pdf", format='pdf')
