import numpy as np
from PIL import Image
import torch
import matplotlib.pyplot as plt
def _visualize_fft_map(fft_map):
    # 创建一个统一的图表，准备绘制所有图像
    n = fft_map.size(0)

    # 创建一个统一的图表，准备绘制所有图像
    fig, axes = plt.subplots(n, 2, figsize=(15, 3 * n))  # n行2列的子图（幅度谱和相位谱）

    # 遍历每一行（每个样本）
    for i in range(n):
        fft_row = fft_map[i]  # 获取第 i 行 FFT 结果

        # 计算幅度谱 (Magnitude Spectrum)
        magnitude = torch.abs(fft_row)

        # 计算相位谱 (Phase Spectrum)
        phase = torch.angle(fft_row)

        # 绘制幅度谱
        axes[i, 0].plot(magnitude.numpy())  # 绘制幅度谱
        axes[i, 0].set_title(f'Magnitude Spectrum (Row {i})')
        axes[i, 0].set_xlabel('Frequency bin')
        axes[i, 0].set_ylabel('Magnitude')

        # 绘制相位谱
        axes[i, 1].plot(phase.numpy())  # 绘制相位谱
        axes[i, 1].set_title(f'Phase Spectrum (Row {i})')
        axes[i, 1].set_xlabel('Frequency bin')
        axes[i, 1].set_ylabel('Phase (radians)')
    # 添加 x 轴和 y 轴标签
    plt.xlabel("元素索引")
    plt.ylabel("值")

    # plt.title(fft_similarity)
    plt.show()
    # 调整布局以防止标签重叠
    plt.tight_layout()

    # 显示图像
    plt.show()


def visualize_fft_map(fft_map):

    # 创建一个新的图形
    plt.figure(figsize=(10, 6))
    # 绘制每一行的折线
    for i in range(fft_map.shape[0]):
        fft = fft_map[i].real
        # 计算最大值和最小值
        min_val = torch.min(fft)
        max_val = torch.max(fft)

        # 归一化
        fft_normalized = (fft - min_val) / (max_val - min_val)
        plt.plot(fft_normalized, label=f"Line {i + 1}")
    # 添加图例
    plt.legend()

    # 添加标题
    plt.title("可视化 nx32 向量")

    # 添加 x 轴和 y 轴标签
    plt.xlabel("元素索引")
    plt.ylabel("值")

    # 显示图形
    plt.show()

def visualize_attn_map(attn_map):

    # 创建一个新的图形
    plt.figure(figsize=(10, 6))
    # 绘制每一行的折线
    for i in range(attn_map.shape[0]):
        attn = attn_map[i]
        # 计算最大值和最小值
        min_val = torch.min(attn)
        max_val = torch.max(attn)

        # 归一化
        # attn_normalized = (attn_map - min_val) / (max_val - min_val)
        plt.plot(attn, label=f"Line {i + 1}")
    # 添加图例
    plt.legend()

    # 添加标题
    plt.title("可视化 nx32 向量")

    # 添加 x 轴和 y 轴标签
    plt.xlabel("元素索引")
    plt.ylabel("值")

    # 显示图形
    plt.show()

def visualize_attention_map(attention_map, name):
    """
    可视化 attention map
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(attention_map, cmap='viridis')
    plt.colorbar()
    plt.title(name)
    plt.axis('off')
    plt.show()


def visualize_fft_similarity(similarity_data):
    plt.figure(figsize=(10, 6))
    plt.hist(similarity_data, bins=100, alpha=0.7)  # bins=100是设定分成100个区间
    plt.title("Histogram of Similarity Scores")
    plt.xlabel("Similarity")
    plt.ylabel("Frequency")
    plt.show()