import contextlib
import time
import torch
import numpy as np
import pandas as pd
import time
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score,f1_score,roc_auc_score

import numpy as np
from sklearn.model_selection import train_test_split
from scipy.io import savemat
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import pairwise
from sklearn.decomposition import KernelPCA
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC 

def compute_Helmholtzians_Hodge_1_Laplacian(edge_index, B2, Nm=None, directed=True):
    """Build the node-edge incidence matrix B1 and L1 = B1^T B1 + B2 B2^T.

    Parameters
    ----------
    edge_index : integer tensor (or array-like) of shape (2, E)
        Column ``i`` is edge ``i``; row 0 holds its tail node, row 1 its head.
    B2 : sparse tensor of shape (E, T)
        Edge-triangle incidence matrix whose rows follow the column order of
        ``edge_index``.
    Nm : int, optional
        Number of nodes (rows of B1). Defaults to ``max(edge_index) + 1``,
        or 0 for an empty edge set.
    directed : bool
        Unused; kept for backward compatibility with existing callers.

    Returns
    -------
    (L1, B1)
        ``B1`` is a sparse (Nm, E) tensor with -1 at each edge's tail and +1 at
        its head; ``L1`` is the sparse (E, E) tensor B1^T B1 + B2 B2^T.
    """
    edge_index = torch.as_tensor(edge_index, dtype=torch.long)
    print(edge_index.shape)
    num_edges = edge_index.shape[1]
    if Nm is None:
        # torch.max raises on an empty tensor, so handle E == 0 explicitly.
        Nm = int(edge_index.max()) + 1 if num_edges > 0 else 0
    Nm = int(Nm)  # torch.Size / sparse sizes want a plain int, not a 0-d tensor

    # Vectorized construction: every edge contributes (-1 at tail, +1 at head)
    # in its own column. Same matrix as the former per-edge Python loop.
    columns = torch.arange(num_edges, dtype=torch.long)
    B1_indices = torch.stack((
        torch.cat((edge_index[0], edge_index[1])),
        torch.cat((columns, columns)),
    ))
    B1_values = torch.cat((
        torch.full((num_edges,), -1.0, dtype=torch.float),
        torch.full((num_edges,), 1.0, dtype=torch.float),
    ))
    B1 = torch.sparse_coo_tensor(B1_indices, B1_values, (Nm, num_edges))

    print(B1.shape, B2.shape)

    return torch.sparse.mm(B1.permute(1, 0), B1) + torch.sparse.mm(B2, B2.permute(1, 0)), B1


def load_signed_edges(data_name, data_dir="./data"):
    """Load ``<data_dir>/<data_name>.csv`` (no header row) as a signed edge list.

    The two Bitcoin datasets carry a fourth column (named ``'t'`` here); every
    other dataset is expected to have exactly SOURCE, TARGET and RATING.
    """
    data = pd.read_csv(f"{data_dir}/{data_name}.csv", header=None)
    if data_name in ('soc-sign-bitcoinalpha', 'soc-sign-bitcoinotc'):
        data.columns = ['SOURCE', 'TARGET', 'RATING', 't']
    else:
        data.columns = ['SOURCE', 'TARGET', 'RATING']
    return data


def build_signed_graph(data):
    """Index the nodes and de-duplicated directed edges of a signed edge list.

    ``data`` must have SOURCE, TARGET and RATING columns. A rating > 0 is class 1,
    anything else class 0. Only the first occurrence of a repeated directed
    (source, target) pair is kept.

    Returns
    -------
    edge_index : LongTensor (E, 2)
        Unique directed edges (source, target) in first-seen order.
    edge_position : dict[(int, int), int]
        Maps each directed edge to its row in ``edge_index`` (its B1 column).
    adjacency : dict[int, set[int]]
        Undirected neighbour sets, including neighbours from duplicate rows.
    classes : list[int]
        Sign class of each row of ``edge_index``.
    node_features : FloatTensor (N, 4)
        Per-node degree counts [pos_in, neg_in, pos_out, neg_out].
    """
    nodes = set(data['SOURCE']) | set(data['TARGET'])
    # Sorted so node numbering is reproducible across runs (hash-randomized keys).
    node_dict = {node: index for index, node in enumerate(sorted(nodes))}
    num_nodes = len(nodes)

    adjacency = {}
    edge_position = {}
    classes = []
    for source, target, rating in zip(data['SOURCE'], data['TARGET'], data['RATING']):
        s = node_dict[source]
        t = node_dict[target]
        adjacency.setdefault(s, set()).add(t)
        adjacency.setdefault(t, set()).add(s)
        if (s, t) in edge_position:
            continue  # duplicate directed edge: keep the first rating only
        # Store the edge's position (not the DataFrame row label) so that B1
        # columns, B2 rows, edge features and labels all share one indexing.
        edge_position[(s, t)] = len(edge_position)
        classes.append(1 if rating > 0 else 0)

    edge_array = np.array(list(edge_position), dtype=np.int64).reshape(-1, 2)
    positive = np.asarray(classes, dtype=np.int64) == 1
    # Class-0 edges are the negative ones: count them too (the former
    # `category == -1` test could never match, leaving these columns at zero).
    degrees = np.stack([
        np.bincount(edge_array[positive, 1], minlength=num_nodes),   # pos in-degree
        np.bincount(edge_array[~positive, 1], minlength=num_nodes),  # neg in-degree
        np.bincount(edge_array[positive, 0], minlength=num_nodes),   # pos out-degree
        np.bincount(edge_array[~positive, 0], minlength=num_nodes),  # neg out-degree
    ], axis=1)

    edge_index = torch.from_numpy(edge_array)
    node_features = torch.from_numpy(degrees).float()
    return edge_index, edge_position, adjacency, classes, node_features


