
import sys
import os

# Directory containing this script; main() also uses it to locate the dataset.
current_dir = os.path.dirname(__file__)
# Append the sibling ``../utils`` directory to the import path before the
# star-imports below run (presumably where ``gnn_model`` and ``utils`` live --
# confirm against the repository layout).
utils_path = os.path.abspath(os.path.join(current_dir, '..', 'utils'))
sys.path.append(utils_path)

from gnn_model import *
from utils import *

import torch
import numpy as np
import argparse
import scipy.sparse as ssp
from torch_sparse import SparseTensor
import networkx as nx
from torch_sparse import SparseTensor
         

def train(model, graph, optimizer, batch_size, aggregation, margin=0.0):
    """Run one full-batch optimisation step on a single graph.

    Node embeddings are computed with ``model(x, adj)``. For every row of
    ``graph.train_mask`` two node-pair embeddings are built (pair 0 and
    pair 1) and the loss ``clamp(cosine_similarity - margin, min=0)``,
    averaged over all rows, is minimised.

    Args:
        model: Module called as ``model(x, adj)``; its output is indexed by
            node id along dimension 0.
        graph: Object exposing ``x``, ``edge_index`` and ``train_mask``.
            ``train_mask`` is indexed as ``[:, pair, endpoint]`` with
            ``pair`` and ``endpoint`` in ``{0, 1}``.
        optimizer: Optimiser stepping the model parameters.
        batch_size: Unused; kept so existing callers keep working.
        aggregation: ``"concatenation"`` or ``"summation"`` -- how the two
            endpoint embeddings of a pair are combined.
        margin: Similarity below this value contributes no loss.

    Returns:
        The scalar loss of this step as a detached tensor.

    Raises:
        ValueError: If ``aggregation`` is not one of the supported modes.
    """
    if aggregation not in ("concatenation", "summation"):
        raise ValueError(
            f"Unknown aggregation {aggregation!r}; "
            "expected 'concatenation' or 'summation'"
        )

    x = graph.x
    neg_edges = graph.train_mask
    num_nodes = x.size(0)

    model.train()
    optimizer.zero_grad()

    edge_index = graph.edge_index
    edge_weight = torch.ones(edge_index.size(1), device=x.device)
    adj = SparseTensor.from_edge_index(
        edge_index, edge_weight, [num_nodes, num_nodes]
    ).to(x.device)

    h = model(x, adj)

    def pair_embedding(pair):
        # Order the two endpoints (smaller id first) so the embedding does
        # not depend on the order in which the endpoints are stored.
        a = neg_edges[:, pair, 0].long()
        b = neg_edges[:, pair, 1].long()
        lo, hi = torch.minimum(a, b), torch.maximum(a, b)
        if aggregation == "concatenation":
            return torch.cat((h[lo], h[hi]), dim=1)
        return h[lo] + h[hi]

    neg_arc1 = pair_embedding(0)
    neg_arc2 = pair_embedding(1)

    neg_sim = F.cosine_similarity(neg_arc1, neg_arc2, dim=-1)
    loss_sim = torch.clamp(neg_sim - margin, min=0).mean()
    # A single backward pass per step: no need to retain the autograd graph.
    loss_sim.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    # Detach so the caller (which only logs the value) does not keep the
    # computation graph alive between steps.
    return loss_sim.detach()

def reliability_check(d, threshold=71.34, q=32):
    """Flag the pairs whose T^2 statistic stays below ``threshold``.

    For every index ``i`` along dimension 1, the samples ``d[:, i, :]`` are
    reduced to their mean ``m`` and covariance-like matrix
    ``S = (D - m)^T (D - m) / (q - 1) + 1e-6 * I``; the statistic is
    ``q * m^T S^{-1} m`` (the form of a one-sample Hotelling T^2).

    Args:
        d: Floating-point tensor indexed as ``[sample, pair, feature]``.
        threshold: Value the statistic is compared against.
        q: Sample count used in the formula. It is NOT derived from
            ``d.shape[0]``; callers are expected to pass matching data.

    Returns:
        1-D CPU bool tensor with one entry per pair, ``True`` where the
        statistic is strictly below ``threshold``. Empty (but still bool)
        when ``d`` has no pairs.
    """
    # Small ridge term keeps S invertible; built once on d's device/dtype so
    # the function also works for CUDA or float64 inputs.
    ridge = 1e-6 * torch.eye(d.shape[-1], dtype=d.dtype, device=d.device)

    flags = []
    for i in range(d.shape[1]):
        d_i = d[:, i, :]
        d_mean = d_i.mean(dim=0)
        centered = d_i - d_mean
        S = centered.T @ centered / (q - 1) + ridge
        # d_mean is 1-D, so the quadratic form needs no transpose.
        t2 = q * (d_mean @ torch.inverse(S) @ d_mean)
        flags.append(t2.item() < threshold)

    # Explicit dtype: torch.tensor([]) would otherwise default to float32.
    return torch.tensor(flags, dtype=torch.bool)

