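# Driver script: extracts the target model's intra-layer (within-layer) circuits for each brain ROI
# and visualizes them. The cross-layer helpers are imported below but are not invoked by this script.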

# Ready-to-run example commands (script path is relative to the project root):
# nohup python tests/test/inner_circuit_extraction.py --exp_subj 5 --exp_model_name "clip_vit-b_16" --exp_device "cuda:0" --autoencoder_name "original" --autoencoder_rate 16 --exp_layers 12 --inner_neighbors 5 --inner_min_dist 0.1 > output.log 2>&1 &
# nohup python tests/test/inner_circuit_extraction.py --exp_subj 5 --exp_model_name "dinov2" --exp_device "cuda:1" --autoencoder_name "original" --autoencoder_rate 16 --exp_layers 12 --inner_neighbors 5 --inner_min_dist 0.1 > output_1.log 2>&1 &
# nohup python tests/test/inner_circuit_extraction.py --exp_subj 5 --exp_model_name "imagenet" --exp_device "cuda:2" --autoencoder_name "original" --autoencoder_rate 16 --exp_layers 12 --inner_neighbors 3 --inner_min_dist 0.05 > output_2.log 2>&1 &
# nohup python tests/test/inner_circuit_extraction.py --exp_subj 5 --exp_model_name "mae" --exp_device "cuda:3" --autoencoder_name "original" --autoencoder_rate 16 --exp_layers 12 > output_3.log 2>&1 & 

import sys
import os

PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, PROJECT_ROOT)

import argparse
import toml
from easydict import EasyDict
from tests.sae.sae_brain_similarity.brain_guide_circuit import  inner_circuit_analysis, cross_layer_circuit_analysis
from tests.sae.sae_brain_similarity.cluster_selection_collection import inner_circuit_cluster_collection, cross_layer_circuit_cluster_collection
from src.util import get_info_from_shell

# ROIs processed when --roi_name is not supplied on the command line.
DEFAULT_ROI_NAMES = ("FFA", "EBA", "RSC", "VWFA", "FOOD")

_TRUE_STRINGS = frozenset({"true", "t", "yes", "y", "1", "on"})
# The empty string stays falsy: under the previous ``type=bool`` parsing it was
# the only value that could switch a flag off, so existing invocations that
# relied on it keep working.
_FALSE_STRINGS = frozenset({"false", "f", "no", "n", "0", "off", ""})


def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``"False"`` -- is parsed as ``True``. This
    converter interprets the usual spellings explicitly instead.

    Args:
        value: The raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` or ``False`` according to the token (case-insensitive,
        surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling; argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false, yes/no, 1/0, on/off), got %r" % (value,)
    )


def build_arg_parser():
    """Build the command-line parser for the inner-circuit extraction script.

    Returns:
        An ``argparse.ArgumentParser`` with the shared experiment options and
        the options specific to this script registered on it.
    """
    parser = argparse.ArgumentParser()
    # Options shared across the experiment scripts.
    parser.add_argument("--exp_subj", type=int, help="被试号", required=True)
    parser.add_argument("--exp_model_name", type=str, help="模型名称", required=True)
    parser.add_argument("--exp_device", type=str, help="设备", required=True)
    parser.add_argument("--autoencoder_name", type=str, help="SAE名称", required=True)
    parser.add_argument("--autoencoder_rate", type=int, help="SAE扩大的倍数", required=True)
    parser.add_argument("--exp_layers", type=int, help="模型全部层", required=True)
    parser.add_argument("--autoencoder_tied", type=str, help="SAE是否共享权重")
    parser.add_argument("--autoencoder_topk", type=int, help="模型的topk特征，当用topk方法时需要")
    # Options specific to this script.
    parser.add_argument("--roi_name", type=str, help="脑区名称")
    parser.add_argument("--extract_topk", type=int, help="选择最相关的特征的前多少数量", default=100)
    parser.add_argument("--threshold", type=float, help="层间预设的阈值", default=None)
    # Boolean flags use str2bool so that e.g. "--cluster_filter False" really disables the feature.
    parser.add_argument("--cluster_filter", type=str2bool, help="是否进行cluster_filter", default=True)
    parser.add_argument("--cluster_filter_topk", type=int, help="cluster_filter进行筛选保留的topk特征", default=10)
    parser.add_argument("--inner_neighbors", type=int, help="层内预设的UMAP邻居的数量", default=5)
    parser.add_argument("--inner_min_dist", type=float, help="层内预设的UMAP最小距离", default=0.1)
    parser.add_argument("--visualize_umap", type=str2bool, help="是否进行umap可视化", default=True)
    parser.add_argument("--inner_cluster_visualize", type=str2bool, help="是否进行层内聚类类别可视化", default=True)
    return parser


def main(argv=None):
    """Run inner-circuit analysis and per-layer cluster collection for each ROI.

    Args:
        argv: Optional list of command-line tokens; ``None`` makes argparse
            read ``sys.argv[1:]``.
    """
    cli_args = build_arg_parser().parse_args(argv)

    # 'config.toml' is resolved relative to the current working directory.
    config_dict = toml.load('config.toml')
    args = EasyDict(config_dict)
    # NOTE(review): get_info_from_shell presumably overlays the CLI values onto
    # the config; its result is what provides args.exp.subj / args.exp.layers
    # below -- confirm against src.util.
    args = get_info_from_shell(cli_args, args)

    # A single explicit ROI, or every default ROI when none was requested.
    if cli_args.roi_name is None:
        roi_names = DEFAULT_ROI_NAMES
    else:
        roi_names = (cli_args.roi_name,)

    for roi_name in roi_names:
        # Extract the most relevant features / circuits for this ROI.
        inner_circuit_analysis(
            args=args,
            roi_name=roi_name,
            subj=args.exp.subj,
            all_layers=args.exp.layers,
            topk=cli_args.extract_topk,
            threshold=cli_args.threshold,
            n_neighbors=cli_args.inner_neighbors,
            min_dist=cli_args.inner_min_dist,
            visualize=cli_args.visualize_umap,
            inner_cluster_visualize=cli_args.inner_cluster_visualize,
            cluster_filter=cli_args.cluster_filter,
            cluster_filter_topk=cli_args.cluster_filter_topk,
        )

        # Collect the clusters of every layer once the ROI has been analysed.
        for layer in range(args.exp.layers):
            inner_circuit_cluster_collection(
                args=args,
                roi_name=roi_name,
                subj=args.exp.subj,
                target_layer=layer,
                topk=cli_args.extract_topk,
            )


if __name__ == "__main__":
    main()

