import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import scipy
import torch.nn.functional as F
import torch.storage
from torch.utils.data import Dataset, DataLoader
from torch.utils import data
#from torchvision import transforms
from PIL import Image
import os
from collections import OrderedDict
import matplotlib.pyplot as plt
#import torchvision.models as models
import time
from pathlib import Path
# This is for the progress bar.
from pyod.models.ecod import ECOD
from pyod.models.lof import LOF
from pyod.models.deep_svdd import DeepSVDD
from pyod.models.ocsvm import OCSVM
from pyod.models.knn import KNN
from pyod.models.auto_encoder import AutoEncoder
from pyod.models.iforest import IsolationForest
from deepod.models.icl import ICL
from pyod.models.dif import DIF
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
import csv
import torch
import math
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support
import random
from sklearn.metrics import average_precision_score
from scipy.stats import t
from sklearn.metrics import roc_curve, roc_auc_score
#from openTSNE import TSNE
#import torchvision
from torch.utils import data
#from torchvision import transforms
import numpy as np
from torch.utils.data import Subset
import torch.nn.functional as F
from matplotlib.pyplot import figure
from torch import Tensor
#import torchvision.transforms as transforms
import numpy as np
#import torchvision.datasets as datasets
import torch
import numpy as np
import torch.utils.data as utils
from torch.utils.data import Sampler, Dataset
from model import _RealNVP
from sklearn import preprocessing
from torchmetrics.functional import auroc
from torchmetrics import AUROC,AveragePrecision


import numpy as np


def avg_relative_gap(arr):
    """Return the mean relative gap between consecutive sorted values.

    The input is flattened and sorted in ascending order; for every adjacent
    pair ``(lo, hi)`` the quantity ``(hi - lo) / hi`` is computed and the
    average of those quantities is returned.

    No guarding is done here: a zero among the sorted values (other than the
    smallest) or fewer than two elements yields NumPy's usual inf/nan result.
    """
    ordered = np.sort(np.asarray(arr), axis=None)  # axis=None flattens first
    upper = ordered[1:]
    lower = ordered[:-1]
    return np.mean((upper - lower) / upper)


def top_80_percent_ratio(row_norms, threshold=0.8):
    """Return the fraction of largest norms needed to reach ``threshold`` of the total.

    Args:
        row_norms: array-like of norms (flattened to 1-D before use).
        threshold: cumulative share of the total to reach, 0.8 (80%) by default.

    Returns:
        ratio: ``k / n`` where ``n`` is the number of norms.
        k: 1-based count of the largest norms whose cumulative sum first
            reaches ``threshold * total`` (clamped to ``n``).
        sorted_norms: the norms sorted from largest to smallest.

    Empty input, or input whose total is zero, returns ``(0.0, 0, <array>)``.
    """
    row_norms = np.asarray(row_norms).ravel()
    n = row_norms.size
    if n == 0:
        return 0.0, 0, row_norms  # convention for empty input

    # Sort from largest to smallest. The previous version sorted ascending,
    # which contradicted the documented contract and measured the share of
    # the *smallest* norms instead.
    sorted_norms = np.sort(row_norms)[::-1]

    cumsum = np.cumsum(sorted_norms)
    total = cumsum[-1]

    if total == 0:
        # Every norm is zero, so no prefix can reach a positive share.
        return 0.0, 0, sorted_norms

    # First 0-based position where the running sum reaches the target.
    idx = np.searchsorted(cumsum, threshold * total, side='left')

    # Convert to a 1-based count; clamp so a threshold above 1 (or rounding
    # noise) cannot report more elements than exist.
    k = min(int(idx) + 1, n)

    ratio = k / n

    return ratio, k, sorted_norms


def set_seed(seed):
    """Seed PyTorch's random generators and put cuDNN in deterministic mode.

    Only torch is seeded: Python's ``random`` module and NumPy's global
    generator are left untouched by this function.
    """
    # CPU generator, current CUDA device, then every CUDA device (multi-GPU).
    for seeder in (torch.manual_seed,
                   torch.cuda.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)

    # Fixed cuDNN behaviour: deterministic kernels, no autotuning.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def plot_metrics_separate(AUC, LOSS):
    """Plot AUC and LOSS against the epoch number in two stacked subplots.

    The x axis runs from 1 to ``len(AUC)``. The figure is displayed with
    ``plt.show()``; nothing is returned.
    """
    # Global default font size for the figure.
    plt.rcParams.update({'font.size': 14})

    epochs = range(1, len(AUC) + 1)

    fig, axes = plt.subplots(2, 1, figsize=(6, 10))

    # (axis, series, line colour, y-axis label) for the top and bottom panel.
    panels = (
        (axes[0], AUC, 'red', 'AUC'),
        (axes[1], LOSS, 'blue', 'LOSS'),
    )
    for ax, series, colour, label in panels:
        ax.plot(epochs, series, color=colour, linewidth=2)
        ax.set_xlabel('Epoch', fontsize=16)
        ax.set_ylabel(label, fontsize=16)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.tick_params(axis='both', which='major', labelsize=14)

    plt.tight_layout()
    plt.show()

class CustomDataset(Dataset):
    """Dataset wrapping a feature array ``X`` and a parallel label sequence ``y``.

    Items are returned as ``(torch.from_numpy(X[idx]), y[idx])``, so ``X[idx]``
    must be a NumPy array.
    """

    def __init__(self, X, y):
        self.data = X
        self.targets = y

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Tensor indices are converted to plain Python ints/lists first.
        index = idx.tolist() if torch.is_tensor(idx) else idx
        sample = self.data[index]
        label = self.targets[index]
        return torch.from_numpy(sample), label
def get_device():
    """Return ``'cuda'`` when a CUDA device is available, otherwise ``'cpu'``."""
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'

# Module-level device string used by the functions below.
device = get_device()

def read_gauss_data(file, normalization='z-score', train_level=1,test_level=2,seed=42):
    """Load a labelled dataset, split it, add Gaussian noise, then normalise.

    The file is read from ``.npz`` (keys ``'X'`` and ``'y'``), pickle or CSV
    (last column is the label). Rows are shuffled with ``RandomState(seed)``.
    The first half of the label-0 rows forms the training set; the remaining
    label-0 rows plus all label-1 rows form the test set. Zero-mean Gaussian
    noise is added *before* normalisation, and normalisation statistics are
    fitted on the noisy training set only.

    Args:
        file: path to a ``.npz``, ``pkl`` or ``csv`` file.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by
            255) or ``'ours'``; any other value skips normalisation.
        train_level: noise variance (``scale = sqrt(level)``) for the
            training rows and for the label-1 test rows.
        test_level: noise variance for the label-0 test rows.
        seed: seed of the shuffle. The noise itself is drawn from NumPy's
            global generator and is not controlled by ``seed``.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)`` where ``sds`` is
        the per-feature training std for ``'z-score'`` and ``None`` otherwise.

    Raises:
        NotImplementedError: if the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True can execute code embedded in the
        # file; only load files from trusted sources.
        data = np.load(file, allow_pickle=True)
        x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'unsupported file type: {file}')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # The noise is written back into ``x`` in place; with an integer (or
    # object) feature array that assignment would silently truncate it.
    x = np.asarray(x, dtype=float)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    test_idx = np.hstack([test_norm_idx, anom_idx])

    data_dim = x.shape[1]
    y_train = y[train_norm_idx]
    y_test = y[test_idx]

    # Draw order (train, anomalies, normal test) is kept so the global RNG
    # stream is consumed exactly as before.
    noise_train = np.random.normal(
        loc=0.0, scale=np.sqrt(train_level), size=(len(train_norm_idx), data_dim))
    # NOTE(review): anomalies get train_level noise here, whereas
    # read_after_gauss_data gives them test_level -- confirm which is intended.
    noise_test_abnorm = np.random.normal(
        loc=0.0, scale=np.sqrt(train_level), size=(len(anom_idx), data_dim))
    noise_test_norm = np.random.normal(
        loc=0.0, scale=np.sqrt(test_level), size=(len(test_norm_idx), data_dim))

    x[anom_idx] = x[anom_idx] + noise_test_abnorm
    x[test_norm_idx] = x[test_norm_idx] + noise_test_norm
    x[train_norm_idx] = x[train_norm_idx] + noise_train

    x_train = x[train_norm_idx]
    x_test = x[test_idx]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    sds = None
    # normalization (statistics come from the training set only)
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train, y_train, x_test, y_test, data_dim, sds


def read_after_gauss_data(file, normalization='z-score', train_level=1,test_level=2,seed=42):
    """Load a labelled dataset, split and normalise it, then add Gaussian noise.

    The file is read from ``.npz`` (keys ``'X'`` and ``'y'``), pickle or CSV
    (last column is the label). Rows are shuffled with ``RandomState(seed)``.
    The first half of the label-0 rows forms the training set; the remaining
    label-0 rows plus all label-1 rows form the test set. Normalisation is
    fitted on the clean training set, and zero-mean Gaussian noise is added
    to the normalised arrays afterwards.

    Two defects of the previous version are fixed here: the noise was added
    to the un-split array after ``x_train``/``x_test`` had already been
    copied out of it, so it never reached the returned data; and ``sds`` was
    unbound for every normalisation other than ``'z-score'``.

    Args:
        file: path to a ``.npz``, ``pkl`` or ``csv`` file.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by
            255) or ``'ours'``; any other value skips normalisation.
        train_level: noise variance (``scale = sqrt(level)``) for the
            training rows and for the label-0 test rows.
        test_level: noise variance for the label-1 test rows.
        seed: seed of the shuffle. The noise itself is drawn from NumPy's
            global generator and is not controlled by ``seed``.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)`` where ``sds`` is
        the per-feature training std for ``'z-score'`` and ``None`` otherwise.

    Raises:
        NotImplementedError: if the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True can execute code embedded in the
        # file; only load files from trusted sources.
        data = np.load(file, allow_pickle=True)
        x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'unsupported file type: {file}')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    test_idx = np.hstack([test_norm_idx, anom_idx])

    x_train = x[train_norm_idx]
    x_test = x[test_idx]
    y_train = y[train_norm_idx]
    y_test = y[test_idx]

    sds = None
    # normalization (statistics come from the clean training set only)
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    data_dim = x_train.shape[1]

    # Draw order (train, anomalies, normal test) is kept so the global RNG
    # stream is consumed exactly as before.
    noise_train = np.random.normal(
        loc=0.0, scale=np.sqrt(train_level), size=(len(train_norm_idx), data_dim))
    # NOTE(review): here anomalies get test_level and normal test rows get
    # train_level, the opposite of read_gauss_data -- confirm which is intended.
    noise_test_abnorm = np.random.normal(
        loc=0.0, scale=np.sqrt(test_level), size=(len(anom_idx), data_dim))
    noise_test_norm = np.random.normal(
        loc=0.0, scale=np.sqrt(train_level), size=(len(test_norm_idx), data_dim))

    # Apply the noise to the arrays that are actually returned. x_test is laid
    # out as [normal test rows, anomalies], so the noise is stacked the same way.
    x_train = x_train + noise_train
    x_test = x_test + np.vstack([noise_test_norm, noise_test_abnorm])

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    return x_train, y_train, x_test, y_test, data_dim, sds


