import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import scipy
import torch.nn.functional as F
import torch.storage
from torch.utils.data import Dataset, DataLoader
from torch.utils import data
#from torchvision import transforms
from PIL import Image
import os
from collections import OrderedDict
import matplotlib.pyplot as plt
#import torchvision.models as models
import time
from pathlib import Path
# This is for the progress bar.
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
import gc
import csv
import torch
import math
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support
import random
from sklearn.metrics import average_precision_score
from scipy.stats import t
from sklearn.metrics import roc_curve, roc_auc_score
#from openTSNE import TSNE
#import torchvision
from torch.utils import data
#from torchvision import transforms
import numpy as np
from torch.utils.data import Subset
import torch.nn.functional as F
from matplotlib.pyplot import figure
from torch import Tensor
#import torchvision.transforms as transforms
import numpy as np
#import torchvision.datasets as datasets
import torch
import numpy as np
import torch.utils.data as utils
from torch.utils.data import Sampler, Dataset
from model import _RealNVP
from sklearn import preprocessing
from torchmetrics.functional import auroc
from torchmetrics import AUROC,AveragePrecision


import numpy as np


def avg_relative_gap(arr):
    """Return the mean relative gap between consecutive sorted values.

    The input is flattened and sorted in ascending order.  For every
    adjacent pair ``(lo, hi)`` the relative gap ``(hi - lo) / hi`` is
    computed, and the gaps are averaged.

    Args:
        arr: Array-like of numbers of any shape (it is flattened).

    Returns:
        The mean of the relative gaps.
    """
    ordered = np.sort(np.asarray(arr).ravel())
    upper = ordered[1:]
    lower = ordered[:-1]
    return np.mean((upper - lower) / upper)


def top_80_percent_ratio(row_norms, threshold=0.8):
    """Count how many sorted entries are needed to reach a share of the total.

    The norms are sorted in ascending order (the ``np.sort`` default) and
    accumulated; ``k`` is the 1-based position of the first cumulative sum
    that is ``>= threshold * total``.

    NOTE(review): the function name and its earlier documentation describe a
    largest-first ordering, but the code accumulates from the smallest value
    upwards -- confirm which ordering the callers expect.

    Args:
        row_norms: 1D array-like holding one norm per row.
        threshold: Cumulative share to reach; 0.8 (80%) by default.

    Returns:
        A tuple ``(ratio, k, sorted_norms)`` where ``ratio`` is ``k / n``,
        ``k`` is the 1-based position described above and ``sorted_norms``
        is the ascending-sorted copy of the input.  Empty input yields
        ``(0.0, 0, row_norms)`` and an all-zero total yields
        ``(0.0, 0, sorted_norms)``.
    """
    norms = np.asarray(row_norms)
    count = norms.size
    if count == 0:
        # Convention for empty input.
        return 0.0, 0, norms

    ordered = np.sort(norms)
    running = np.cumsum(ordered)
    grand_total = running[-1]
    if grand_total == 0:
        # Every norm is zero, so no share of the total can ever be reached.
        return 0.0, 0, ordered

    # searchsorted gives the 0-based position; +1 turns it into a 1-based rank.
    target = threshold * grand_total
    k = np.searchsorted(running, target, side='left') + 1
    return k / count, k, ordered


def set_seed(seed):
    """Seed PyTorch's CPU and CUDA generators and make cuDNN deterministic.

    Python's ``random`` module and NumPy are not seeded by this function.

    Args:
        seed: Integer seed handed to every PyTorch seeding call.
    """
    seeders = (
        torch.manual_seed,           # CPU generator
        torch.cuda.manual_seed,      # current GPU
        torch.cuda.manual_seed_all,  # every GPU (multi-GPU runs)
    )
    for seeder in seeders:
        seeder(seed)

    # Fix cuDNN to deterministic mode and switch off its benchmark mode.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def plot_metrics_separate(AUC, LOSS):
    """Plot AUC and LOSS against the epoch number in two stacked subplots.

    Args:
        AUC: Sequence with one AUC value per epoch; its length defines the
            epoch axis ``1 .. len(AUC)``.
        LOSS: Sequence with one loss value per epoch, plotted against the
            same epoch axis.
    """
    # Global default font size; change this value to resize all text.
    plt.rcParams.update({'font.size': 14})

    epochs = range(1, len(AUC) + 1)
    _, axes = plt.subplots(2, 1, figsize=(6, 10))

    # One panel per metric: (axes object, series, line colour, y-axis label).
    panels = (
        (axes[0], AUC, 'red', 'AUC'),
        (axes[1], LOSS, 'blue', 'LOSS'),
    )
    for axis, series, colour, label in panels:
        axis.plot(epochs, series, color=colour, linewidth=2)
        axis.set_xlabel('Epoch', fontsize=16)
        axis.set_ylabel(label, fontsize=16)
        axis.grid(True, linestyle='--', alpha=0.7)
        axis.tick_params(axis='both', which='major', labelsize=14)

    plt.tight_layout()
    plt.show()

class CustomDataset(Dataset):
    """Map-style dataset pairing an indexable feature store with its targets."""

    def __init__(self, X, y):
        """Keep references to the features ``X`` and the targets ``y``."""
        self.data = X
        self.targets = y

    def __len__(self):
        """Return the number of stored samples (``len`` of the features)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(features, target)`` for ``idx``.

        Tensor indices are converted to plain Python values first.  The
        selected features are wrapped with ``torch.from_numpy``, so the
        stored features must be NumPy data; the target is returned as-is.
        """
        index = idx.tolist() if torch.is_tensor(idx) else idx
        sample = torch.from_numpy(self.data[index])
        label = self.targets[index]
        return sample, label
def get_device():
    """Return ``'cuda'`` when a CUDA device is available, otherwise ``'cpu'``."""
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'

# Module-wide device string ('cuda' or 'cpu'), resolved once at import time.
device = get_device()

