from matplotlib import pyplot as plt
import matplotlib.image as mpimg
from pathlib import Path
import os


def grad_hist(grads, dir_name="grad", edge=1e-5):
    """Save one histogram image per entry of ``grads`` into ``dir_name``.

    Each entry is drawn with ``plt.hist`` using 100 bins on a fresh 6x8 inch,
    300 dpi figure and written to ``{dir_name}/{i}.png``, where ``i`` is the
    entry's zero-based position in ``grads``.

    Args:
        grads: Iterable of values accepted by ``plt.hist``.
            NOTE(review): presumably per-iteration gradient arrays already
            detached and moved to CPU -- confirm against the caller.
        dir_name: Output directory. It is created, including any missing
            parent directories, when it does not exist yet.
        edge: Currently unused. Kept so existing callers that pass it keep
            working; the clipping / x-limit logic that used it is disabled.

    Returns:
        None. The only effect is the PNG files written to ``dir_name``.
    """
    # makedirs(exist_ok=True) handles nested paths and avoids the
    # check-then-create race of os.path.exists() + os.mkdir().
    os.makedirs(dir_name, exist_ok=True)

    for i, grad in enumerate(grads):
        fig = plt.figure(figsize=(6, 8), dpi=300)
        try:
            # plt.hist draws on the current figure, i.e. the one just created.
            plt.hist(grad, bins=100)
            fig.savefig(f"{dir_name}/{i}.png")
        finally:
            # Always release the figure, even if plotting or saving fails,
            # so a failure does not leave open figures accumulating.
            plt.close(fig)

    # NOTE: combining the saved images into a single grid figure
    # ("{dir_name}/{dir_name}.png") is intentionally not done here.