def read_noise_data(file, normalization='z-score', seed=42,noise_rate=0):
    """Load a labelled dataset and build a training set contaminated with anomalies.

    The file is read from ``.npz`` (keys ``'X'`` and ``'y'``), pickle or CSV
    (last column is the label). Rows are shuffled with ``RandomState(seed)``.
    The training set holds the first half of the label-0 rows plus the first
    ``int(split * noise_rate)`` label-1 rows, where ``split`` is the number of
    label-0 training rows. The test set holds the remaining rows of both
    classes. Normalisation statistics are fitted on the training set.

    Args:
        file: path to a ``.npz``, ``pkl`` or ``csv`` file.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by
            255) or ``'ours'``; any other value skips normalisation.
        seed: seed of the shuffle.
        noise_rate: anomalies added to training, as a fraction of the number
            of normal training rows.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)`` where ``sds`` is
        the per-feature training std for ``'z-score'`` and ``None`` otherwise.

    Raises:
        NotImplementedError: if the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True can execute code embedded in the
        # file; only load files from trusted sources.
        data = np.load(file, allow_pickle=True)
        x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'unsupported file type: {file}')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    ab_train = int(split * noise_rate)  # number of anomalies leaked into training
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    train_ab_idx, test_ab_idx = anom_idx[:ab_train], anom_idx[ab_train:]

    train_idx = np.hstack([train_norm_idx, train_ab_idx])
    test_idx = np.hstack([test_norm_idx, test_ab_idx])
    x_train, y_train = x[train_idx], y[train_idx]
    x_test, y_test = x[test_idx], y[test_idx]
    data_dim = x_train.shape[1]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    sds = None
    # normalization (statistics come from the training set only)
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train, y_train, x_test, y_test, data_dim, sds

def read_OD_data(file, normalization='z-score', seed=42):
    """Load a labelled dataset and return the whole shuffled set as both train and test.

    The file is read from ``.npz`` (keys ``'X'`` and ``'y'``), pickle or CSV
    (last column is the label). Rows are shuffled with ``RandomState(seed)``
    and normalised using statistics of the full dataset; no train/test split
    is performed.

    Args:
        file: path to a ``.npz``, ``pkl`` or ``csv`` file.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by
            255) or ``'ours'``; any other value skips normalisation.
        seed: seed of the shuffle.

    Returns:
        ``(x, y, x, y, data_dim, sds)`` -- the same normalised features and
        labels are returned in both the train and the test slots. ``sds`` is
        the per-feature std for ``'z-score'`` and ``None`` otherwise
        (previously it was unbound, raising ``UnboundLocalError``).

    Raises:
        NotImplementedError: if the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True can execute code embedded in the
        # file; only load files from trusted sources.
        data = np.load(file, allow_pickle=True)
        x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'unsupported file type: {file}')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # shuffle only; every row is kept
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    x_train = x
    y_train = y

    sds = None
    # normalization
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)

    return x_train, y_train, x_train, y_train, x.shape[1], sds

def read_data(file, normalization='z-score', seed=42):
    """Load a labelled dataset and split it into a normal-only train set and a test set.

    The file is read from ``.npz`` (keys ``'X'`` and ``'y'``), pickle or CSV
    (last column is the label). Rows are shuffled with ``RandomState(seed)``.
    The first half of the label-0 rows forms the training set; the remaining
    label-0 rows plus all label-1 rows form the test set. Normalisation
    statistics are fitted on the training set only.

    Args:
        file: path to a ``.npz``, ``pkl`` or ``csv`` file.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by
            255) or ``'ours'``; any other value skips normalisation.
        seed: seed of the shuffle.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)`` where ``sds`` is
        the per-feature training std for ``'z-score'`` and ``None`` otherwise.

    Raises:
        NotImplementedError: if the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True can execute code embedded in the
        # file; only load files from trusted sources.
        data = np.load(file, allow_pickle=True)
        x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'unsupported file type: {file}')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    test_idx = np.hstack([test_norm_idx, anom_idx])

    x_train, y_train = x[train_norm_idx], y[train_norm_idx]
    x_test, y_test = x[test_idx], y[test_idx]
    data_dim = x_train.shape[1]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    sds = None
    # normalization (statistics come from the training set only)
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train, y_train, x_test, y_test, data_dim, sds
def read_front_data(file, normalization='z-score', seed=42, max_train=2048):
    """Load a labelled dataset in file order and return a capped training set.

    The file is read from ``.npz`` (keys ``'X'`` and ``'y'``), pickle or CSV
    (last column is the label). Rows are *not* shuffled. The first half of
    the label-0 rows forms the training set; the remaining label-0 rows plus
    all label-1 rows form the test set. Normalisation statistics are fitted
    on the full training half, and only the first ``max_train`` training rows
    are returned.

    Args:
        file: path to a ``.npz``, ``pkl`` or ``csv`` file.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by
            255) or ``'ours'``; any other value skips normalisation.
        seed: unused (no shuffle is done); kept for signature compatibility
            with the other readers.
        max_train: maximum number of training rows returned (default 2048).

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)`` where ``sds`` is
        the per-feature training std for ``'z-score'`` and ``None`` otherwise
        (previously it was unbound, raising ``UnboundLocalError``).

    Raises:
        NotImplementedError: if the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True can execute code embedded in the
        # file; only load files from trusted sources.
        data = np.load(file, allow_pickle=True)
        x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'unsupported file type: {file}')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting in original row order
    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    test_idx = np.hstack([test_norm_idx, anom_idx])

    x_train, y_train = x[train_norm_idx], y[train_norm_idx]
    x_test, y_test = x[test_idx], y[test_idx]
    data_dim = x_train.shape[1]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    sds = None
    # normalization (statistics come from the full training half)
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train[:max_train], y_train[:max_train], x_test, y_test, data_dim, sds
def nonlinear_transform(x, out_dim=166):
    """Map 2-D points to ``out_dim`` features through a fixed nonlinear formula.

    For ``i = 1 .. out_dim`` output column ``i - 1`` is::

        sin(a*x0**2 + b*x1**3) + cos(c*exp(x0*x1))
            + exp(-a*c*(x0**2 + x1**2)) + d*log(1 + x0**2 + x1**2)

    with ``a = 0.1*i``, ``b = 0.05*i``, ``c = 0.01*i`` and ``d = 1/i``.

    The previous loop stopped one index short, which left the last output
    column at zero; every column is now filled.

    Args:
        x: tensor whose first two columns are used, shape ``(N, >=2)``.
        out_dim: number of output features (default 166).

    Returns:
        Tensor of shape ``(N, out_dim)`` created with ``torch.zeros`` (default
        dtype and device).
    """
    N = x.shape[0]
    y = torch.zeros((N, out_dim))

    # Quantities that do not depend on the output index.
    x0, x1 = x[:, 0], x[:, 1]
    x0_sq = x0 ** 2
    x1_cube = x1 ** 3
    exp_prod = torch.exp(x0 * x1)
    sq_radius = x0_sq + x1 ** 2
    log_term = torch.log(1 + x0_sq + x1 ** 2)

    for i in range(1, out_dim + 1):
        a_i = i * 0.1
        b_i = i * 0.05
        c_i = i * 0.01
        d_i = 1 / i
        y[:, i - 1] = (
            torch.sin(a_i * x0_sq + b_i * x1_cube) +
            torch.cos(c_i * exp_prod) +
            torch.exp(-a_i * c_i * sq_radius) +
            d_i * log_term
        )

    return y