def read_gauss_data(file, normalization='z-score', train_level=1,test_level=2,seed=42):
    """Load a labelled dataset, add Gaussian noise, then split and normalize.

    The label is 0 for normal samples and 1 for anomalies.  Rows are
    shuffled with a ``RandomState(seed)`` permutation; the first half of the
    normal samples forms the training set, the remaining normal samples
    plus all anomalies form the test set.  Zero-mean Gaussian noise is added
    to the raw features *before* normalization.

    Args:
        file: Path ending in ``.npz`` (arrays ``X`` and ``y``), ``pkl`` or
            ``csv`` (a table whose last column is the label).  For tables,
            infinities are replaced by NaN and NaNs are forward-filled.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by
            255) or ``'ours'`` (standardize with ``std + 1e-4``); any other
            value leaves the data unnormalized.  Statistics are fitted on
            the (noisy) training split only.
        train_level: Noise variance (``scale = sqrt(train_level)``) for the
            training normals and for the anomalies.
        test_level: Noise variance for the test normals.
        seed: Seed of the shuffling permutation.  The noise itself is drawn
            from NumPy's global generator and is not controlled by ``seed``.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)`` where
        ``data_dim`` is the number of features and ``sds`` is the
        per-feature training standard deviation for ``'z-score'`` (zeros
        replaced by 1) and ``None`` otherwise.

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True can execute arbitrary code when the
        # archive is untrusted; only load dataset files you trust.
        with np.load(file, allow_pickle=True) as data:
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'Unsupported file type: {file!r}')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    # Cast to float so the noise added below is not truncated when the
    # features are stored with an integer dtype.
    x, y = np.asarray(x[idx], dtype=float), y[idx]

    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]

    data_dim = x.shape[1]
    y_train = y[train_norm_idx]

    # The draw order (train, anomalies, test normals) is kept so that runs
    # seeded through np.random.seed reproduce the same noise as before.
    # NOTE(review): here the anomalies use train_level and the test normals
    # use test_level, while read_after_gauss_data does the opposite --
    # confirm which assignment is intended.
    noise_train = np.random.normal(
        loc=0.0, scale=np.sqrt(train_level),
        size=(len(train_norm_idx), data_dim))
    noise_test_abnorm = np.random.normal(
        loc=0.0, scale=np.sqrt(train_level),
        size=(len(anom_idx), data_dim))
    noise_test_norm = np.random.normal(
        loc=0.0, scale=np.sqrt(test_level),
        size=(len(test_norm_idx), data_dim))

    x_train = x[train_norm_idx] + noise_train
    x_test = np.vstack([x[test_norm_idx] + noise_test_norm,
                        x[anom_idx] + noise_test_abnorm])
    y_test = y[np.hstack([test_norm_idx, anom_idx])]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    sds = None
    # normalization
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # avoid dividing constant features by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train, y_train, x_test, y_test, data_dim, sds


def read_after_gauss_data(file, normalization='z-score', train_level=1,test_level=2,seed=42):
    """Load a labelled dataset, split and normalize it, then add Gaussian noise.

    The label is 0 for normal samples and 1 for anomalies.  Rows are
    shuffled with a ``RandomState(seed)`` permutation; the first half of the
    normal samples forms the training set, the remaining normal samples
    plus all anomalies form the test set.  Zero-mean Gaussian noise is added
    to the splits *after* normalization.

    Fixes over the previous version: the noise used to be written into the
    shuffled source array after the splits had already been copied out, so
    the returned data never contained it; and ``sds`` was only bound for
    ``'z-score'``, making every other normalization fail on return.

    Args:
        file: Path ending in ``.npz`` (arrays ``X`` and ``y``), ``pkl`` or
            ``csv`` (a table whose last column is the label).  For tables,
            infinities are replaced by NaN and NaNs are forward-filled.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by
            255) or ``'ours'`` (standardize with ``std + 1e-4``); any other
            value leaves the data unnormalized.  Statistics are fitted on
            the clean training split only.
        train_level: Noise variance (``scale = sqrt(train_level)``) for the
            training normals and the test normals.
        test_level: Noise variance for the anomalies.
        seed: Seed of the shuffling permutation.  The noise itself is drawn
            from NumPy's global generator and is not controlled by ``seed``.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)`` where
        ``data_dim`` is the number of features and ``sds`` is the
        per-feature training standard deviation for ``'z-score'`` (zeros
        replaced by 1) and ``None`` otherwise.

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True can execute arbitrary code when the
        # archive is untrusted; only load dataset files you trust.
        with np.load(file, allow_pickle=True) as data:
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'Unsupported file type: {file!r}')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]

    x_train = x[train_norm_idx]
    x_test = x[np.hstack([test_norm_idx, anom_idx])]
    y_train = y[train_norm_idx]
    y_test = y[np.hstack([test_norm_idx, anom_idx])]

    sds = None
    # normalization (statistics come from the clean training split)
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # avoid dividing constant features by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    data_dim = x_train.shape[1]

    # The draw order (train, anomalies, test normals) is kept unchanged.
    # NOTE(review): here the anomalies use test_level and the test normals
    # use train_level, while read_gauss_data does the opposite -- confirm
    # which assignment is intended.
    noise_train = np.random.normal(
        loc=0.0, scale=np.sqrt(train_level),
        size=(len(train_norm_idx), data_dim))
    noise_test_abnorm = np.random.normal(
        loc=0.0, scale=np.sqrt(test_level),
        size=(len(anom_idx), data_dim))
    noise_test_norm = np.random.normal(
        loc=0.0, scale=np.sqrt(train_level),
        size=(len(test_norm_idx), data_dim))

    # Apply the noise to the returned splits.  x_test stacks the test
    # normals first and the anomalies second, so the noise is stacked alike.
    x_train = x_train + noise_train
    x_test = x_test + np.vstack([noise_test_norm, noise_test_abnorm])

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    return x_train, y_train, x_test, y_test, data_dim, sds


def read_noise_data(file, normalization='z-score', seed=42,noise_rate=0):
    """Load a labelled dataset and build a split with a contaminated train set.

    The label is 0 for normal samples and 1 for anomalies.  Rows are
    shuffled with a ``RandomState(seed)`` permutation.  The first half of
    the normal samples plus the first ``int(split * noise_rate)`` anomalies
    (``split`` being the number of training normals) form the training set;
    the remaining normals and anomalies form the test set.

    Args:
        file: Path ending in ``.npz`` (arrays ``X`` and ``y``), ``pkl`` or
            ``csv`` (a table whose last column is the label).  For tables,
            infinities are replaced by NaN and NaNs are forward-filled.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by
            255) or ``'ours'`` (standardize with ``std + 1e-4``); any other
            value leaves the data unnormalized.  Statistics are fitted on
            the training split only.
        seed: Seed of the shuffling permutation.
        noise_rate: Number of anomalies put into the training set, expressed
            as a fraction of the number of training normals.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)`` where
        ``data_dim`` is the number of features and ``sds`` is the
        per-feature training standard deviation for ``'z-score'`` (zeros
        replaced by 1) and ``None`` otherwise.

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True can execute arbitrary code when the
        # archive is untrusted; only load dataset files you trust.
        with np.load(file, allow_pickle=True) as data:
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'Unsupported file type: {file!r}')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    ab_train = int(split * noise_rate)  # anomalies leaked into training
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    train_ab_idx, test_ab_idx = anom_idx[:ab_train], anom_idx[ab_train:]

    train_idx = np.hstack([train_norm_idx, train_ab_idx])
    test_idx = np.hstack([test_norm_idx, test_ab_idx])
    x_train, y_train = x[train_idx], y[train_idx]
    x_test, y_test = x[test_idx], y[test_idx]
    data_dim = x_train.shape[1]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    sds = None
    # normalization
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # avoid dividing constant features by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train, y_train, x_test, y_test, data_dim, sds

def read_OD_data(file, normalization='z-score', seed=42):
    """Load a labelled dataset for the outlier-detection setting (no split).

    Rows are shuffled with a ``RandomState(seed)`` permutation and the whole
    dataset -- normal samples and anomalies together -- is returned as both
    the training and the test set.

    Fix over the previous version: ``sds`` was only bound for ``'z-score'``,
    so every other normalization raised ``UnboundLocalError`` on return.

    Args:
        file: Path ending in ``.npz`` (arrays ``X`` and ``y``), ``pkl`` or
            ``csv`` (a table whose last column is the label).  For tables,
            infinities are replaced by NaN and NaNs are forward-filled.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by
            255) or ``'ours'`` (standardize with ``std + 1e-4``); any other
            value leaves the data unnormalized.
        seed: Seed of the shuffling permutation.

    Returns:
        ``(x, y, x, y, data_dim, sds)``: the normalized features and labels
        (each returned twice, as train and test), the number of features,
        and the per-feature standard deviation for ``'z-score'`` (zeros
        replaced by 1) or ``None`` otherwise.

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True can execute arbitrary code when the
        # archive is untrusted; only load dataset files you trust.
        with np.load(file, allow_pickle=True) as data:
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'Unsupported file type: {file!r}')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # Shuffle only; the full dataset is used for training and testing.
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    x_train = x
    y_train = y

    sds = None
    # normalization
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # avoid dividing constant features by zero
        x_train = (x_train - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)

    return x_train, y_train, x_train, y_train, x.shape[1], sds

