import torch, nibabel as nib, numpy as np
from scipy.stats import ttest_1samp
from sklearn.metrics import precision_score, recall_score, f1_score
from statsmodels.stats.multitest import multipletests
import os
import json

# ---------- 0 Configuration ----------
# Every combination of SUBJECTS x (attention off / on) x TOPKS is evaluated.
CLIP_NAME = 'large14'
SUBJECTS = [12, 5, 7]
TOPKS = [100, 200, 300, 500, 1000]
ROI = "bodies"  # could be extended to a list; currently fixed
REPEAT = 1
FDR_ALPHA = 0.01
SAVE_DIR = './attention_roi_visualizations/combined_results'


def subject_tag(subj):
    """Return the per-subject directory name used by every input path.

    NOTE(review): this keeps the original ``f'subj0{subj}'`` pattern, which
    yields ``'subj012'`` (not ``'subj12'``) for subject 12 — confirm that this
    matches the on-disk layout.
    """
    return f'subj0{subj}'


def load_floc_vector(subj, roi):
    """Load the fLoc label volume for *roi*, restricted to the nsdgeneral mask.

    Returns a 1-D array with one label per voxel where nsdgeneral >= 1. Its
    length must match the voxel axis of the predicted betas. Labels may be
    multi-valued; callers binarise them.
    """
    tag = subject_tag(subj)
    general_mask = nib.load(f'Cache/floc/{tag}/nsdgeneral.nii.gz').get_fdata() >= 1  # bool 3-D
    floc_3d = nib.load(f'Cache/floc/{tag}/floc-{roi}.nii.gz').get_fdata()
    return floc_3d[general_mask]


def pred_beta_path(subj, roi, topk, repeat, clip_name):
    """Return the path of the saved predicted-fMRI tensor for one configuration."""
    return (
        f'./conception_localization/{subject_tag(subj)}/'
        f'{roi}_top{topk}_repeat{repeat}_{clip_name}_prompt_extra1_ae+prior_pred_fmri.pt'
    )


def load_pred_beta(subj, roi, topk, repeat, clip_name):
    """Load predicted betas as a NumPy array of shape (n_samples, n_voxels).

    ``squeeze(1)`` assumes the stored tensor has a singleton second axis —
    presumably (n_samples, 1, n_voxels); confirm against the producer.
    ``map_location='cpu'`` lets tensors saved from a GPU load on CPU-only hosts.
    """
    tensor = torch.load(pred_beta_path(subj, roi, topk, repeat, clip_name), map_location='cpu')
    return tensor.squeeze(1).numpy()


def attention_dir(subj, roi, topk, clip_name):
    """Return the directory holding the attention visualizations for one configuration."""
    return f'./attention_roi_visualizations/{subject_tag(subj)}/{roi}_top{topk}_{clip_name}_prompt_extra1'


def attention_sample_ids(attn_dir):
    """Return the sorted sample indices encoded as ``<index>.png`` file names in *attn_dir*.

    Sorting makes the selection deterministic (``os.listdir`` order is
    arbitrary); the downstream t-test does not depend on row order.
    Raises ValueError if a ``.png`` stem is not an integer.
    """
    return sorted(
        int(os.path.splitext(name)[0])
        for name in os.listdir(attn_dir)
        if name.endswith('.png')
    )


def predict_roi(pred_beta, alpha=FDR_ALPHA):
    """Threshold predicted betas into a binary voxel ROI.

    Runs a two-sided one-sample t-test against 0 for each voxel (axis 0 =
    samples). It keeps only voxels with a positive mean effect, then applies
    Benjamini-Hochberg FDR correction at *alpha*.

    Returns an int array of 0/1 with one entry per voxel.
    """
    t_stats, p_values = ttest_1samp(pred_beta, popmean=0, axis=0)
    # Only the positive activation direction counts: non-positive (and NaN)
    # t-values get p = 1 so BH can never reject them.
    p_values = np.where(t_stats > 0, p_values, 1.0)
    reject, _, _, _ = multipletests(p_values, alpha=alpha, method='fdr_bh')
    return reject.astype(int)


def compute_metrics(floc_vec, pred_roi):
    """Score a predicted ROI against the fLoc labels.

    Any voxel with a positive fLoc label counts as ground truth.
    Returns a dict with precision / recall / f1 (floats) and n_pred / n_true
    (voxel counts).
    Raises ValueError if the two vectors do not cover the same voxels.
    """
    if floc_vec.shape != pred_roi.shape:
        raise ValueError(
            f"voxel count mismatch: fLoc vector has shape {floc_vec.shape}, "
            f"prediction has shape {pred_roi.shape}"
        )
    y_true = (floc_vec > 0).astype(int)
    return {
        "precision": float(precision_score(y_true, pred_roi, zero_division=0)),
        "recall": float(recall_score(y_true, pred_roi, zero_division=0)),
        "f1": float(f1_score(y_true, pred_roi, zero_division=0)),
        "n_pred": int(pred_roi.sum()),
        "n_true": int(y_true.sum()),
    }


def main():
    """Evaluate every configuration, print per-run metrics, and save all of them to JSON."""
    all_results = []
    for subj in SUBJECTS:
        # The nsdgeneral / fLoc volumes depend only on the subject: load them once.
        floc_vec = load_floc_vector(subj, ROI)
        for attn in (False, True):
            for topk in TOPKS:
                pred_beta = load_pred_beta(subj, ROI, topk, REPEAT, CLIP_NAME)
                if attn:
                    # Keep only the samples that have an attention visualization.
                    ids = attention_sample_ids(attention_dir(subj, ROI, topk, CLIP_NAME))
                    pred_beta = pred_beta[ids]

                print(f"subj{subj}, topk{topk}, attn{attn} -> pred_beta.shape:", pred_beta.shape)
                print("小于 0 的元素个数:", (pred_beta < 0).sum())

                pred_roi = predict_roi(pred_beta, alpha=FDR_ALPHA)
                metrics = compute_metrics(floc_vec, pred_roi)

                print(f"Precision: {metrics['precision']:.4f}")
                print(f"Recall:    {metrics['recall']:.4f}")
                print(f"F1:        {metrics['f1']:.4f}")

                all_results.append({
                    "subj": subj,
                    "roi": ROI,
                    "topk": topk,
                    "attn": attn,
                    **metrics,
                })

    # ---------- Save all results together ----------
    os.makedirs(SAVE_DIR, exist_ok=True)
    save_path = os.path.join(SAVE_DIR, 'all_metrics.json')
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=4)

    print(f"\n所有结果已统一保存至: {save_path}")


if __name__ == "__main__":
    main()