def generate_unit_norm_matrix(rows, cols):
    """Return a random ``rows x cols`` matrix whose rows each have L2 norm 1.

    Entries are drawn with ``np.random.randn`` (NumPy's global generator) and
    every row is then divided by its own Euclidean norm.

    Args:
        rows: number of rows.
        cols: number of columns.

    Returns:
        The row-normalised matrix.
    """
    raw = np.random.randn(rows, cols)
    # keepdims=True gives a (rows, 1) column so the division broadcasts per row.
    row_lengths = np.linalg.norm(raw, axis=1, keepdims=True)
    return raw / row_lengths

def min_max_normalize(x):
    """Drop constant columns from ``x`` and min-max scale the remaining ones.

    A column is removed when ``np.unique`` finds at most one distinct value in
    it; the removed column indices are printed. The remaining columns are
    scaled with a ``MinMaxScaler`` fitted on ``x`` itself.

    This function used to be defined twice in a row with identical bodies
    (the second definition shadowed the first); it is now defined once.

    Args:
        x: 2-D NumPy array of features.

    Returns:
        The scaled array, possibly with fewer columns than the input.
    """
    filter_lst = [k for k in range(x.shape[1])
                  if len(np.unique(x[:, k])) <= 1]
    if filter_lst:
        print('remove features', filter_lst)
        x = np.delete(x, filter_lst, 1)

    scaler = MinMaxScaler()
    scaler.fit(x)
    x = scaler.transform(x)

    return x


import numpy as np

def max_gap_over_mean_row_norm(A):
    """Return the largest adjacent gap of the sorted values of ``A``, scaled by their range.

    ``A`` is used as-is as the vector of values (no norm is computed here).
    NOTE(review): despite the function name, the denominator is
    ``max - min`` of the values, not their mean -- confirm which is intended.

    Args:
        A: NumPy array of values; presumably 1-D (already-computed row norms).

    Returns:
        ratio: ``max_gap / (A.max() - A.min())``, or ``np.inf`` when the range is 0.
        max_gap: largest difference between neighbours in descending order.
        gap_index: position ``k`` in the descending order such that
            ``max_gap = sorted[k] - sorted[k + 1]``.
        When ``A`` has fewer than two entries ``(0.0, 0.0, None)`` is returned.
    """
    values = A

    # With a single entry there is no adjacent pair to compare.
    if A.shape[0] < 2:
        return 0.0, 0.0, None

    descending = np.sort(values)[::-1]
    gaps = descending[:-1] - descending[1:]

    where = np.argmax(gaps)
    biggest = gaps[where]

    spread = values.max() - values.min()
    if spread != 0:
        ratio = biggest / spread
    else:
        ratio = np.inf

    return ratio, biggest, where

import torch

def first_i_by_max_gap_ratio_1d(jac_sum: torch.Tensor):
    """Return the original indices of the values below the largest relative gap.

    Steps, for a 1-D ``jac_sum``:
    1) sort ascending: a_1 <= ... <= a_d;
    2) compute r_i = (a_{i+1} - a_i) / (a_{i+1} + 1e-12);
    3) pick i* with the largest r_i;
    4) return the positions in ``jac_sum`` of a_1 .. a_{i*}.
    """
    assert jac_sum.dim() == 1, "jac_sum must be 1D"

    ascending, order = torch.sort(jac_sum, descending=False)  # order holds original indices

    upper = ascending[1:]
    rel_gaps = (upper - ascending[:-1]) / (upper + 1e-12)  # shape (d-1,)

    # argmax is 0-based; +1 turns it into the number of leading elements kept.
    cut = torch.argmax(rel_gaps).item() + 1

    return order[:cut]


def first_i_by_min_gap_ratio_1d(jac_sum: torch.Tensor):
    """Return the original indices of the largest values, cut at the smallest relative drop.

    Steps, for a 1-D ``jac_sum``:
    1) sort descending: a_1 >= ... >= a_d;
    2) compute r_i = (a_{i+1} - a_i) / (a_{i+1} + 1e-12), which is <= 0 for
       non-negative input;
    3) pick i* with the largest r_i, i.e. the least negative one;
    4) return the positions in ``jac_sum`` of a_1 .. a_{i*}.
    """
    assert jac_sum.dim() == 1, "jac_sum must be 1D"

    descending, order = torch.sort(jac_sum, descending=True)  # order holds original indices

    later = descending[1:]
    rel_gaps = (later - descending[:-1]) / (later + 1e-12)  # shape (d-1,)

    # argmax is 0-based; +1 turns it into the number of leading elements kept.
    cut = torch.argmax(rel_gaps).item() + 1

    return order[:cut]



def top_percent_positions(jac_sum: torch.Tensor, percent: float):
    """
    选出 jac_sum 中按数值从大到小的前 percent 比例元素的位置

    Returns:
        indices_nd: 形如 (k, jac_sum.dim()) 的长整型张量，每行是一个位置
        mask: 和 jac_sum 同形状的 bool mask，top percent 的位置为 True
        values: top percent 的数值
    """
    assert 0 < percent <= 1, "percent should be in (0, 1]"
    flat = jac_sum.reshape(-1)
    k = int(math.ceil(percent * flat.numel()))

    # topk 取最大的 k 个
    values, idx_flat = torch.topk(flat, k, largest=False, sorted=True)

    # unravel 成多维坐标
    indices_nd = torch.stack(torch.unravel_index(idx_flat, jac_sum.shape), dim=1)

    mask = torch.zeros_like(flat, dtype=torch.bool)
    mask[idx_flat] = True
    mask = mask.view_as(jac_sum)

    return indices_nd, mask, values


def contribution_calculation(model, train_loader, std, percent):
    """Rank output dimensions by averaged input-Jacobian magnitude and cut at the largest relative gap.

    For every batch the gradient of each batch-summed output component with
    respect to the inputs is computed, multiplied by ``std``, and its absolute
    value is summed over the batch. The accumulated sum is divided by the
    number of samples seen, reduced with ``torch.norm(..., dim=1)`` to one
    score per output dimension, and passed to ``first_i_by_max_gap_ratio_1d``.

    Args:
        model: callable as ``model(inputs, 0)`` returning ``(outputs, _)``.
        train_loader: iterable of ``(inputs, labels)`` batches.
        std: right-hand operand of ``jac @ std``; presumably a matrix built
            from the per-feature std, already on the right device -- TODO confirm.
        percent: unused in this function.

    Returns:
        The index tensor returned by ``first_i_by_max_gap_ratio_1d``.

    NOTE(review): an empty ``train_loader`` leaves ``jac_sum`` as None and
    ``outdim`` unbound, so the code after the loop fails.
    """
    model = model
    jac_sum = None
    num = 0
    #std = std.to(device).float()
    i=0

    for batch in tqdm(train_loader):
     #if i<=1:
        # NOTE(review): loop-invariant; the model is moved and put in train
        # mode on every batch.
        model = model.to(device)
        model.train()
        i=i+1

        imgs, lab = batch
        imgs = imgs.float().to(device)
        imgs.requires_grad = True

        outputs, _ = model(imgs, 0)

        num += outputs.shape[0]
        # Sum over the batch so one grad call per output component is enough.
        outVector = torch.sum(outputs, 0).view(-1)
        outdim = outVector.size(0)

        # The comprehension's ``i`` is local to it and does not touch the
        # batch counter ``i`` above.
        jac = torch.stack(
            [torch.autograd.grad(outVector[i], imgs, retain_graph=True, create_graph=False)[0]
             for i in range(outdim)],
            dim=0
        )  # (outdim, B, in_dim), assuming imgs is (B, in_dim)
        jac = jac.permute(1, 0, 2)          # (B, outdim, in_dim)
        jac = jac @ std                    # last dim depends on std's shape -- TODO confirm
        jac = torch.abs(jac)

        if jac_sum is None:
            jac_sum = torch.sum(jac, dim=0)     # summed over the batch dimension
        else:
            jac_sum += torch.sum(jac, dim=0)

    jac_sum = jac_sum / num
    jac_sum=torch.norm(jac_sum,dim=1)

    # Selection step: the percent-based variant is disabled; the cut is chosen
    # by the largest relative gap of the ascending-sorted scores.
    #top_idx, top_mask, top_values = top_percent_positions(jac_sum, percent)
    #print(len(top_idx))
    top_idx=first_i_by_max_gap_ratio_1d(jac_sum)
    print(len(top_idx))
    print(outdim)

    return top_idx


