import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Plot the degree of every topic word, sorted from largest to smallest, on a
# "broken" y axis: the top panel shows the high-degree range and the bottom
# panel the low-degree range, so large outliers do not flatten the bulk of the
# data. The figure is saved to a PNG and then shown.

# Load the already-cleaned CSV file (expects 'Topic Word' and 'Degree' columns).
df = pd.read_csv('document/embeding/spatial_understanding/hard_topic_word_degrees_bad_cleaned.csv')

# Two stacked panels sharing the x axis; the bottom panel is twice as tall.
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(12, 6), gridspec_kw={'height_ratios': [1, 2]})

# Sort by degree, largest first.
df_sorted = df.sort_values(by='Degree', ascending=False)

# Plot against the rank position instead of the word strings. The x axis is
# labelled as an index, and a categorical (string) axis would collapse
# duplicated words onto a single x position.
positions = np.arange(len(df_sorted))

# Log-transform the degrees to compress their range for the colour scale.
degrees_log = np.log1p(df_sorted['Degree'])

# Normalise the log-degrees to [0, 1] for the colormap. When every degree is
# identical the span is zero; fall back to a constant instead of producing
# NaN (which the colormap would render as fully transparent markers).
log_min = degrees_log.min()
log_span = degrees_log.max() - log_min
if log_span > 0:
    color_scale = ((degrees_log - log_min) / log_span).to_numpy()
else:
    color_scale = np.zeros(len(degrees_log))
colors = plt.cm.coolwarm(color_scale)

# Marker area grows with the degree; the offset keeps small degrees visible.
marker_sizes = df_sorted['Degree'] + 10

# Draw the same line + scatter on both panels; each panel later gets its own
# y-limits so that together they form the broken axis.
for ax in (ax1, ax2):
    ax.plot(positions, df_sorted['Degree'], color='lightgray', linestyle='-', linewidth=2, alpha=0.5)
    ax.scatter(positions, df_sorted['Degree'], color=colors, marker='o', linewidth=1, alpha=1, s=marker_sizes)

# Label every fifth position on the (shared) x axis.
tick_positions = positions[::5]
ax2.set_xticks(tick_positions)
ax2.set_xticklabels(tick_positions, rotation=45)

# Y ranges of the two panels; values between 40 and 110 are not displayed.
ax1.set_ylim(110, 130)  # top panel: 110 to 130
ax2.set_ylim(0, 40)     # bottom panel: 0 to 40

# Major y tick every 5 units on both panels.
ax1.yaxis.set_major_locator(plt.MultipleLocator(5))
ax2.yaxis.set_major_locator(plt.MultipleLocator(5))

# Hide the spines facing the break, and the top panel's x tick marks that
# would otherwise stick out into the gap.
ax1.spines['bottom'].set_visible(False)
ax2.spines['top'].set_visible(False)
ax1.tick_params(axis='x', which='both', bottom=False)

# Draw short diagonal marks at the panel corners to indicate the axis break.
# NOTE(review): d is a fraction of each panel's own size, so with the 1:2
# height ratio the marks on the two panels have slightly different slopes.
d = .015  # size of the break marks, in axes coordinates
kwargs = dict(transform=ax1.transAxes, color='k', clip_on=False)
ax1.plot((-d, +d), (-d, +d), **kwargs)        # bottom-left mark of the top panel
ax1.plot((1 - d, 1 + d), (-d, +d), **kwargs)  # bottom-right mark of the top panel

kwargs.update(transform=ax2.transAxes)        # switch to the bottom panel's coordinates
ax2.plot((-d, +d), (1 - d, 1 + d), **kwargs)  # top-left mark of the bottom panel
ax2.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)  # top-right mark of the bottom panel

# Axis labels (only on the bottom panel).
ax2.set_xlabel('Index of Topic Word', fontsize=18)
ax2.set_ylabel('Degree', fontsize=18)

# Grid on both panels.
ax1.grid(True)
ax2.grid(True)

# Save the figure before showing it, since show() may clear the figure.
plt.savefig('embeding_bad_with_broken_axis.png', dpi=1000)
plt.show()