def read_data(file, normalization='z-score', seed=42):
    """Load a labelled dataset and build a one-class train/test split.

    The label is 0 for normal samples and 1 for anomalies.  Rows are
    shuffled with a ``RandomState(seed)`` permutation; the first half of the
    normal samples forms the training set, the remaining normal samples
    plus all anomalies form the test set.

    Args:
        file: Path ending in ``.npz`` (arrays ``X`` and ``y``), ``pkl`` or
            ``csv`` (a table whose last column is the label).  For tables,
            infinities are replaced by NaN and NaNs are forward-filled.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by
            255) or ``'ours'`` (standardize with ``std + 1e-4``); any other
            value leaves the data unnormalized.  Statistics are fitted on
            the training split only.
        seed: Seed of the shuffling permutation.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)`` where
        ``data_dim`` is the number of features and ``sds`` is the
        per-feature training standard deviation for ``'z-score'`` (zeros
        replaced by 1) and ``None`` otherwise.

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True can execute arbitrary code when the
        # archive is untrusted; only load dataset files you trust.
        with np.load(file, allow_pickle=True) as data:
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'Unsupported file type: {file!r}')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    test_idx = np.hstack([test_norm_idx, anom_idx])

    x_train, y_train = x[train_norm_idx], y[train_norm_idx]
    x_test, y_test = x[test_idx], y[test_idx]
    data_dim = x_train.shape[1]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    sds = None
    # normalization
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # avoid dividing constant features by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train, y_train, x_test, y_test, data_dim, sds
def read_front_data(file, normalization='z-score', seed=42):
    """Load a labelled dataset and split it in file order (no shuffling).

    The label is 0 for normal samples and 1 for anomalies.  The first half
    of the normal samples, in their original order, forms the training
    set; the remaining normal samples plus all anomalies form the test set.

    Fix over the previous version: ``sds`` was only bound for ``'z-score'``,
    so every other normalization raised ``UnboundLocalError`` on return.

    Args:
        file: Path ending in ``.npz`` (arrays ``X`` and ``y``), ``pkl`` or
            ``csv`` (a table whose last column is the label).  For tables,
            infinities are replaced by NaN and NaNs are forward-filled.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by
            255) or ``'ours'`` (standardize with ``std + 1e-4``); any other
            value leaves the data unnormalized.  Statistics are fitted on
            the training split only.
        seed: Accepted for signature parity with the other readers; unused
            because the rows are not shuffled.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)`` where
        ``data_dim`` is the number of features and ``sds`` is the
        per-feature training standard deviation for ``'z-score'`` (zeros
        replaced by 1) and ``None`` otherwise.

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True can execute arbitrary code when the
        # archive is untrusted; only load dataset files you trust.
        with np.load(file, allow_pickle=True) as data:
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'Unsupported file type: {file!r}')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting in the original row order
    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    test_idx = np.hstack([test_norm_idx, anom_idx])

    x_train, y_train = x[train_norm_idx], y[train_norm_idx]
    x_test, y_test = x[test_idx], y[test_idx]
    data_dim = x_train.shape[1]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    sds = None
    # normalization
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # avoid dividing constant features by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train, y_train, x_test, y_test, data_dim, sds
def nonlinear_transform(x, out_dim=166):
    """Map 2D points to ``out_dim`` nonlinear features.

    For ``i = 1 .. out_dim`` column ``i - 1`` of the result is::

        sin(a*x0**2 + b*x1**3) + cos(c*exp(x0*x1))
            + exp(-a*c*(x0**2 + x1**2)) + d*log(1 + x0**2 + x1**2)

    with ``a = 0.1*i``, ``b = 0.05*i``, ``c = 0.01*i`` and ``d = 1/i``.

    Args:
        x: Input tensor of shape (N, 2), one point per row.
        out_dim: Number of output features.  Defaults to 166, the width
            that used to be hard-coded.

    Returns:
        Tensor of shape (N, out_dim) allocated with ``torch.zeros``; each
        row is the transformed point.
    """
    x0, x1 = x[:, 0], x[:, 1]

    # Terms that do not depend on the feature index are computed once.
    x0_sq = x0 ** 2
    x1_sq = x1 ** 2
    x1_cu = x1 ** 3
    exp_prod = torch.exp(x0 * x1)
    log_term = torch.log(1 + x0_sq + x1_sq)
    radius_sq = x0_sq + x1_sq

    y = torch.zeros((x.shape[0], out_dim))

    # Iterate up to and including out_dim so the last column is filled too;
    # the former range(1, 166) stopped one short and left column 165 at zero.
    for i in range(1, out_dim + 1):
        a_i = i * 0.1
        b_i = i * 0.05
        c_i = i * 0.01
        d_i = 1 / i
        y[:, i - 1] = (
            torch.sin(a_i * x0_sq + b_i * x1_cu) +
            torch.cos(c_i * exp_prod) +
            torch.exp(-a_i * c_i * radius_sq) +
            d_i * log_term
        )

    return y
def generate_unit_norm_matrix(rows, cols):
    """Draw a random matrix whose rows all have an L2 norm of 1.

    Args:
        rows: Number of rows of the matrix.
        cols: Number of columns of the matrix.

    Returns:
        A ``(rows, cols)`` array of standard-normal samples with every row
        divided by its own L2 norm.
    """
    raw = np.random.randn(rows, cols)
    # keepdims=True keeps the norms as a column so they broadcast row-wise.
    row_lengths = np.linalg.norm(raw, axis=1, keepdims=True)
    return raw / row_lengths

def min_max_normalize(x):
    """Drop constant columns, then min-max scale the remaining columns.

    This function used to be defined twice in a row with identical bodies;
    the redundant second definition has been merged into this one.

    Args:
        x: 2D array with one sample per row and one feature per column.

    Returns:
        The array without its constant columns (columns holding at most one
        distinct value), transformed by a ``MinMaxScaler`` fitted on that
        same data.  The indices of removed columns are printed.
    """
    constant_cols = [
        k for k in range(x.shape[1])
        if len(np.unique(x[:, k])) <= 1
    ]
    if constant_cols:
        print('remove features', constant_cols)
        x = np.delete(x, constant_cols, 1)

    return MinMaxScaler().fit_transform(x)


import numpy as np

def max_gap_over_mean_row_norm(A):
    """Relate the largest gap between neighbouring sorted values to their range.

    Despite its name, the function neither computes row norms nor a mean:
    ``A`` is used as-is as the vector of values, and the denominator is
    ``A.max() - A.min()``.

    Args:
        A: NumPy array of values, used directly as the norms.
            NOTE(review): a 1D array is presumably expected -- confirm.

    Returns:
        ``(ratio, max_gap, gap_index)`` where ``max_gap`` is the largest
        difference between neighbours of the descending-sorted values,
        ``gap_index`` is its position ``k`` (the gap is
        ``sorted[k] - sorted[k + 1]``) and ``ratio`` is ``max_gap`` divided
        by the value range (``np.inf`` when the range is 0).  When ``A`` has
        fewer than two entries along its first axis there is no gap and
        ``(0.0, 0.0, None)`` is returned.
    """
    values = A

    if values.shape[0] < 2:
        return 0.0, 0.0, None

    # Largest first, then the drop from every entry to its successor.
    descending = np.sort(values)[::-1]
    drops = descending[:-1] - descending[1:]

    widest_at = np.argmax(drops)
    widest = drops[widest_at]

    spread = values.max() - values.min()
    if spread != 0:
        ratio = widest / spread
    else:
        ratio = np.inf

    return ratio, widest, widest_at

import torch

def first_i_by_max_gap_ratio_1d(jac_sum: torch.Tensor):
    """Return the original indices of the values below the largest relative gap.

    Steps for a 1D ``jac_sum``:
      1. Sort ascending: ``a_1 <= ... <= a_d``.
      2. Compute ``r_i = (a_{i+1} - a_i) / (a_{i+1} + 1e-12)``.
      3. Find ``i*``, the 1-based position of the largest ``r_i``.
      4. Return the indices, in the original tensor, of ``a_1 .. a_{i*}``.

    Args:
        jac_sum: 1D tensor of values.

    Returns:
        Tensor holding the original (0-based) indices of the ``i*``
        smallest values, ordered by ascending value.
    """
    assert jac_sum.dim() == 1, "jac_sum must be 1D"

    # original_positions[j] is where the j-th smallest value lives in jac_sum.
    ordered, original_positions = torch.sort(jac_sum, descending=False)

    lower, upper = ordered[:-1], ordered[1:]
    eps = 1e-12  # guards against dividing by an exact zero
    gap_ratios = (upper - lower) / (upper + eps)

    # argmax is the 0-based position of r_i; +1 gives the number of kept values.
    cut = torch.argmax(gap_ratios).item() + 1
    return original_positions[:cut]


def top_percent_positions(jac_sum: torch.Tensor, percent: float):
    """
    选出 jac_sum 中按数值从大到小的前 percent 比例元素的位置

    Returns:
        indices_nd: 形如 (k, jac_sum.dim()) 的长整型张量，每行是一个位置
        mask: 和 jac_sum 同形状的 bool mask，top percent 的位置为 True
        values: top percent 的数值
    """
    assert 0 < percent <= 1, "percent should be in (0, 1]"
    flat = jac_sum.reshape(-1)
    k = int(math.ceil(percent * flat.numel()))

    # topk 取最大的 k 个
    values, idx_flat = torch.topk(flat, k, largest=False, sorted=True)

    # unravel 成多维坐标
    indices_nd = torch.stack(torch.unravel_index(idx_flat, jac_sum.shape), dim=1)

    mask = torch.zeros_like(flat, dtype=torch.bool)
    mask[idx_flat] = True
    mask = mask.view_as(jac_sum)

    return indices_nd, mask, values