def contribution_calculation_con(model, train_loader, std, percent):
    """Rank output dimensions by averaged input-Jacobian magnitude and select via ``first_i_by_min_gap_ratio_1d``.

    For every batch the gradient of each batch-summed output component with
    respect to the inputs is computed, multiplied by ``std``, and its absolute
    value is summed over the batch. The accumulated sum is divided by the
    number of samples seen, reduced with ``torch.norm(..., dim=1)`` to one
    score per output dimension, and passed to ``first_i_by_min_gap_ratio_1d``.

    Args:
        model: callable as ``model(inputs, 0)`` returning ``(outputs, _)``.
        train_loader: iterable of ``(inputs, labels)`` batches.
        std: right-hand operand of ``jac @ std``; presumably a matrix built
            from the per-feature std, already on the right device -- TODO confirm.
        percent: unused in this function.

    Returns:
        The index tensor returned by ``first_i_by_min_gap_ratio_1d``.

    NOTE(review): an empty ``train_loader`` leaves ``jac_sum`` as None and
    ``outdim`` unbound, so the code after the loop fails.
    """
    model = model
    jac_sum = None
    num = 0
    #std = std.to(device).float()
    i=0

    for batch in tqdm(train_loader):
     #if i<=1:
        # NOTE(review): loop-invariant; the model is moved and put in train
        # mode on every batch.
        model = model.to(device)
        model.train()
        i=i+1

        imgs, lab = batch
        imgs = imgs.float().to(device)
        imgs.requires_grad = True

        outputs, _ = model(imgs, 0)

        num += outputs.shape[0]
        # Sum over the batch so one grad call per output component is enough.
        outVector = torch.sum(outputs, 0).view(-1)
        outdim = outVector.size(0)

        # The comprehension's ``i`` is local to it and does not touch the
        # batch counter ``i`` above.
        jac = torch.stack(
            [torch.autograd.grad(outVector[i], imgs, retain_graph=True, create_graph=False)[0]
             for i in range(outdim)],
            dim=0
        )  # (outdim, B, in_dim), assuming imgs is (B, in_dim)
        jac = jac.permute(1, 0, 2)          # (B, outdim, in_dim)
        jac = jac @ std                    # last dim depends on std's shape -- TODO confirm
        jac = torch.abs(jac)

        if jac_sum is None:
            jac_sum = torch.sum(jac, dim=0)     # summed over the batch dimension
        else:
            jac_sum += torch.sum(jac, dim=0)

    jac_sum = jac_sum / num
    jac_sum=torch.norm(jac_sum,dim=1)

    # Selection step: the percent-based variant is disabled; the cut is chosen
    # on the descending-sorted scores.
    #top_idx, top_mask, top_values = top_percent_positions(jac_sum, percent)
    #print(len(top_idx))
    top_idx=first_i_by_min_gap_ratio_1d(jac_sum)
    # NOTE(review): a_sorted and idx_sorted are computed but never used.
    a_sorted, idx_sorted = torch.sort(jac_sum, descending=True)  # idx_sorted holds original indices
    print(len(top_idx))
    print(outdim)

    return top_idx



def contribution_calculation_percent(model, train_loader, std, percent):
    """Return the indices of the ``percent`` fraction of output dimensions with the smallest Jacobian score.

    For every batch the gradient of each batch-summed output component with
    respect to the inputs is computed, multiplied by ``std``, and its absolute
    value is summed over the batch. The accumulated sum is divided by the
    number of samples seen and reduced with ``torch.norm(..., dim=1)`` to one
    score per output dimension. The scores are sorted ascending and the first
    ``max(1, int(numel * percent))`` original indices are returned.

    Args:
        model: callable as ``model(inputs, 0)`` returning ``(outputs, _)``.
        train_loader: iterable of ``(inputs, labels)`` batches.
        std: right-hand operand of ``jac @ std``; presumably a matrix built
            from the per-feature std, already on the right device -- TODO confirm.
        percent: fraction of output dimensions to return.

    Returns:
        Index tensor of the ``k`` lowest-scoring output dimensions.

    NOTE(review): an empty ``train_loader`` leaves ``jac_sum`` as None and
    ``outdim`` unbound, so the code after the loop fails.
    """
    model = model
    jac_sum = None
    num = 0
    #std = std.to(device).float()
    i=0

    for batch in tqdm(train_loader):
     #if i<=1:
        # NOTE(review): loop-invariant; the model is moved and put in train
        # mode on every batch.
        model = model.to(device)
        model.train()
        i=i+1

        imgs, lab = batch
        imgs = imgs.float().to(device)
        imgs.requires_grad = True

        outputs, _ = model(imgs, 0)

        num += outputs.shape[0]
        # Sum over the batch so one grad call per output component is enough.
        outVector = torch.sum(outputs, 0).view(-1)
        outdim = outVector.size(0)

        # The comprehension's ``i`` is local to it and does not touch the
        # batch counter ``i`` above.
        jac = torch.stack(
            [torch.autograd.grad(outVector[i], imgs, retain_graph=True, create_graph=False)[0]
             for i in range(outdim)],
            dim=0
        )  # (outdim, B, in_dim), assuming imgs is (B, in_dim)
        jac = jac.permute(1, 0, 2)          # (B, outdim, in_dim)
        jac = jac @ std                    # last dim depends on std's shape -- TODO confirm
        jac = torch.abs(jac)

        if jac_sum is None:
            jac_sum = torch.sum(jac, dim=0)     # summed over the batch dimension
        else:
            jac_sum += torch.sum(jac, dim=0)

    jac_sum = jac_sum / num
    jac_sum=torch.norm(jac_sum,dim=1)

    # Selection step.
    #top_idx, top_mask, top_values = top_percent_positions(jac_sum, percent)
    #print(len(top_idx))
    # NOTE(review): top_idx is only used for the length printed below; the
    # returned indices come from the ascending sort.
    top_idx=first_i_by_min_gap_ratio_1d(jac_sum)
    a_sorted, idx_sorted = torch.sort(jac_sum, descending=False)  # idx_sorted holds original indices
    k = max(1, int(jac_sum.numel() * percent))
    print(len(top_idx))
    print(outdim)
    return idx_sorted[:k]         


_PNAL_CHOICES = ('L_2', 'L_1sq', 'L2cL1', 'L2-L1', 'overL2cL1', 'overL2-L1',
                 'half/half', 'half-half')


def _pnal_penalty(jac, PNAL):
    """Collapse the batch-mean absolute Jacobian ``jac`` (outdim, in_dim) into a scalar penalty."""
    if PNAL == 'L_2':
        # Sum of row-wise L2 norms.
        return torch.sum(torch.norm(jac, dim=1))
    if PNAL == 'L_1sq':
        # Sum over rows of sqrt(row L1 norm); entries are already non-negative.
        return torch.sum(torch.sqrt(torch.sum(jac, dim=1)))
    if PNAL in ('L2cL1', 'L2-L1', 'overL2cL1', 'overL2-L1'):
        if PNAL.startswith('over'):
            jac = 1 / (jac + 1)
        squared = torch.pow(jac, 2)
        frobenius = torch.sqrt(torch.sum(squared))
        row_l2_sum = torch.sum(torch.sqrt(torch.sum(squared, dim=1)))
        if PNAL in ('L2cL1', 'overL2cL1'):
            return row_l2_sum / frobenius
        return row_l2_sum - frobenius
    if PNAL in ('half/half', 'half-half'):
        row_norms = torch.norm(jac, p=2, dim=1)
        order = torch.argsort(row_norms)
        half_size = len(row_norms) // 2
        # Smaller half contributes its norms, larger half its squared norms.
        min_half_sum = row_norms[order[:half_size]].sum()
        max_half_square_sum = (row_norms[order[half_size:]] ** 2).sum()
        if PNAL == 'half/half':
            return min_half_sum / max_half_square_sum
        return min_half_sum - max_half_square_sum
    raise ValueError(f"Unsupported PNAL {PNAL!r}; expected one of {_PNAL_CHOICES}")


