import sys
import os

# Make the sibling ``../utils`` directory importable; presumably it hosts the
# project modules star-imported below (TODO confirm which ones live there).
current_dir = os.path.dirname(__file__)
utils_path = os.path.abspath(os.path.join(current_dir, os.pardir, "utils"))
sys.path.append(utils_path)

from gnn_model import *
from utils import *
from NEOGNN.neognn import NeoGNN
import torch
import numpy as np
import argparse
import scipy.sparse as ssp
import copy

from torch_sparse import SparseTensor


def train(model, graph, optimizer, aggregation, args, device):
    """Run a single optimisation step on the training pairs of ``graph``.

    ``graph.train_mask`` is iterated as a sequence of pairs, each pair holding
    two edges whose endpoints are read with ``edge[0].item()`` /
    ``edge[1].item()``. The edges are flattened into one ``2 x E`` edge index
    so that columns ``2j`` and ``2j + 1`` are the two edges of pair ``j``.
    The loss penalises positive cosine similarity between the two edge
    embeddings of each pair (negative similarities are clamped to zero).

    Args:
        model: NeoGNN-style model; its second output (``[1]``) is used as the
            per-edge representation (presumably one row per edge column —
            TODO confirm against the model).
        graph: Graph object exposing ``x``, ``edge_index``, ``edge_weight``,
            ``A`` and ``train_mask``.
        optimizer: Optimizer over ``model``'s parameters.
        aggregation: Aggregation mode forwarded to the model.
        args: Parsed CLI arguments (``args.only_use_feature`` is used).
        device: Device on which the sparse adjacency and edge index are built.

    Returns:
        The scalar training loss as a detached tensor (callers use ``.item()``).
    """
    model.train()
    optimizer.zero_grad()

    x = graph.x
    num_nodes = x.size(0)
    adj = SparseTensor.from_edge_index(
        graph.edge_index, graph.edge_weight, [num_nodes, num_nodes]
    ).to(device)

    edge_0, edge_1 = [], []
    for pair in graph.train_mask:
        for edge in pair:
            edge_0.append(edge[0].item())
            edge_1.append(edge[1].item())
    # Explicit long dtype so an empty mask still yields a valid index tensor.
    edge_index = torch.tensor([edge_0, edge_1], dtype=torch.long, device=device)

    h = model(edge_index, adj, graph.A, x, num_nodes, aggregation,
              only_feature=args.only_use_feature)[1]
    # Even rows are the first edge of each pair, odd rows the second.
    neg_arc1, neg_arc2 = h[::2], h[1::2]
    neg_sim = F.cosine_similarity(neg_arc1, neg_arc2, dim=-1)
    loss_sim = torch.clamp(neg_sim, min=0).mean()

    # The graph is used by a single backward pass, so there is no need to retain it.
    loss_sim.backward()
    optimizer.step()

    # Detach so the returned value does not keep the autograd graph alive.
    return loss_sim.detach()

def reliability_check(d, threshold=71.34, q=32):
    """Check whether the model is stable across isomorphic versions of a graph.

    For every slot ``i`` along axis 1 of ``d``, a one-sample Hotelling T^2
    statistic ``q * mean^T S^-1 mean`` is computed over the samples on axis 0,
    where ``S`` is the sample covariance normalised by ``q - 1`` plus a small
    ridge term. A statistic below ``threshold`` means the mean difference
    between isomorphic graphs is not significantly different from zero.

    Args:
        d: Tensor of differences with shape ``(n_samples, n_slots, dim)``,
            obtained from pairs of isomorphic graphs.
        threshold: Critical value for the T^2 statistic.
        q: Number of samples; should equal ``d.shape[0]`` because it is used
            both to normalise the covariance and to scale the statistic.

    Returns:
        CPU bool tensor of shape ``(n_slots,)``, ``True`` where
        ``T^2 < threshold`` (reliable) and ``False`` otherwise.
    """
    if d.shape[1] == 0:
        # Keep the bool dtype so callers can still combine results with ``&``.
        return torch.zeros(0, dtype=torch.bool)

    mean = d.mean(dim=0)                                    # (n_slots, dim)
    centered = d - mean.unsqueeze(0)                        # (n, n_slots, dim)
    cov = torch.einsum('npi,npj->pij', centered, centered) / (q - 1)
    # Ridge regularisation so near-singular covariances remain solvable.
    cov = cov + torch.eye(cov.shape[-1], dtype=cov.dtype, device=cov.device) * 1e-7
    # Solve S x = mean instead of forming S^-1 explicitly (numerically safer).
    solved = torch.linalg.solve(cov, mean.unsqueeze(-1)).squeeze(-1)
    t2 = q * (mean * solved).sum(dim=-1)                    # (n_slots,)
    return (t2 < threshold).cpu()

