import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List

def visualize_weights(
    layer_dict: Dict[str, torch.Tensor],
    save_dir: str,
    mode: str = "layerwise",
    max_channels: int = 16,
):
    """
    Visualize weight distributions as violin plots, per layer and/or per channel.

    :param layer_dict: Mapping from layer name to weight tensor.
    :param save_dir: Directory where the PNG files are written; created if missing.
    :param mode: "layerwise" draws a single figure with one violin per layer;
        "channelwise" draws one figure per layer with one violin per row of the
        tensor's leading dimension; "both" does both.
    :param max_channels: Maximum number of leading-dimension rows plotted per
        layer in channelwise mode (default 16, the previously hard-coded value).
    :raises ValueError: If ``mode`` is not one of the supported values.
    """
    valid_modes = ("layerwise", "channelwise", "both")
    if mode not in valid_modes:
        # Previously an unknown mode silently produced no output at all.
        raise ValueError(f"mode must be one of {valid_modes}, got {mode!r}")

    # Make sure the output directory exists.
    os.makedirs(save_dir, exist_ok=True)

    if not layer_dict:
        # Nothing to plot; seaborn cannot draw violins from an empty frame.
        return

    # One figure with one violin per layer.
    if mode in ("layerwise", "both"):
        names = list(layer_dict)
        # .float() is required because NumPy has no bfloat16 dtype, so .numpy()
        # would raise on bf16 weights; the plotted data was float32 anyway.
        arrays = [w.detach().flatten().float().cpu().numpy() for w in layer_dict.values()]
        counts = [a.size for a in arrays]

        data = pd.DataFrame({
            "Weight Value": np.concatenate(arrays),
            # Object array of shared string references: no per-element string
            # copies and, unlike a fixed-width "U256" dtype, no silent truncation
            # of long layer names.
            "Layer": np.repeat(np.array(names, dtype=object), counts),
        })

        save_file = os.path.join(save_dir, "weights_layerwise_violin.png")
        fig = plt.figure(figsize=(14, 8))
        try:
            sns.violinplot(x="Layer", y="Weight Value", hue="Layer", data=data, inner="quart", palette="Set2", legend=False)

            plt.title("Distribution of Weights by Layer")
            plt.xlabel("Layer Name")
            plt.ylabel("Weight Value")
            plt.xticks(rotation=90)

            plt.tight_layout()
            plt.savefig(save_file)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)

        print(f"Saved layerwise violin plot at {save_file}")

    # One figure per layer with one violin per channel.
    if mode in ("channelwise", "both"):
        for name, weights in layer_dict.items():
            # Treat dim 0 as the channel axis and flatten the rest. A 0-d tensor
            # is treated as a single channel holding one value.
            w = weights.detach()
            rows = w.reshape(w.shape[0], -1) if w.dim() > 0 else w.reshape(1, 1)
            # Slice before the device transfer so only the plotted rows are copied.
            channels = rows[:max_channels].float().cpu().numpy()
            num_rows, row_len = channels.shape

            labels = np.array([f"Channel {idx}" for idx in range(num_rows)], dtype=object)
            data = pd.DataFrame({
                # Row-major flatten: all values of channel 0, then channel 1, ...
                "Weight Value": channels.reshape(-1),
                "Channel": np.repeat(labels, row_len),
            })

            # Layer names containing a path separator would otherwise point
            # into a non-existent sub-directory.
            safe_name = name.replace(os.sep, "_")
            if os.altsep:
                safe_name = safe_name.replace(os.altsep, "_")
            save_file = os.path.join(save_dir, f"weights_channelwise_{safe_name}.png")

            fig = plt.figure(figsize=(14, 8))
            try:
                sns.violinplot(x="Channel", y="Weight Value", hue="Channel", data=data, inner="quart", palette="Set2", legend=False)

                plt.title(f"Distribution of Weights in Channels for {name}")
                plt.xlabel("Channel")
                plt.ylabel("Weight Value")
                plt.xticks(rotation=90)

                plt.tight_layout()
                plt.savefig(save_file)
            finally:
                plt.close(fig)

            print(f"Saved channelwise violin plot for {name} at {save_file}")