def training(model,train_loader,lr,adam=1,wd=0,pun=0.01, grad_pun=0.01,PNAL=None,std=None,opt=None):
    """Run one pass over ``train_loader``, minimising flow NLL plus a Jacobian penalty.

    Per batch the loss is ``mean(-(log N(z; 0, I) + sldj))`` where ``z, sldj``
    come from ``model(x, sldj=0)``.  When ``grad_pun != 0`` the Jacobian dz/dx
    is built with ``create_graph=True``, right-multiplied by ``std``, taken in
    absolute value, averaged over the batch and reduced to a scalar by the
    ``PNAL`` variant; ``grad_pun`` times that scalar is added to the loss.

    Args:
        model: Flow module called as ``model(x, sldj=0)`` -> ``(outputs, sldj)``.
        train_loader: Iterable yielding ``(inputs, labels)`` pairs.
        lr: Learning rate.
        adam: ``1`` for Adam (amsgrad, weight decay ``wd``), ``0`` for SGD with
            momentum 0.8 (``wd`` is not applied to SGD).
        wd: Weight decay for Adam.
        pun: Unused; kept for call compatibility.
        grad_pun: Weight of the Jacobian penalty; ``0`` disables it.
        PNAL: Penalty variant, one of ``_PNAL_CHOICES``; required when
            ``grad_pun != 0``.
        std: Matrix applied to the Jacobian; required when ``grad_pun != 0``.
        opt: Unused; kept for call compatibility.

    Returns:
        ``(je, loss)``: ``je`` is a zero tensor of length ``outdim`` (it is
        never updated), ``loss`` is the detached CPU loss of the last batch.

    Raises:
        ValueError: For an unknown ``adam`` / ``PNAL`` value, a missing ``std``
            while the penalty is active, or an empty ``train_loader``.
    """
    model = model.to(device)
    model.device = device
    model.train()

    if std is not None:
        std = std.to(device).float()
    elif grad_pun != 0:
        raise ValueError("std must be provided when grad_pun != 0")
    if grad_pun != 0 and PNAL not in _PNAL_CHOICES:
        # Fail before any optimizer step instead of inside loss.backward().
        raise ValueError(f"Unsupported PNAL {PNAL!r}; expected one of {_PNAL_CHOICES}")

    # NOTE: a fresh optimizer is built on every call, so its internal state
    # (e.g. Adam moments) does not persist between calls.
    if adam == 1:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd, amsgrad=True)
    elif adam == 0:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.8)
    else:
        raise ValueError(f"adam must be 0 (SGD) or 1 (Adam), got {adam!r}")

    log_two_pi = torch.log(torch.tensor(torch.pi * 2))  # loop-invariant constant
    je = None
    loss = None
    penalty = 0

    for imgs, _ in tqdm(train_loader):
        optimizer.zero_grad()
        imgs = imgs.to(device).float()
        imgs.requires_grad = True  # inputs must be differentiable for the Jacobian

        outputs, sldj = model(imgs, sldj=0)
        log_likelihood = -0.5 * (torch.pow(outputs, 2) + log_two_pi)
        sample_likelihood = torch.sum(log_likelihood, dim=1)

        # Summing over the batch yields every sample's gradient for one output
        # dimension from a single autograd call.
        outVector = torch.sum(outputs, 0).view(-1)
        outdim = outVector.size(0)
        if je is None:
            je = torch.zeros(outdim)

        penalty = 0
        if grad_pun != 0:
            jac = torch.stack(
                [torch.autograd.grad(outVector[d], imgs,
                                     retain_graph=True, create_graph=True)[0]
                 for d in range(outdim)],
                dim=0
            )                                            # (outdim, B, in_dim)
            jac = torch.matmul(jac.permute(1, 0, 2), std)  # (B, outdim, in_dim)
            jac = torch.mean(torch.abs(jac), dim=0)        # (outdim, in_dim)
            penalty = _pnal_penalty(jac, PNAL)

        loss = torch.mean(-(sample_likelihood + sldj))
        loss += penalty * grad_pun
        loss.backward()
        optimizer.step()

    if loss is None:
        raise ValueError("train_loader yielded no batches; nothing was trained")

    print(f" Train |  loss = {loss:.5f}  grad={penalty:.5f}")
    return je,loss.detach().cpu()

import torch
from tqdm import tqdm

def build_contributed_testset(model, test_loader, top_idx, device='cuda'):
    """Rebuild the test inputs from latents in which only selected positions are kept.

    Each batch is mapped to the latent ``z`` with ``model(x, sldj=0)``, every
    latent position not listed in ``top_idx`` is zeroed, and the masked latent
    is mapped back with ``model.reverse``.

    Args:
        model: Flow module exposing ``model(x, sldj=0)`` and
            ``model.reverse(z, sldj=0)``, each returning a pair.
        test_loader: Iterable yielding ``(x, y)`` pairs.
        top_idx: Positions to keep.  For a 2-D latent ``(B, D)`` it is flattened
            and used as column indices; otherwise it is read as ``(k, 2)``
            ``(row, col)`` pairs indexing the two dimensions after the batch.
        device: Device used for the model and the computation.

    Returns:
        ``(new_test_set, new_test_data, new_test_lab)`` where ``new_test_set``
        is ``CustomDataset(new_test_data, new_test_lab)`` and the other two are
        NumPy arrays of reconstructed inputs and their labels.
    """
    model = model.to(device)
    model.eval()

    # Normalise the index container once and keep it on the compute device so
    # that indexing the mask never mixes devices.
    if not torch.is_tensor(top_idx):
        top_idx = torch.tensor(top_idx, dtype=torch.long)
    top_idx = top_idx.long().to(device)

    new_x_list = []
    new_y_list = []

    with torch.no_grad():
        for x, y in tqdm(test_loader):
            x = x.to(device).float()

            z, _ = model(x, sldj=0)

            mask = torch.zeros_like(z)
            if z.dim() == 2:
                # Latent is (B, D): keep the listed columns.
                mask[:, top_idx.reshape(-1)] = 1.0
            else:
                # Latent has extra dims: keep the listed (row, col) positions.
                mask[:, top_idx[:, 0], top_idx[:, 1]] = 1.0

            x_new, _ = model.reverse(z * mask, sldj=0)

            new_x_list.append(x_new.detach().cpu())
            new_y_list.append(y.detach().cpu())

    new_test_data = torch.cat(new_x_list, dim=0).numpy()
    new_test_lab  = torch.cat(new_y_list, dim=0).numpy()

    new_test_set = CustomDataset(new_test_data, new_test_lab)
    return new_test_set, new_test_data, new_test_lab


def testing(model, test_loader, device='cuda'):
    """Score every test sample with the flow and return ``(AUROC, average precision)``.

    The anomaly score of a sample is ``-(sum(-0.5 * z**2) + sldj)`` with
    ``z, sldj = model(x, sldj=0)``; the Gaussian normalising constant is left
    out of the score.

    Args:
        model: Flow module called as ``model(x, sldj=0)`` -> ``(outputs, sldj)``.
        test_loader: Iterable yielding ``(x, y)`` pairs; ``y`` holds the binary
            labels handed to the metrics.
        device: Device used for the model and the computation.

    Returns:
        ``(roc_auc, ap_score)`` as returned by torchmetrics' binary ``AUROC``
        and ``AveragePrecision``.
    """
    model = model.to(device)
    model.eval()

    score_batches, target_batches = [], []
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device).float()
            y = y.to(device)

            outputs, sldj = model(x, sldj=0)
            sample_likelihood = torch.sum(-0.5 * torch.pow(outputs, 2), dim=1)
            score_batches.append(-(sample_likelihood + sldj))
            target_batches.append(y)

    # Concatenate whole batches instead of stacking per-sample scalars.
    preds = torch.cat(score_batches)
    targets = torch.cat(target_batches)

    auroc_metric = AUROC(task="binary")
    average_precision = AveragePrecision(task="binary")
    roc_auc = auroc_metric(preds, targets)
    ap_score = average_precision(preds, targets)

    return roc_auc,ap_score