def contribution_calculation(model, train_loader, std, percent):
    """Score every model output by its input sensitivity and pick a subset.

    For each batch the gradient of every batch-summed output component with
    respect to the inputs is taken, multiplied by ``std``, turned into
    absolute values and summed over the batch.  The sums are divided by the
    number of samples seen, reduced to one L2 norm per output, and the
    outputs below the largest relative gap of the sorted norms are selected.

    Args:
        model: Module called as ``model(inputs, 0)`` that returns a pair
            whose first element is the output tensor.  It is moved to the
            module-level ``device`` and put in training mode.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        std: Tensor right-multiplied onto the Jacobian; it is not moved to
            ``device`` here.
            NOTE(review): presumably a (in_dim, in_dim) scaling matrix
            already on ``device`` -- confirm against the caller.
        percent: Share passed to ``top_percent_positions``; that selection
            is only used for the first printed count.

    Returns:
        The index tensor produced by ``first_i_by_max_gap_ratio_1d``.
    """
    accumulated = None
    seen = 0

    for batch in tqdm(train_loader):
        model = model.to(device)
        model.train()

        inputs, _labels = batch
        inputs = inputs.float().to(device)
        inputs.requires_grad = True

        latents, _ = model(inputs, 0)
        seen += latents.shape[0]

        # One scalar per output component, summed over the batch.
        summed = torch.sum(latents, 0).view(-1)
        grads = [
            torch.autograd.grad(summed[j], inputs, retain_graph=True, create_graph=False)[0]
            for j in range(summed.size(0))
        ]
        # Stack to (outdim, B, in_dim), reorder to (B, outdim, in_dim).
        batch_jac = torch.stack(grads, dim=0).permute(1, 0, 2)
        batch_jac = torch.abs(batch_jac @ std)

        batch_total = torch.sum(batch_jac, dim=0)  # (outdim, in_dim)
        if accumulated is None:
            accumulated = batch_total
        else:
            accumulated += batch_total

    contributions = torch.norm(accumulated / seen, dim=1)

    # Both selections are computed and their sizes printed; only the
    # gap-based one is returned.
    selected, _mask, _values = top_percent_positions(contributions, percent)
    print(len(selected))
    selected = first_i_by_max_gap_ratio_1d(contributions)
    print(len(selected))

    return selected



def training(model,train_loader,lr,adam=None,wd=0,pun=0.01, grad_pun=0.01,PNAL=None,std=None,opt=None):
    """Run one optimisation pass over ``train_loader``.

    For every batch the loss is the mean negative log-likelihood of the
    model outputs under a standard normal (plus the ``sldj`` term returned
    by the model), plus ``grad_pun`` times a Jacobian penalty selected by
    ``PNAL``.

    Args:
        model: Module called as ``model(imgs, sldj=0)`` returning
            ``(outputs, sldj)``.  It is moved to the module-level ``device``
            and put in training mode.
        train_loader: Iterable of ``(inputs, second_item)`` batches.
        lr: Learning rate for the Adam optimizer created when ``adam`` is
            None.
        adam: Optional existing optimizer to use instead of a new one.
        wd: Weight decay for the newly created optimizer.
        pun: Unused in this function.
        grad_pun: Weight of the Jacobian penalty; 0 skips its computation.
        PNAL: Penalty name: ``'L_2'``, ``'L_1sq'``, ``'L2cL1'``, ``'L2-L1'``,
            ``'overL2cL1'``, ``'overL2-L1'``, ``'half/half'`` or
            ``'half-half'``.
        std: Tensor right-multiplied onto the Jacobian; moved to ``device``
            and cast to float.
        opt: Unused in this function.

    Returns:
        ``(_, loss_sum)`` where ``_`` is the second item of the last batch
        and ``loss_sum`` is the sum of the per-batch losses.

    NOTE(review): ``std`` defaults to None but ``std.to(device)`` is called
    unconditionally, so a tensor must always be passed.
    NOTE(review): with ``grad_pun != 0`` and a ``PNAL`` matching none of the
    branches, ``jac`` stays a matrix and the loss is no longer a scalar.
    NOTE(review): an empty ``train_loader`` leaves ``loss`` and ``_``
    unbound for the final print/return.
    """
    model = model.to(device)
    model.device = device
    model.train()
    std=std.to(device)
    std=std.float()
    optimizer=adam
    # Build a fresh Adam optimizer unless the caller supplied one.
    if adam==None:
        optimizer = torch.optim.Adam(model.parameters(), lr = lr, weight_decay=wd,amsgrad=1)
    # best_loss, best_auc, rs, batch_size, F_norm and je are assigned but
    # never read again in this function.
    best_loss = 1000000
    loss_sum=0
    best_auc=0
    i=0
    je=None
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        loss = 0.0
        rs=0
        imgs,_ = batch
        imgs = imgs.to(device)
        imgs=imgs.float()
        # Gradients w.r.t. the inputs are needed for the Jacobian penalty.
        imgs.requires_grad=True
        outputs,sldj= model(imgs,sldj=0)
        # Element-wise standard-normal log-density of the outputs, summed
        # over dim 1 to get one value per sample.
        log_likelihood=-0.5 * (torch.pow(outputs,2) + torch.log(torch.tensor(torch.pi*2)))
        sample_likelihood=torch.sum(log_likelihood,dim=1)
        batch_size = outputs.size(0)
        # One scalar per output component, summed over the batch.
        outVector = torch.sum(outputs,0).view(-1)
        outdim = outVector.size()[0]
        if je==None:
           je=torch.zeros(outdim)
        jac=0
        F_norm=0
        L_2=1
        if grad_pun!=0:
            # Gradient of every summed output component w.r.t. the inputs;
            # create_graph=True keeps the penalty differentiable.
            jac = torch.stack([torch.autograd.grad(outVector[i], imgs,
                                    retain_graph=True, create_graph=True)[0] for i in range(outdim)], dim=0)
            # Swap the first two dims, scale by std, take absolute values
            # and average over the (new) first dim.
            jac=jac.permute(1,0,2)
            jac=torch.matmul(jac,std)
            jac=torch.abs(jac)
            jac=torch.mean(jac,dim=0)
            if PNAL=='L_2':
             # Sum of the L2 norms of the rows.
             jac_norm=torch.norm(jac,dim=1)
             jac=torch.sum(jac_norm)
            elif PNAL=='L_1sq':
             # Sum over rows of the square root of each row sum.
             jac=torch.sum(torch.sqrt(torch.sum(jac,dim=1)))
            elif PNAL=='L2cL1':
             # Sum of row L2 norms divided by the norm of all entries.
             jac=torch.pow(jac,2)
             L_2=torch.sqrt(torch.sum(jac))
             jac=torch.sum(torch.sqrt(torch.sum(jac,dim=1)))
             jac=jac/L_2
            elif PNAL=='L2-L1':
             # Sum of row L2 norms minus the norm of all entries.
             jac=torch.pow(jac,2)
             L_2=torch.sqrt(torch.sum(jac))
             jac=torch.sum(torch.sqrt(torch.sum(jac,dim=1)))
             jac=jac-L_2
            elif PNAL=='overL2cL1':
             # Same as 'L2cL1' after mapping every entry to 1 / (entry + 1).
             jac=1/(jac+1)
             jac=torch.pow(jac,2)
             L_2=torch.sqrt(torch.sum(jac))
             jac=torch.sum(torch.sqrt(torch.sum(jac,dim=1)))
             jac=jac/L_2
            elif PNAL=='overL2-L1':
             # Same as 'L2-L1' after mapping every entry to 1 / (entry + 1).
             jac=1/(jac+1)
             jac=torch.pow(jac,2)
             L_2=torch.sqrt(torch.sum(jac))
             jac=torch.sum(torch.sqrt(torch.sum(jac,dim=1)))
             jac=jac-L_2
            elif PNAL=='half/half':
             # Sort the row norms; sum of the smaller half divided by the
             # sum of squares of the larger half.
             row_norms = torch.norm(jac, p=2, dim=1)
             sorted_indices = torch.argsort(row_norms)
             half_size = len(row_norms) // 2
             min_half_indices = sorted_indices[:half_size]
             max_half_indices = sorted_indices[half_size:]
             min_half_sum = row_norms[min_half_indices].sum()
             max_half_square_sum = (row_norms[max_half_indices] ** 2).sum()
             jac=min_half_sum/max_half_square_sum
            elif PNAL=='half-half':
             # Sort the row norms; sum of the smaller half minus the sum of
             # squares of the larger half.
             row_norms = torch.norm(jac, p=2, dim=1)
             sorted_indices = torch.argsort(row_norms)
             half_size = len(row_norms) // 2
             min_half_indices = sorted_indices[:half_size]
             max_half_indices = sorted_indices[half_size:]
             min_half_sum = row_norms[min_half_indices].sum()
             max_half_square_sum = (row_norms[max_half_indices] ** 2).sum()
             jac=min_half_sum-max_half_square_sum
        # Negative log-likelihood (latent log-density plus sldj), then the
        # weighted penalty.
        loss=torch.mean(-(sample_likelihood+sldj))
        loss+=(jac)*grad_pun
        loss.backward()
        optimizer.step()
        loss_sum+=loss.detach().cpu()
        i+=1
    # Reports the loss and penalty of the last batch only.
    print(f" Train |  loss = {loss:.5f}  grad={jac:.5f}")
    return _,loss_sum