def major_procedure(d, threshold=71.34, q=32):
    """Flag the pairs whose T^2 statistic exceeds ``threshold``.

    For every index ``i`` along dimension 1, the samples ``d[:, i, :]`` are
    reduced to their mean ``m`` and covariance-like matrix
    ``S = (D - m)^T (D - m) / (q - 1) + 1e-6 * I``; the statistic is
    ``q * m^T S^{-1} m`` (the form of a one-sample Hotelling T^2).

    Args:
        d: Floating-point tensor indexed as ``[sample, pair, feature]``.
        threshold: Value the statistic is compared against.
        q: Sample count used in the formula. It is NOT derived from
            ``d.shape[0]``; callers are expected to pass matching data.

    Returns:
        1-D CPU bool tensor with one entry per pair, ``True`` where the
        statistic is strictly above ``threshold``. Empty (but still bool)
        when ``d`` has no pairs.
    """
    # Small ridge term keeps S invertible; built once on d's device/dtype so
    # the function also works for CUDA or float64 inputs.
    ridge = 1e-6 * torch.eye(d.shape[-1], dtype=d.dtype, device=d.device)

    flags = []
    for i in range(d.shape[1]):
        d_i = d[:, i, :]
        d_mean = d_i.mean(dim=0)
        centered = d_i - d_mean
        S = centered.T @ centered / (q - 1) + ridge
        # d_mean is 1-D, so the quadratic form needs no transpose.
        t2 = q * (d_mean @ torch.inverse(S) @ d_mean)
        flags.append(t2.item() > threshold)

    # Explicit dtype: torch.tensor([]) would otherwise default to float32.
    return torch.tensor(flags, dtype=torch.bool)

@torch.no_grad()
def test(model, graph, val_test, batch_size, aggregation, margin=0.0, q=32):
    """Evaluate the model on every graph of ``graph``.

    For each graph the two pair embeddings of every row of the chosen mask
    are computed. Their per-graph differences are stacked, truncated to the
    first ``q`` graphs and passed to ``major_procedure``; the differences
    between graph 0's first-pair embedding and every other graph's
    first-pair embedding are passed to ``reliability_check``. A row counts
    as correct when both checks are true.

    Args:
        model: Module called as ``model(x, adj)``.
        graph: Mapping ``key -> graph object`` exposing ``x``,
            ``edge_index`` and ``val_mask`` / ``test_mask``. The entry with
            key ``0`` is the reference graph and must be iterated first.
        val_test: ``"val"`` selects ``val_mask``; anything else selects
            ``test_mask``.
        batch_size: Unused; kept so existing callers keep working.
        aggregation: ``"concatenation"`` or ``"summation"``.
        margin: Margin of the hinge on the cosine similarity.
        q: Number of graphs kept for ``major_procedure`` and the sample
            count forwarded to it.

    Returns:
        ``(loss_sim, accuracy)`` where ``loss_sim`` is the loss of the LAST
        graph iterated (not an average) and ``accuracy`` is a scalar tensor.

    Raises:
        ValueError: If ``aggregation`` is unsupported, or if a graph other
            than key ``0`` is encountered before the reference graph.
    """
    if aggregation not in ("concatenation", "summation"):
        raise ValueError(
            f"Unknown aggregation {aggregation!r}; "
            "expected 'concatenation' or 'summation'"
        )

    model.eval()

    def pair_embedding(h, pairs, pair):
        # Order the two endpoints (smaller id first) before combining them.
        a = pairs[:, pair, 0].long()
        b = pairs[:, pair, 1].long()
        lo, hi = torch.minimum(a, b), torch.maximum(a, b)
        if aggregation == "concatenation":
            return torch.cat((h[lo], h[hi]), dim=1)
        return h[lo] + h[hi]

    differences, differences_isomorfic = [], []
    first = None
    for k, single_graph in graph.items():
        x = single_graph.x
        neg_edges = single_graph.val_mask if val_test == "val" else single_graph.test_mask
        num_nodes = x.size(0)

        edge_index = single_graph.edge_index
        edge_weight = torch.ones(edge_index.size(1), device=x.device)
        adj = SparseTensor.from_edge_index(
            edge_index, edge_weight, [num_nodes, num_nodes]
        ).to(x.device)

        h = model(x, adj)
        neg_arc1 = pair_embedding(h, neg_edges, 0)
        neg_arc2 = pair_embedding(h, neg_edges, 1)

        neg_sim = F.cosine_similarity(neg_arc1, neg_arc2, dim=-1)
        # Overwritten every iteration: the value returned is the last graph's.
        loss_sim = torch.clamp(neg_sim - margin, min=0).mean()
        differences.append(neg_arc1 - neg_arc2)

        if k == 0:
            first = neg_arc1
        elif first is None:
            raise ValueError(
                "The reference graph (key 0) must be the first entry of 'graph'"
            )
        else:
            differences_isomorfic.append(first - neg_arc1)

    differences = torch.stack(differences, dim=0)[:q]
    # Forward q so the statistic uses the same sample count as the slice above.
    result = major_procedure(differences, q=q)

    differences_isomorfic = torch.stack(differences_isomorfic, dim=0)
    # NOTE(review): this stack holds len(graph) - 1 samples and is not
    # truncated, while reliability_check uses its own default q -- confirm
    # that the dataset size matches that default.
    result_isomorfic = reliability_check(differences_isomorfic)

    accuracy = (result & result_isomorfic).sum() / len(result)

    return loss_sim, accuracy