def build_triangle_incidence(adjacency, edge_position):
    """Enumerate triangles and build the oriented edge-triangle incidence B2.

    Each triangle is reported once as (u, v, w) with u < v < w. Its boundary is
    [u,v] - [u,w] + [v,w]. Each term is multiplied by +1 if the edge is stored
    in that direction and -1 if it is stored reversed. This makes B1 @ B2 == 0
    for the B1 built by ``compute_Helmholtzians_Hodge_1_Laplacian``.

    Returns
    -------
    triangles : list[tuple[int, int, int]]
        Triangle ``k`` corresponds to column ``k`` of B2.
    B2 : sparse FloatTensor (E, T)
        E is ``len(edge_position)``.
    """
    def lookup(a, b):
        # (column, orientation) of the edge joining a and b, or None if absent.
        col = edge_position.get((a, b))
        if col is not None:
            return col, 1.0
        col = edge_position.get((b, a))
        if col is not None:
            return col, -1.0
        return None

    triangles = []
    rows, cols, values = [], [], []
    for u in sorted(adjacency):
        neighbors_u = adjacency[u]
        for v in neighbors_u:
            if v <= u:
                continue
            for w in neighbors_u & adjacency[v]:
                if w <= v:
                    continue
                faces = (lookup(u, v), lookup(u, w), lookup(v, w))
                if any(face is None for face in faces):
                    continue  # only record triangles whose three edges all exist
                column = len(triangles)
                triangles.append((u, v, w))
                for (edge_col, orientation), sign in zip(faces, (1.0, -1.0, 1.0)):
                    rows.append(edge_col)
                    cols.append(column)
                    values.append(sign * orientation)

    B2 = torch.sparse_coo_tensor(
        torch.tensor([rows, cols], dtype=torch.long),
        torch.tensor(values, dtype=torch.float),
        (len(edge_position), len(triangles)),
    )
    return triangles, B2


def build_simplex_features(node_features, edge_index, triangles):
    """Concatenate endpoint node features into edge and triangle features.

    Returns ``(edge_features, triangle_features)``. Row ``i`` of
    ``edge_features`` is [x_src, x_dst] for ``edge_index[i]``. Row ``k`` of
    ``triangle_features`` is [x_u, x_v, x_w] for ``triangles[k]``, matching
    column ``k`` of B2.
    """
    num_features = node_features.size(1)
    edge_features = node_features[edge_index].reshape(edge_index.size(0), 2 * num_features)
    triangle_index = torch.tensor(triangles, dtype=torch.long).reshape(-1, 3)
    triangle_features = node_features[triangle_index].reshape(
        triangle_index.size(0), 3 * num_features)
    return edge_features, triangle_features


