
import glob
import json
import re

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.cluster.hierarchy import dendrogram, leaves_list, linkage
from scipy.spatial.distance import squareform
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import adjusted_rand_score
# ===============================
# User configuration
# ===============================
# Glob pattern matching the per-step task_grouping JSON result files
INPUT_PATTERN = "./logdir/your logname/task_grouping_output/{your prefix}_step*.json"
# Prefix used for the names of all output files (JSON result and PNG figures)
prefix = "dmc_vision"
# Number of groups in the final partition
N_CLUSTERS = 10
# How time steps are weighted (later time steps are considered more reliable).
# Options: "linear" (linearly increasing) or "exp" (exponentially increasing);
# any other value results in uniform weights.
WEIGHT_MODE = "linear"
# Input files whose step number is larger than this value are ignored
total_steps = 250000

OUTPUT_FILE = f"{prefix}_final_groups.json"
# All paths matching INPUT_PATTERN, in lexicographic order
files = sorted(glob.glob(INPUT_PATTERN))
if not files:
    raise FileNotFoundError(f"未找到匹配的文件：{INPUT_PATTERN}")
def extract_step(file_path):
    """Return the step number encoded in a file name, or None.

    Only the final path component is inspected, so a directory such as
    ``step100_run/`` can never be mistaken for the file's own step number.

    Args:
        file_path: Path to a result file, e.g. ``out/prefix_step5000.json``.

    Returns:
        The integer that follows the first ``step`` in the file name, or
        ``None`` when the file name contains no ``step<digits>`` part.
    """
    # Drop directory components; accept both "/" and "\" as separators.
    file_name = re.split(r'[\\/]', file_path)[-1]
    match = re.search(r'step(\d+)', file_name)
    return int(match.group(1)) if match else None


# Keep only files with a parseable step number that is <= total_steps, and
# order them numerically by step. The lexicographic order produced by
# sorted(glob.glob(...)) places e.g. "step10000" before "step2000"; since the
# time weights further below are assigned by position, that order would give
# the "late step" weights to the wrong files.
steps_and_files = []
for file in files:
    step = extract_step(file)
    if step is not None and step <= total_steps:
        steps_and_files.append((step, file))
steps_and_files.sort(key=lambda pair: pair[0])
filtered_files = [path for _, path in steps_and_files]
if not filtered_files:
    # Fail early with a clear message instead of clustering an empty matrix.
    raise FileNotFoundError(
        f"No input file has a step number <= {total_steps}: {INPUT_PATTERN}"
    )
files = filtered_files

# One {task: group name} mapping per input file, in the order of `files`.
clusters = []
for path in files:
    with open(path, "r", encoding="utf-8") as handle:
        grouping = json.load(handle)
    # Invert {group: [tasks]} into {task: group}; if a task is listed more
    # than once, the last occurrence wins.
    clusters.append(
        {member: gid for gid, members in grouping.items() for member in members}
    )

# Every task name seen in any file.
tasks_set = set().union(*clusters)

tasks = sorted(tasks_set)   # fixed task order used for all matrices below
n = len(tasks)              # number of tasks
T = len(clusters)           # number of time steps (input files)


# cluster_matrix[t, i] is the integer group id of tasks[i] at time step t.
# Ids are assigned per row, so they are only comparable within one time step.
cluster_matrix = np.zeros((T, n), dtype=int)
for t, cluster_dict in enumerate(clusters):
    unique_groups = sorted(set(cluster_dict.values()))
    group_to_id = {g: i for i, g in enumerate(unique_groups)}
    # A task that is absent from this step's file gets a fresh id of its own,
    # i.e. it is treated as a singleton group for this step, instead of
    # aborting the whole script with a KeyError.
    next_singleton_id = len(unique_groups)
    for i, task in enumerate(tasks):
        if task in cluster_dict:
            cluster_matrix[t, i] = group_to_id[cluster_dict[task]]
        else:
            cluster_matrix[t, i] = next_singleton_id
            next_singleton_id += 1
    n_missing = next_singleton_id - len(unique_groups)
    if n_missing:
        print(f"  warning: timestep {t+1}: {n_missing} task(s) missing from "
              f"the grouping file; each is treated as its own group")


# Per-time-step weights: later steps count more. Any WEIGHT_MODE other than
# "linear" or "exp" results in uniform weights.
if WEIGHT_MODE == "linear":
    weights = np.linspace(1, 2, T)
elif WEIGHT_MODE == "exp":
    weights = np.exp(np.linspace(0, 1, T))
