import contextlib
import time
import torch
import numpy as np
import pandas as pd
import time
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score,f1_score,roc_auc_score

import numpy as np
from sklearn.model_selection import train_test_split
from scipy.io import savemat
def compute_Helmholtzians_Hodge_1_Laplacian(edge_index, B2, Nm=None, directed=True):
    """Build the node-edge incidence matrix B1 and the Hodge 1-Laplacian.

    Column ``i`` of B1 holds -1 in row ``edge_index[0, i]`` and +1 in row
    ``edge_index[1, i]``. The Laplacian returned is ``B1^T @ B1 + B2 @ B2^T``.

    Args:
        edge_index: Integer tensor or array indexed as ``[2, num_edges]``.
        B2: Sparse tensor with ``num_edges`` rows.
        Nm: Number of rows of B1 (node count). When ``None`` it is taken as
            ``max(edge_index) + 1`` (0 for an empty edge set).
        directed: Accepted for interface compatibility; not used.

    Returns:
        Tuple ``(L1, B1)`` of sparse tensors.
    """
    print(edge_index.shape)
    edge_index = torch.as_tensor(edge_index)
    num_edges = edge_index.shape[1]
    if Nm is None:
        # Keep the node count a plain int; max() of an empty tensor raises.
        Nm = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    else:
        Nm = int(Nm)

    # Build B1 on B2's device so the two products below can be added.
    device = B2.device
    endpoints = edge_index[:2].to(device=device, dtype=torch.long)
    # Interleave as (src_0, dst_0, src_1, dst_1, ...): two entries per edge.
    B1_rows = endpoints.t().reshape(-1)
    B1_cols = torch.arange(num_edges, device=device).repeat_interleave(2)
    B1_values = torch.tensor([-1.0, 1.0], dtype=torch.float, device=device).repeat(num_edges)

    B1 = torch.sparse_coo_tensor(
        torch.stack((B1_rows, B1_cols)), B1_values, (Nm, num_edges)
    )

    print(B1.shape, B2.shape)

    lower = torch.sparse.mm(B1.t(), B1)
    upper = torch.sparse.mm(B2, B2.t())
    return lower + upper, B1


def _index_signed_edges(data, node_dict):
    """Index the unique directed edges of a signed edge list.

    Args:
        data: DataFrame with 'SOURCE', 'TARGET' and 'RATING' columns.
        node_dict: Mapping from raw node label to a contiguous node index.

    Returns:
        ``(edges_index, graph, classes, degrees)`` where ``edges_index`` maps
        each unique ``(source, target)`` index pair to its 0-based edge id in
        first-seen order, ``graph`` is a list of undirected neighbour sets,
        ``classes`` holds +1/-1 per unique edge in edge-id order, and
        ``degrees`` is ``(pos_in, neg_in, pos_out, neg_out)``, each a
        ``(num_nodes, 1)`` float tensor.
    """
    num_nodes = len(node_dict)
    edges_index = {}
    graph = [set() for _ in range(num_nodes)]
    classes = []
    pos_in_degree = torch.zeros(num_nodes, 1, dtype=torch.float)
    neg_in_degree = torch.zeros(num_nodes, 1, dtype=torch.float)
    pos_out_degree = torch.zeros(num_nodes, 1, dtype=torch.float)
    neg_out_degree = torch.zeros(num_nodes, 1, dtype=torch.float)

    for source, target, rating in zip(data['SOURCE'], data['TARGET'], data['RATING']):
        category = 1 if rating > 0 else -1
        source_index = node_dict[source]
        target_index = node_dict[target]
        graph[source_index].add(target_index)
        graph[target_index].add(source_index)

        if (source_index, target_index) in edges_index:
            # Duplicate row: the first occurrence defines the edge and its sign.
            continue
        # Use the position among unique edges as the id (not the CSV row
        # number) so ids stay aligned with the columns of B1, the rows of B2
        # and the entries of `classes` even when duplicate rows are skipped.
        edges_index[(source_index, target_index)] = len(edges_index)

        if category == 1:
            pos_out_degree[source_index] += 1
            pos_in_degree[target_index] += 1
        else:
            neg_out_degree[source_index] += 1
            neg_in_degree[target_index] += 1
        classes.append(category)

    degrees = (pos_in_degree, neg_in_degree, pos_out_degree, neg_out_degree)
    return edges_index, graph, classes, degrees


