import pandas as pd
import replicate
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
import json
from matplotlib.colors import Normalize



# 读取CSV文件
file_path = '/Users/wad3/Downloads/paper/visual_autobench/document/embeding/spatial_understanding/hard_topic_word_degrees_good.csv'
df = pd.read_csv(file_path)

# 提取texts和degree
texts = df['Topic Word'].tolist()
texts = json.dumps(texts)
degrees = df['Degree'].tolist()

# 调用模型计算向量
# output = replicate.run(
#     "nateraw/bge-large-en-v1.5:9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
#     input={
#         "texts": texts,
#         "batch_size": 32,
#         "convert_to_numpy": False,
#         "normalize_embeddings": True
#     }
# )

# # 将输出转换为数组
# embeddings = np.array(output)
embeddings = np.load('/Users/wad3/Downloads/paper/visual_autobench/document/embeding/embeddings_bad.npy')
# 使用t-SNE进行降维
tsne = TSNE(n_components=2, random_state=42)
embeddings_tsne = tsne.fit_transform(embeddings)

# 添加抖动（微小扰动）
jitter_strength = 1 # 控制抖动强度的参数
embeddings_tsne += np.random.normal(scale=jitter_strength, size=embeddings_tsne.shape)

# 创建可视化
plt.figure(figsize=(10, 8))
norm = Normalize(vmin=5, vmax=40)
# 可选的 cmap 包括:
# 连续色彩映射: 'viridis', 'inferno', 'magma', 'cividis', 'YlOrRd', 'YlGnBu', 'RdYlBu'
# 发散色彩映射: 'coolwarm', 'bwr', 'seismic', 'RdBu', 'PiYG'
# 循环色彩映射: 'hsv', 'twilight', 'twilight_shifted'
# 定性色彩映射: 'Set1', 'Set2', 'Set3', 'Paired', 'tab10', 'tab20'

# 保存向量到文件
# np.save('embeddings.npy', embeddings)

# 如果需要加载向量，可以使用:
# 创建散点图
scatter = plt.scatter(
    embeddings_tsne[:, 0],  # x坐标，使用t-SNE降维后的第一个分量
    embeddings_tsne[:, 1],  # y坐标，使用t-SNE降维后的第二个分量
    c=degrees,              # 点的颜色，根据度数值来确定
    cmap='coolwarm',         # 颜色映射方案，使用'viridis'色彩方案
    s=200,                   # 点的大小，设置为70
    edgecolors='face',      # 点的边缘颜色设置为与点自身颜色相同
    alpha=0.7,              # 点的透明度，设置为0.5
    norm=norm               # 颜色标准化对象，用于将度数值映射到颜色范围
)
# 这行代码创建了一个散点图，每个点代表一个文本嵌入，位置由t-SNE降维结果决定，
# 颜色表示度数，大小和透明度固定，边缘为白色，使用viridis颜色方案。

# 添加颜色条
cbar = plt.colorbar(scatter)
cbar.set_label('Degree', rotation=270, labelpad=20, fontsize=24)  # 增大字体和调整标签间距
cbar.ax.tick_params(labelsize=20)  # 增大颜色条数字的字体大小
plt.subplots_adjust(right=0.952, left=0.176, top=0.88, bottom=0.11)
# plt.title('t-SNE visualization of text embeddings with jitter', fontsize=16)
plt.xlabel('t-SNE component 1', fontsize=24)
plt.ylabel('t-SNE component 2', fontsize=24)
plt.tick_params(axis='both', which='major', labelsize=24)
plt.savefig('semantic_good.pdf', dpi = 300)
plt.show()