
from PIL import Image
import torch
import os
import matplotlib.pyplot as plt
import cv2
import torch.nn as nn

def save_weight(sam, new_checkpoint_path="sam_h_upscale_weights.pth", keyword="up"):
    """Save every parameter of ``sam`` whose name contains ``keyword`` to a checkpoint.

    Args:
        sam: a module exposing ``named_parameters()`` (e.g. an ``nn.Module``).
        new_checkpoint_path: destination file for ``torch.save``. Defaults to the
            previously hard-coded ``"sam_h_upscale_weights.pth"``.
        keyword: substring matched against parameter names. The default ``"up"``
            is meant to select the upscaling-related modules. This is a plain
            substring test, so any parameter name containing it is selected.

    Returns:
        dict mapping parameter name -> detached tensor, i.e. exactly what was saved.
    """
    upscale_weights = {}

    for name, param in sam.named_parameters():
        if keyword in name:  # keep only parameters of modules matching the keyword
            # detach(): same storage as .data, but the supported way to drop autograd.
            upscale_weights[name] = param.detach()
            print(f"提取的权重: {name} -> 形状: {param.shape}")
    torch.save(upscale_weights, new_checkpoint_path)
    print(f"已成功将权重保存到 {new_checkpoint_path}")
    return upscale_weights



def save_2txt(attention, data_folder, output_folder, filename):
    """Dump the per-pixel channel vectors of a 3-D map to a text file.

    Each pixel ``(y, x)`` produces one line ``P: (y, x), F: [v_0, ..., v_{C-1}]``
    with ``v_c = attention[c, y, x]``. Pixels are visited in row-major order.

    Args:
        attention: 3-D array-like indexed as ``attention[c, y, x]``. It must
            provide ``.shape`` and ``.tolist()``, e.g. a torch tensor or a
            numpy array.
        data_folder: prefix that is string-concatenated (not path-joined) with
            ``output_folder``.
        output_folder: output directory name appended to ``data_folder``. It is
            created if missing.
        filename: source file name. Its extension is replaced with ``.txt``.

    Returns:
        list of dicts ``{"p": (y, x), "f": [channel values]}``, one per pixel.
    """
    output_folder = data_folder + output_folder
    os.makedirs(output_folder, exist_ok=True)

    # Use the map's real spatial size. The old hard-coded 256x256 raised
    # IndexError on smaller maps and silently dropped pixels of larger ones.
    height, width = attention.shape[1], attention.shape[2]
    # One bulk conversion instead of H*W separate slice + tolist() calls.
    channels = attention.tolist()
    output = [
        {"p": (y, x), "f": [channel[y][x] for channel in channels]}
        for y in range(height)
        for x in range(width)
    ]

    filename = os.path.splitext(filename)[0] + ".txt"
    output_path = os.path.join(output_folder, filename)
    with open(output_path, "w") as file:
        file.writelines(f"P: {item['p']}, F: {item['f']}\n" for item in output)
    print(f"Results saved to: {output_path}")
    return output

def save_imgsize_attn_map(attention_map, data_folder, output_folder, filename, i):
    """Write ``attention_map`` as ``<stem>.JPG`` into an index-suffixed folder.

    The target folder is ``data_folder + output_folder + "a" + str(i)``. Both
    ``data_folder + output_folder`` and the target folder are created if missing.
    All paths are built by plain string concatenation, so callers must include
    any separators themselves. The array is handed to ``PIL.Image.fromarray``
    unchanged.
    """
    stem = os.path.splitext(filename)[0]
    base_dir = data_folder + output_folder
    os.makedirs(base_dir, exist_ok=True)
    target_dir = base_dir + "a" + str(i)
    os.makedirs(target_dir, exist_ok=True)
    image_path = os.path.join(target_dir, stem + ".JPG")
    Image.fromarray(attention_map).save(image_path)


def save_torch(attention, data_folder, output_folder, output_folder_look, filename,
               selected_indices=(18, 28, 31)):
    """Min-max normalize a tensor, save it as ``.pt``, and write a PNG preview.

    Args:
        attention: torch tensor to save. The preview indexes its LAST dimension
            with ``selected_indices``, so it presumably has a channels-last
            (H, W, C) layout. TODO: confirm against callers.
        data_folder: prefix that is string-concatenated (not path-joined) with
            both output folders.
        output_folder: folder under ``data_folder`` that receives ``<stem>.pt``.
        output_folder_look: folder under ``data_folder`` that receives the
            ``<stem>.png`` preview.
        filename: source file name. Its extension is replaced by ``.pt``/``.png``.
        selected_indices: three last-dimension indices rendered as the R, G, B
            channels of the preview. Defaults to the previously hard-coded
            ``(18, 28, 31)``.

    Raises:
        ValueError: if ``selected_indices`` does not contain exactly three indices.
    """
    if len(selected_indices) != 3:
        raise ValueError("selected_indices must contain exactly three channel indices")

    stem = os.path.splitext(filename)[0]
    tensor_dir = data_folder + output_folder
    os.makedirs(tensor_dir, exist_ok=True)
    preview_dir = data_folder + output_folder_look
    output_path = os.path.join(tensor_dir, stem + ".pt")
    output_path_look = os.path.join(preview_dir, stem + ".png")
    os.makedirs(preview_dir, exist_ok=True)

    # Global (scalar) min-max normalization over the whole tensor, not per channel.
    _min = attention.min()
    _max = attention.max()
    # The epsilon avoids division by zero for a constant tensor.
    attn_map_normalized = (attention - _min) / (_max - _min + 1e-8)
    torch.save(attn_map_normalized, output_path)

    # Convert the selected channels to an explicit 8-bit image. cv2.cvtColor only
    # accepts 8U/16U/32F input, so passing raw float64/float16 data would fail.
    preview_rgb = (
        attn_map_normalized[:, :, list(selected_indices)]
        .detach()
        .mul(255)
        .round()
        .clamp(0, 255)
        .to(torch.uint8)
        .cpu()
        .numpy()
    )
    preview_bgr = cv2.cvtColor(preview_rgb, cv2.COLOR_RGB2BGR)  # OpenCV writes BGR order
    # cv2.imwrite reports failure via its return value instead of raising.
    if not cv2.imwrite(output_path_look, preview_bgr):
        print(f"Warning: failed to write preview image to {output_path_look}")

    print(f"Tensor saved to {output_path}")