import os
import numpy as np
import torch
import matplotlib.pyplot as plt

def reduce_to_2d(arr):
    """
    Collapse an array of any rank to exactly two dimensions.

    - 0-D (scalar) input becomes shape (1, 1).
    - 1-D input of length n becomes a single row of shape (1, n).
    - 2-D input is returned unchanged (the very same object).
    - Higher-rank input is averaged over its leading axis, one axis at a
      time, until only the last two axes remain.

    :param arr: numpy array
    :return: 2-D numpy array
    """
    rank = arr.ndim
    if rank < 2:
        # reshape(1, -1) turns a scalar into (1, 1) and a vector into (1, n).
        return arr.reshape(1, -1)
    reduced = arr
    # Each mean over axis 0 removes exactly one leading axis.
    for _ in range(rank - 2):
        reduced = reduced.mean(axis=0)
    return reduced

def visualize_mean_abs_deltas(mean_abs_deltas, title="Mean Absolute Deltas", save_path=None):
    """
    Plot the per-layer mean absolute parameter deltas as a grid of heatmaps.

    Each dict entry is drawn in its own panel (at most 4 panels per row) with
    the ``hot`` colormap and its own colorbar. Arrays with more than two
    dimensions are first collapsed to 2-D with ``reduce_to_2d``. If the dict
    is empty, a placeholder figure saying "No data to display" is saved.

    :param mean_abs_deltas: dict mapping a layer/parameter name to a tensor of
        mean absolute differences
    :param title: figure title; also used to derive the file name when
        ``save_path`` is None
    :param save_path: full path of the output image. If None, the file name is
        ``title`` with every character other than alphanumerics and ``._-``
        replaced by ``_`` plus ``.png``, written to the current directory
    """
    if save_path is None:
        safe_title = "".join(c if c.isalnum() or c in "._-" else "_" for c in title)
        save_path = safe_title + ".png"
    # os.path.dirname() is "" for a bare file name and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when the path names one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    keys = list(mean_abs_deltas.keys())
    if not keys:
        fig = plt.figure(figsize=(6, 4))
        plt.text(0.5, 0.5, "No data to display", ha="center", va="center", fontsize=12)
        plt.title(title)
        fig.savefig(save_path)
        plt.close(fig)
        return

    n_layers = len(keys)
    n_cols = min(4, n_layers)
    n_rows = (n_layers + n_cols - 1) // n_cols

    # squeeze=False always yields a 2-D array of Axes (even for a 1x1 grid),
    # so a flat view works uniformly.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.ravel()
    for ax, key in zip(axes, keys):
        # detach() so tensors that require grad can still be converted.
        delta_array = reduce_to_2d(mean_abs_deltas[key].detach().cpu().numpy())
        im = ax.imshow(delta_array, cmap='hot', interpolation='nearest')
        ax.set_title(str(key))
        ax.axis('off')
        fig.colorbar(im, ax=ax)
    # Hide the empty cells left over in the last row.
    for ax in axes[n_layers:]:
        ax.axis('off')

    fig.suptitle(title)
    fig.savefig(save_path)
    plt.close(fig)

def visualize_mask(mask, title="Mask", save_path=None):
    """
    Plot per-layer saliency masks as a grid of heatmaps.

    Each dict entry is drawn in its own panel (at most 4 panels per row) with
    the ``hot`` colormap fixed to the range [0, 1] and its own colorbar.
    Arrays with more than two dimensions are first collapsed to 2-D with
    ``reduce_to_2d``. If the dict is empty, a placeholder figure saying
    "No mask data to display" is saved.

    :param mask: dict mapping a layer/parameter name to a mask tensor
    :param title: figure title; also used to derive the file name when
        ``save_path`` is None
    :param save_path: full path of the output image. If None, the file name is
        ``title`` with every character other than alphanumerics and ``._-``
        replaced by ``_`` plus ``.png``, written to the current directory
    """
    if save_path is None:
        safe_title = "".join(c if c.isalnum() or c in "._-" else "_" for c in title)
        save_path = safe_title + ".png"
    # os.path.dirname() is "" for a bare file name and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when the path names one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    keys = list(mask.keys())
    if not keys:
        fig = plt.figure(figsize=(6, 4))
        plt.text(0.5, 0.5, "No mask data to display", ha="center", va="center", fontsize=12)
        plt.title(title)
        fig.savefig(save_path)
        plt.close(fig)
        return

    n_layers = len(keys)
    n_cols = min(4, n_layers)
    n_rows = (n_layers + n_cols - 1) // n_cols

    # squeeze=False always yields a 2-D array of Axes (even for a 1x1 grid).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.ravel()
    for ax, key in zip(axes, keys):
        # detach() so tensors that require grad can still be converted.
        mask_array = reduce_to_2d(mask[key].detach().cpu().numpy())
        im = ax.imshow(mask_array, cmap='hot', vmin=0, vmax=1, interpolation='nearest')
        ax.set_title(str(key))
        ax.axis('off')
        fig.colorbar(im, ax=ax)
    # Hide the empty cells left over in the last row.
    for ax in axes[n_layers:]:
        ax.axis('off')

    fig.suptitle(title)
    fig.savefig(save_path)
    plt.close(fig)