def major_procedure(d, threshold=71.34, q=32):
    """Test whether the model distinguishes two non-isomorphic graphs (Hotelling T^2).

    For every slot ``i`` along axis 1 of ``d``, a one-sample Hotelling T^2
    statistic ``q * mean^T S^-1 mean`` is computed over the samples on axis 0,
    where ``S`` is the sample covariance normalised by ``q - 1`` plus a small
    ridge term. A statistic above ``threshold`` means the mean difference is
    significantly non-zero, i.e. the two graphs are told apart.

    Args:
        d: Tensor of differences with shape ``(n_samples, n_slots, dim)``.
        threshold: Critical value for the T^2 statistic.
        q: Number of samples; should equal ``d.shape[0]`` because it is used
            both to normalise the covariance and to scale the statistic.

    Returns:
        CPU bool tensor of shape ``(n_slots,)``, ``True`` where
        ``T^2 > threshold`` (distinguished), one entry per test.
    """
    if d.shape[1] == 0:
        # Keep the bool dtype so callers can still combine results with ``&``.
        return torch.zeros(0, dtype=torch.bool)

    mean = d.mean(dim=0)                                    # (n_slots, dim)
    centered = d - mean.unsqueeze(0)                        # (n, n_slots, dim)
    cov = torch.einsum('npi,npj->pij', centered, centered) / (q - 1)
    # Ridge regularisation so near-singular covariances remain solvable.
    cov = cov + torch.eye(cov.shape[-1], dtype=cov.dtype, device=cov.device) * 1e-7
    # Solve S x = mean instead of forming S^-1 explicitly (numerically safer).
    solved = torch.linalg.solve(cov, mean.unsqueeze(-1)).squeeze(-1)
    t2 = q * (mean * solved).sum(dim=-1)                    # (n_slots,)
    return (t2 > threshold).cpu()

