# -*- coding: utf-8 -*-
# 基于给定日志的 Acc / Loss 可视化（柱状图+折线图，双坐标轴）
# 要求：每5个epoch为一组；组内 Acc 下降的柱子标红；Loss 上升的点标红

import matplotlib.pyplot as plt
from utils.visual import init_mpl

init_mpl()

# ----------------------------
# 1) 手动录入每组(Adapt 1..5)的5个 epoch 的 Acc / Loss
# ----------------------------
# data = {
#     "Adapt 1": {
#         "acc":  [0.9928, 0.9928, 0.9928, 0.9928, 0.9928],
#         "loss": [0.0313, 0.0314, 0.0315, 0.0315, 0.0314],
#     },
#     "Adapt 2": {
#         "acc":  [0.9769, 0.9773, 0.9774, 0.9775, 0.9778],
#         "loss": [0.1114, 0.1126, 0.1135, 0.1142, 0.1139],
#     },
#     "Adapt 3": {
#         "acc":  [0.9387, 0.9386, 0.9395, 0.9399, 0.9403],
#         "loss": [0.3702, 0.3787, 0.3799, 0.3839, 0.3872],
#     },
#     "Adapt 4": {
#         "acc":  [0.8416, 0.8423, 0.8411, 0.8421, 0.8424],
#         "loss": [1.2520, 1.2522, 1.2795, 1.2873, 1.3071],
#     },
#     "Adapt 5": {
#         "acc":  [0.7145, 0.7153, 0.7154, 0.7158, 0.7161],
#         "loss": [2.6718, 2.6457, 2.6401, 2.6452, 2.6646],
#     },
# }

data = {
    "Adapt 1": {
        "acc":  [0.9856, 0.9857, 0.9860, 0.9860, 0.9860],
        "loss": [0.0508, 0.0506, 0.0501, 0.0497, 0.0507],
    },
    "Adapt 2": {
        "acc":  [0.9593, 0.9600, 0.9598, 0.9601, 0.9605],
        "loss": [0.1836, 0.1812, 0.1825, 0.1822, 0.1829],
    },
    "Adapt 3": {
        "acc":  [0.8961, 0.8963, 0.8960, 0.8961, 0.8973],
        "loss": [0.5806, 0.5900, 0.5960, 0.5972, 0.5958],
    },
    "Adapt 4": {
        "acc":  [0.7619, 0.7618, 0.7618, 0.7608, 0.7623],
        "loss": [1.6182, 1.6584, 1.6879, 1.7238, 1.7152],
    },
    "Adapt 5": {
        "acc":  [0.6060, 0.6059, 0.6053, 0.6056, 0.6047],
        "loss": [3.4810, 3.5445, 3.6056, 3.6475, 3.7063],
    },
}


# ----------------------------
# 2) 展平为 25 个 epoch 序列，并记录组边界
# ----------------------------
acc_all, loss_all, groups, xlabels = [], [], [], []
for gi, (gname, d) in enumerate(data.items(), start=1):
    accs = d["acc"]
    losss = d["loss"]
    acc_all.extend(accs)
    loss_all.extend(losss)
    groups.extend([gname]*len(accs))
    # xlabels.extend([f"E{j}" for j in range(1, len(accs)+1)])

N = len(acc_all)  # 25
xs = list(range(1, N+1))

# ----------------------------
# 3) 计算需要标红的位置
#    - Acc 下降：与本组前一个 epoch 相比 acc[i] < acc[i-1]
#    - Loss 上升：与本组前一个 epoch 相比 loss[i] > loss[i-1]
# ----------------------------
acc_drop_idx = []   # 需要红色柱的 x 索引
loss_rise_idx = []  # 需要红色点的 x 索引

start = 0
for gname, d in data.items():
    accs = d["acc"]
    losss = d["loss"]
    # 组内比较（从第二个 epoch 起）
    for j in range(1, len(accs)):
        global_idx = start + j  # 1-based: xs[global_idx]
        if accs[j] < accs[j-1]:
            acc_drop_idx.append(global_idx+1)  # +1 对应 xs 的位置
        if losss[j] > losss[j-1]:
            loss_rise_idx.append(global_idx+1)
    start += len(accs)

# ----------------------------
# 4) 绘图
# ----------------------------
fig, ax1 = plt.subplots(figsize=(10, 4))
ax2 = ax1.twinx()

# (a) 先画所有 Acc 柱（默认颜色）
bar_container = ax1.bar(xs, acc_all, label="Accuracy", alpha=0.85, width=0.8, color="#59AC77")

ax1.set_ylim(0.5, 1.0)

# (b) 叠加画“Acc 下降”的红色柱：只在下降的那些位置覆盖同宽柱子
if acc_drop_idx:
    acc_drop_vals = [acc_all[i-1] for i in acc_drop_idx]  # i-1 转为 0-based
    ax1.bar(acc_drop_idx, acc_drop_vals, label="Acc. decreased", alpha=0.95, color="#3396D3")

# (c) 画 Loss 折线（默认颜色）
line_loss, = ax2.plot(xs, loss_all, marker="^", label="Loss", linewidth=2, markersize=10, color="#FAA533")

# (d) 叠加画“Loss 上升”的红色点（覆盖原点）
if loss_rise_idx:
    loss_rise_vals = [loss_all[i-1] for i in loss_rise_idx]
    ax2.scatter(loss_rise_idx, loss_rise_vals, label="Loss increased", s=100, color="#ED3F27", zorder=5, marker="^")

# ----------------------------
# 5) 美化：组分隔线、组标签、坐标轴与网格
# ----------------------------
# 每组 5 个 epoch，分隔线画在 5.5, 10.5, 15.5, 20.5
for k in range(1, len(data)):
    ax1.axvline(x=k*5 + 0.5, linestyle="--", alpha=1, color="#3A6F43")

# 在每组中心位置写组名
group_names = list(data.keys())
for gi, gname in enumerate(group_names, start=0):
    center_x = gi*5 + 3  # 该组的5个 x 为 [5*gi+1 ... 5*gi+5]，中心约在 +3
    ax1.text(center_x, ax1.get_ylim()[1]*0.65, gname, ha="center", va="top")

# X 轴刻度仅显示 E1..E5 循环，可选：也可直接用 1..25
# ax1.set_xticks(xs)
# ax1.set_xticklabels(xlabels, rotation=0)
# 无刻度
ax1.set_xticks([])

# ax1.set_xlabel("Epoch")
ax1.set_ylabel("Accuracy")
ax2.set_ylabel("Loss")

ax1.grid(axis="y", alpha=0.25)

# 图例（合并左右轴的图例）
handles1, labels1 = ax1.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
# ax1.legend(handles1 + handles2, labels1 + labels2, loc="upper left")
ax1.legend(handles1 + handles2, labels1 + labels2,
           loc='center left',        # 左侧中间
           bbox_to_anchor=(0.02, 0.5))  # 稍微往里 0.02，垂直 0.5

plt.tight_layout()

# 可选：保存图片
# plt.savefig("acc_loss_dual_axis.png", dpi=200)
plt.savefig("acc_loss_dual_axis.svg") 

# plt.show()