def _figure_to_rgb(fig):
    """Draw ``fig`` and return its pixels as an (H, W, 3) uint8 array."""
    fig.canvas.draw()
    # buffer_rgba() replaces the Agg canvas' tostring_rgb(), which was
    # deprecated in Matplotlib 3.8 and removed in 3.10. The buffer already has
    # shape (H, W, 4), so no manual reshape from get_width_height() is needed.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop alpha and copy so the pixels outlive the (soon closed) figure.
    return rgba[..., :3].copy()


def _layer_grid_to_rgb(arrays, grid_title, empty_message, **imshow_kwargs):
    """Render one heatmap per dict entry (max 4 per row) off-screen and return it as RGB pixels."""
    keys = list(arrays.keys())
    if not keys:
        fig = plt.figure(figsize=(6, 4))
        try:
            plt.text(0.5, 0.5, empty_message, ha="center", va="center", fontsize=12)
            plt.title(grid_title)
            return _figure_to_rgb(fig)
        finally:
            plt.close(fig)

    n_layers = len(keys)
    n_cols = min(4, n_layers)
    n_rows = (n_layers + n_cols - 1) // n_cols
    # squeeze=False always yields a 2-D array of Axes (even for a 1x1 grid).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    try:
        axes = axes.ravel()
        for ax, key in zip(axes, keys):
            # detach() so tensors that require grad can still be converted.
            array_2d = reduce_to_2d(arrays[key].detach().cpu().numpy())
            im = ax.imshow(array_2d, interpolation='nearest', **imshow_kwargs)
            ax.set_title(str(key))
            ax.axis('off')
            fig.colorbar(im, ax=ax)
        # Hide the empty cells left over in the last row.
        for ax in axes[n_layers:]:
            ax.axis('off')
        fig.suptitle(grid_title)
        return _figure_to_rgb(fig)
    finally:
        plt.close(fig)


