import sys
import os

current_dir = os.path.dirname(__file__)
utils_path = os.path.abspath(os.path.join(current_dir, '..', 'utils'))
sys.path.append(utils_path)

from gnn_model import *
from utils import *

import torch
import numpy as np
import argparse
import scipy.sparse as ssp
from torch_sparse import SparseTensor


def train(model, graph, optimizer, batch_size, aggregation, margin=0.0):
    """Run one full-graph optimisation step on the pairs stored in ``graph.train_mask``.

    ``graph.train_mask`` is indexed as ``[:, side, endpoint]``: ``train_mask[:, 0]``
    and ``train_mask[:, 1]`` each give one ``(u, v)`` node pair per row. Each pair
    is turned into a representation (endpoints ordered as min/max index, then
    concatenated or summed). The loss is the mean of
    ``clamp(cos_sim(arc1, arc2) - margin, min=0)``.

    Args:
        model: Called as ``model(x, adj)`` to produce node embeddings.
        graph: Object exposing ``x``, ``edge_index`` and ``train_mask``.
        optimizer: Optimiser over ``model``'s parameters.
        batch_size: Unused; kept for interface compatibility.
        aggregation: ``"concatenation"`` or ``"summation"``.
        margin: Cosine similarities at or below this value contribute no loss.

    Returns:
        The scalar loss as a tensor detached from the autograd graph.

    Raises:
        ValueError: If ``aggregation`` is not a supported mode.
    """
    if aggregation not in ("concatenation", "summation"):
        raise ValueError(f"unsupported aggregation: {aggregation!r}")

    x = graph.x
    model.train()
    num_nodes = x.size(0)

    # Unweighted sparse adjacency on the same device as the node features.
    edge_index = graph.edge_index
    edge_weight = torch.ones(edge_index.size(1), device=x.device)
    adj = SparseTensor.from_edge_index(edge_index, edge_weight, [num_nodes, num_nodes]).to(x.device)

    emb = model(x, adj)
    # NOTE(review): pair representations here are built from rows of
    # sigmoid(emb @ emb.T), whereas test() builds them from the raw embeddings
    # returned by model(x, adj) -- confirm this asymmetry is intended.
    h = torch.sigmoid(torch.mm(emb, emb.t()))

    def pair_repr(pairs):
        # Order each pair's endpoints so (u, v) and (v, u) map to the same representation.
        u, v = pairs[:, 0].long(), pairs[:, 1].long()
        lo, hi = torch.minimum(u, v), torch.maximum(u, v)
        if aggregation == "concatenation":
            return torch.cat((h[lo], h[hi]), dim=1)
        return h[lo] + h[hi]

    neg_edges = graph.train_mask
    neg_arc1 = pair_repr(neg_edges[:, 0])
    neg_arc2 = pair_repr(neg_edges[:, 1])

    optimizer.zero_grad()
    neg_sim = torch.nn.functional.cosine_similarity(neg_arc1, neg_arc2, dim=-1)
    loss_sim = torch.clamp(neg_sim - margin, min=0).mean()
    # Exactly one backward pass per forward pass, so the graph need not be retained.
    loss_sim.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    # Detach so the caller does not keep the whole autograd graph alive.
    return loss_sim.detach()

def _hotelling_t2(d, q=32):
    """Return the one-sample Hotelling T^2 statistic for every slice ``d[:, i, :]``.

    For slice ``i`` the samples are the rows of ``d[:, i, :]``. The statistic is
    ``q * mean^T S^{-1} mean``, where ``S`` is the centred scatter matrix divided
    by ``q - 1`` plus a ``1e-6`` ridge on the diagonal.

    Args:
        d: 3-D tensor indexed as ``(sample, slice, feature)``.
        q: Sample count used to scale the statistic and normalise ``S``.
            NOTE(review): this is not derived from ``d.shape[0]``; callers are
            expected to pass a ``d`` holding exactly ``q`` samples.

    Returns:
        1-D tensor with one T^2 value per slice (length ``d.shape[1]``).
    """
    if d.shape[1] == 0:
        return d.new_zeros(0)
    d_mean = d.mean(dim=0)                          # (slice, feature)
    centered = d - d_mean                           # broadcast over samples
    # Per-slice scatter matrix: sum_n c[n, i, :]^T c[n, i, :]  -> (slice, feature, feature)
    S = torch.einsum('nik,nil->ikl', centered, centered) / (q - 1)
    # Small ridge so S stays invertible when samples are (nearly) collinear.
    S = S + torch.eye(S.shape[-1], dtype=S.dtype, device=S.device) * 1e-6
    # Solve S x = mean rather than forming S^{-1} explicitly (better conditioned).
    sol = torch.linalg.solve(S, d_mean.unsqueeze(-1)).squeeze(-1)
    return q * (d_mean * sol).sum(dim=-1)