def visualize_weights_layerwise(original_weights: List[torch.Tensor], reconstructed_weights: List[torch.Tensor], layer_names: List[str], save_file: str):
    """
    Compare original and reconstructed weights per layer with split violin plots.

    Each layer gets one violin on the x-axis. The "Original" and
    "Reconstructed" groups are drawn as the two halves of that violin
    (``split=True``).

    :param original_weights: Original weight tensors, one per layer.
    :param reconstructed_weights: Reconstructed weight tensors, aligned index-by-index
        with ``original_weights``.
    :param layer_names: Names used as x-axis labels, aligned with the weight lists.
    :param save_file: Output image path. Its parent directory is created if it
        is missing; a bare file name writes to the current directory.
    :raises ValueError: If the three lists differ in length, or if an original and
        its reconstructed tensor contain a different number of elements.
    """
    # Check that the inputs are aligned before touching the filesystem.
    if not (len(original_weights) == len(reconstructed_weights) == len(layer_names)):
        raise ValueError("original_weights, reconstructed_weights, and layer_names must have the same length.")

    if not layer_names:
        # Nothing to compare; seaborn cannot draw violins from an empty frame.
        return

    # Collect per-layer NumPy chunks and concatenate once. Extending Python lists
    # element by element is very slow and memory-hungry for large layers.
    values, groups, layers = [], [], []
    group_names = np.array(["Original", "Reconstructed"], dtype=object)
    for name, orig_t, recon_t in zip(layer_names, original_weights, reconstructed_weights):
        # .float(): NumPy has no bfloat16 dtype, so .numpy() would raise on bf16.
        original = orig_t.detach().flatten().float().cpu().numpy()
        reconstructed = recon_t.detach().flatten().float().cpu().numpy()

        # Only element counts are compared (both sides are flattened), matching
        # the previous check. Raise instead of assert so it survives `python -O`.
        if original.size != reconstructed.size:
            raise ValueError(
                f"Original and Reconstructed weights have different numbers of elements "
                f"for layer {name}: {original.size} vs {reconstructed.size}."
            )

        values.extend((original, reconstructed))
        # np.repeat gives [Original]*n followed by [Reconstructed]*n, matching `values`.
        groups.append(np.repeat(group_names, original.size))
        layers.append(np.full(2 * original.size, name, dtype=object))

    data = pd.DataFrame({
        "Weight Value": np.concatenate(values),
        "Group": np.concatenate(groups),
        "Layer": np.concatenate(layers),
    })

    # os.makedirs("") raises, so only create a parent when one is given.
    parent_dir = os.path.dirname(save_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    fig = plt.figure(figsize=(20, 8))
    try:
        sns.violinplot(x="Layer", y="Weight Value", hue="Group", data=data, split=True, inner="quart", palette="Set2")

        plt.title("Comparison of Original Weights and Reconstructed Weights")
        plt.xlabel("Layer Name")
        plt.ylabel("Weight Value")
        plt.legend(title="Group", loc="upper right")
        # Vertical x-axis labels so long layer names do not overlap.
        plt.xticks(rotation=90)

        plt.tight_layout()
        plt.savefig(save_file)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    print(f"Saved violin plot at {save_file}")
        

def visualize_weights_channelwise(original_weights: List[torch.Tensor], reconstructed_weights: List[torch.Tensor], layer_names: List[str], save_dir: str, num_channels: int = 20):
    """
    Compare original and reconstructed weights channel by channel with split violin plots.

    Channels are rows of the tensor's leading dimension. For each layer, the
    first ``num_channels`` channels go into one figure, one violin per
    channel. The "Original" and "Reconstructed" groups are drawn as the two
    halves of each violin. One PNG per layer is written to ``save_dir``.

    :param original_weights: Original weight tensors, one per layer.
    :param reconstructed_weights: Reconstructed weight tensors, aligned index-by-index
        with ``original_weights``.
    :param layer_names: Layer names used in plot titles and output file names.
    :param save_dir: Directory where the images are saved; created if missing.
    :param num_channels: Maximum number of leading-dimension rows plotted per layer.
        The default of 20 is what the code always plotted; the old docs and
        title wrongly claimed 100.
    :raises ValueError: If the three lists differ in length, or if the plotted
        slices of an original/reconstructed pair have different shapes.
    """
    # Check that the inputs are aligned.
    if not (len(original_weights) == len(reconstructed_weights) == len(layer_names)):
        raise ValueError("original_weights, reconstructed_weights, and layer_names must have the same length.")

    # Make sure the output directory exists.
    os.makedirs(save_dir, exist_ok=True)

    def _leading_rows(tensor):
        # Return the first `num_channels` rows of `tensor` as a 2-D float32 array.
        # The view is (dim0, -1), so 1-D tensors such as biases give one value
        # per row and >2-D tensors get their trailing dims flattened. A 0-d
        # tensor becomes a single 1x1 row. Slicing happens before .cpu() so only
        # plotted rows are copied, and .float() is needed because NumPy cannot
        # represent bfloat16.
        t = tensor.detach()
        t = t.reshape(t.shape[0], -1) if t.dim() > 0 else t.reshape(1, 1)
        return t[:num_channels].float().cpu().numpy()

    group_names = np.array(["Original", "Reconstructed"], dtype=object)

    for name, orig_t, recon_t in zip(layer_names, original_weights, reconstructed_weights):
        original = _leading_rows(orig_t)
        reconstructed = _leading_rows(recon_t)
        if original.shape != reconstructed.shape:
            raise ValueError(
                f"Original and Reconstructed weights have different shapes for layer {name}: "
                f"{original.shape} vs {reconstructed.shape}."
            )
        n_rows, row_len = original.shape

        channel_labels = np.array([f"Channel {c}" for c in range(n_rows)], dtype=object)
        data = pd.DataFrame({
            # Stacking gives (rows, 2, row_len). Flattening that yields, for each
            # channel, its original values followed by its reconstructed values.
            "Weight Value": np.stack([original, reconstructed], axis=1).reshape(-1),
            "Group": np.tile(np.repeat(group_names, row_len), n_rows),
            "Channel": np.repeat(channel_labels, 2 * row_len),
        })

        # Layer names containing a path separator would otherwise point into a
        # non-existent sub-directory.
        safe_name = name.replace(os.sep, "_")
        if os.altsep:
            safe_name = safe_name.replace(os.altsep, "_")
        save_file = os.path.join(save_dir, f"{safe_name}_channelwise_violin.png")

        fig = plt.figure(figsize=(28, 8))
        try:
            sns.violinplot(x="Channel", y="Weight Value", hue="Group", data=data, split=True, inner="quart", palette="Set2")

            # The title reports the number of channels actually plotted.
            plt.title(f"Comparison of Original and Reconstructed Weights in {name} (First {n_rows} Channels)")
            plt.xlabel("Channel")
            plt.ylabel("Weight Value")
            plt.legend(title="Group", loc="upper right")
            # Vertical x-axis labels so channel names do not overlap.
            plt.xticks(rotation=90)

            plt.tight_layout()
            plt.savefig(save_file)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)

        print(f"Saved channelwise violin plot for {name} at {save_file}")