def visualize_image_and_deltas(image, mean_abs_deltas, mask, title="Image and Deltas", save_path=None):
    """
    Show the original image, the per-layer mean absolute deltas and the masks side by side.

    The delta grid and the mask grid are each rendered to an off-screen figure,
    rasterised to RGB pixels, and embedded as images in the second and third
    panels of the final figure. All arrays are collapsed to 2-D with
    ``reduce_to_2d`` before plotting.

    :param image: original image tensor; squeezed, collapsed to 2-D (leading
        axes averaged away) and shown in grayscale
    :param mean_abs_deltas: dict mapping a layer name to a tensor of mean
        absolute parameter differences (drawn with the ``hot`` colormap)
    :param mask: dict mapping a layer name to a mask tensor (drawn in
        grayscale, fixed to the range [0, 1])
    :param title: figure title; also used to derive the file name when
        ``save_path`` is None
    :param save_path: full path of the output image. If None, the file name is
        ``title`` with every character other than alphanumerics and ``._-``
        replaced by ``_`` plus ``.png``, written to the current directory
    """
    # squeeze() drops singleton axes; reduce_to_2d then turns 0-D/1-D input
    # into a single row and averages away any axes beyond the last two.
    image_array = reduce_to_2d(image.detach().cpu().numpy().squeeze())

    delta_img = _layer_grid_to_rgb(
        mean_abs_deltas, "Mean Absolute Deltas", "No delta data to display", cmap='hot'
    )
    mask_img = _layer_grid_to_rgb(
        mask, "Mask", "No mask data to display", cmap='gray', vmin=0, vmax=1
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(image_array, cmap='gray')
    axes[0].set_title("Original Image")
    axes[0].axis('off')
    axes[1].imshow(delta_img)
    axes[1].set_title("Mean Absolute Deltas")
    axes[1].axis('off')
    axes[2].imshow(mask_img)
    axes[2].set_title("Mask")
    axes[2].axis('off')
    fig.suptitle(title)

    if save_path is None:
        safe_title = "".join(c if c.isalnum() or c in "._-" else "_" for c in title)
        save_path = safe_title + ".png"
    # os.path.dirname() is "" for a bare file name and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when the path names one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(save_path)
    plt.close(fig)

def visualize_layer_wise(mean_abs_deltas, masks, title="Layer-wise Visualization", save_path=None):
    """
    Plot one row per layer: the layer's mean absolute deltas, then its mask for every ratio.

    Column 0 shows the deltas (``hot`` colormap, auto-scaled). Columns 1..n
    show the layer's mask for each ratio in ascending ratio order (``hot``
    colormap, fixed to the range [0, 1]). Arrays with more than two
    dimensions are first collapsed to 2-D with ``reduce_to_2d``. If
    ``mean_abs_deltas`` is empty, a placeholder figure is saved instead.

    :param mean_abs_deltas: dict mapping a layer name to a tensor of mean
        absolute parameter differences; its keys define the rows
    :param masks: dict mapping a ratio to a dict of layer name -> mask tensor;
        every inner dict must contain every key of ``mean_abs_deltas``
    :param title: figure title; also used to derive the file name when
        ``save_path`` is None
    :param save_path: full path of the output image. If None, the file name is
        ``title`` with every character other than alphanumerics and ``._-``
        replaced by ``_`` plus ``.png``, written to the current directory
    """
    if save_path is None:
        safe_title = "".join(c if c.isalnum() or c in "._-" else "_" for c in title)
        save_path = safe_title + ".png"
    # os.path.dirname() is "" for a bare file name and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when the path names one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    keys = list(mean_abs_deltas.keys())
    if not keys:
        fig = plt.figure(figsize=(6, 4))
        plt.text(0.5, 0.5, "No data to display", ha="center", va="center", fontsize=12)
        plt.title(title)
        fig.savefig(save_path)
        plt.close(fig)
        return

    ratios = sorted(masks.keys())
    n_rows = len(keys)
    n_cols = len(ratios) + 1

    # squeeze=False guarantees a (n_rows, n_cols) array, so axes[i, j] works
    # even with a single layer and/or no ratios (previously a 1-layer call
    # wrapped the axes in a list and axes[i, 0] raised TypeError).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)

    for i, key in enumerate(keys):
        # detach() so tensors that require grad can still be converted.
        delta_array = reduce_to_2d(mean_abs_deltas[key].detach().cpu().numpy())
        ax = axes[i, 0]
        im = ax.imshow(delta_array, cmap='hot', interpolation='nearest')
        ax.set_title(f"{key} - Mean Absolute Deltas")
        ax.axis('off')
        fig.colorbar(im, ax=ax)

        for j, ratio in enumerate(ratios, start=1):
            mask_array = reduce_to_2d(masks[ratio][key].detach().cpu().numpy())
            ax = axes[i, j]
            im = ax.imshow(mask_array, cmap='hot', vmin=0, vmax=1, interpolation='nearest')
            ax.set_title(f"{key} - Ratio {ratio}")
            ax.axis('off')
            fig.colorbar(im, ax=ax)

    fig.suptitle(title)
    fig.savefig(save_path)
    plt.close(fig)

def visualize_saliency_and_mask(mean_abs_deltas, masks, title="Saliency and Mask", save_path=None, ratio=0.5):
    """
    Plot each layer's saliency map (mean absolute deltas) next to its mask.

    Every layer gets two adjacent panels: the saliency heatmap (``hot``
    colormap, auto-scaled) followed by the mask for ``ratio`` (grayscale,
    fixed to the range [0, 1]). Up to four layers (eight panels) are placed
    per figure row. Arrays with more than two dimensions are first collapsed
    to 2-D with ``reduce_to_2d``. If ``mean_abs_deltas`` is empty, a
    placeholder figure is saved instead.

    :param mean_abs_deltas: dict mapping a layer name to a tensor of mean
        absolute parameter updates (the saliency map)
    :param masks: dict mapping a ratio to a dict of layer name -> mask tensor
    :param title: figure title; also used to derive the file name when
        ``save_path`` is None
    :param save_path: full path of the output image. If None, the file name is
        ``title`` with every character other than alphanumerics and ``._-``
        replaced by ``_`` plus ``.png``, written to the current directory
    :param ratio: key of ``masks`` whose masks are displayed (default 0.5,
        the value that used to be hard-coded)
    :raises KeyError: if ``ratio`` is not in ``masks`` (checked only when
        there is at least one layer to plot)
    """
    if save_path is None:
        safe_title = "".join(c if c.isalnum() or c in "._-" else "_" for c in title)
        save_path = safe_title + ".png"
    # os.path.dirname() is "" for a bare file name and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when the path names one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    keys = list(mean_abs_deltas.keys())
    if not keys:
        # Save the placeholder like the sibling plotting functions instead of
        # calling plt.show() and leaking the figure.
        fig = plt.figure(figsize=(6, 4))
        plt.text(0.5, 0.5, "No data to display", ha="center", va="center", fontsize=12)
        plt.title(title)
        fig.savefig(save_path)
        plt.close(fig)
        return

    # Look up the ratio before creating the figure so a bad ratio does not
    # leave an open figure behind.
    try:
        ratio_masks = masks[ratio]
    except KeyError:
        raise KeyError(f"ratio {ratio!r} not found in masks (available: {list(masks)})") from None

    n_layers = len(keys)
    layers_per_row = min(4, n_layers)
    n_rows = (n_layers + layers_per_row - 1) // layers_per_row
    # Two panels per layer; n_cols is even, so a layer's pair never wraps rows.
    n_cols = 2 * layers_per_row

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.ravel()

    for idx, key in enumerate(keys):
        saliency_ax = axes[2 * idx]
        mask_ax = axes[2 * idx + 1]

        # Saliency map (detach() so tensors that require grad still convert).
        delta_array = reduce_to_2d(mean_abs_deltas[key].detach().cpu().numpy())
        im = saliency_ax.imshow(delta_array, cmap='hot', interpolation='nearest')
        saliency_ax.set_title(f"{key} - Saliency Map")
        saliency_ax.axis('off')
        fig.colorbar(im, ax=saliency_ax)

        # Mask, drawn on its own axes so it no longer hides the saliency map.
        mask_array = reduce_to_2d(ratio_masks[key].detach().cpu().numpy())
        im = mask_ax.imshow(mask_array, cmap='gray', vmin=0, vmax=1, interpolation='nearest')
        mask_ax.set_title(f"{key} - Mask")
        mask_ax.axis('off')
        fig.colorbar(im, ax=mask_ax)

    # Hide the empty cells left over in the last row.
    for ax in axes[2 * n_layers:]:
        ax.axis('off')

    fig.suptitle(title)
    fig.savefig(save_path)
    plt.close(fig)

def visualize_feature_maps(feature_maps, title="Feature Maps", save_path=None):
    """
    Plot channel 0 of every feature map in a grid with 8 panels per row.

    :param feature_maps: tensor indexed as ``[map, channel, h, w]``, e.g. of
        shape (64, 3, 5, 5); only channel 0 of each map is shown (``hot``
        colormap)
    :param title: figure title; also used to derive the file name when
        ``save_path`` is None
    :param save_path: full path of the output image. If None, the file name is
        ``title`` with every character other than alphanumerics and ``._-``
        replaced by ``_`` plus ``.png``, written to the current directory
    """
    # detach() so tensors that require grad can still be converted.
    maps = feature_maps.detach().cpu().numpy()
    num_feature_maps = maps.shape[0]

    images_per_row = 8
    # max(1, ...) keeps plt.subplots valid (it rejects 0 rows) for empty input.
    num_rows = max(1, (num_feature_maps + images_per_row - 1) // images_per_row)

    # squeeze=False + ravel() gives a flat array of Axes for any row count.
    # Previously a single row was wrapped as [array_of_axes], so axes[0] was
    # an ndarray and .imshow failed whenever there were <= 8 maps.
    fig, axes = plt.subplots(num_rows, images_per_row, figsize=(16, 8), squeeze=False)
    axes = axes.ravel()

    for i, ax in enumerate(axes[:num_feature_maps]):
        ax.imshow(maps[i, 0, :, :], cmap='hot', interpolation='nearest')
        ax.set_title(f"Feature Map {i+1}")
        ax.axis('off')

    # Hide the empty cells left over in the last row.
    for ax in axes[num_feature_maps:]:
        ax.axis('off')

    fig.suptitle(title)
    if save_path is None:
        safe_title = "".join(c if c.isalnum() or c in "._-" else "_" for c in title)
        save_path = safe_title + ".png"
    # os.path.dirname() is "" for a bare file name and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when the path names one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(save_path)
    plt.close(fig)

def visualize_feature_maps_with_masks(feature_maps, masks, title="Feature Maps with Masks", save_path=None):
    """
    Plot each feature map followed by its mask for every ratio.

    Every feature map occupies ``1 + num_ratios`` consecutive panels: channel
    0 of the feature map (``hot`` colormap), then channel 0 of the matching
    mask for each ratio (grayscale, fixed to the range [0, 1]). Eight feature
    maps, together with their masks, are placed per figure row.

    :param feature_maps: numpy array indexed as ``[map, channel, h, w]``,
        e.g. of shape (64, 3, 5, 5)
    :param masks: numpy array indexed as ``[ratio, map, channel, h, w]``,
        e.g. of shape (10, 64, 3, 5, 5) for 10 ratios
    :param title: figure title; also used to derive the file name when
        ``save_path`` is not given
    :param save_path: output image path (saved at 300 dpi). If None or empty,
        the file name is ``title`` with every character other than
        alphanumerics and ``._-`` replaced by ``_`` plus ``.png``, written to
        the current directory (consistent with the other plotting helpers)
    """
    if not save_path:
        safe_title = "".join(c if c.isalnum() or c in "._-" else "_" for c in title)
        save_path = safe_title + ".png"

    num_feature_maps = feature_maps.shape[0]
    num_ratios = masks.shape[0]
    panels_per_map = 1 + num_ratios

    images_per_row = 8
    # max(1, ...) keeps plt.subplots valid (it rejects 0 rows) for empty input.
    num_rows = max(1, (num_feature_maps + images_per_row - 1) // images_per_row)

    # squeeze=False + ravel() gives a flat array of Axes for any row count.
    # Previously a single row was wrapped as [array_of_axes], so axes[0] was
    # an ndarray and .imshow failed whenever there were <= 8 maps.
    fig, axes = plt.subplots(
        num_rows,
        images_per_row * panels_per_map,
        figsize=(16, 8 * num_rows),
        squeeze=False,
    )
    axes = axes.ravel()

    for i in range(num_feature_maps):
        base = i * panels_per_map
        ax = axes[base]
        ax.imshow(feature_maps[i, 0, :, :], cmap='hot', interpolation='nearest')
        ax.set_title(f"Feature Map {i+1}")
        ax.axis('off')

        for j in range(num_ratios):
            ax = axes[base + j + 1]
            # Fixed [0, 1] range so constant masks are not auto-scaled to black.
            ax.imshow(masks[j, i, 0, :, :], cmap='gray', vmin=0, vmax=1, interpolation='nearest')
            ax.set_title(f"Mask {i+1} - Ratio {j+1}")
            ax.axis('off')

    # Hide the empty cells left over in the last row.
    for ax in axes[num_feature_maps * panels_per_map:]:
        ax.axis('off')

    fig.suptitle(title)
    # os.path.dirname() is "" for a bare file name and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when the path names one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(save_path, dpi=300)
    plt.close(fig)