def reliability_check(d, threshold=71.34, q=32):
    """Return, per slice ``d[:, i, :]``, whether its Hotelling T^2 is below ``threshold``.

    Args:
        d: 3-D tensor indexed as ``(sample, slice, feature)``.
        threshold: Critical value for the T^2 statistic.
        q: Sample count used by the statistic (see ``_hotelling_t2``).

    Returns:
        CPU bool tensor of length ``d.shape[1]``.
    """
    return (_hotelling_t2(d, q) < threshold).cpu()


def major_procedure(d, threshold=71.34, q=32):
    """Return, per slice ``d[:, i, :]``, whether its Hotelling T^2 exceeds ``threshold``.

    Args:
        d: 3-D tensor indexed as ``(sample, slice, feature)``.
        threshold: Critical value for the T^2 statistic.
        q: Sample count used by the statistic (see ``_hotelling_t2``).

    Returns:
        CPU bool tensor of length ``d.shape[1]``.
    """
    return (_hotelling_t2(d, q) > threshold).cpu()

@torch.no_grad()
def test(model, graph, val_test, batch_size, aggregation, margin=0.0, q=32):
    """Evaluate the model with per-pair Hotelling T^2 tests over a collection of graphs.

    For every graph, the pairs in ``val_mask`` (when ``val_test == "val"``) or
    ``test_mask`` (otherwise) are turned into two pair representations
    ``arc1``/``arc2`` from the raw node embeddings ``model(x, adj)``. Two sets of
    differences are then tested slice-by-slice, one slice per masked row:

    * ``arc1 - arc2`` over the first ``q`` graphs, via ``major_procedure``;
    * ``arc1(graph 0) - arc1(graph k)`` for every other key ``k``, via
      ``reliability_check``.

    Args:
        model: Called as ``model(x, adj)`` to produce node embeddings.
        graph: Mapping of key -> graph object; key ``0`` is the reference graph.
        val_test: ``"val"`` selects ``val_mask``; any other value selects ``test_mask``.
        batch_size: Unused; kept for interface compatibility.
        aggregation: ``"concatenation"`` or ``"summation"``.
        margin: Cosine similarities at or below this value contribute no loss.
        q: Number of graphs kept for ``major_procedure`` (also forwarded to it).

    Returns:
        ``(loss, accuracy)``. ``loss`` is the hinge cosine-similarity loss of the
        last graph iterated. ``accuracy`` is the fraction of rows for which both
        tests return True.

    Raises:
        ValueError: If ``aggregation`` is unsupported, or ``graph`` lacks key ``0``
            or any other graph to compare against.
    """
    if aggregation not in ("concatenation", "summation"):
        raise ValueError(f"unsupported aggregation: {aggregation!r}")
    model.eval()

    def pair_repr(h, pairs):
        # Order each pair's endpoints so (u, v) and (v, u) map to the same representation.
        u, v = pairs[:, 0].long(), pairs[:, 1].long()
        lo, hi = torch.minimum(u, v), torch.maximum(u, v)
        if aggregation == "concatenation":
            return torch.cat((h[lo], h[hi]), dim=1)
        return h[lo] + h[hi]

    differences = []
    reference_arc = None   # arc1 of graph key 0
    other_arcs = []        # arc1 of every other graph, in iteration order
    loss_sim = None
    for k, single_graph in graph.items():
        x = single_graph.x
        num_nodes = x.size(0)
        neg_edges = single_graph.val_mask if val_test == "val" else single_graph.test_mask

        edge_index = single_graph.edge_index
        edge_weight = torch.ones(edge_index.size(1), device=x.device)
        adj = SparseTensor.from_edge_index(edge_index, edge_weight, [num_nodes, num_nodes]).to(x.device)

        h = model(x, adj)
        neg_arc1 = pair_repr(h, neg_edges[:, 0])
        neg_arc2 = pair_repr(h, neg_edges[:, 1])

        differences.append(neg_arc1 - neg_arc2)
        neg_sim = torch.nn.functional.cosine_similarity(neg_arc1, neg_arc2, dim=-1)
        # NOTE(review): overwritten each iteration, so only the last graph's loss is returned.
        loss_sim = torch.clamp(neg_sim - margin, min=0).mean()

        if k == 0:
            reference_arc = neg_arc1
        else:
            other_arcs.append(neg_arc1)

    if reference_arc is None or not other_arcs:
        raise ValueError("test() needs a graph with key 0 and at least one other graph")

    # Key 0 is resolved after the loop, so iteration order no longer matters.
    differences_isomorfic = torch.stack([reference_arc - arc for arc in other_arcs], dim=0)

    differences = torch.stack(differences, dim=0)[:q]
    result = major_procedure(differences, q=q)
    result_isomorfic = reliability_check(differences_isomorfic)

    accuracy = (result & result_isomorfic).float().mean()
    return loss_sim, accuracy