def propagate_edge_embeddings(L, B1, B2, node_features, edge_features,
                              triangle_features, m, num_layers):
    """Run ``num_layers`` rounds of random-projection message passing.

    Each round draws fresh Gaussian projections of width ``m`` and binarizes
    (> 0) the node, triangle and L-smoothed edge signals. It then sums what
    reaches each edge from its nodes (B1^T), from itself (L) and from its
    triangles (B2), and binarizes the sum again. Returns the final (E, m) edge
    embedding as a NumPy array. Uses the global torch RNG.
    """
    node_h, edge_h, triangle_h = node_features, edge_features, triangle_features
    for _ in range(num_layers):
        W1 = torch.randn((node_h.size(1), m), dtype=torch.float)
        W2 = torch.randn((edge_h.size(1), m), dtype=torch.float)
        W3 = torch.randn((triangle_h.size(1), m), dtype=torch.float)

        node_h = (torch.mm(node_h, W1) > 0).float()
        triangle_h = (torch.mm(triangle_h, W3) > 0).float()
        node2edge = torch.sparse.mm(B1.t(), node_h)
        edge_self = (torch.sparse.mm(L, torch.mm(edge_h, W2)) > 0).float()
        triangle2edge = torch.mm(B2, triangle_h)

        edge_h = ((node2edge + edge_self + triangle2edge) > 0).float()
    return edge_h.numpy()


def evaluate_cosine_svm(embeds, classes, test_size=0.2, random_state=42):
    """Fit an SVM on a precomputed cosine kernel and score it on a held-out split.

    Only the train x train and test x train kernel blocks are computed. The
    full E x E matrix (and its two fancy-indexed copies) is never materialized.
    Cosine similarity of a pair depends only on those two rows, so the values
    match the former full-matrix slicing.

    Returns ``(f1_binary, f1_macro, accuracy, auc)``.
    """
    # Split the train and test indices.
    train_idx, test_idx, train_Y, test_Y = train_test_split(
        range(len(embeds)), classes, test_size=test_size, random_state=random_state)
    train_emb = embeds[train_idx]
    test_emb = embeds[test_idx]

    train_X = cosine_similarity(train_emb)           # similarities among training samples
    test_X = cosine_similarity(test_emb, train_emb)  # test-vs-train similarities

    clf = SVC(kernel='precomputed')
    clf.fit(train_X, train_Y)
    y_pred = clf.predict(test_X)

    f1_binary = f1_score(y_true=test_Y, y_pred=y_pred, average='binary')
    f1_macro = f1_score(y_true=test_Y, y_pred=y_pred, average='macro')
    accuracy = accuracy_score(y_true=test_Y, y_pred=y_pred)
    auc = roc_auc_score(y_true=test_Y, y_score=clf.decision_function(test_X))
    return f1_binary, f1_macro, accuracy, auc


if __name__ == "__main__":
    datasets = ['soc-sign-bitcoinalpha', 'soc-sign-bitcoinotc', 'Slashdot', 'epinions']
    # Full sweep of projection widths: [50, 100, 150, 200, 250, 300]
    M = [300]
    num_turns = 1   # repeated trials per width
    max_layers = 5  # evaluate t = 1..max_layers propagation rounds
    for data_name in datasets:
        data = load_signed_edges(data_name)

        # The graph, incidence matrices and input features do not depend on
        # the projection width m, so they are built once per dataset.
        edge_index, edge_position, adjacency, classes, node_feature_tensor = build_signed_graph(data)
        if edge_index.size(0) == 0 or len(set(classes)) < 2:
            print(f"Skipping {data_name}: need edges from both sign classes.")
            continue
        num_nodes = node_feature_tensor.size(0)
        triangles, B2 = build_triangle_incidence(adjacency, edge_position)
        L, B1 = compute_Helmholtzians_Hodge_1_Laplacian(edge_index.T, B2, Nm=num_nodes, directed=True)
        edge_feature_tensor, triangle_feature_tensor = build_simplex_features(
            node_feature_tensor, edge_index, triangles)

        F1s, Accuracys, Costs, F1_macros, AUCs = (np.zeros((max_layers,)) for _ in range(5))
        for _turn in range(num_turns):
            for m in M:
                for t in range(1, max_layers + 1):
                    start_time = time.time()
                    embeds = propagate_edge_embeddings(
                        L, B1, B2, node_feature_tensor, edge_feature_tensor,
                        triangle_feature_tensor, m, t)
                    f1_binary, f1_macro, accuracy, auc = evaluate_cosine_svm(embeds, classes)
                    cost = time.time() - start_time

                    # Store the results.
                    F1s[t - 1], F1_macros[t - 1], Accuracys[t - 1], AUCs[t - 1], Costs[t - 1] = (
                        f1_binary, f1_macro, accuracy, auc, cost)
                    print(f"Binary F1 Score: {f1_binary}, Macro F1 Score: {f1_macro}, Accuracy: {accuracy}, AUC: {auc}, Cost: {cost}")

        # NOTE: result saving is disabled; re-enable to write per-dataset metrics.
        # results_dict = {'F1s': F1s, 'Accuracys': Accuracys, 'cost': Costs}
        # savemat(f'./result/{data_name}_cosine.mat', results_dict)