import os
import sys
print(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from GAOOD.detector.mybase import DeepDetector
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from GAOOD.metric import *
import os
from torch_geometric.nn import GCNConv, global_mean_pool, GINConv

from ..nn import dgmmd
import warnings
import random
from sklearn.metrics import roc_auc_score

def set_global_seed(seed: int = 42):
    """Seed Python's, NumPy's and PyTorch's RNGs so runs are reproducible.

    When CUDA is available, every CUDA device's generator is seeded as well.

    Args:
        seed: Value passed to each seeding function.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

class NegativeSampleGenerator:
    """Create perturbed ("negative") copies of graphs.

    A negative sample is derived from an input graph in two ways, each optional:
    (a) node features are cyclically swapped between a random subset of nodes
    and/or perturbed with Gaussian noise;
    (b) part of the dense adjacency matrix's eigen-decomposition is rescaled,
    and the reconstruction is re-thresholded into a new edge set.

    Graphs are duck-typed. The code reads ``x``, ``edge_index`` and ``num_nodes``
    and calls ``clone()`` (the interface of a PyG ``Data`` object).
    """

    def __init__(self,
                 perturbation_methods=('feature_swap', 'feature_noise', 'structure'),
                 attr_swap_ratio=0.2,
                 spectral_perturb_alpha=0.1,
                 spectral_perturb_ratio=0.2,
                 feature_noise_std=0.1,
                 spectral_k=20,
                 threshold=0.5,
                 verbose=False):
        """
        Args:
            perturbation_methods: Container checked with ``in`` for
                'feature_swap', 'feature_noise' and 'structure'. The default is
                an immutable tuple, so it is never shared mutable state.
            attr_swap_ratio: Fraction of nodes whose feature rows are cyclically
                permuted among themselves.
            spectral_perturb_alpha: Selected eigenvectors are scaled by
                ``1 + spectral_perturb_alpha``.
            spectral_perturb_ratio: Together with ``spectral_k``, sets how many
                eigenvectors are scaled:
                ``max(1, int(spectral_k * spectral_perturb_ratio))``.
            feature_noise_std: Standard deviation of the additive Gaussian
                feature noise.
            spectral_k: See ``spectral_perturb_ratio``.
            threshold: Reconstructed adjacency entries strictly above this value
                become edges.
            verbose: Stored only; not read anywhere in this class.
        """
        self.perturbation_methods = perturbation_methods
        self.attr_swap_ratio = attr_swap_ratio
        self.spectral_perturb_alpha = spectral_perturb_alpha
        self.spectral_perturb_ratio = spectral_perturb_ratio
        self.feature_noise_std = feature_noise_std
        self.spectral_k = spectral_k
        self.threshold = threshold
        self.verbose = verbose

    def generate_negative_graph(self, g):
        """Return a perturbed clone of ``g``; ``g`` itself is not modified."""
        g_neg = g.clone()

        # Feature perturbation (swap and/or noise).
        if 'feature_swap' in self.perturbation_methods or 'feature_noise' in self.perturbation_methods:
            g_neg = self.feature_perturbation(g_neg)

        # Structural perturbation through the adjacency spectrum.
        if 'structure' in self.perturbation_methods:
            g_neg = self.spectral_structure_perturbation(g_neg)

        return g_neg

    def feature_perturbation(self, g):
        """Swap node-feature rows and/or add Gaussian noise, writing the result to ``g.x``.

        ``g`` is modified in place (its ``x`` is replaced) and also returned.
        Graphs without node features are returned unchanged.
        """
        if g.x is None:
            return g
        x = g.x.clone()
        num_nodes = x.size(0)

        # Cyclic swap: node idx[i] receives the features of node idx[i - 1].
        # The right-hand side is gathered before assignment, so no row is lost.
        if 'feature_swap' in self.perturbation_methods:
            num_swaps = int(num_nodes * self.attr_swap_ratio)
            if num_swaps > 1:
                idx = torch.randperm(num_nodes)[:num_swaps]
                swapped_idx = torch.roll(idx, shifts=1)
                x[idx] = x[swapped_idx]

        # Additive Gaussian noise (requires floating-point features).
        if 'feature_noise' in self.perturbation_methods:
            x = x + torch.randn_like(x) * self.feature_noise_std

        g.x = x
        return g

    def spectral_structure_perturbation(self, g):
        """Rebuild the edge set from a spectrally rescaled, thresholded adjacency.

        Steps: build a dense 0/1 adjacency from ``edge_index``, eigen-decompose
        it, scale the selected eigenvectors, reconstruct and symmetrize, clip
        negatives, then keep entries above ``threshold`` as edges.

        Returns:
            A clone of ``g`` whose ``edge_index`` is the new ``(2, E)`` long
            tensor on the original device. Any per-edge attributes on ``g`` are
            carried over unchanged and may no longer align with the new edges.
            If the graph is empty or the eigen-decomposition fails, an
            unperturbed clone is returned.
        """
        num_nodes = int(g.num_nodes)
        if num_nodes == 0:
            return g.clone()

        edge_index = g.edge_index
        device = edge_index.device
        # numpy needs host memory, and indexing a CPU tensor with CUDA indices
        # would raise, so work on a CPU copy of the indices.
        ei_cpu = edge_index.detach().cpu()
        adj = torch.zeros((num_nodes, num_nodes))
        adj[ei_cpu[0], ei_cpu[1]] = 1
        adj = adj.numpy()

        try:
            eigenvalues, eigenvectors = np.linalg.eigh(adj)
        except np.linalg.LinAlgError:
            # The decomposition did not converge: fall back to an unperturbed copy.
            return g.clone()

        # eigh returns eigenvalues in ascending order, so the last columns
        # belong to the LARGEST adjacency eigenvalues.
        # NOTE(review): for an adjacency matrix these are usually the smoothest
        # (low-frequency) components, while the original comment says "amplify
        # high-frequency information" -- confirm which end is intended.
        count = max(1, int(self.spectral_k * self.spectral_perturb_ratio))
        eigenvectors[:, -count:] *= (1 + self.spectral_perturb_alpha)

        # Each scaled eigenvector appears twice in V diag(lambda) V^T, so its
        # contribution is effectively scaled by (1 + alpha) ** 2.
        adj_perturbed = eigenvectors @ np.diag(eigenvalues) @ eigenvectors.T
        adj_perturbed = (adj_perturbed + adj_perturbed.T) / 2
        adj_perturbed[adj_perturbed < 0] = 0

        # Binarize.
        adj_perturbed = (adj_perturbed > self.threshold).astype(float)

        # np.nonzero returns a (rows, cols) tuple; torch.from_numpy accepts only
        # a single ndarray, so stack the tuple into a (2, E) array first.
        edge_index_new = torch.from_numpy(np.stack(np.nonzero(adj_perturbed))).long()
        g_new = g.clone()
        g_new.edge_index = edge_index_new.to(device)
        return g_new

def batch_graphs(graph_list):
    """Collate a list of samples into a single batch.

    PyTorch Geometric ``Data`` objects are merged with ``Batch.from_data_list``.
    ``torch.utils.data.dataloader.default_collate`` cannot handle them and
    raises ``TypeError``. Any other samples (tensors, dicts, sequences, ...)
    are passed to ``default_collate`` unchanged.

    Args:
        graph_list: Iterable of samples; it is materialized into a list first.

    Returns:
        A PyG ``Batch`` for a list of ``Data`` objects, otherwise whatever
        ``default_collate`` returns.
    """
    graph_list = list(graph_list)
    try:
        from torch_geometric.data import Batch, Data
    except ImportError:  # PyG not installed: only plain collation is possible.
        Batch = Data = None
    if Data is not None and graph_list and all(isinstance(g, Data) for g in graph_list):
        return Batch.from_data_list(graph_list)
    return torch.utils.data.dataloader.default_collate(graph_list)

class AnomalyDetectionLoss(nn.Module):
    """Negative mean relative gap between normal and negative-sample scores.

    ``loss = -mean((P_normal - P_negative) / (P_normal + epsilon))``

    Minimizing the loss increases ``P_normal`` relative to ``P_negative``,
    measured relative to the size of ``P_normal``.
    """

    def __init__(self, epsilon=1e-6):
        super().__init__()
        # Keeps the denominator away from zero.
        self.epsilon = epsilon

    def forward(self, P_normal, P_negative):
        denominator = P_normal + self.epsilon
        relative_gap = (P_normal - P_negative) / denominator
        return relative_gap.mean().neg()


class DGMMDKDE(DeepDetector):
    """Graph-level anomaly / OOD detector built on ``dgmmd.DGMMD`` and KDE scores.

    Training contrasts the KDE scores of real training graphs with those of two
    families of perturbed negatives (``neg_gen_sub`` and ``neg_gen_add``). The
    model with the best validation AUC is checkpointed under ``self.path`` and
    reloaded by ``decision_function``.
    """

    def __init__(self,
                 in_dim,
                 hidden_dim=64,
                 out_dim=64,
                 num_layers=2,
                 bandwidths=None,
                 dropout=0.1,
                 batch_norm=True,
                 kde_dimension=1,
                 args=None,
                 **kwargs):
        """
        Args:
            in_dim: Input node-feature dimension, forwarded to ``DeepDetector``
                and ``dgmmd.DGMMD``.
            hidden_dim, out_dim, num_layers, dropout, batch_norm, kde_dimension:
                Forwarded unchanged to ``dgmmd.DGMMD``.
            bandwidths: KDE bandwidths forwarded to ``dgmmd.DGMMD``; ``None``
                (the default) means ``[1.0]``.
            args: Experiment namespace. ``exp_type``, ``model``, ``DS`` and
                ``DS_pair`` select the checkpoint directory. ``None`` disables
                checkpointing.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        super(DGMMDKDE, self).__init__(in_dim=in_dim)

        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.num_layers = num_layers
        # None sentinel instead of a mutable list default shared across instances.
        self.bandwidths = [1.0] if bandwidths is None else bandwidths
        self.dropout = dropout
        self.batch_norm = batch_norm
        self.kde_dimension = kde_dimension
        self.args = args
        self.build_save_path()

        # Negative-sample generators. Both swap 20% of node features and rescale
        # part of the adjacency spectrum, by (1 - 0.4) for "sub" and (1 + 0.5)
        # for "add" -- presumably meant to remove / add edges respectively.
        self.neg_gen_sub = NegativeSampleGenerator(
            perturbation_methods=['feature_swap', 'structure'],
            attr_swap_ratio=0.2,
            spectral_perturb_alpha=-0.4,
            spectral_perturb_ratio=0.3,
            threshold=0.5
        )

        self.neg_gen_add = NegativeSampleGenerator(
            perturbation_methods=['feature_swap', 'structure'],
            attr_swap_ratio=0.2,
            spectral_perturb_alpha=0.5,
            spectral_perturb_ratio=0.3,
            threshold=0.5
        )

    def build_save_path(self):
        """Resolve the checkpoint directory from ``self.args``, create it and empty it.

        The layout is ``<package root>/model_save/<model>/<exp_type>/<DS_pair>``
        for OOD-detection runs. Tox21 datasets use ``<exp_type>Tox21/<DS>`` and
        everything else uses ``<exp_type>/<DS>``. Existing files in the directory
        (recursively) are deleted, presumably so a stale checkpoint cannot be
        reloaded. When ``self.args`` is None, ``self.path`` is set to None and
        nothing is touched on disk.
        """
        if self.args is None:
            self.path = None
            return
        root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        if self.args.exp_type == 'oodd':
            path = os.path.join(root, 'model_save', self.args.model, self.args.exp_type, self.args.DS_pair)
        elif self.args.DS.startswith('Tox21'):
            path = os.path.join(root, 'model_save', self.args.model, self.args.exp_type + 'Tox21', self.args.DS)
        else:
            path = os.path.join(root, 'model_save', self.args.model, self.args.exp_type, self.args.DS)
        self.path = path
        os.makedirs(path, exist_ok=True)
        self.delete_files_in_directory(path)

    def delete_files_in_directory(self, directory):
        """Recursively delete every file under ``directory``; sub-directories themselves are kept."""
        for filename in os.listdir(directory):
            file_path = os.path.join(directory, filename)
            if os.path.isfile(file_path):
                os.remove(file_path)
            elif os.path.isdir(file_path):
                self.delete_files_in_directory(file_path)

    def init_model(self, **kwargs):
        """Build the underlying ``dgmmd.DGMMD`` network from the stored hyper-parameters."""
        return dgmmd.DGMMD(
            in_dim=self.in_dim,
            hidden_dim=self.hidden_dim,
            out_dim=self.out_dim,
            num_layers=self.num_layers,
            bandwidths=self.bandwidths,
            dropout=self.dropout,
            batch_norm=self.batch_norm,
            kde_dimension=self.kde_dimension
        )

    def _score_loader(self, loader):
        """Return ``(scores, labels)`` lists of KDE scores and ``y`` labels for every graph in ``loader``.

        Runs in eval mode without gradient tracking.
        """
        self.model.eval()
        scores, labels = [], []
        with torch.no_grad():
            for data in loader:
                data = data.to(self.device)
                dist_matrix = self.model.compute_distance_matrix(data)
                batch_scores = self.model.compute_kde_scores(dist_matrix)
                scores.extend(batch_scores.cpu().numpy())
                labels.extend(data.y.cpu().numpy())
        return scores, labels

    def _load_checkpoint(self, ckpt_path):
        """Load a whole pickled model written by ``torch.save(self.model, ...)`` onto ``self.device``.

        The checkpoint is a full module, not a state_dict. torch>=2.6 defaults
        to ``weights_only=True``, which refuses to unpickle it, so that flag is
        disabled where it exists. Unpickling can execute code: only load
        checkpoints this class wrote itself.
        """
        import inspect
        load_kwargs = {'map_location': self.device}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        return torch.load(ckpt_path, **load_kwargs)

    def fit(self, dataset, args=None, label=None, dataloader=None, dataloader_val=None):
        """Train the model, evaluating on ``dataloader_val`` every ``args.eval_freq`` epochs.

        Args:
            dataset: Unused here; kept for the detector interface.
            args: Namespace read for ``gpu``, ``lr``, ``num_epoch`` and ``eval_freq``.
            label: Unused.
            dataloader: Iterable of training batches. Each batch must support
                ``.to(device)`` and ``.to_data_list()`` (e.g. a PyG ``Batch``).
            dataloader_val: Iterable of validation batches carrying labels in ``.y``.

        Returns:
            self. Training stops early after 30 consecutive evaluations without
            a validation-AUC improvement.
        """
        self.device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')
        # NOTE(review): self.kwargs is assumed to be set by DeepDetector -- confirm.
        self.model = self.init_model(**self.kwargs).to(self.device)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=args.lr)
        criterion = AnomalyDetectionLoss()
        self.max_AUC = 0

        stop_counter = 0
        N = 30  # early-stopping patience, counted in evaluations

        for epoch in range(1, args.num_epoch + 1):
            self.model.train()
            epoch_losses = []

            for data in dataloader:
                data = data.to(self.device)
                optimizer.zero_grad()

                # KDE scores of the real (normal) training graphs.
                train_distance_matrix = self.model.compute_distance_matrix(data)
                P_normal = self.model.compute_kde_scores(train_distance_matrix)

                # Negatives are generated on the CPU (the spectral step uses numpy).
                # Split the batch into graphs first: iterating a PyG batch yields
                # (key, value) pairs, and Data.to() moves tensors in place, which
                # would leave `data` on the CPU for the calls below.
                graphs_cpu = [g.to('cpu') for g in data.to_data_list()]
                neg_graphs_sub = [self.neg_gen_sub.generate_negative_graph(g) for g in graphs_cpu]
                neg_graphs_add = [self.neg_gen_add.generate_negative_graph(g) for g in graphs_cpu]

                neg_sub_batch = dgmmd.batch_graphs(neg_graphs_sub).to(self.device)
                neg_add_batch = dgmmd.batch_graphs(neg_graphs_add).to(self.device)

                # Negative-sample scores; the normal batch is passed as the second
                # argument (presumably the reference set -- see dgmmd.DGMMD).
                neg_sub_distance_matrix = self.model.compute_distance_matrix(neg_sub_batch, data)
                neg_add_distance_matrix = self.model.compute_distance_matrix(neg_add_batch, data)

                P_neg_sub = self.model.compute_kde_scores(neg_sub_distance_matrix)
                P_neg_add = self.model.compute_kde_scores(neg_add_distance_matrix)

                # Equal-weight contrastive loss against both negative families.
                loss = 0.5 * criterion(P_normal, P_neg_sub) + 0.5 * criterion(P_normal, P_neg_add)

                loss.backward()
                optimizer.step()
                epoch_losses.append(loss.item())

            avg_loss = np.mean(epoch_losses)
            print(f'[TRAIN] Epoch:{epoch:03d} | Loss:{avg_loss:.4f}')

            if epoch % args.eval_freq == 0:
                val_scores, val_labels = self._score_loader(dataloader_val)
                # Negate so that a larger value means "less like the training data".
                val_auc = ood_auc(val_labels, -np.array(val_scores))

                if val_auc > self.max_AUC:
                    self.max_AUC = val_auc
                    stop_counter = 0
                    if self.path is not None:
                        torch.save(self.model, os.path.join(self.path, 'model_DGMMDKDE.pth'))
                else:
                    stop_counter += 1

                print(f'Epoch:{epoch:03d} | val_auc:{val_auc:.4f}')
                if stop_counter >= N:
                    print(f'Early stopping triggered after {epoch} epochs')
                    break

        return self

    def decision_function(self, dataset, label=None, dataloader=None, args=None):
        """Score every graph in ``dataloader``.

        Loads ``model_DGMMDKDE.pth`` from ``self.path`` when the directory is
        available and non-empty. Otherwise it prints a message and uses the
        in-memory model. Requires ``fit`` to have set ``self.device``.

        Returns:
            ``(scores, labels)``: lists of raw KDE scores and ``y`` labels.
        """
        if self.path is None or self.is_directory_empty(self.path):
            print("Can't find the path")
        else:
            self.model = self._load_checkpoint(os.path.join(self.path, 'model_DGMMDKDE.pth'))

        return self._score_loader(dataloader)

    def predict(self, dataset=None, label=None, dataloader=None, args=None):
        """Return ``(scores, labels)`` from ``decision_function`` when ``dataset`` is given.

        NOTE(review): with ``dataset=None`` this reads ``self.decision_score_``
        but still returns an empty tuple -- confirm whether the score should be
        returned.
        """
        output = ()
        if dataset is None:
            score = self.decision_score_
        else:
            score, y_all = self.decision_function(dataset, label, dataloader, args)
            output = (score, y_all)
        return output