def main():
    """Parse CLI arguments, load the graph dictionary, and train/evaluate a GCN for ``--runs`` runs."""
    parser = argparse.ArgumentParser(description='homo')
    parser.add_argument('--data_name', type=str, default='cora')
    parser.add_argument('--neg_mode', type=str, default='equal')
    parser.add_argument('--gnn_model', type=str, default='GCN')
    parser.add_argument('--score_model', type=str, default='mlp_score')
    parser.add_argument('--aggregation', type=str, default='concatenation')

    # gnn setting
    parser.add_argument('--num_layers', type=int, default=1)
    parser.add_argument('--num_layers_predictor', type=int, default=2)
    parser.add_argument('--hidden_channels', type=int, default=64)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--patience', type=int, default=10)

    # train setting
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=9999)
    parser.add_argument('--eval_steps', type=int, default=5)
    parser.add_argument('--runs', type=int, default=10)
    parser.add_argument('--kill_cnt', dest='kill_cnt', default=10, type=int, help='early stopping')
    parser.add_argument('--output_dir', type=str, default='output_test')
    parser.add_argument('--filename', type=str, default='dataset.pt')
    parser.add_argument('--l2', type=float, default=0.0, help='L2 Regularization for Optimizer')
    parser.add_argument('--seed', type=int, default=999)

    parser.add_argument('--save', action='store_true', default=False)
    parser.add_argument('--use_saved_model', action='store_true', default=False)
    parser.add_argument('--metric', type=str, default='MRR')
    parser.add_argument('--device', type=int, default=1)
    parser.add_argument('--log_steps', type=int, default=1)

    # gin
    parser.add_argument('--gin_mlp_layer', type=int, default=2)

    parser.add_argument('--with_loss_weight', default=False, action='store_true')

    args = parser.parse_args()

    print(args.with_loss_weight)
    print(args)

    init_seed(args.seed)

    device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')

    # os.path.join avoids producing '/../data/...' when current_dir is empty.
    dataset_path = os.path.join(current_dir, '..', 'data', args.filename)
    # NOTE(security): torch.load unpickles arbitrary objects -- only load trusted dataset files.
    graph = torch.load(dataset_path)

    # Append each edge's reverse (row 0 and row 1 swapped), then move the graph to the device.
    for k, single_graph in graph.items():
        reversed_edges = single_graph.edge_index.flip(0)
        single_graph.edge_index = torch.cat([single_graph.edge_index, reversed_edges], dim=1)
        graph[k] = single_graph.to(device)

    input_channel = graph[0].x.size(1)
    model = GCN(input_channel, args.hidden_channels,
                args.hidden_channels, args.num_layers, args.dropout).to(device)

    for run in range(args.runs):
        early_stopping = EarlyStopping(patience=args.patience)
        best_model_state = None
        print('#################################          ', run, '          #################################')
        seed = args.seed if args.runs == 1 else run
        print('seed: ', seed)

        init_seed(seed)
        model.reset_parameters()

        optimizer = torch.optim.Adam(
            list(model.parameters()), lr=args.lr, weight_decay=args.l2)

        for epoch in range(args.epochs):
            loss = train(model, graph[0], optimizer, args.batch_size, args.aggregation, margin=0.0)
            print(f"Epoch {epoch+1}/{args.epochs}, Loss: {loss:.4f}")
            val_loss, val_accuracy = test(model, graph, "val", args.batch_size, args.aggregation, margin=0.0)
            print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")
            # Presumably a truthy return means val_loss improved -- confirm against EarlyStopping.
            if early_stopping(val_loss):
                # state_dict() returns tensors sharing storage with the live parameters;
                # clone them so later optimizer steps don't overwrite the saved weights.
                best_model_state = {name: t.detach().clone() for name, t in model.state_dict().items()}

            if early_stopping.early_stop:
                break

        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        test_loss, test_accuracy = test(model, graph, "test", args.batch_size, args.aggregation, margin=0.0)
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")
# Script entry point: run main() only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
   