def draw_fig(jac_sum,epoch):
    """Save a bar chart of the ascending-sorted values of a 1-D tensor.

    Args:
        jac_sum: 1-D tensor of values to plot.
        epoch: Used only to build the output file name
            ``./jac_figs/<epoch> .png``.
    """
    # detach()/cpu() make the conversion work for CUDA or grad-tracking tensors.
    sorted_data = torch.sort(jac_sum.detach().cpu()).values.numpy()

    # savefig does not create missing directories.
    os.makedirs('./jac_figs', exist_ok=True)

    plt.figure(figsize=(20, 12))
    plt.bar(range(jac_sum.shape[0]), sorted_data)
    plt.xlabel('Index')
    plt.ylabel('Value')
    plt.title('Bar Plot of Sorted Tensor Values')
    plt.savefig('./jac_figs/'+str(epoch)+" .png", format='png')
    plt.close()
#torch.manual_seed(42)
#torch.cuda.manual_seed(42)

def draw_3D_fig(x,labels,epoch):
    """Show a 3-D scatter plot of the first three columns of ``x``, coloured by label.

    Args:
        x: Array indexable as ``x[mask]`` and ``[:, 0..2]``.
        labels: Array of per-point labels; each distinct value gets one colour.
        epoch: Currently unused (only a disabled save path referred to it).
    """
    legend_names = ['normal data','noised normal data','anomalous data']

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    unique_labels = np.unique(labels)
    # One viridis colour per distinct label.
    colors = plt.cm.viridis(np.linspace(0, 1, len(unique_labels)))

    for i, label in enumerate(unique_labels):
        label_points = x[labels == label]
        # Fall back to the raw label when there are more labels than names.
        name = legend_names[i] if i < len(legend_names) else str(label)
        ax.scatter(label_points[:, 0], label_points[:, 1], label_points[:, 2],
                   s=5, color=colors[i], label=name, alpha=0.9)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.legend()

    plt.show()
    plt.close()
def draw_minmax_fig(jac_sum,epoch):
    """Save a heat map of a 2-D tensor to ``./MIN-MAX-figs/<epoch> .png``.

    Args:
        jac_sum: Tensor rendered with ``pcolormesh``.
        epoch: Used only to build the output file name.
    """
    print(jac_sum.shape)
    # detach()/cpu() make the conversion work for CUDA or grad-tracking tensors.
    data_np = jac_sum.detach().cpu().numpy()

    # savefig does not create missing directories.
    os.makedirs('./MIN-MAX-figs', exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 10), dpi=1000)  # square canvas
    heatmap = ax.pcolormesh(data_np, cmap='coolwarm', shading='auto')
    # Narrow colour bar so the main axes stay square.
    plt.colorbar(heatmap, ax=ax, fraction=0.046, pad=0.04)
    plt.title("Square Heatmap without Colorbar")
    plt.savefig('./MIN-MAX-figs/'+str(epoch)+" .png", format='png')
    plt.close()
def knn_average_distance(train_data, test_data, k):
    """Return, for each test sample, the mean Euclidean distance to its k nearest training samples.

    Args:
        train_data: torch.Tensor of shape (num_train_samples, num_features).
        test_data: torch.Tensor of shape (num_test_samples, num_features).
        k: Number of nearest neighbours to average over.

    Returns:
        torch.Tensor of shape (num_test_samples,) with the average distances.
    """
    # Pairwise L2 distances, one row per test sample.
    pairwise = torch.cdist(test_data, train_data, p=2)
    # Keep the k smallest distances in every row and average them.
    nearest = pairwise.topk(k, dim=1, largest=False).values
    return nearest.mean(dim=1)


import torch

def verify_reverse_correctness(model, data_loader, device='cuda', atol=1e-5, rtol=1e-5, max_batches=3):
    """Check that ``model.reverse`` inverts the forward pass and print error statistics.

    Args:
        model: Module exposing ``model(x, sldj=0)`` and
            ``model.reverse(z, sldj=0)``, each returning ``(tensor, sldj)``.
        data_loader: Iterable yielding either ``(x, y)`` sequences or bare ``x``.
        device: Device used for the model and the computation.
        atol, rtol: Tolerances forwarded to ``torch.allclose``.
        max_batches: Only the first ``max_batches`` batches are checked.

    Prints:
        Per-batch ``allclose`` result with mean / max absolute reconstruction
        error, then a summary.  When both passes return a non-None ``sldj`` the
        mean ``|sldj_inv + sldj_fwd|`` is reported as well.

    Returns:
        None.
    """
    model = model.to(device)
    model.eval()

    recon_errors = []
    sldj_errors = []

    with torch.no_grad():
        for bi, batch in enumerate(data_loader):
            if bi >= max_batches:
                break

            # Accept both (x, y) batches and bare tensors.
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            x = x.to(device).float()

            z, sldj_fwd = model(x, sldj=0)
            x_rec, sldj_inv = model.reverse(z, sldj=0)

            err = (x_rec - x).abs()
            mean_err = err.mean().item()
            recon_errors.append(mean_err)

            # Only meaningful if reverse also accumulates sldj: the two values
            # are then expected to cancel.
            if sldj_fwd is not None and sldj_inv is not None:
                sldj_errors.append((sldj_inv + sldj_fwd).abs().mean().item())

            ok = torch.allclose(x_rec, x, atol=atol, rtol=rtol)
            print(f"[batch {bi}] allclose(x_rec, x) = {ok}, mean|err| = {mean_err:.3e}, max|err| = {err.max().item():.3e}")

    print("\n=== Summary ===")
    if not recon_errors:
        # Nothing was checked; avoid dividing by zero below.
        print("No batches were checked (empty data_loader or max_batches <= 0).")
        return
    print(f"Avg mean|x_rec-x| over {len(recon_errors)} batches: {sum(recon_errors)/len(recon_errors):.3e}")
    if sldj_errors:
        print(f"Avg mean|sldj_inv + sldj_fwd|: {sum(sldj_errors)/len(sldj_errors):.3e}")


def read_noise_data(file, normalization='z-score', seed=42,noise_rate=0):
    """Load a labelled tabular dataset and build a train / test split with contaminated training data.

    Rows are shuffled with ``seed``.  The first half of the normal rows
    (label 0) form the training set, optionally joined by the first
    ``int(split * noise_rate)`` anomalies (label 1).  The test set holds the
    remaining normal rows plus all anomalies from position
    ``int(split * 0.1)`` onwards.

    Args:
        file: Path ending in ``.npz`` (arrays ``X`` and ``y``), ``pkl`` or
            ``csv`` (last column is the label).
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by 255)
            or ``'ours'`` (z-score with ``std + 1e-4``); any other value leaves
            the data untouched.  Statistics come from the training split.
        seed: Seed of the shuffling RNG.
        noise_rate: Fraction (relative to the number of normal training rows)
            of anomalies added to the training set.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)``; ``sds`` is the
        per-feature training standard deviation for ``'z-score'`` and ``None``
        otherwise.

    Raises:
        NotImplementedError: For an unsupported file extension.
    """
    # SECURITY: allow_pickle=True and read_pickle can execute arbitrary code;
    # only load files from trusted sources.
    if file.endswith('.npz'):
        with np.load(file, allow_pickle=True) as data:  # close the archive handle
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'Unsupported file type: {file!r} (expected .npz, pkl or csv)')

        df = func(file)
        # Treat infinities as missing, then forward-fill.  ``ffill`` replaces
        # the deprecated ``fillna(method='ffill')``.
        df = df.replace([np.inf, -np.inf], np.nan).ffill()
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y==0)[0]
    anom_idx = np.where(y==1)[0]
    split = int(0.5 * len(norm_idx))
    ab_train = int(split*noise_rate)
    max_ab_train = int(split*0.1)
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    # NOTE(review): test anomalies start at a fixed offset, so the test set is
    # identical for every noise_rate <= 0.1; a larger noise_rate makes the
    # training and test anomalies overlap -- confirm that is acceptable.
    train_ab_idx, test_ab_idx = anom_idx[:ab_train], anom_idx[max_ab_train:]

    train_idx = np.hstack([train_norm_idx, train_ab_idx])
    test_idx = np.hstack([test_norm_idx, test_ab_idx])
    x_train, y_train = x[train_idx], y[train_idx]
    x_test, y_test = x[test_idx], y[test_idx]
    data_dim = x_train.shape[1]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    sds = None
    # normalization (statistics are always taken from the training split)
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization =='ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train, y_train, x_test, y_test,data_dim,sds



