import argparse
import os
import re
import sys
import cv2
import numpy as np
from visualizer import get_local
get_local.activate()
import torch

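# Add the gs_feature_fusion submodule (a path relative to the current working
# directory, not to this file) to sys.path so the custom modules below can be imported.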
sys.path.append("submodules/gs_feature_fusion")
from segment_anything import sam_model_registry, SamGS, upscale_model_registry
from feature_output import visualize_fft_map, visualize_attention_map, visualize_attn_map
from feature_output import save_torch, save_pcd

def main():
    """Compute a SAM attention map for every image in a dataset and save the results.

    Every regular file under ``<data_folder>/images/`` is read with OpenCV,
    converted to RGB, and passed to ``SamGS.set_image``. The map returned by
    ``sam_gs.get_attn_map()`` is written out with ``save_torch`` and collected
    in a dict keyed by the first integer found in the filename. The dict is
    finally passed to ``save_pcd``.

    Command-line options:
        --data_folder: dataset root containing an ``images/`` sub-folder (required).
        --sam_checkpoint: SAM weights (default ``./model/sam_vit_h_4b8939.pth``).
        --upscale_checkpoint: upscale-model weights
            (default ``./model/sam_h_upscale_weights.pth``).
        --device: device passed to ``.to(device=...)`` (default ``cuda``).

    Raises:
        ValueError: if an image filename contains no digits, or two filenames
            yield the same integer ID (which would otherwise silently overwrite
            an earlier attention map).
    """
    parser = argparse.ArgumentParser(description="SAM GS 图像处理脚本，配置数据目录")
    parser.add_argument(
        "--data_folder",
        type=str,
        required=True,
        help="数据文件夹路径，例如：/home/kemove/Project/2dgsseg/data/exam/scan24/GSSeg_P/",
    )
    # Defaults below are exactly the previously hard-coded values.
    parser.add_argument(
        "--sam_checkpoint",
        type=str,
        default="./model/sam_vit_h_4b8939.pth",
        help="SAM 模型权重路径",
    )
    parser.add_argument(
        "--upscale_checkpoint",
        type=str,
        default="./model/sam_h_upscale_weights.pth",
        help="upscale 模型权重路径",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="运行设备，例如 cuda、cuda:0 或 cpu",
    )
    args = parser.parse_args()

    # Ensure a trailing separator. data_folder is handed unchanged to
    # save_torch / save_pcd, and the original code concatenated sub-paths
    # onto it, so a folder given without "/" produced paths like
    # ".../GSSeg_Pimages/". os.path.join(x, "") is a no-op if x already ends in "/".
    data_folder = os.path.join(args.data_folder, "")
    images_dir = os.path.join(data_folder, "images")

    # Initialise SAM and the upscale model, then wrap them in SamGS.
    model_type = "vit_h"
    device = args.device
    sam = sam_model_registry[model_type](checkpoint=args.sam_checkpoint)
    sam.to(device=device)
    upscale_model = upscale_model_registry[model_type](checkpoint=args.upscale_checkpoint)
    upscale_model.to(device=device)
    sam_gs = SamGS(sam, upscale_model)

    attn_dict = {}
    # os.listdir returns entries in arbitrary order; sort for reproducible runs.
    for filename in sorted(os.listdir(images_dir)):
        file_path = os.path.join(images_dir, filename)
        if not os.path.isfile(file_path):
            continue  # ignore sub-directories

        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread returns None for files it cannot decode (e.g. .DS_Store);
            # passing None to cvtColor would crash with an opaque OpenCV error.
            print(f"Skipping unreadable image file: {file_path}", file=sys.stderr)
            continue

        # Resolve the image ID before the expensive embedding step so a bad
        # filename fails fast.
        match = re.search(r'\d+', filename)  # first run of digits in the filename
        if not match:
            raise ValueError(f"Cannot extract image ID from filename: {filename}")
        image_id = int(match.group())
        if image_id in attn_dict:
            raise ValueError(
                f"Duplicate image ID {image_id} extracted from filename: {filename}"
            )

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        sam_gs.set_image(image)
        # Alternative source of the map: sam_gs.get_pretrain_map()
        attn_map = sam_gs.get_attn_map()
        attn_dict[image_id] = attn_map

        save_torch(attn_map, data_folder, "attention/", "attention_look/", filename)

    # Save the point-cloud result built from the per-image attention maps.
    save_pcd(attn_dict, data_folder)
    print("处理完成，点云已保存。")

if __name__ == "__main__":
    # Entry point: run the CLI only when executed as a script, not on import.
    main()
