import argparse
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import FuncFormatter

# Global matplotlib defaults for every figure in this script:
# high-resolution output with large, bold text.
plt.rcParams.update({
    'figure.dpi': 300,
    'font.weight': 'bold',
    'font.size': 20,
})


def _summarize_gamma(df_gamma):
    """Return per-gamma mean/std of F1 score and SHD, sorted by gamma.

    Rows of ``df_gamma`` are read by position: row 0 = gamma, row 2 = SHD,
    row 3 = precision, row 4 = recall.  Column 0 is skipped (presumably a
    label column -- confirm against the sheet); every other column is one
    experiment.
    """
    gamma = df_gamma.iloc[0, 1:].astype(float)
    shd = df_gamma.iloc[2, 1:].astype(int)
    precision = df_gamma.iloc[3, 1:].astype(float)
    recall = df_gamma.iloc[4, 1:].astype(float)

    # F1 is 0/0 -> NaN when precision and recall are both zero.  A NaN would
    # be skipped by the group mean and bias it upward, so define F1 = 0 there.
    pr_sum = precision + recall
    f1 = (2 * (precision * recall) / pr_sum).where(pr_sum != 0, 0.0)

    per_run = pd.DataFrame({'gamma': gamma, 'f1': f1, 'shd': shd})
    # Note: 'std' is the sample standard deviation, which is NaN for a gamma
    # value that has only a single experiment.
    summary = per_run.groupby('gamma').agg(
        mean_f1=('f1', 'mean'),
        std_f1=('f1', 'std'),
        mean_shd=('shd', 'mean'),
        std_shd=('shd', 'std'),
    )
    return summary.reset_index().sort_values('gamma')


def plot_gamma(df_gamma, *, f1_ylim=(0.4, 1), shd_ylim=(20, 50)):
    """Plot mean F1 score and mean SHD against gamma on a twin-axis chart.

    Experiments sharing the same gamma value are averaged; the error bars
    show the standard deviation within each group.  The figure is displayed
    with ``plt.show()``.

    Args:
        df_gamma: Header-less DataFrame whose rows are read by position
            (row 0 = gamma, row 2 = SHD, row 3 = precision, row 4 = recall)
            and whose columns from index 1 onward are individual experiments.
        f1_ylim: ``(low, high)`` range of the F1 axis, or ``None`` to let
            matplotlib autoscale it.
        shd_ylim: ``(low, high)`` range of the SHD axis, or ``None`` to let
            matplotlib autoscale it.
    """
    summary = _summarize_gamma(df_gamma)

    # The x axis is shown in units of 1e-3, matching the axis label below.
    x = summary['gamma'] * 1000

    fig, ax = plt.subplots()

    # Left axis: F1 score with error bars.
    f1_color = 'tab:red'
    ax.errorbar(x, summary['mean_f1'], yerr=summary['std_f1'], fmt='-o',
                color=f1_color, ecolor=f1_color, elinewidth=2, capsize=3,
                label='F1 score', linewidth=3)
    ax.set_xlabel('gamma Values ($\\times 10^{-3}$)', fontweight='bold')
    ax.set_ylabel('F1 score', color=f1_color, fontweight='bold')
    if f1_ylim is not None:
        ax.set_ylim(*f1_ylim)
    ax.tick_params(axis='y', labelcolor=f1_color)

    # Right axis (shares x): SHD with error bars.
    ax2 = ax.twinx()
    shd_color = 'tab:blue'
    ax2.errorbar(x, summary['mean_shd'], yerr=summary['std_shd'], fmt='-o',
                 color=shd_color, ecolor=shd_color, elinewidth=2, capsize=3,
                 label='SHD', linewidth=3)
    ax2.set_ylabel('SHD', color=shd_color, fontweight='bold')
    if shd_ylim is not None:
        ax2.set_ylim(*shd_ylim)
    ax2.tick_params(axis='y', labelcolor=shd_color)

    # Each axis keeps its own legend entries; merge them into a single legend.
    f1_handles, f1_labels = ax.get_legend_handles_labels()
    shd_handles, shd_labels = ax2.get_legend_handles_labels()
    ax2.legend(f1_handles + shd_handles, f1_labels + shd_labels,
               loc='upper right')

    # Tighten the layout so axis labels are not cut off.
    plt.tight_layout()

    plt.show()


def _parse_precision_recall(cell):
    """Return ``(precision, recall)`` from a ``'a/precision/recall'`` cell.

    Fields 1 and 2 of the '/'-separated string are used; field 0 is ignored.
    """
    parts = cell.split('/')
    return float(parts[1]), float(parts[2])


def _kq_f1_table(df_kq):
    """Build the K x Q table of F1 scores from the raw 'KQ' sheet."""
    # Row labels: first run of digits in each entry of the 'K' column.
    # Casting to str first keeps `.str` working when Excel yields numbers.
    k_values = df_kq['K'].astype(str).str.extract(r'(\d+)')[0].values
    # Column labels: first run of digits in every header after the first.
    q_values = (df_kq.columns[1:].astype(str)
                .str.extract(r'(\d+)')[0].values.astype(str))

    # Parse every data cell (all columns except the first) exactly once.
    parsed_rows = [
        [_parse_precision_recall(cell) for cell in row]
        for row in df_kq.iloc[:, 1:].itertuples(index=False, name=None)
    ]
    precision = pd.DataFrame(
        [[p for p, _ in row] for row in parsed_rows],
        index=k_values, columns=q_values,
    )
    recall = pd.DataFrame(
        [[r for _, r in row] for row in parsed_rows],
        index=k_values, columns=q_values,
    )

    # F1 is 0/0 -> NaN when precision and recall are both zero; define it as
    # 0 so the heatmap shows a real value instead of a NaN cell.
    pr_sum = precision + recall
    return (2 * (precision * recall) / pr_sum).where(pr_sum != 0, 0.0)


def plot_kq(df_kq):
    """Show a heatmap of F1 score over the (K, Q) hyper-parameter grid.

    Args:
        df_kq: DataFrame with a 'K' column followed by one column per Q
            setting.  Each data cell is a '/'-separated string whose second
            and third fields are precision and recall.

    The heatmap is annotated with two-decimal F1 values and displayed with
    ``plt.show()``.
    """
    f1_scores = _kq_f1_table(df_kq)

    fig, ax = plt.subplots()
    sns.heatmap(f1_scores, cmap='YlGnBu', ax=ax, annot=True, fmt='.2f', linewidths=1)
    ax.set_xlabel('Q', fontweight='bold')
    ax.set_ylabel('K', fontweight='bold')

    # Tighten the layout so axis labels are not cut off.
    plt.tight_layout()

    plt.show()


def main():
    """Command-line entry point: load the workbook and draw both figures.

    Reads the Excel file given by ``--file``, which must contain a 'gamma'
    sheet (parsed without a header row) and a 'KQ' sheet (first row used as
    the header), then calls ``plot_gamma`` and ``plot_kq`` in that order.
    """
    parser = argparse.ArgumentParser(description='plot hyper-param analysis picture for gamma')
    parser.add_argument('--file', type=str, required=True, help='like 绘图数据.xlsx')
    args = parser.parse_args()

    # Load both sheets inside a context manager so the workbook handle is
    # closed before the (blocking) plot windows are shown.
    with pd.ExcelFile(args.file) as excel_file:
        df_gamma = excel_file.parse('gamma', header=None)
        df_kq = excel_file.parse('KQ')

    plot_gamma(df_gamma)
    plot_kq(df_kq)


# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