if __name__ == "__main__":
    from torch_geometric.data import Data

    def _summarize(title, graph):
        """Print ``title`` followed by the node/edge counts and tensor shapes of ``graph``."""
        print(title)
        print(f"节点数: {graph.num_nodes}")
        print(f"边数: {graph.num_edges}")
        print(f"节点特征形状: {graph.x.shape}")
        print(f"边索引形状: {graph.edge_index.shape}")

    # A 5-node path graph (both edge directions) with random 10-dim node features.
    node_features = torch.randn(5, 10)
    path_edges = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4],
                               [1, 0, 2, 1, 3, 2, 4, 3]], dtype=torch.long)
    test_graph = Data(x=node_features, edge_index=path_edges)

    # Negative-sample generator with every perturbation enabled.
    neg_gen = NegativeSampleGenerator(
        perturbation_methods=['feature_swap', 'feature_noise', 'structure'],
        attr_swap_ratio=0.2,
        spectral_perturb_alpha=0.1,
        spectral_perturb_ratio=0.2,
        feature_noise_std=0.1,
        spectral_k=3,
        threshold=0.5,
    )
    neg_graph = neg_gen.generate_negative_graph(test_graph)

    _summarize("原始图:", test_graph)
    _summarize("\n生成的负样本图:", neg_graph)

    # Instantiate the detector and print its structure.
    model = DGMMDKDE(
        in_dim=10,
        hidden_dim=32,
        out_dim=32,
        num_layers=2,
        bandwidths=[0.1, 1.0, 10.0],
        dropout=0.1,
        batch_norm=True,
        kde_dimension=1,
    )

    print("\nDGMMDKDE模型结构:")
    print(model)