import torch
from tqdm import tqdm

def build_contributed_testset(model, test_loader, top_idx, device='cuda'):
    """
    Rebuild the test set from latent codes restricted to selected positions.

    Each batch of ``test_loader`` is mapped to latent space with
    ``model(x, sldj=0)``, every latent entry outside ``top_idx`` is set to
    zero, and the masked code is mapped back with
    ``model.reverse(..., sldj=0)``.

    Args:
        model: flow model exposing a forward call and ``reverse()``.
        test_loader: iterable yielding ``(x, y)`` batches.
        top_idx: latent positions to keep. It is flattened and used as
            column indices when the latent code is 2-D; otherwise its first
            two columns are read as ``(row, col)`` pairs.
        device: device the model and the batches are moved to.

    Returns:
        new_test_set: ``CustomDataset(new_test_data, new_test_lab)``
        new_test_data: numpy array with the reconstructed inputs
        new_test_lab: numpy array with the matching labels
    """
    model = model.to(device)
    model.eval()

    # Normalise the index container to a long tensor.
    keep_idx = top_idx if torch.is_tensor(top_idx) else torch.tensor(top_idx, dtype=torch.long)
    keep_idx = keep_idx.long()

    rebuilt_inputs, kept_labels = [], []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader):
            inputs = inputs.to(device).float()
            labels = labels.to(device)

            # Forward pass: only the latent code is needed here.
            latent, _ = model(inputs, sldj=0)

            # 1.0 at the positions to keep, 0.0 everywhere else.
            keep_mask = torch.zeros_like(latent)
            if latent.dim() == 2:
                # (B, D) latent code: top_idx is treated as flat column indices.
                keep_mask[:, keep_idx.view(-1)] = 1.0
            else:
                # Higher-rank latent code: top_idx holds (row, col) pairs that
                # address the two dimensions following the batch dimension.
                keep_mask[:, keep_idx[:, 0], keep_idx[:, 1]] = 1.0

            # Inverse pass on the masked code gives the new inputs.
            restored, _ = model.reverse(latent * keep_mask, sldj=0)

            rebuilt_inputs.append(restored.detach().cpu())
            kept_labels.append(labels.detach().cpu())

    new_test_data = torch.cat(rebuilt_inputs, dim=0).numpy()
    new_test_lab = torch.cat(kept_labels, dim=0).numpy()

    return CustomDataset(new_test_data, new_test_lab), new_test_data, new_test_lab


def testing(model,test_loader,top_idx,std):
    """
    Score every sample of ``test_loader`` and return (AUROC, average precision).

    For each batch the per-sample Jacobian of the model output with respect
    to the input is built one output coordinate at a time, right-multiplied
    by ``std``, restricted to the output rows in ``top_idx`` and turned into
    ``det(J J^T + eps * I)``. The anomaly score of a sample is
    ``-(sum(-0.5 * z**2) + det)``.

    NOTE(review): the determinant itself, not its logarithm, is added to the
    log-density term although ``slogdet`` is used to obtain it -- confirm
    this is intended.

    Args:
        model: flow model called as ``model(x, sldj=0)`` returning
            ``(outputs, sldj)``.
        test_loader: iterable yielding ``(x, y)`` batches; assumes ``x`` is
            2-D ``(batch, features)`` because the stacked gradients are
            permuted as a 3-D tensor -- TODO confirm.
        top_idx: indices of the output coordinates kept in the Jacobian.
        std: matrix the Jacobian is right-multiplied with.

    Returns:
        (roc_auc, ap_score): torchmetrics binary AUROC and average precision
        of the scores against the labels.
    """
    device = 'cuda'
    model = model.to(device)
    model.eval()
    preds, targets = [], []

    for batch in tqdm(test_loader):
        x, y = batch
        x, y = x.to(device), y.to(device)
        x = x.float()
        # Gradients w.r.t. the input are required to build the Jacobian.
        x.requires_grad = True
        outputs, sldj = model(x, sldj=0)

        # Summing over the batch lets one backward pass per output coordinate
        # yield that coordinate's gradient for every sample (assumes samples
        # do not interact inside the model).
        outVector = torch.sum(outputs, 0).view(-1)
        outdim = outVector.size()[0]
        # Per-dimension term -0.5 * z^2 of the latent log-density.
        log_likelihood = -0.5 * (torch.pow(outputs, 2))
        jac = torch.stack([torch.autograd.grad(outVector[i], x,
                                               retain_graph=True, create_graph=False)[0]
                           for i in range(outdim)], dim=0)
        jac = jac.permute(1, 0, 2)  # -> (batch, outdim, input_dim)
        jac = torch.matmul(jac, std)
        jac_sel = jac[:, top_idx, :]

        # J J^T : (batch, m, m), regularised with eps * I for stability.
        jjt = jac_sel @ jac_sel.transpose(-1, -2)
        m = jjt.size(-1)
        eps = 1e-8
        eye = torch.eye(m, device=jac.device).unsqueeze(0)  # (1, m, m)
        jjt_stable = jjt + eps * eye
        sign, logabsdet = torch.linalg.slogdet(jjt_stable)  # (batch,), (batch,)
        det_vals = sign * torch.exp(logabsdet)

        sample_likelihood = torch.sum(log_likelihood, dim=1)
        pred = -(sample_likelihood + det_vals)

        # Detach before storing: the scores are only used for metrics, and
        # keeping them attached would hold every batch's autograd graph (and
        # its saved activations) alive until the function returns.
        preds.append(pred.detach())
        targets.append(y.detach())

        gc.collect()  # drop unreferenced Python objects
        torch.cuda.empty_cache()  # hand cached, unused CUDA blocks back to the driver

    auroc_metric = AUROC(task="binary")
    average_precision = AveragePrecision(task="binary")
    targets = torch.cat(targets)
    preds = torch.cat(preds)
    roc_auc = auroc_metric(preds, targets)
    ap_score = average_precision(preds, targets)

    return roc_auc,ap_score


def draw_fig(jac_sum,epoch):
  """
  Save a bar chart of the values of ``jac_sum`` in ascending order.

  Args:
      jac_sum: CPU tensor of values to plot (sorted along its last dimension).
      epoch: used only to build the output file name.

  The figure is written to ``./jac_figs/<epoch> .png`` and then closed.
  """
  # Ascending values as a numpy array for matplotlib.
  ascending = torch.sort(jac_sum).values.numpy()
  positions = range(jac_sum.shape[0])

  plt.figure(figsize=(20, 12))
  plt.bar(positions, ascending)
  plt.xlabel('Index')
  plt.ylabel('Value')
  plt.title('Bar Plot of Sorted Tensor Values')

  out_path = './jac_figs/' + str(epoch) + " .png"
  plt.savefig(out_path, format='png')
  plt.close()
