import matplotlib.pyplot as plt
import numpy as np
import random
import torch
import os


def visualize_linear_weights(model, seq_len, pred_len, save_path, tick_step=20):
    """Save a heatmap of a linear layer's min-max normalized weight matrix.

    If ``model`` exposes a ``weight`` attribute, it is treated as a layer.
    The raw weight is saved as ``linear_weights.npy`` in the directory of
    ``save_path``. The bias is saved as ``linear_bias.npy`` when the layer
    has one. Otherwise ``model`` is treated as the weight tensor itself and
    no ``.npy`` files are written.

    Args:
        model: A layer with ``.weight`` (and optionally ``.bias``), or a
            weight tensor. The weight is assumed to be 2-D. Presumably it is
            ``(pred_len, seq_len)``, as for ``nn.Linear(seq_len, pred_len)``
            — TODO confirm against callers.
        seq_len: Extent of the x axis (look-back length).
        pred_len: Extent of the y axis (forecast horizon).
        save_path: Output image path. Its directory also receives the
            ``.npy`` dumps.
        tick_step: Spacing between axis ticks (default 20, as before).
    """
    out_dir = os.path.dirname(save_path)

    weight_param = getattr(model, 'weight', None)
    if weight_param is not None:
        weight = weight_param.data.cpu().numpy()
        np.save(os.path.join(out_dir, 'linear_weights.npy'), weight)
        # Layers built without a bias (bias=False) expose ``bias`` as None.
        bias_param = getattr(model, 'bias', None)
        if bias_param is not None:
            np.save(os.path.join(out_dir, 'linear_bias.npy'),
                    bias_param.data.cpu().numpy())
    else:
        # No ``.weight``: assume ``model`` is already the weight tensor.
        weight = model.cpu().data.numpy()

    w_min = weight.min()
    value_range = weight.max() - w_min
    if value_range > 0:
        normalized_weight = (weight - w_min) / value_range
    else:
        # Constant matrix: min-max scaling would be 0/0 (all NaN), so plot zeros.
        normalized_weight = np.zeros_like(weight, dtype=float)

    fig = plt.figure(figsize=(6, 5))
    try:
        # origin='lower' puts row 0 at the bottom of the heatmap.
        plt.imshow(normalized_weight, cmap='viridis', aspect='auto',
                   extent=[0, seq_len, 0, pred_len], origin='lower')

        plt.xlabel('Look-back Length', fontsize=14)
        plt.ylabel('Forecast Horizon', fontsize=14)
        plt.xticks(np.arange(0, seq_len + 1, tick_step), fontsize=12)
        plt.yticks(np.arange(0, pred_len + 1, tick_step), fontsize=12)

        cbar = plt.colorbar()
        # Colorbar label intentionally left blank (no title either).
        cbar.set_label('', rotation=270, labelpad=15)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)