@torch.no_grad()
def test(model, graph, val_test, aggregation, args, device, q=32):
    """Evaluate the model on every graph and run both T^2 tests.

    For each graph, the validation (``val_test == "val"``) or test mask is
    flattened into an edge index exactly as in training, and the model's
    second output is split into the first (even rows) and second (odd rows)
    edge of each pair. Two kinds of differences are collected:

    * ``neg_arc1 - neg_arc2`` per graph, fed to ``major_procedure``
      (should the model tell the pair apart?);
    * ``arc1(graph 0) - arc1(graph k)`` for every other graph ``k``, fed to
      ``reliability_check`` (is the model stable across graph versions?
      presumably the graphs are isomorphic copies — TODO confirm).

    Both difference stacks are truncated to ``q`` samples and ``q`` is passed
    to both tests so the statistics are normalised consistently.

    Args:
        model: NeoGNN-style model (second output ``[1]`` is used).
        graph: Dict of graph objects; key ``0`` is the reference graph.
        val_test: ``"val"`` to use ``val_mask``, anything else for ``test_mask``.
        aggregation: Aggregation mode forwarded to the model.
        args: Parsed CLI arguments (``args.only_use_feature`` is used).
        device: Device on which sparse adjacency and edge index are built.
        q: Number of samples used by the T^2 tests.

    Returns:
        ``(mean_loss, accuracy)``: the mean clamped-cosine loss over graphs and
        the fraction of pairs that pass both tests, as a Python float.
    """
    model.eval()
    differences, losses = [], []
    arc1_by_key = {}
    for k, single_graph in graph.items():
        neg_edges = single_graph.val_mask if val_test == "val" else single_graph.test_mask

        x = single_graph.x
        num_nodes = x.size(0)
        adj = SparseTensor.from_edge_index(
            single_graph.edge_index, single_graph.edge_weight, [num_nodes, num_nodes]
        ).to(device)

        edge_0, edge_1 = [], []
        for pair in neg_edges:
            for edge in pair:
                edge_0.append(edge[0].item())
                edge_1.append(edge[1].item())
        # Explicit long dtype so an empty mask still yields a valid index tensor.
        edge_index = torch.tensor([edge_0, edge_1], dtype=torch.long, device=device)

        h = model(edge_index, adj, single_graph.A, x, num_nodes, aggregation,
                  only_feature=args.only_use_feature)[1]
        neg_arc1, neg_arc2 = h[::2], h[1::2]

        neg_sim = F.cosine_similarity(neg_arc1, neg_arc2, dim=-1)
        losses.append(torch.clamp(neg_sim, min=0).mean().cpu().numpy())
        differences.append(neg_arc1 - neg_arc2)
        arc1_by_key[k] = neg_arc1

    # Compare every other graph against graph 0, regardless of dict iteration order.
    reference = arc1_by_key[0]
    differences_isomorfic = [reference - arc1 for k, arc1 in arc1_by_key.items() if k != 0]

    differences = torch.stack(differences, dim=0)[:q]
    result = major_procedure(differences, q=q)

    differences_isomorfic = torch.stack(differences_isomorfic, dim=0)[:q]
    result_isomorfic = reliability_check(differences_isomorfic, q=q)

    accuracy = (result & result_isomorfic).float().mean().item()
    return np.mean(losses), accuracy

def data_to_device(data, device):
    """Move the standard link-prediction tensors of ``data`` to ``device``.

    The dict is updated in place: the entries for ``adj``, ``train_pos``,
    ``train_val``, ``valid_pos``, ``valid_neg``, ``test_pos``, ``test_neg``
    and ``x`` are replaced by their ``.to(device)`` copies, in that order.
    Other entries are left untouched. A missing key raises ``KeyError``.

    Returns:
        The same ``data`` dict.
    """
    tensor_keys = (
        'adj', 'train_pos', 'train_val', 'valid_pos',
        'valid_neg', 'test_pos', 'test_neg', 'x',
    )
    for key in tensor_keys:
        data[key] = data[key].to(device)
    return data


def _build_arg_parser():
    """Return the command-line parser for this training script."""
    parser = argparse.ArgumentParser(description='homo')
    parser.add_argument('--data_name', type=str, default='cora')
    parser.add_argument('--neg_mode', type=str, default='equal')
    parser.add_argument('--gnn_model', type=str, default='NeoGNN')
    parser.add_argument('--score_model', type=str, default='mlp_score')
    parser.add_argument('--aggregation', type=str, default='summation')

    # gnn setting
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--num_layers_predictor', type=int, default=2)
    parser.add_argument('--hidden_channels', type=int, default=64)
    parser.add_argument('--dropout', type=float, default=0.5)

    # train setting
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--eval_steps', type=int, default=5)
    parser.add_argument('--patience', type=int, default=300)
    parser.add_argument('--runs', type=int, default=5)
    parser.add_argument('--kill_cnt', dest='kill_cnt', default=10, type=int, help='early stopping')
    parser.add_argument('--output_dir', type=str, default='output_test')
    parser.add_argument('--filename', type=str, default='dataset.pt')
    parser.add_argument('--l2', type=float, default=0.0, help='L2 Regularization for Optimizer')
    parser.add_argument('--seed', type=int, default=999)

    parser.add_argument('--save', action='store_true', default=False)
    parser.add_argument('--use_saved_model', action='store_true', default=False)
    parser.add_argument('--metric', type=str, default='MRR')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)

    # gin
    parser.add_argument('--gin_mlp_layer', type=int, default=2)

    # gat
    parser.add_argument('--gat_head', type=int, default=1)

    # mf
    parser.add_argument('--cat_node_feat_mf', default=False, action='store_true')

    # neo-gnn
    parser.add_argument('--f_edge_dim', type=int, default=8)
    parser.add_argument('--f_node_dim', type=int, default=128)
    parser.add_argument('--g_phi_dim', type=int, default=128)
    parser.add_argument('--only_use_feature', action='store_true', default=False,
                        help='whether only use the feature')
    parser.add_argument('--beta', type=float, default=0.1)

    parser.add_argument('--eval_mrr_data_name', type=str, default='ogbl-citation2')
    return parser