#torch.manual_seed(42)
#torch.cuda.manual_seed(42)

def draw_3D_fig(x,labels,epoch):
 """
 Show a 3-D scatter plot of the first three columns of ``x``, one colour
 per distinct label.

 Args:
     x: array of points; columns 0, 1 and 2 are plotted.
     labels: per-point labels used to group and colour the points.
     epoch: currently unused (the save-to-file call is commented out).

 The i-th distinct label (in ``np.unique`` order) gets the i-th legend name,
 so more than three distinct labels raise an IndexError.
 """
 figure_3d = plt.figure()
 axes = figure_3d.add_subplot(111, projection='3d')

 legend_names = ['normal data','noised normal data','anomalous data']

 # One viridis colour per distinct label value.
 distinct = np.unique(labels)
 palette = plt.cm.viridis(np.linspace(0, 1, len(distinct)))

 for pos, value in enumerate(distinct):
     # Points carrying the current label.
     members = x[labels == value]
     axes.scatter(members[:, 0], members[:, 1], members[:, 2],
                  s=5, color=palette[pos], label=legend_names[pos], alpha=0.9)

 axes.set_xlabel('X')
 axes.set_ylabel('Y')
 axes.set_zlabel('Z')

 # Legend built from the per-label scatter calls above.
 axes.legend()

 plt.show()
 #plt.savefig('./NVP_3D/'+str(epoch)+" .png", format='png')
 plt.close()
def draw_minmax_fig(jac_sum,epoch):
  """
  Save a heat map of the tensor ``jac_sum``.

  Args:
      jac_sum: CPU tensor rendered with ``pcolormesh``; its shape is printed.
      epoch: used only to build the output file name.

  The figure is written to ``./MIN-MAX-figs/<epoch> .png`` and then closed.
  """
  print(jac_sum.shape)
  values = jac_sum.numpy()

  # Square canvas at high resolution.
  fig, ax = plt.subplots(figsize=(10, 10), dpi=1000)
  mesh = ax.pcolormesh(values, cmap='coolwarm', shading='auto')
  # Narrow colour bar so the main axes keep their square footprint.
  plt.colorbar(mesh, ax=ax, fraction=0.046, pad=0.04)
  plt.title("Square Heatmap without Colorbar")

  target = './MIN-MAX-figs/' + str(epoch) + " .png"
  plt.savefig(target, format='png')
  plt.close()
def knn_average_distance(train_data, test_data, k):
    """
    Mean Euclidean distance from each test sample to its k nearest
    training samples.

    Args:
        train_data: torch.Tensor of shape (num_train_samples, num_features).
        test_data: torch.Tensor of shape (num_test_samples, num_features).
        k: number of nearest neighbours to average over.

    Returns:
        torch.Tensor of shape (num_test_samples,) holding, for every test
        sample, the average of its k smallest distances to the training set.
    """
    # Pairwise L2 distances, shape (num_test_samples, num_train_samples).
    pairwise = torch.cdist(test_data, train_data, p=2)
    # The k smallest distances in every row, then their mean.
    nearest = torch.topk(pairwise, k, dim=1, largest=False).values
    return nearest.mean(dim=1)


import torch

def verify_reverse_correctness(model, data_loader, device='cuda', atol=1e-5, rtol=1e-5, max_batches=3):
    """
    Check that ``model.reverse`` inverts the model's forward pass.

    For the first ``max_batches`` batches the input is pushed through
    ``model(x, sldj=0)`` and back through ``model.reverse(z, sldj=0)``; the
    reconstruction error and, when both passes return one, the sum of the
    two log-determinant terms are reported.

    Args:
        model: flow model exposing a forward call and ``reverse()``.
        data_loader: iterable yielding ``(x, y)`` pairs or bare ``x`` batches.
        device: device the model and the batches are moved to.
        atol, rtol: tolerances passed to ``torch.allclose``.
        max_batches: number of leading batches to check.

    Prints:
        per-batch ``allclose`` result with mean/max absolute error, then a
        summary of the averaged errors. When no batch was evaluated (empty
        loader or ``max_batches <= 0``) a notice is printed instead of the
        averages.

    Returns:
        None
    """
    model = model.to(device)
    model.eval()

    recon_errors = []
    sldj_errors = []

    with torch.no_grad():
        for bi, batch in enumerate(data_loader):
            if bi >= max_batches:
                break

            # Accept both (x, y) batches and bare x batches.
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            x = x.to(device).float()

            z, sldj_fwd = model(x, sldj=0)
            x_rec, sldj_inv = model.reverse(z, sldj=0)

            err = (x_rec - x).abs()
            mean_err = err.mean().item()
            recon_errors.append(mean_err)

            # If reverse also tracks the log-determinant, forward and inverse
            # contributions are expected to cancel.
            if sldj_fwd is not None and sldj_inv is not None:
                s_err = (sldj_inv + sldj_fwd).abs().mean().item()
                sldj_errors.append(s_err)

            ok = torch.allclose(x_rec, x, atol=atol, rtol=rtol)
            print(f"[batch {bi}] allclose(x_rec, x) = {ok}, mean|err| = {mean_err:.3e}, max|err| = {err.max().item():.3e}")

    print("\n=== Summary ===")
    if not recon_errors:
        # Nothing was checked; averaging would divide by zero.
        print("No batches were evaluated (empty loader or max_batches <= 0).")
        return
    print(f"Avg mean|x_rec-x| over {len(recon_errors)} batches: {sum(recon_errors)/len(recon_errors):.3e}")
    if len(sldj_errors) > 0:
        print(f"Avg mean|sldj_inv + sldj_fwd|: {sum(sldj_errors)/len(sldj_errors):.3e}")



def read_noise_data(file, normalization='z-score', seed=42,noise_rate=0):
    """
    Load a labelled dataset and build a training set contaminated with anomalies.

    The rows are shuffled with ``seed``. The first half of the normal samples
    (label 0) forms the training set, optionally joined by
    ``int(split * noise_rate)`` anomalies (label 1); the second half of the
    normal samples plus the remaining anomalies forms the test set.

    Args:
        file: path (str or path-like) to a ``.npz`` file with arrays ``X`` and
            ``y``, or to a ``pkl`` / ``csv`` table whose last column is the label.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by 255)
            or ``'ours'`` (standardise with ``std + 1e-4``); any other value
            leaves the features untouched. Statistics come from the training
            split only.
        seed: seed of the shuffling permutation.
        noise_rate: number of training anomalies as a fraction of the number
            of normal training samples.

    Returns:
        x_train, y_train, x_test, y_test, data_dim, sds -- ``data_dim`` is the
        number of feature columns, ``sds`` the per-feature training standard
        deviation used by ``'z-score'`` (zeros replaced by 1) and ``None`` for
        every other normalization.

    Raises:
        NotImplementedError: for an unsupported file extension.
    """
    file = str(file)  # also accept pathlib.Path
    if file.endswith('.npz'):
        # NOTE(security): allow_pickle=True can run arbitrary code when the
        # archive is unpickled -- only load files from trusted sources.
        with np.load(file, allow_pickle=True) as data:
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'Unsupported file type: {file}')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # fillna(method='ffill') is deprecated; ffill() is the equivalent call.
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y==0)[0]
    anom_idx = np.where(y==1)[0]
    split = int(0.5 * len(norm_idx))
    ab_train = int(split*noise_rate)
    # The first 10% (relative to `split`) of the anomalies is reserved for
    # contamination, so the test set is the same for every noise_rate <= 0.1.
    max_ab_train = int(split*0.1)
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    train_ab_idx = anom_idx[:ab_train]
    # Start the test anomalies after whichever boundary is larger so that a
    # noise_rate above 0.1 cannot put the same anomaly in both splits.
    test_ab_idx = anom_idx[max(ab_train, max_ab_train):]

    train_idx = np.hstack([train_norm_idx, train_ab_idx])
    test_idx = np.hstack([test_norm_idx, test_ab_idx])
    x_train, y_train = x[train_idx], y[train_idx]
    x_test, y_test = x[test_idx], y[test_idx]
    data_dim = x_train.shape[1]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    sds = None
    # normalization
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train, y_train, x_test, y_test,data_dim,sds