def read_after_gauss_data(file, normalization='z-score', train_level=1,test_level=2,seed=42):
    """Load a dataset, split it, normalise it, then add Gaussian noise to every split.

    Rows are shuffled with ``seed``; the first half of the normal rows
    (label 0) become the training set, the other half plus all anomalies
    (label 1) the test set.  After normalisation, zero-mean Gaussian noise is
    added: variance ``train_level`` for the training rows and the test
    anomalies, variance ``test_level`` for the normal test rows.

    Args:
        file: Path ending in ``.npz`` (arrays ``X`` and ``y``), ``pkl`` or
            ``csv`` (last column is the label).
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by 255)
            or ``'ours'`` (z-score with ``std + 1e-4``); any other value leaves
            the data untouched.  Statistics come from the training split.
        train_level: Noise variance for training rows and test anomalies.
        test_level: Noise variance for normal test rows.
        seed: Seed of the shuffling RNG.  The noise itself is drawn from
            NumPy's global RNG and is not controlled by ``seed``.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)``; ``sds`` is the
        per-feature training standard deviation for ``'z-score'`` and ``None``
        otherwise.

    Raises:
        NotImplementedError: For an unsupported file extension.
    """
    # SECURITY: allow_pickle=True and read_pickle can execute arbitrary code;
    # only load files from trusted sources.
    if file.endswith('.npz'):
        with np.load(file, allow_pickle=True) as data:  # close the archive handle
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'Unsupported file type: {file!r} (expected .npz, pkl or csv)')

        df = func(file)
        # Treat infinities as missing, then forward-fill.  ``ffill`` replaces
        # the deprecated ``fillna(method='ffill')``.
        df = df.replace([np.inf, -np.inf], np.nan).ffill()
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y==0)[0]
    anom_idx = np.where(y==1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    x_train = x[train_norm_idx]
    x_test_norm = x[test_norm_idx]
    x_test_ab = x[anom_idx]

    # Every branch transforms all three splits with training statistics, and
    # ``sds`` is always bound so the return statement cannot fail.
    sds = None
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test_norm = minmax_scaler.transform(x_test_norm)
        x_test_ab = minmax_scaler.transform(x_test_ab)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test_norm = (x_test_norm - mus) / sds
        x_test_ab = (x_test_ab - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test_norm = x_test_norm / 255
        x_test_ab = x_test_ab / 255

    elif normalization =='ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test_norm = (x_test_norm - mean) / (std + 1e-4)
        x_test_ab = (x_test_ab - mean) / (std + 1e-4)

    data_dim = x_train.shape[1]
    y_train = y[train_norm_idx]

    # Draw order (train, test anomalies, test normals) is kept stable so runs
    # that seed NumPy's global RNG stay reproducible.
    noise_train = np.random.normal(
        loc=0.0,
        scale=np.sqrt(train_level),
        size=(len(train_norm_idx), x.shape[1])
    )
    noise_test_abnorm = np.random.normal(
        loc=0.0,
        scale=np.sqrt(train_level),
        size=(len(anom_idx), x.shape[1])
    )
    noise_test_norm = np.random.normal(
        loc=0.0,
        scale=np.sqrt(test_level),
        size=(len(test_norm_idx), x.shape[1])
    )
    print(noise_test_norm)
    x_test_ab = x_test_ab + noise_test_abnorm
    x_test_norm = x_test_norm + noise_test_norm
    x_train = x_train + noise_train

    # np.concatenate is available on every NumPy version (np.concat is 2.0+).
    x_test = np.concatenate([x_test_norm, x_test_ab])
    y_test = y[np.hstack([test_norm_idx, anom_idx])]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    return x_train, y_train, x_test, y_test,data_dim,sds


def read_after_sgauss_data(file, normalization='z-score', train_level=1,test_level=2,seed=42):
    """Load a dataset, split it, normalise it, then add Gaussian noise to every split.

    Rows are shuffled with ``seed``; the first half of the normal rows
    (label 0) become the training set, the other half plus all anomalies
    (label 1) the test set.  After normalisation, zero-mean Gaussian noise is
    added: variance ``train_level`` for the training rows and the normal test
    rows, variance ``test_level`` for the test anomalies.

    Args:
        file: Path ending in ``.npz`` (arrays ``X`` and ``y``), ``pkl`` or
            ``csv`` (last column is the label).
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by 255)
            or ``'ours'`` (z-score with ``std + 1e-4``); any other value leaves
            the data untouched.  Statistics come from the training split.
        train_level: Noise variance for training rows and normal test rows.
        test_level: Noise variance for test anomalies.
        seed: Seed of the shuffling RNG.  The noise itself is drawn from
            NumPy's global RNG and is not controlled by ``seed``.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)``; ``sds`` is the
        per-feature training standard deviation for ``'z-score'`` and ``None``
        otherwise.

    Raises:
        NotImplementedError: For an unsupported file extension.
    """
    # SECURITY: allow_pickle=True and read_pickle can execute arbitrary code;
    # only load files from trusted sources.
    if file.endswith('.npz'):
        with np.load(file, allow_pickle=True) as data:  # close the archive handle
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'Unsupported file type: {file!r} (expected .npz, pkl or csv)')

        df = func(file)
        # Treat infinities as missing, then forward-fill.  ``ffill`` replaces
        # the deprecated ``fillna(method='ffill')``.
        df = df.replace([np.inf, -np.inf], np.nan).ffill()
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y==0)[0]
    anom_idx = np.where(y==1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    x_train = x[train_norm_idx]
    x_test_norm = x[test_norm_idx]
    x_test_ab = x[anom_idx]

    # Every branch transforms all three splits with training statistics, and
    # ``sds`` is always bound so the return statement cannot fail.
    sds = None
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test_norm = minmax_scaler.transform(x_test_norm)
        x_test_ab = minmax_scaler.transform(x_test_ab)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test_norm = (x_test_norm - mus) / sds
        x_test_ab = (x_test_ab - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test_norm = x_test_norm / 255
        x_test_ab = x_test_ab / 255

    elif normalization =='ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test_norm = (x_test_norm - mean) / (std + 1e-4)
        x_test_ab = (x_test_ab - mean) / (std + 1e-4)

    data_dim = x_train.shape[1]
    y_train = y[train_norm_idx]

    # Draw order (train, test anomalies, test normals) is kept stable so runs
    # that seed NumPy's global RNG stay reproducible.
    noise_train = np.random.normal(
        loc=0.0,
        scale=np.sqrt(train_level),
        size=(len(train_norm_idx), x.shape[1])
    )
    noise_test_abnorm = np.random.normal(
        loc=0.0,
        scale=np.sqrt(test_level),
        size=(len(anom_idx), x.shape[1])
    )
    noise_test_norm = np.random.normal(
        loc=0.0,
        scale=np.sqrt(train_level),
        size=(len(test_norm_idx), x.shape[1])
    )
    print(noise_test_norm)
    x_test_ab = x_test_ab + noise_test_abnorm
    x_test_norm = x_test_norm + noise_test_norm
    x_train = x_train + noise_train

    # np.concatenate is available on every NumPy version (np.concat is 2.0+).
    x_test = np.concatenate([x_test_norm, x_test_ab])
    y_test = y[np.hstack([test_norm_idx, anom_idx])]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    return x_train, y_train, x_test, y_test,data_dim,sds



# ---- Experiment configuration (module-level constants read by the loop below) ----
# `act`, `mid_dim`, `mu`, `sigma` and `learn` are forwarded to the _RealNVP constructor.
act=2
# Batch size of the training and contribution loaders.
bs=512#512512
# Number of training epochs per run.
epoch=100
mid_dim=2048
# Optimizer switch passed to training(): 1 = Adam, 0 = SGD.
adam=1#AMAZON:59.1 no regularizer
# NOTE(review): shadows the builtin `pow`; not referenced in the loop below.
pow=1
# Jacobian-penalty variant passed to training().
PNAL='L_1sq'#L_2,L_1sq,L2-L1,L2cL1,overL2-L1,overL2cL1,half-half,half/half
learn=False
mu=torch.zeros(1)
#set_seed(4)
sigma=torch.ones(1)
#data='30_satellite.npz'#bank should run  satellite:84  waveform:87
# Only used by the disabled directory scan below.
folder_path = Path('./datasets_small_eye')
# Only used by disabled skip logic inside the loop below.
j=0
# Experiment driver: for every listed dataset and hyper-parameter combination,
# train a fresh _RealNVP five times and report the mean of the per-run best
# AUROC / AUPRC.
#for file_path in folder_path.rglob('*'):  # match all files or folders
#for k in range (0,1):
#for file_path in ["datasets/30_satellite.npz","datasets/7_Cardiotocography.npz","datasets/29_Pima.npz", "datasets/35_SpamBase.npz"]:
for file_path in [#"datasets/4_breastw.npz"
                  "datasets/7_Cardiotocography.npz",
                  #"datasets/22_magic.gamma.npz",
                  #"datasets/29_Pima.npz",
                  #"datasets/30_satellite.npz",
                  #"datasets/35_SpamBase.npz",
                  #"datasets/46_WPBC.npz",
                  #"datasets/47_yeast.npz",
                  #"datasets/19_landsat.npz",
                  #"datasets/18_Ionosphere.npz"
#"datasets/36_speech.npz"
                  ]:
 #file_path="datasets/34_smtp.npz"
 #if "Ionosphere" in file_path:
 #   j=1
 #if j==0:
 #   continue
 #file_path="datasets/24_mnist.npz"
 #first_slash_idx = file_path.find('\\')
 file_path=str(file_path)
 first_slash_idx = file_path.find('/')
 print("first_slash_idx")
 print(first_slash_idx)
 fit_ratio=0
 dot_idx = file_path.rfind('.')
 # Dataset name = text between the first '/' and the last '.'; only the
 # disabled result-writing code at the bottom uses it.
 # NOTE(review): rebinding `data` shadows the `torch.utils.data` module imported under that name.
 data = file_path[first_slash_idx + 1 : dot_idx]
 percent=20
 # Hyper-parameter grid (single values at the moment).
 for k in [1]:
  for lr in [0.005]:
   for grad_pun in [0.1]:#1#63.6
            avg_auroc,avg_auprc,avg_con_auroc,avg_con_auprc,avg_f1,avg_con_f1=[],[],[],[],[],[]
            # Five independent runs per configuration.
            for i in range (0,5):
                losses,aucs=[],[]
                #z=torch.empty(5000).uniform_(-1, 1)
                #x1=z*(1/torch.sqrt(z.var(unbiased=False)))
                #x2=z*2*(1/torch.sqrt((z*2).var(unbiased=False)))
                #x3=z**3*(1/torch.sqrt((z**3).var(unbiased=False)))
                #train_data=torch.stack([x1,x2,x3]).T
                #e=torch.normal(0,0.001,size=[5000,100])
                #train_data=(train_data+e).numpy()
                #train_data,train_lab,test_data,test_lab,Input_dim,std=read_after_gauss_data(file_path,normalization='z-score',train_level=0.1,test_level=0.2)
                #train_data,train_lab,test_data,test_lab,Input_dim,std=read_after_gauss_data(file_path,normalization='z-score',train_level=0.1,test_level=0.2)#read_gauss_data(file_path,normalization='z-score',train_level=0.1,test_level=0.2)#read_data(file_path,normalization='z-score')
                train_data,train_lab,test_data,test_lab,Input_dim,std=read_data(file_path,normalization='z-score')
                train_set,test_set=CustomDataset(train_data, train_lab), CustomDataset(test_data, test_lab)
                # NOTE(review): the diag(1/std) matrix built on the next line is
                # overwritten by the identity right after it, so the Jacobian is
                # effectively unscaled -- confirm which one is intended.
                std=torch.diag(torch.tensor(1/std)).cuda().float()
                std=torch.eye(Input_dim).cuda()
                print(std.shape)
                # A new model is created for every run.
                model=_RealNVP(input_dim=Input_dim,
                 mid_dim=mid_dim, 
                 masktype=0,
                 act=act,
                 mu=mu,
                 sigma=sigma,
                 learn=learn
             )
                best_auc,best_con_auc,best_sldj_auc,best_pz_auc,best_auprc,best_con_auprc,best_f1,best_con_f1=0,0,0,0,0,0,0,0
                o_auc,o_prc=0,0
                time1=time.time()
                for n_epoch in range(epoch):
                 print(f"[EPOCH: {n_epoch:.1f}]")
                 torch.cuda.empty_cache()
                #std=torch.eye(Input_dim)
                #train_data,train_lab,test_data,test_lab,Input_dim=read_front_data('datasets/'+data,normalization='z-score')
                 #random_indices = np.random.choice(train_data.shape[0], size=2048, replace=False)
                 #train_set,test_set=CustomDataset(train_data[random_indices, :], np.ones(random_indices.shape)), CustomDataset(test_data, test_lab)
                 #train_set,test_set=CustomDataset(train_data, train_lab), CustomDataset(test_data, test_lab)
                 # The three loaders are rebuilt every epoch; con_loader also
                 # iterates the training set and feeds the contribution step.
                 train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=bs,
        shuffle=True,
        num_workers=0
    )
                 test_loader = torch.utils.data.DataLoader(
        dataset=test_set, 
        batch_size=2048,
        shuffle=True,
        num_workers=0
    )           
                 con_loader = torch.utils.data.DataLoader(
        dataset=train_set, 
        batch_size=bs,
        shuffle=True,
        num_workers=0
    )           
                 # One optimisation pass over the training data.
                 je,loss=training(
                        model=model,
                        train_loader=train_loader,
                        lr=lr,
                        adam=adam,
                        wd=0,
                        pun=0,
                        grad_pun=grad_pun,
                        PNAL=PNAL,
                        std=std      
                        )
                 #losses.append(loss.detach().cpu())
                 #verify_reverse_correctness(model, train_loader, device='cuda', atol=1e-5, rtol=1e-5, max_batches=5)
                 #top_idx=contribution_calculation_con(model,con_loader,std,percent=percent)
                 # Select latent positions from the training data, then rebuild
                 # the test inputs keeping only those positions.
                 top_idx=contribution_calculation(model,con_loader,std,percent=None)
                 #top_idx=contribution_calculation_percent(model,con_loader,std,percent=1)
                 print("SOURCE IDX: "+str(len(top_idx)))
                 time2=time.time()
                 #auc,con_auc,sldj_auc,auprc,con_auprc,f1,con_f1=0,0,0,0,0,0,0
                 # NOTE: test_data / test_lab are rebound to the reconstructed
                 # arrays here; test_set (built before the epoch loop) still
                 # wraps the original test data.
                 new_test_set,test_data,test_lab=build_contributed_testset(model,test_loader,top_idx)
                 #new_train_set,train_data,train_lab=build_contributed_testset(model,test_loader,top_idx)
                 new_test_loader = torch.utils.data.DataLoader(
        dataset=new_test_set, 
        batch_size=2048,
        shuffle=True,
        num_workers=0
    )           
                 auc,prc=testing(model,new_test_loader)
                 losses.append(loss)
                 aucs.append(auc.detach().cpu())

                 # Track the best test metrics seen over the epochs of this run.
                 best_auc=max(auc,best_auc)
                 best_auprc=max(best_auprc,prc)
                 #print(f"AUROC:  {best_auc:.3f}, AUPRC: {best_auprc}")
                # NOTE(review): time2 is taken before evaluation in the last
                # epoch and the difference is multiplied by 100 -- confirm the
                # intended unit.
                print("###TIME: "+str((time2-time1)*100))
                #plot_metrics_separate(aucs,losses)
                avg_auroc.append(best_auc)
                avg_auprc.append(best_auprc)
            # Mean of the per-run best metrics over the five runs.
            avg_auroc=torch.mean(torch.stack(avg_auroc))
            avg_auprc=torch.mean(torch.stack(avg_auprc))
            print(f"AVG AUROC:  {avg_auroc:.3f}, AVG AUPRC: {avg_auprc}")
            #with open('./ABL/100/'+data+" lr- "+str(lr)+" "+PNAL+' pun- '+str(grad_pun)+'.txt', 'a+') as f:
            #    f.write("best_auroc: "+str(torch.mean(avg_auroc))+" std: "+str(torch.std(avg_auroc))+" best_auprc: "+str(torch.mean(avg_auprc))+"std: "+str(torch.std(avg_auprc))+'\n')
            #with open('./OD-RESULT/FLOW/'+data+" lr- "+str(lr)+" "+PNAL+' pun- '+str(grad_pun)+'.txt', 'a+') as f:
            #with open('./NOISY-FLOW-tabular-RESULT/'+data+str()+'.txt', 'a+') as f:
            #    f.write("best_auroc: "+str(torch.mean(avg_auroc))+" std: "+str(torch.std(avg_auroc))+" best_auprc: "+str(torch.mean(avg_auprc))+"std: "+str(torch.std(avg_auprc))+" best_f1: "+str(np.mean(avg_f1))+"std: "+str(np.std(avg_f1))+'\n')
            #with open('./NOISY-CON-FLOW-tabular-RESULT/'+data+str()+'.txt', 'a+') as f:
            #    f.write("best_auroc: "+str(torch.mean(avg_con_auroc))+" std: "+str(torch.std(avg_con_auroc))+" best_auprc: "+str(torch.mean(avg_con_auprc))+"std: "+str(torch.std(avg_con_auprc))+" best_f1: "+str(np.mean(avg_con_f1))+"std: "+str(np.std(avg_con_f1))+'\n')