import torch
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import TruncatedSVD
from torch_geometric.nn import knn_graph
from torch_geometric.data import Data, Dataset
import torch.optim as optim
import pandas as pd
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from torch_geometric.loader import DataLoader#
from model import GCN_geo as GCN
import time
import dgl
#from torch.utils.data import Dataset, DataLoader
import dgl.heterograph
import matplotlib.pyplot as plt
from model import GCN_TF as GCN_TF,DW_GCN,DW_GCNTF,DW_GCNTF_KDE,TF_P_GIN
from sklearn.utils import resample
import gc
#from torch_geometric.graphgym.config import (cfg, dump_cfg,
#                                             # set_agg_dir,
#                                             set_cfg, load_cfg,
#                                             makedirs_rm_exist)
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support
import os
#import posenc_stats
#import posenc_config
from torchmetrics import AUROC, AveragePrecision
#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
# Note: scale the network up, use more data, and train/test on image data.
#dgl.heterograph.DGLGraph = dgl.heterograph.DGLHeteroGraph
def normalize_cdist(cdist):
    """Standardise a square distance matrix using its strict upper triangle.

    The mean and standard deviation are taken over the entries above the
    diagonal only. Every entry is shifted and scaled with those statistics,
    and the value a zero entry would map to is then removed from the
    diagonal, so an input with a zero diagonal keeps a zero diagonal.

    Args:
        cdist: 2-D NumPy array of pairwise distances.

    Returns:
        A new array of the same shape holding the standardised distances.
    """
    size, _ = cdist.shape
    pair_values = cdist[np.triu_indices_from(cdist, k=1)]
    center = np.mean(pair_values)
    scale = np.std(pair_values) + 1e-9  # epsilon avoids division by zero

    standardized = (cdist - center) / scale
    # (0 - center) / scale is what a zero diagonal entry was mapped to.
    zero_image = -center / scale
    return standardized - np.eye(size) * zero_image
def get_device():
    """Return 'cuda' when torch reports an available CUDA device, else 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'

def Hbeta(D=np.array([]), beta=1.0):
    """
        Compute the entropy and the P-row for a specific value of the
        precision of a Gaussian distribution.

        Parameters
        ----------
        D : numpy array of distances from one point to the others.
        beta : precision of the Gaussian (scalar or length-1 array).

        Returns
        -------
        (H, P) : H is the entropy of the normalised row (the logarithm of
        its perplexity); P is exp(-D * beta) normalised to sum to one.
    """

    # Unnormalised Gaussian affinities; D is only read, so no copy is needed.
    P = np.exp(-D * beta)
    # np.sum reduces over every element in compiled code, matching the
    # np.sum used for the entropy below (the built-in sum() loops in Python
    # and would only reduce the first axis of a 2-D input).
    # The 1e-10 keeps the divisions finite when all affinities underflow.
    sumP = np.sum(P) + 1e-10
    H = np.log(sumP) + beta * np.sum(D * P) / sumP
    P = P / sumP
    return H, P


def x2p(X=np.array([]), tol=1e-2, perplexity=30.0):
    """
        Build a symmetric affinity matrix from the rows of X.

        For every row a bisection search (at most 50 steps) tunes the
        Gaussian precision until the entropy returned by Hbeta is within
        `tol` of log(perplexity). The resulting conditional rows are
        symmetrised, normalised to sum to one, multiplied by 4 and floored
        at 1e-12. The diagonal receives no affinity before the flooring.
    """

    print("Computing pairwise distances...")
    n, _ = X.shape
    sq_norms = np.sum(np.square(X), 1)
    # Squared Euclidean distances: |a|^2 + |b|^2 - 2 a.b
    D = np.add(np.add(-2 * np.dot(X, X.T), sq_norms).T, sq_norms)
    P = np.zeros((n, n))
    beta = np.ones((n, 1))
    target_entropy = np.log(perplexity)

    for i in range(n):

        if i % 500 == 0:
            print("Computing P-values for point %d of %d..." % (i, n))

        # Every column except i itself; reused when the row is written back.
        others = np.concatenate((np.r_[0:i], np.r_[i+1:n]))
        Di = D[i, others]

        lower, upper = -np.inf, np.inf
        H, row = Hbeta(Di, beta[i])
        gap = H - target_entropy

        attempts = 0
        while np.abs(gap) > tol and attempts < 50:

            if gap > 0:
                # Entropy too high: raise the precision.
                lower = beta[i].copy()
                beta[i] = beta[i] * 2. if np.isinf(upper) else (beta[i] + upper) / 2.
            else:
                # Entropy too low: lower the precision.
                upper = beta[i].copy()
                beta[i] = beta[i] / 2. if np.isinf(lower) else (beta[i] + lower) / 2.

            H, row = Hbeta(Di, beta[i])
            gap = H - target_entropy
            attempts += 1

        P[i, others] = row

    print("Mean value of sigma: %f" % np.mean(np.sqrt(1 / beta)))
    symmetric = P + np.transpose(P)
    normalised = symmetric / np.sum(symmetric)
    exaggerated = normalised * 4.  # early exaggeration
    return np.maximum(exaggerated, 1e-12)
def balance_data(X, y, target_ratio=0.1):
    """Resample (X, y) towards a label-1 share of `target_ratio`.

    The number of label-1 rows kept is int(len(y) * target_ratio). When
    more label-1 rows exist they are under-sampled and every label-0 row
    is kept; otherwise label-1 rows are over-sampled with replacement and
    label-0 rows are under-sampled. Both draws use random_state=42; the
    final shuffle uses NumPy's global random state.

    Returns:
        Tuple (X_resampled, y_resampled) in shuffled order.
    """
    ones = np.nonzero(y == 1)[0]
    zeros = np.nonzero(y == 0)[0]
    wanted_ones = int(len(y) * target_ratio)

    if len(ones) > wanted_ones:
        # Too many label-1 rows: under-sample them, leave label 0 untouched.
        kept_ones = resample(ones, replace=False, n_samples=wanted_ones, random_state=42)
        kept_zeros = zeros
    else:
        # Too few label-1 rows: over-sample them and trim label 0 to match.
        kept_ones = resample(ones, replace=True, n_samples=wanted_ones, random_state=42)
        wanted_zeros = int(wanted_ones / target_ratio * (1 - target_ratio))
        kept_zeros = resample(zeros, replace=False, n_samples=wanted_zeros, random_state=42)

    order = np.concatenate([kept_ones, kept_zeros])
    np.random.shuffle(order)

    return X[order], y[order]
# Gaussian kernel function
def gaussian_kernel(distances, sigma=1.0):
    """Map distances to similarities with exp(-d^2 / (2 * sigma^2))."""
    squared = distances ** 2
    denominator = 2 * sigma**2
    return np.exp(-squared / denominator)
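# NOTE: stale heading, originally "function that computes the SVD and extracts features"; the function below performs stratified subsampling instead.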
def resam(X_sub, y_sub, n_samples, random_state=None):
    """Down-sample a 0/1-labelled set to `n_samples` rows, keeping the class ratio.

    When the input has at most `n_samples` rows it is returned unchanged.
    Otherwise int(n_samples * mean(y_sub)) label-1 rows and the remaining
    number of label-0 rows are drawn without replacement.

    Args:
        X_sub: 2-D NumPy feature array.
        y_sub: 1-D NumPy array of 0/1 labels aligned with X_sub.
        n_samples: maximum number of rows to return.
        random_state: seed forwarded to sklearn's `resample`. When omitted,
            the module-level variable `rand` is used if it exists (this is
            what existing callers rely on); otherwise the draw is unseeded.

    Returns:
        Tuple (X, y). After down-sampling, label-0 rows come first and the
        labels are rebuilt as floats (zeros followed by ones).
    """
    if len(X_sub) <= n_samples:
        return X_sub, y_sub

    if random_state is None:
        # Backward compatibility: the seed used to be read implicitly from
        # the global loop variable `rand`, raising NameError when unset.
        random_state = globals().get('rand')

    # Share of label 1 decides how many rows each class contributes.
    ratio_1 = np.mean(y_sub)
    n_1 = int(n_samples * ratio_1)
    n_0 = n_samples - n_1

    X_0 = X_sub[y_sub == 0]
    X_1 = X_sub[y_sub == 1]

    X_0_resampled, y_0_resampled = resample(
        X_0, np.zeros(len(X_0)), n_samples=n_0, random_state=random_state, replace=False)
    X_1_resampled, y_1_resampled = resample(
        X_1, np.ones(len(X_1)), n_samples=n_1, random_state=random_state, replace=False)

    X_out = np.vstack((X_0_resampled, X_1_resampled))
    y_out = np.hstack((y_0_resampled, y_1_resampled))
    return X_out, y_out

def read_OD_data(file, normalization='z-score', seed=42):
    """Load a labelled dataset, shuffle it and normalise the features.

    No train/test split is performed: the whole shuffled dataset is returned.

    Args:
        file: path ending in '.npz' (arrays stored under keys 'X' and 'y'),
            'pkl' (pandas pickle) or 'csv'. For the tabular formats the last
            column is the label and all other columns are features;
            infinities are treated as missing and missing values are
            forward-filled.
        normalization: 'z-score' (per-feature standardisation, constant
            features are divided by 1), 'scale' (divide by 255) or 'ours'
            (standardise with std + 1e-4). Any other value leaves the
            features unnormalised.
        seed: seed of the RandomState used for the row permutation.

    Returns:
        Tuple (x, y) with the features and the integer labels.

    Raises:
        NotImplementedError: if the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True can execute code embedded in the
        # file; only load archives from trusted sources.
        # The context manager closes the archive's file handle.
        with np.load(file, allow_pickle=True) as data:
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            loader = pd.read_pickle
        elif file.endswith('csv'):
            loader = pd.read_csv
        else:
            raise NotImplementedError(
                f"Unsupported file type for {file!r}; expected .npz, pkl or csv")

        df = loader(file)
        # DataFrame.ffill() replaces the deprecated fillna(method='ffill').
        df = df.replace([np.inf, -np.inf], np.nan).ffill()
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # Reproducible shuffle of the rows.
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x_train, y_train = x[idx], y[idx]

    if normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
    elif normalization == 'scale':
        x_train = x_train / 255
    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
    print("data shape: "+str(x_train.shape))

    return x_train, y_train

def _svd_positional_encoding(similarity, k):
    """Return the first k left singular vectors of `similarity`, each scaled by sqrt(singular value)."""
    U, S, _ = torch.linalg.svd(similarity)
    return (U[:, :k] * torch.sqrt(S[:k])).cpu()


def get_complet_graph(X, pairwise_dist, y, sigma=1.0, k=16, method='gaussian'):
    """Build a weighted complete DGL graph with SVD-based node encodings.

    Args:
        X: feature matrix; only used by method 'tsne' (passed to x2p).
        pairwise_dist: square matrix of pairwise distances (NumPy array or
            torch tensor). Its first dimension gives the number of nodes.
        y: per-node labels, stored as float32 in g.ndata['y'].
        sigma: multiplier of the mean distance used as the Gaussian
            bandwidth; for method 'tsne' it is passed as the perplexity.
        k: number of singular directions kept for the node encoding.
        method: one of 'gaussian', 'xepressive' (spelling kept because it is
            the key callers pass), 'tsne', 'student_t', 'linear', 'softmax'.

    Returns:
        dgl graph with an edge for every non-NaN weight, edge feature
        'weight' and node features 'y' and 'emb'.

    Raises:
        NotImplementedError: if `method` is not one of the names above.
    """
    N = pairwise_dist.shape[0]
    # Accept both NumPy arrays and tensors. Previously only the Gaussian
    # branches converted the input, so the torch-only branches failed on
    # the NumPy matrix that process_data passes in.
    dist = torch.as_tensor(pairwise_dist)

    if method == 'gaussian':
        bandwidth = torch.mean(dist) * sigma
        weights_a = torch.exp(-dist ** 2 / (2 * bandwidth ** 2))
    elif method == 'xepressive':
        bandwidth = torch.mean(dist) * sigma
        gaussian_weights = torch.exp(-dist ** 2 / (2 * bandwidth ** 2))
        identity = torch.eye(N, dtype=gaussian_weights.dtype, device=gaussian_weights.device)
        # (W + I)^-1 @ W
        weights_a = torch.inverse(gaussian_weights + identity) @ gaussian_weights
    elif method == 'tsne':
        # x2p assigns no affinity to a node itself (the diagonal is only floored).
        weights_a = torch.from_numpy(x2p(X, perplexity=sigma))
    elif method == 'student_t':
        weights_a = 1 / (1 + dist ** 2)
    elif method == 'linear':
        inverse = 1 / dist
        diagonal = torch.eye(N, dtype=torch.bool, device=inverse.device)
        weights_a = inverse.masked_fill(diagonal, 1.0)
    elif method == 'softmax':
        diagonal = torch.eye(N, dtype=torch.bool, device=dist.device)
        negated = (-1 * dist).masked_fill(diagonal, float('-inf'))
        weights_a = torch.nn.functional.softmax(negated, dim=1) + 1e-9
        weights_a = weights_a.masked_fill(diagonal, 1.0)
    else:
        raise NotImplementedError(f"Unknown weighting method: {method!r}")

    # Node encoding from the SVD of the weight matrix. The three branches
    # that used to compute it each did exactly this; the other branches
    # never defined it and failed with NameError further down.
    # NOTE(review): the SVD may not succeed if weights_a holds inf/NaN
    # (e.g. 'linear' with duplicate points) - confirm before using.
    X_svd = _svd_positional_encoding(weights_a, k)
    print(X_svd.shape)

    src, dst = torch.where(~torch.isnan(weights_a))
    edge_weights = weights_a[src, dst]
    g = dgl.graph((src, dst), num_nodes=N)
    g.edata['weight'] = edge_weights.to(torch.float32)
    g.ndata['y'] = torch.tensor(y).to(torch.float32)
    g.ndata['emb'] = X_svd.to(torch.float32)
    return g

def process_data(X, y, k=10, sigma=1.0, method="gaussian"):
    """Build the complete weighted graph of X for a single `sigma`.

    X is used as given: no scaling is applied here. The previous body
    created an unused StandardScaler under a "standardise" comment.

    Args:
        X: feature matrix (rows are samples).
        y: per-sample labels forwarded to the graph.
        k: number of SVD directions kept for the node encoding.
        sigma: bandwidth multiplier (or perplexity for method 'tsne').
        method: weighting scheme understood by get_complet_graph.

    Returns:
        The graph returned by get_complet_graph.
    """
    # Euclidean distance between every pair of rows.
    pairwise_dist = euclidean_distances(X, X)
    return get_complet_graph(X, pairwise_dist, y, sigma=sigma, k=k, method=method)
def process_all_data(X,y,k,sigma,method="gaussian"):
    """Build one graph per entry of sigma[0..4] and bundle them in a Data object.

    The graphs are stored, in order, under the attributes x_1 ... x_5.
    """
    graphs = {}
    for slot in range(5):
        graphs[f"x_{slot + 1}"] = process_data(X, y, k, sigma[slot], method=method)
    return Data(**graphs)

class CustomGraphDataset(Dataset):
    """Dataset that serves items from an already-built Python list."""

    def __init__(self, root, data_list):
        """Keep `data_list` and initialise the base Dataset with `root`."""
        # Assigned before the base constructor runs.
        # NOTE(review): presumably because the base class may call len()/get()
        # during initialisation - confirm against the Dataset implementation.
        self.data_list = data_list
        super().__init__(root)

    def len(self):
        """Return the number of stored items."""
        return len(self.data_list)

    def get(self, idx):
        """Return the item stored at position `idx`."""
        return self.data_list[idx]
    
class LargeDataset(Dataset):
    """Dataset that loads one serialized sample per '.tar' file on access.

    Only the file paths are collected up front; each sample is read from
    disk in __getitem__, so the whole dataset is never held in memory.
    """

    def __init__(self, data_dir):
        """Record `data_dir` and collect the paths of its '.tar' files."""
        # NOTE(review): the base Dataset constructor is not called here;
        # __len__/__getitem__ are overridden directly - confirm this is intended.
        self.data_dir = data_dir
        self.file_list = self._load_file_list()

    def _load_file_list(self):
        """Return the sorted paths of all '.tar' files directly inside data_dir."""
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping (and un-shuffled iteration) reproducible.
        return sorted(
            os.path.join(self.data_dir, name)
            for name in os.listdir(self.data_dir)
            if name.endswith('.tar')
        )

    def __len__(self):
        """Return the number of sample files found."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load and return the sample stored in the idx-th file."""
        # NOTE(review): torch.load unpickles arbitrary objects, so data_dir
        # must only contain trusted files. Newer PyTorch releases may need
        # weights_only=False for non-tensor payloads - confirm for the
        # installed version.
        return torch.load(self.file_list[idx])
# ---- Run configuration (module-level globals read by the script below) ----
torch.cuda.empty_cache()
# Rebound to LargeDataset instances further down.
train_data,test_data=[],[]
# Width of the node encoding; the preprocessing below uses k=256 as well.
feature_dim=256
in_channels = feature_dim  # e.g. the dimension of the node features
hidden_channels = 1024# hidden-layer width (trailing '#64' in the original is presumably an earlier value)
out_channels =64# original note: "e.g. number of node-classification classes (assuming binary)"; trailing '#32' presumably an earlier value
TF_layers=6# trailing '#4' in the original is presumably an earlier value
TF_nhead=4
TF_dim=1024
lamda=0.0
# NOTE(review): TF_drop is not referenced anywhere in the visible code.
TF_drop=0
TF_act="gelu"
# Passed to the model as positive_processing.
pp=None
# Log file opened in append mode by the training loop (the name contains a space).
write_path = "./loggers/UniOD "+'.txt'
# 0 -> rebuild the cached graph files; any other value -> reuse what is on disk.
saved=1
n_epoch=50
# One graph is built per entry (bandwidth multipliers for the Gaussian weights).
sigmas=[0.3,0.5,1,3,5]
# NOTE(review): "44_Wilt" appears in both the training and the test list below - confirm this overlap is intended.
datasets_name=["5_campaign","1_ALOI","16_http","23_mammography","32_shuttle","36_speech","6_cardio","10_cover","22_magic.gamma","8_celeba","33_skin","34_smtp","20_letter","40_vowels","44_Wilt"]
test_datasets_name=["30_satellite",'27_PageBlocks',"28_pendigits","17_InternetAds","19_landsat","26_optdigits","7_Cardiotocography","29_Pima","35_SpamBase","38_thyroid","43_WDBC","41_Waveform","44_Wilt","4_breastw","31_satimage-2"]
device = get_device()
print(device)
if saved==0:
    # Rebuild the cached graph files from the raw .npz datasets.
    # torch.save does not create missing parent directories.
    os.makedirs('./Processed Graph/train_data', exist_ok=True)
    os.makedirs('./Processed Graph/test_data', exist_ok=True)

    for filepath in datasets_name:
        X,y=read_OD_data("./datasets/Classical/"+filepath+".npz",'z-score')
        # Two stratified 60% subsets per training dataset (seeds 42 and 43).
        # `rand` must stay a module-level name: resam() reads it as its seed.
        for rand in range(42,42+2):
            X_sub, _, y_sub, _ = train_test_split(X, y, test_size=0.4,random_state=rand,stratify=y)
            # Cap every subset at 6000 rows, keeping the class ratio.
            X_sub,y_sub=resam(X_sub,y_sub,6000)
            pe=process_all_data(X_sub,y_sub,k=256,sigma=sigmas)
            torch.save(pe, './Processed Graph/train_data/!train '+filepath+' '+str(rand)+'2000.tar')

    # Test datasets are processed whole, without subsampling.
    for filepath in test_datasets_name:
        X,y=read_OD_data("./datasets/Classical/"+filepath+".npz","z-score")
        rand=42
        pe=process_all_data(X,y,k=256,sigma=sigmas)
        torch.save(pe, './Processed Graph/test_data/test '+filepath+'_or.tar')
# When saved != 0 nothing is loaded here: the cached files are read lazily
# by the LargeDataset instances created below.
# ---- Datasets and loaders: samples are read from disk on access ----
train_data = LargeDataset('./Processed Graph/train_data')
print(len(train_data))
test_data = LargeDataset('./Processed Graph/test_data')
print("LOADED")
print(len(test_data))

train_loader = DataLoader(train_data, batch_size=2, shuffle=True, num_workers=2)
test_loader = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=1)
print("CPU NUMBERS: "+str(os.cpu_count()))

# ---- Model, optimiser and learning-rate schedule ----
model = TF_P_GIN(
    in_channels, hidden_channels, out_channels,
    TF_layers, TF_nhead, TF_act, TF_dim,
    n_layers=4, positive_processing=pp, lamda=lamda,
).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-5*5, weight_decay=1e-6, amsgrad=False)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epoch)

# ---- Bookkeeping and evaluation metrics ----
test_auroc = 0
train_auc = []
test_auc = []
best_auc = 0
auroc = AUROC(task="binary")
auprc = AveragePrecision(task="binary")
# The log is opened in append mode every epoch; make sure its directory exists.
log_dir = os.path.dirname(write_path)
if log_dir:
    os.makedirs(log_dir, exist_ok=True)

# NOTE(review): scheduler.step() is never called (it was commented out),
# so the learning rate stays constant for the whole run.
for epoch in range(n_epoch):
    # ---- Training pass ----
    model.train()
    total_loss = 0
    total_percent_loss = 0
    total_auroc = []
    time1 = time.time()
    for data in train_loader:
        optimizer.zero_grad()
        data = data.to(device)
        out, percent_loss, preds, targets = model.multi_forward(data)
        loss = out
        auc = 0  # placeholder: no training AUROC is computed
        total_loss += loss.item()
        total_percent_loss += percent_loss.item()
        total_auroc.append(auc)
        loss.backward()
        optimizer.step()
    time2 = time.time()
    take = time2 - time1
    print("Training time caused: "+str(take))

    # ---- Evaluation: only after the final epoch ----
    t_auroc = []
    t_auprc = []
    model.eval()
    # Tied to n_epoch instead of the former hard-coded 49, so evaluation
    # still runs when the number of epochs is changed.
    if epoch == n_epoch - 1:
        with torch.no_grad(), open(write_path, "a", encoding="utf-8") as f:
            for data in test_loader:
                data = data.to(device)
                out = model.train_test(data)
                # The binary torchmetrics expect integer targets; the cast
                # was previously applied for AUPRC only.
                labels = data.x_1[0].ndata['y'].int()
                auc = auroc(out, labels)
                prc = auprc(out, labels)
                t_auroc.append(auc)
                t_auprc.append(prc)
                f.write("AUROC: "+str(auc)+"\n")
                print("AUROC: "+str(auc)+" AUPRC: "+str(prc))

    # ---- Per-epoch summary ----
    total_auroc = sum(total_auroc)/len(total_auroc)
    if t_auroc:
        t_auroc = torch.mean(torch.stack(t_auroc))
        t_prc = torch.mean(torch.stack(t_auprc))
    else:
        t_auroc = 0
    best_auc = max(best_auc, t_auroc)
    print(f'Epoch {epoch+1}, Loss: {total_loss:.4f}')
    with open(write_path, "a", encoding="utf-8") as f:
        f.write(f'Epoch {epoch+1}, Loss: {total_loss:.4f}'+'\n')
epochs = range(n_epoch)