def read_after_gauss_data(file, normalization='z-score', train_level=1,test_level=2,seed=42):
    """
    Load a labelled dataset, normalise it and then add Gaussian noise.

    The rows are shuffled with ``seed``; the first half of the normal samples
    (label 0) is the training set, the second half plus all anomalies
    (label 1) is the test set. Noise is added after normalisation:
    variance ``train_level`` for the training samples and the test anomalies,
    variance ``test_level`` for the normal test samples.

    NOTE(review): the noise is drawn from the global ``np.random`` state, so
    ``seed`` fixes the split but not the noise.

    Args:
        file: path (str or path-like) to a ``.npz`` file with arrays ``X`` and
            ``y``, or to a ``pkl`` / ``csv`` table whose last column is the label.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by 255)
            or ``'ours'`` (standardise with ``std + 1e-4``); any other value
            leaves the features untouched. Statistics come from the training
            split only.
        train_level: noise variance for training samples and test anomalies.
        test_level: noise variance for the normal test samples.
        seed: seed of the shuffling permutation.

    Returns:
        x_train, y_train, x_test, y_test, data_dim, sds -- ``sds`` is the
        per-feature training standard deviation used by ``'z-score'`` (zeros
        replaced by 1) and ``None`` for every other normalization.

    Raises:
        NotImplementedError: for an unsupported file extension.
    """
    file = str(file)  # also accept pathlib.Path
    if file.endswith('.npz'):
        # NOTE(security): allow_pickle=True can run arbitrary code when the
        # archive is unpickled -- only load files from trusted sources.
        with np.load(file, allow_pickle=True) as data:
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'Unsupported file type: {file}')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # fillna(method='ffill') is deprecated; ffill() is the equivalent call.
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y==0)[0]
    anom_idx = np.where(y==1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    x_train = x[train_norm_idx]
    x_test_norm = x[test_norm_idx]
    x_test_ab = x[anom_idx]

    # Every branch normalises the three parts with training statistics; sds
    # stays None unless the z-score branch sets it.
    sds = None
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test_norm = minmax_scaler.transform(x_test_norm)
        x_test_ab = minmax_scaler.transform(x_test_ab)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test_norm = (x_test_norm - mus) / sds
        x_test_ab = (x_test_ab - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test_norm = x_test_norm / 255
        x_test_ab = x_test_ab / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test_norm = (x_test_norm - mean) / (std + 1e-4)
        x_test_ab = (x_test_ab - mean) / (std + 1e-4)

    data_dim = x_train.shape[1]
    y_train = y[train_norm_idx]

    # The draw order (train, test anomalies, test normals) is kept so runs
    # that seed the global generator reproduce earlier results.
    n_features = x.shape[1]
    noise_train = np.random.normal(
        loc=0.0,
        scale=np.sqrt(train_level),
        size=(len(train_norm_idx), n_features),
    )
    noise_test_abnorm = np.random.normal(
        loc=0.0,
        scale=np.sqrt(train_level),
        size=(len(anom_idx), n_features),
    )
    noise_test_norm = np.random.normal(
        loc=0.0,
        scale=np.sqrt(test_level),
        size=(len(test_norm_idx), n_features),
    )
    print(noise_test_norm)
    x_test_ab = x_test_ab + noise_test_abnorm
    x_test_norm = x_test_norm + noise_test_norm
    x_train = x_train + noise_train

    # np.concat only exists in NumPy >= 2.0; np.concatenate works everywhere.
    x_test = np.concatenate([x_test_norm, x_test_ab])
    y_test = y[np.hstack([test_norm_idx, anom_idx])]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    return x_train, y_train, x_test, y_test,data_dim,sds


def read_after_sgauss_data(file, normalization='z-score', train_level=1,test_level=2,seed=42):
    """
    Load a labelled dataset, normalise it and then add Gaussian noise.

    The rows are shuffled with ``seed``; the first half of the normal samples
    (label 0) is the training set, the second half plus all anomalies
    (label 1) is the test set. Noise is added after normalisation:
    variance ``train_level`` for the training samples and the normal test
    samples, variance ``test_level`` for the test anomalies.

    NOTE(review): the noise is drawn from the global ``np.random`` state, so
    ``seed`` fixes the split but not the noise.

    Args:
        file: path (str or path-like) to a ``.npz`` file with arrays ``X`` and
            ``y``, or to a ``pkl`` / ``csv`` table whose last column is the label.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by 255)
            or ``'ours'`` (standardise with ``std + 1e-4``); any other value
            leaves the features untouched. Statistics come from the training
            split only.
        train_level: noise variance for training samples and normal test samples.
        test_level: noise variance for the test anomalies.
        seed: seed of the shuffling permutation.

    Returns:
        x_train, y_train, x_test, y_test, data_dim, sds -- ``sds`` is the
        per-feature training standard deviation used by ``'z-score'`` (zeros
        replaced by 1) and ``None`` for every other normalization.

    Raises:
        NotImplementedError: for an unsupported file extension.
    """
    file = str(file)  # also accept pathlib.Path
    if file.endswith('.npz'):
        # NOTE(security): allow_pickle=True can run arbitrary code when the
        # archive is unpickled -- only load files from trusted sources.
        with np.load(file, allow_pickle=True) as data:
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(f'Unsupported file type: {file}')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # fillna(method='ffill') is deprecated; ffill() is the equivalent call.
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y==0)[0]
    anom_idx = np.where(y==1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    x_train = x[train_norm_idx]
    x_test_norm = x[test_norm_idx]
    x_test_ab = x[anom_idx]

    # Every branch normalises the three parts with training statistics; sds
    # stays None unless the z-score branch sets it.
    sds = None
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test_norm = minmax_scaler.transform(x_test_norm)
        x_test_ab = minmax_scaler.transform(x_test_ab)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test_norm = (x_test_norm - mus) / sds
        x_test_ab = (x_test_ab - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test_norm = x_test_norm / 255
        x_test_ab = x_test_ab / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test_norm = (x_test_norm - mean) / (std + 1e-4)
        x_test_ab = (x_test_ab - mean) / (std + 1e-4)

    data_dim = x_train.shape[1]
    y_train = y[train_norm_idx]

    # The draw order (train, test anomalies, test normals) is kept so runs
    # that seed the global generator reproduce earlier results.
    n_features = x.shape[1]
    noise_train = np.random.normal(
        loc=0.0,
        scale=np.sqrt(train_level),
        size=(len(train_norm_idx), n_features),
    )
    noise_test_abnorm = np.random.normal(
        loc=0.0,
        scale=np.sqrt(test_level),
        size=(len(anom_idx), n_features),
    )
    noise_test_norm = np.random.normal(
        loc=0.0,
        scale=np.sqrt(train_level),
        size=(len(test_norm_idx), n_features),
    )
    print(noise_test_norm)
    x_test_ab = x_test_ab + noise_test_abnorm
    x_test_norm = x_test_norm + noise_test_norm
    x_train = x_train + noise_train

    # np.concat only exists in NumPy >= 2.0; np.concatenate works everywhere.
    x_test = np.concatenate([x_test_norm, x_test_ab])
    y_test = y[np.hstack([test_norm_idx, anom_idx])]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    return x_train, y_train, x_test, y_test,data_dim,sds




# ---------------------------------------------------------------------------
# Experiment driver: for every listed dataset and every (lr, grad_pun)
# combination, train a _RealNVP model over 5 repetitions and print the
# average of the best per-repetition AUROC / AUPRC.
# ---------------------------------------------------------------------------
# Module-level settings. `act`, `mid_dim`, `mu`, `sigma` and `learn` are passed
# to _RealNVP below; `bs` is the training batch size; `epoch` is the number of
# training epochs; `PNAL` is forwarded to training().
# NOTE(review): `adam`, `pow`, `folder_path`, `j` and `fit_ratio` are not used
# in this section -- they may be read elsewhere in the file; `pow` shadows the
# builtin of the same name.
act=2
bs=512#512512
epoch=100
# This value is replaced by the `for lr in [...]` loop below before it is used here.
lr=1e-3
mid_dim=2048
adam=1#AMAZON:59.1 no regularizer
pow=1
PNAL='L_1sq'#L_2,L_1sq,L2-L1,L2cL1,overL2-L1,overL2cL1,half-half,half/half
learn=False
mu=torch.zeros(1)
#set_seed(4)
sigma=torch.ones(1)
#data='30_satellite.npz'#bank should run  satellite:84  waveform:87
folder_path = Path('./datasets_small')
j=0
#for file_path in folder_path.rglob('*.npz'):  # match all files or folders
# Datasets to evaluate; un-comment entries to include them.
for file_path in [#"datasets/4_breastw.npz",
                  "datasets/7_Cardiotocography.npz",
                  #"datasets/22_magic.gamma.npz",
                  #"datasets/29_Pima.npz",
                  #"datasets/30_satellite.npz",
                  #"datasets/35_SpamBase.npz",
                  #"datasets/46_WPBC.npz",
                  #"datasets/47_yeast.npz",
                  #"datasets/19_landsat.npz",
                  #"datasets/18_Ionosphere.npz"

]:#"datasets/18_Ionosphere.npz"]:
 #file_path=str(file_path)
#for k in range (0,1):
#for file_path in ["datasets/30_satellite.npz","datasets/7_Cardiotocography.npz","datasets/29_Pima.npz", "datasets/35_SpamBase.npz"]:
#for file_path in ["datasets/35_SpamBase.npz"]:#7_Cardiotocography,0.01,0.01 #35_SpamBase: 0.01,0.1#30_satellite:0.001,1#29_Pima:0.001,0.1#46_WPBC:0.005,0.1
 #file_path="datasets/34_smtp.npz"
 #if "Ionosphere" in file_path:
 #   j=1
 #if j==0:
 #   continue
 #file_path="datasets/24_mnist.npz"
 #first_slash_idx = file_path.find('\\')
 print(file_path)
 # Dataset name = the text between the first '/' and the last '.' of the path.
 first_slash_idx = file_path.find('/')
 fit_ratio=0
 dot_idx = file_path.rfind('.')
 data = file_path[first_slash_idx + 1 : dot_idx]
 print(data)
 # Forwarded to contribution_calculation() as `percent`.
 percent=0.4
 #for lr in [0.01]:
   #for grad_pun in [1]:#1#63.6
 # Hyper-parameter grid (currently a single point); `k` is not used in the body.
 for k in [1]:
  for lr in [0.005]:#,0.005,0.01]:#[0.001,0.005,0.01]:
   for grad_pun in [0.1]:#[0,0.01,0.1,1]:
    # Only avg_auroc and avg_auprc are filled below; the other lists stay empty.
    avg_auroc,avg_auprc,avg_con_auroc,avg_con_auprc,avg_f1,avg_con_f1=[],[],[],[],[],[]
    #lr = BEST_HP[data]["lr"]
    #grad_pun = BEST_HP[data]["grad_pun"]
    # Five repetitions; each one reloads the data and builds a fresh model.
    for i in range (0,5):
                losses,aucs=[],[]
                train_data,train_lab,test_data,test_lab,Input_dim,std=read_data(file_path)
                #train_data,train_lab,test_data,test_lab,Input_dim,std=read_OD_data(file_path,normalization='z-score')
                #train_data,train_lab,test_data,test_lab,Input_dim,std=read_after_sgauss_data(file_path,normalization='z-score',train_level=0.1,test_level=0.2)
                #train_data,train_lab,test_data,test_lab,Input_dim,std=read_after_gauss_data(file_path,normalization='z-score',train_level=0.1,test_level=0.2)#read_gauss_data(file_path,normalization='z-score',train_level=0.1,test_level=0.2)#read_data(file_path,normalization='z-score')
                # NOTE(review): the diag(1/std) matrix built on the next line is
                # overwritten by the identity right after it, so the rescaling
                # has no effect -- confirm which one is intended.
                std=torch.diag(torch.tensor(1/std)).cuda().float()
                std=torch.eye(Input_dim).cuda()
                print(std.shape)
                model=_RealNVP(input_dim=Input_dim,
                 mid_dim=mid_dim, 
                 masktype=0,
                 act=act,
                 mu=mu,
                 sigma=sigma,
                 learn=learn
             )
                # Only best_auc and best_auprc are updated in the epoch loop below.
                best_auc,best_con_auc,best_sldj_auc,best_pz_auc,best_auprc,best_con_auprc,best_f1,best_con_f1=0,0,0,0,0,0,0,0
                time1=time.time()
                for n_epoch in range(epoch):
                 print(f"[EPOCH: {n_epoch:.1f}]")
                 torch.cuda.empty_cache()
                #std=torch.eye(Input_dim)
                #train_data,train_lab,test_data,test_lab,Input_dim=read_front_data('datasets/'+data,normalization='z-score')
                 #random_indices = np.random.choice(train_data.shape[0], size=2048, replace=False)
                 #train_set,test_set=CustomDataset(train_data[random_indices, :], np.ones(random_indices.shape)), CustomDataset(test_data, test_lab)
                 # NOTE(review): datasets, loaders and the optimizer are rebuilt
                 # on every epoch; re-creating the optimizer discards its
                 # accumulated state each epoch -- confirm this is intended.
                 train_set,test_set=CustomDataset(train_data, train_lab), CustomDataset(test_data, test_lab)
                 optimizer = torch.optim.Adam(model.parameters(), lr = lr, weight_decay=0,amsgrad=1)
                 train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=bs,
        shuffle=True,
        num_workers=0
    )
                 test_loader = torch.utils.data.DataLoader(
        dataset=test_set, 
        batch_size=2048,
        shuffle=True,
        num_workers=0
    )           
                 # Separate loader over the training set, passed to contribution_calculation().
                 con_loader = torch.utils.data.DataLoader(
        dataset=train_set, 
        batch_size=512,
        shuffle=True,
        num_workers=0
    )           
                 # One call to training() per epoch; the optimizer is passed via the `adam` argument.
                 je,loss=training(
                        model=model,
                        train_loader=train_loader,
                        lr=lr,
                        adam=optimizer,
                        wd=0,
                        pun=0,
                        grad_pun=grad_pun,
                        PNAL=PNAL,
                        std=std      
                        )
                 losses.append(loss.detach().cpu())
                 #verify_reverse_correctness(model, train_loader, device='cuda', atol=1e-5, rtol=1e-5, max_batches=5)
                 top_idx=contribution_calculation(model,con_loader,std,percent=percent)
                 print("SOURCE IDX: "+str(len(top_idx)))
                 time2=time.time()
                 # Test set rebuilt from latent codes masked to `top_idx`, then scored.
                 new_test_set,_,_=build_contributed_testset(model,test_loader,top_idx)
                 new_test_loader = torch.utils.data.DataLoader(
        dataset=new_test_set, 
        batch_size=128*4,
        shuffle=True,
        num_workers=0
    )           
                 print("allocated:", torch.cuda.memory_allocated() / 1024**2, "MB")
                 print("reserved :", torch.cuda.memory_reserved() / 1024**2, "MB")
                 auc,prc=testing(model,new_test_loader,top_idx,std)
                 aucs.append(auc)#.detach().cpu())
                 # NOTE(review): the best epoch is selected on the test metrics themselves.
                 best_auc=max(auc,best_auc)
                 best_auprc=max(best_auprc,prc)
                 #print(f"AUROC:  {best_auc:.3f}, AUPRC: {best_auprc}")
                # time2 is overwritten every epoch, so this is the time from the
                # start of the repetition to the last epoch's contribution step,
                # multiplied by 100.
                print("###TIME: "+str((time2-time1)*100))
                #plot_metrics_separate(aucs,losses)
                avg_auroc.append(best_auc)
                avg_auprc.append(best_auprc)
    # Mean of the per-repetition best scores for this (lr, grad_pun) setting.
    avg_auroc=torch.mean(torch.stack(avg_auroc))
    avg_auprc=torch.mean(torch.stack(avg_auprc))
    print(f"AVG AUROC:  {avg_auroc:.3f}, AVG AUPRC: {avg_auprc}")