def main():
    """Command-line entry point: train and evaluate the GNN over several runs.

    Loads a mapping of graphs from ``<script dir>/../data/<filename>``, makes
    every edge list symmetric, trains on graph ``0`` with early stopping on
    the validation loss, restores the best snapshot, evaluates on the test
    masks, and prints mean and standard deviation of the test accuracy.
    """
    parser = argparse.ArgumentParser(description='homo')
    parser.add_argument('--data_name', type=str, default='citeseer')
    parser.add_argument('--neg_mode', type=str, default='equal')
    parser.add_argument('--gnn_model', type=str, default='SAGE')
    parser.add_argument('--score_model', type=str, default='mlp_score')

    # GNN settings
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--num_layers_predictor', type=int, default=2)
    parser.add_argument('--hidden_channels', type=int, default=4)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--aggregation', type=str, default='concatenation')

    # Training settings
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=2000)
    parser.add_argument('--eval_steps', type=int, default=5)
    parser.add_argument('--patience', type=int, default=10)
    parser.add_argument('--runs', type=int, default=5)
    parser.add_argument('--output_dir', type=str, default='output_test')
    parser.add_argument('--filename', type=str, default='dataset.pt')
    parser.add_argument('--l2', type=float, default=0.0, help='L2 Regularization for Optimizer')
    parser.add_argument('--seed', type=int, default=999)

    parser.add_argument('--save', action='store_true', default=False)
    parser.add_argument('--use_saved_model', action='store_true', default=False)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    # GIN
    parser.add_argument('--gin_mlp_layer', type=int, default=2)
    # GAT
    parser.add_argument('--gat_head', type=int, default=1)
    # MF
    parser.add_argument('--cat_node_feat_mf', default=False, action='store_true')

    args = parser.parse_args()
    init_seed(args.seed)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): the CUDA selection above is overridden, so everything
    # runs on CPU and --device has no effect. Kept as in the original;
    # confirm whether this is intentional before removing it.
    device = 'cpu'
    device = torch.device(device)

    # Join path components instead of concatenating strings: with an empty
    # current_dir the old form resolved to the absolute path '/../data/'.
    dataset_path = os.path.join(current_dir, '..', 'data', args.filename)

    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # dataset files from a trusted source.
    graph = torch.load(dataset_path)
    for k, v in graph.items():
        # Add the reversed copy of every edge so the edge list is symmetric.
        v.edge_index = torch.cat([v.edge_index, v.edge_index.flip(0)], dim=1)
        # Store the result: do not rely on .to() moving the object in place.
        graph[k] = v.to(device)

    node_num = graph[0]['x'].size(0)
    input_channel = graph[0].x.size(1)

    # SECURITY NOTE: eval() executes the --gnn_model string as Python code.
    # Acceptable for a local research script, but never expose this
    # argument to untrusted input; an explicit name -> class table is safer.
    model = eval(args.gnn_model)(
        input_channel, args.hidden_channels, args.hidden_channels,
        args.num_layers, args.dropout, args.gin_mlp_layer, args.gat_head,
        node_num, args.cat_node_feat_mf,
    ).to(device)

    test_accuracies = []
    for run in range(args.runs):
        early_stopping = EarlyStopping(patience=args.patience)
        best_model_state = None
        seed = args.seed if args.runs == 1 else run
        print('seed: ', seed)
        init_seed(seed)

        model.reset_parameters()
        optimizer = torch.optim.Adam(
            list(model.parameters()), lr=args.lr, weight_decay=args.l2)

        for epoch in range(args.epochs):
            loss = train(model, graph[0], optimizer, args.batch_size, args.aggregation, margin=0.0)
            print(f"Epoch {epoch+1}/{args.epochs}, Loss: {loss:.4f}")

            val_loss, val_accuracy = test(model, graph, "val", args.batch_size, args.aggregation, margin=0.0)
            print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f},")
            # Presumably truthy when val_loss improved -- confirm against
            # utils.EarlyStopping.
            if early_stopping(val_loss):
                # state_dict() returns references to the live parameters, so
                # clone them; otherwise later optimiser steps would silently
                # overwrite the "best" snapshot.
                best_model_state = {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                }

            if early_stopping.early_stop:
                break

        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        test_loss, test_accuracy = test(model, graph, "test", args.batch_size, args.aggregation, margin=0.0)
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")
        # Plain floats keep the NumPy aggregation below independent of
        # tensor device/dtype.
        test_accuracies.append(float(test_accuracy))
    print(f"Test accuracy over {args.runs} runs: {np.mean(test_accuracies)} ± {np.std(test_accuracies)}")

# Run the training/evaluation pipeline only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
   