def _prepare_graphs(graph, beta, device):
    """Symmetrise each graph, attach the NeoGNN matrix ``A`` and move it to ``device``.

    ``A`` is the scipy CSR adjacency of the symmetrised edge list plus
    ``beta`` times its square (two-hop counts). The graph objects are
    modified in place.
    """
    for g in graph.values():
        # Make the graph undirected by appending reversed edges.
        g.edge_index = torch.cat([g.edge_index, g.edge_index.flip(0)], dim=1)
        node_num = g.x.size(0)
        A = ssp.csr_matrix(
            (torch.ones(g.edge_index.size(1), dtype=float), (g.edge_index[0], g.edge_index[1])),
            shape=(node_num, node_num),
        )
        # For scipy sparse matrices ``*`` is matrix multiplication.
        g.A = A + beta * (A * A)
        # Relies on ``.to`` updating the object in place (return value unused).
        g.to(device)


def main():
    """Parse arguments, train NeoGNN for ``args.runs`` runs and report test accuracy."""
    args = _build_arg_parser().parse_args()

    print('cat_node_feat_mf: ', args.cat_node_feat_mf)
    print(args)

    device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')

    dataset_path = os.path.join(current_dir+'/../data/', args.filename)
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    graph = torch.load(dataset_path)
    _prepare_graphs(graph, args.beta, device)

    input_channel = graph[0].x.size(1)

    test_accuracies = []
    for run in range(args.runs):
        early_stopping = EarlyStopping(patience=args.patience)
        init_seed(run)
        model = NeoGNN(input_channel, args.hidden_channels,
                       args.hidden_channels, args.num_layers, args.dropout, args).to(device)
        model.reset_parameters()
        print('#################################          ', run, '          #################################')

        optimizer = torch.optim.Adam(
            list(model.parameters()), lr=args.lr, weight_decay=args.l2)

        # Stays None if early stopping never reports an improvement; the
        # final-epoch weights are then used for testing.
        best_model_state = None
        for epoch in range(1, 1 + args.epochs):
            train_loss = train(model, graph[0], optimizer, args.aggregation, args, device)
            print(f"Epoch: {epoch}, train loss: {train_loss.item()}")
            val_loss, val_acc = test(model, graph, "val", args.aggregation, args, device)
            print(f"Validation Loss: {val_loss:.4f}, accuracy: {val_acc:.4f}")
            # Presumably a truthy return means val_loss improved — TODO confirm
            # against EarlyStopping.
            if early_stopping(val_loss):
                # state_dict() returns references to the live parameters; copy
                # them or later epochs silently overwrite the "best" snapshot.
                best_model_state = copy.deepcopy(model.state_dict())

            if early_stopping.early_stop:
                break

        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        test_loss, test_acc = test(model, graph, "test", args.aggregation, args, device)
        print(f"Test acc: {test_acc}")
        test_accuracies.append(test_acc)
    print(f"Test accuracy over {args.runs} runs: {np.mean(test_accuracies)} ± {np.std(test_accuracies)}")

if __name__ == "__main__":
    # Run the training/evaluation pipeline only when executed as a script.
    main()
   