else:
    weights = np.ones(T)

# Weighted co-association matrix: A[i, j] is the weight-normalised share of
# time steps in which tasks i and j were placed in the same group.
A = np.zeros((n, n))
for step_weight, step_labels in zip(weights, cluster_matrix):
    # Broadcasting compares every pair of tasks at once; this replaces two
    # nested Python loops over n without changing the accumulated values.
    same_group = step_labels[:, None] == step_labels[None, :]
    A += step_weight * same_group
A /= np.sum(weights)


# Average-linkage agglomerative clustering on the distance matrix 1 - A.
# scikit-learn 1.2 renamed the ``affinity`` parameter to ``metric`` and
# release 1.4 removed ``affinity``; try the current name first and fall back
# to the old one so the script runs on both old and new releases.
try:
    model = AgglomerativeClustering(
        n_clusters=N_CLUSTERS,
        metric="precomputed",
        linkage="average"
    )
except TypeError:
    model = AgglomerativeClustering(
        n_clusters=N_CLUSTERS,
        affinity="precomputed",
        linkage="average"
    )
labels_final = model.fit_predict(1 - A)

# Build a human-readable {"group_k": [task, ...]} dictionary. Every group
# index in range(N_CLUSTERS) gets a key; tasks keep their sorted order.
final_groups = {}
for k in range(N_CLUSTERS):
    members = [task for task, lab in zip(tasks, labels_final) if lab == k]
    final_groups[f"group_{k}"] = members

# Save the result as JSON
with open(OUTPUT_FILE, "w", encoding="utf-8") as out_fp:
    json.dump(final_groups, out_fp, indent=2, ensure_ascii=False)


# Report how closely each time step's grouping agrees with the final one
# (adjusted Rand index between that step's labels and the consensus labels).
for step_no, step_labels in enumerate(cluster_matrix, start=1):
    ari = adjusted_rand_score(step_labels, labels_final)
    print(f"  timestep {step_no}: ARI = {ari:.3f}")

# ===============================
# visualization
# ===============================

# One colour per final group; label_colors[i] is the colour of tasks[i].
palette = sns.color_palette("tab10", N_CLUSTERS)
label_colors = [palette[group] for group in labels_final]

# scipy's linkage() treats a 2-D array as n observation vectors and would
# compute Euclidean distances between the ROWS of 1 - A. Pass the condensed
# form of the distance matrix instead, so the tree is built from the same
# distances that the clustering above used.
distance = 1 - A
np.fill_diagonal(distance, 0.0)  # remove floating-point residue on the diagonal
linked = linkage(squareform(distance, checks=False), method='average')
plt.figure(figsize=(12, 6))
dend = dendrogram(linked, labels=tasks, leaf_rotation=90, leaf_font_size=9,
                  link_color_func=lambda k: "grey")

# Colour each leaf label by the final group of its task.
task_to_color = dict(zip(tasks, label_colors))
ax = plt.gca()
x_labels = ax.get_xmajorticklabels()
for lbl in x_labels:
    lbl.set_color(task_to_color[lbl.get_text()])

plt.title("Hierarchical Dendrogram", fontsize=14)
plt.tight_layout()
plt.savefig(f"{prefix}_hierarchical_dendrogram_colored.png", dpi=300)
plt.close()


# Reorder tasks by dendrogram leaf order so related tasks are adjacent.
order = leaves_list(linked)
A_sorted = A[order][:, order]
labels_sorted = [labels_final[i] for i in order]
tasks_sorted = [tasks[i] for i in order]

plt.figure(figsize=(10, 8))
heat_ax = sns.heatmap(A_sorted, cmap="viridis",
                      xticklabels=tasks_sorted, yticklabels=tasks_sorted)
heat_ax.set_title("Co Association Heatmap", fontsize=14)

# Group-colour strips on both sides of the heatmap. seaborn draws cell i over
# [i, i + 1] in data coordinates with the x-axis spanning [0, n], so each
# strip segment starts at y = i and the strips sit just outside the axes;
# clipping is disabled because patches outside the axes limits are otherwise
# not drawn.
strip_width = 0.2
for row, group_id in enumerate(labels_sorted):
    for x_start in (-strip_width, n):
        heat_ax.add_patch(plt.Rectangle((x_start, row), strip_width, 1,
                                        color=palette[group_id], lw=0,
                                        clip_on=False))

plt.tight_layout()
plt.savefig(f"{prefix}_co_association_heatmap_colored.png", dpi=300)
plt.close()