def _build_triangle_boundary(graph, edges_index):
    """Build the sparse edge-to-triangle matrix B2 (num_edges x num_triangles).

    Every triangle ``u < v < w`` of the undirected neighbour sets becomes one
    column with +1 on edge uv, -1 on edge uw and +1 on edge vw.
    """
    def edge_id(a, b):
        # An edge may be stored in either direction; try (a, b), then (b, a).
        found = edges_index.get((a, b))
        return edges_index.get((b, a)) if found is None else found

    B2_row = []
    B2_column = []
    triangle_num = 0
    for u, neighbors_u in enumerate(graph):
        for v in neighbors_u:
            if u >= v:
                continue
            for w in neighbors_u & graph[v]:
                if w <= v:  # enumerate each triangle once, as u < v < w
                    continue
                ids = (edge_id(u, v), edge_id(u, w), edge_id(v, w))
                if None in ids:
                    continue
                # NOTE(review): the signs below are not flipped when only the
                # reversed edge (b, a) is stored, while B1 follows the stored
                # direction; confirm this orientation convention is intended.
                B2_row.extend(ids)
                B2_column.extend((triangle_num, triangle_num, triangle_num))
                triangle_num += 1

    B2_indices = torch.tensor([B2_row, B2_column], dtype=torch.long)
    B2_values = torch.tensor([1, -1, 1] * triangle_num, dtype=torch.float)
    return torch.sparse_coo_tensor(B2_indices, B2_values, (len(edges_index), triangle_num))


if __name__ == "__main__":
    datasets = ['soc-sign-bitcoinalpha', 'soc-sign-bitcoinotc', 'Slashdot', 'epinions']
    # Candidate settings were [50, 100, 150, 200, 250, 300]; only 300 is active.
    M = [300]
    for data_name in datasets:
        data = pd.read_csv(f"./data/{data_name}.csv", header=None)
        if data_name in ('soc-sign-bitcoinalpha', 'soc-sign-bitcoinotc'):
            # These two files carry a fourth column.
            data.columns = ['SOURCE', 'TARGET', 'RATING', 't']
        else:
            data.columns = ['SOURCE', 'TARGET', 'RATING']

        nodes = set(data['SOURCE']) | set(data['TARGET'])
        node_dict = {node: index for index, node in enumerate(nodes)}
        num_nodes = len(nodes)

        for _turn in range(1):
            # NOTE(review): the value taken from M is not consumed by the
            # pipeline below; the loop only repeats the run per setting.
            for _m in M:
                start = time.time()

                # Rebuild all graph state per run so repetitions are
                # independent and each timing covers the same work.
                edges_index, graph, classes, degrees = _index_signed_edges(data, node_dict)
                num_edges = len(edges_index)
                edge_index = torch.tensor(list(edges_index), dtype=torch.long).reshape(-1, 2)

                B2 = _build_triangle_boundary(graph, edges_index)
                L, B1 = compute_Helmholtzians_Hodge_1_Laplacian(
                    edge_index.t(), B2, Nm=num_nodes, directed=True
                )
                node_feature_tensor = torch.cat(degrees, dim=1)
                # NOTE(review): L, B1 and node_feature_tensor are computed but
                # not used below; the classifier is fit on all-zero features.
                edge_feature_tensor = torch.zeros((num_edges, 8), dtype=torch.float)

                train_X, test_X, train_Y, test_Y = train_test_split(
                    edge_feature_tensor, classes, test_size=0.2
                )
                clf = LogisticRegression()
                clf.fit(train_X, train_Y)
                train_end = time.time()
                train_time = train_end - start

                y_pred = clf.predict(test_X)
                macro_f1 = f1_score(y_true=test_Y, y_pred=y_pred, average='macro')
                binary_f1 = f1_score(y_true=test_Y, y_pred=y_pred, average='binary')
                accuracy = accuracy_score(y_true=test_Y, y_pred=y_pred)
                # Use predict_proba to get probabilities for AUC calculation
                auc = roc_auc_score(test_Y, clf.predict_proba(test_X)[:, 1])
                print(f"AUC: {auc},F1:{binary_f1} Mean Accuracy: {accuracy}, Macro_f1: {macro_f1}")
                infer_time = time.time() - train_end
                print("train_time:{},infer_time:{}".format(train_time, infer_time))