import sys
import os

# Put <this file's dir>/../utils on the module search path; the wildcard
# imports of ``gnn_model`` and ``utils`` that follow rely on it.
current_dir = os.path.dirname(__file__)
utils_path = os.path.abspath(os.path.join(current_dir, os.pardir, 'utils'))
sys.path.append(utils_path)

from gnn_model import *
from utils import *

import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_sparse import SparseTensor
import torch_geometric.transforms as T
from NCN.model import predictor_dict, convdict, GCN_NCN
from functools import partial
from sklearn.metrics import roc_auc_score, average_precision_score
from ogb.linkproppred import PygLinkPropPredDataset
from torch_geometric.utils import negative_sampling, to_undirected
from torch_geometric.datasets import Planetoid

from typing import Iterable
import random
import os
import copy
from scipy.stats import f

    
def set_seed(seed=2020):
    """Seed all random generators used by the script and pin cuDNN behaviour.

    Seeds NumPy, PyTorch (CPU and the current CUDA device) and Python's
    ``random`` module with the same value, then asks cuDNN for deterministic
    kernels and disables its benchmark-based algorithm selection.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (np.random.seed, torch.manual_seed, torch.cuda.manual_seed, random.seed)
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def train(model, predictor, graph, optimizer, aggregation, args, device, maskinput,
          cnprobs: "Iterable[float] | None" = None,
          alpha: "float | None" = None):
    """Run one optimisation step of encoder + predictor on a single graph.

    The edges listed in ``graph.train_mask`` are flattened (in order) into an
    edge index, scored with ``predictor.multidomainforward`` and the loss is
    the mean of the positive part of the cosine similarity between the
    predictions of consecutive edges (indices 2i and 2i+1).

    Args:
        model: encoder, called as ``model(graph.x, graph.edge_index)``.
        predictor: predictor exposing ``multidomainforward`` (and
            ``setalpha`` when ``alpha`` is given).
        graph: graph object providing ``x``, ``edge_index``, ``edge_weight``
            and ``train_mask``. ``train_mask`` is iterated as groups of edges,
            each edge being indexable as ``edge[0]`` / ``edge[1]``.
        optimizer: optimizer over the parameters to update.
        aggregation, args, maskinput: not used by this function; kept for
            interface compatibility.
        device: device on which the adjacency and edge index are built.
        cnprobs: forwarded to the predictor as ``cndropprobs``; ``None``
            (the default) means an empty list.
        alpha: when not ``None``, passed to ``predictor.setalpha`` first.

    Returns:
        The scalar training loss as a detached tensor.
    """
    if cnprobs is None:
        cnprobs = []
    if alpha is not None:
        predictor.setalpha(alpha)
    model.train()
    predictor.train()
    # Clear gradients from the previous call; otherwise they accumulate
    # across epochs and every step uses the running sum.
    optimizer.zero_grad()

    num_nodes = graph.x.size(0)
    adj = SparseTensor.from_edge_index(graph.edge_index, graph.edge_weight,
                                       [num_nodes, num_nodes]).to(device)

    # Flatten the grouped edges into a (2, E) index. Order is preserved so
    # that consecutive edges still form the pairs compared below.
    edge_0, edge_1 = [], []
    for pairs in graph.train_mask:
        for edge in pairs:
            edge_0.append(edge[0].item())
            edge_1.append(edge[1].item())
    # Explicit dtype keeps the index integral even if train_mask is empty.
    edge_index = torch.tensor([edge_0, edge_1], dtype=torch.long, device=device)

    h = model(graph.x, graph.edge_index)
    edge_predictions = predictor.multidomainforward(h, adj, edge_index,
                                                    cndropprobs=cnprobs)

    # Penalise positive cosine similarity between paired predictions
    # (negative similarities are clamped to zero and do not contribute).
    neg_arc1 = edge_predictions[::2]
    neg_arc2 = edge_predictions[1::2]
    neg_sim = F.cosine_similarity(neg_arc1, neg_arc2, dim=-1)
    loss_sim = torch.clamp(neg_sim, min=0).mean()
    # Single backward pass per step, so the autograd graph need not be kept.
    loss_sim.backward()

    # Only the encoder's gradients are clipped; the predictor's are not.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    return loss_sim.detach()
    
def _hotelling_t2(d_i, q):
    """Return the one-sample Hotelling T^2 statistic for the rows of ``d_i``.

    ``d_i`` is a 2-D tensor indexed as ``d_i[sample, feature]``. The sample
    covariance is normalised by ``q - 1`` and the statistic is scaled by
    ``q``, so ``q`` is expected to equal ``d_i.shape[0]``.
    """
    d_mean = d_i.mean(dim=0)
    centered = d_i - d_mean
    S = centered.T @ centered / (q - 1)  # sample covariance matrix
    # Solving S x = mean gives the same value as mean^T S^-1 mean without
    # forming the inverse explicitly (better conditioned). Like
    # torch.inverse, this raises if S is singular.
    return q * torch.dot(d_mean, torch.linalg.solve(S, d_mean))


def reliability_check(d, threshold=71.34, q=32):
    """Check that the model is stable across isomorphic versions of a graph.

    For every index ``i`` along the second axis, a Hotelling T^2 statistic is
    computed over the samples of ``d[:, i, :]``. A slice is considered
    reliable when its T^2 is below ``threshold``.

    Args:
        d: 3-D tensor of differences between isomorphic graphs, indexed as
            ``d[sample, i, feature]``.
        threshold: T^2 threshold.
        q: number of samples; expected to equal ``d.shape[0]``.

    Returns:
        1-D bool tensor with one entry per index ``i`` (True = reliable).
    """
    results = [_hotelling_t2(d[:, i, :], q).item() < threshold
               for i in range(d.shape[1])]
    # Explicit dtype keeps the result boolean even when there are no slices.
    return torch.tensor(results, dtype=torch.bool)


def major_procedure(d, threshold=71.34, q=32):
    """Test whether the model distinguishes two non-isomorphic graphs.

    For every index ``i`` along the second axis, a Hotelling T^2 statistic is
    computed over the samples of ``d[:, i, :]``. The pair is considered
    distinguished when its T^2 exceeds ``threshold``.

    Args:
        d: 3-D tensor of differences, indexed as ``d[sample, i, feature]``.
        threshold: T^2 threshold.
        q: number of samples; expected to equal ``d.shape[0]``.

    Returns:
        1-D bool tensor with one entry per index ``i`` (True = distinguished).
    """
    results = [_hotelling_t2(d[:, i, :], q).item() > threshold
               for i in range(d.shape[1])]
    return torch.tensor(results, dtype=torch.bool)

@torch.no_grad()
def test(model, predictor, graph, val_test, aggregation, args, device, q=32):
    """Evaluate distinguishability and isomorphism reliability on all graphs.

    For every graph the validation edges (``val_test == "val"``) or test
    edges (otherwise) are scored. Predictions of consecutive edges are
    paired, as in ``train``:

    * ``arc1 - arc2`` per graph feeds ``major_procedure``;
    * ``arc1`` of graph key 0 minus ``arc1`` of every other graph feeds
      ``reliability_check`` (key 0 is the reference; the other entries are
      presumably isomorphic copies of it -- confirm against the dataset).

    Args:
        model: encoder, called as ``model(x, edge_index)``.
        predictor: predictor exposing ``multidomainforward``.
        graph: dict of graph objects; must contain key 0.
        val_test: ``"val"`` selects ``val_mask``, anything else ``test_mask``.
        aggregation, args: not used by this function.
        device: device on which the adjacency and edge index are built.
        q: number of samples used by both T^2 tests; both difference stacks
            are truncated to their first ``q`` entries.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the mean over graphs of
        the clamped cosine-similarity loss. ``accuracy`` is a Python float:
        the fraction of edge pairs that pass both tests.
    """
    model.eval()
    # The predictor has its own train/eval mode (train() switches it to
    # train mode), so it must be switched to eval here as well.
    predictor.eval()

    differences, losses = [], []
    arc1_by_key = {}

    for k, single_graph in graph.items():
        num_nodes = single_graph.x.size(0)
        adj = SparseTensor.from_edge_index(single_graph.edge_index, single_graph.edge_weight,
                                           [num_nodes, num_nodes]).to(device)
        neg_edges = single_graph.val_mask if val_test == "val" else single_graph.test_mask

        edge_0, edge_1 = [], []
        for pairs in neg_edges:
            for edge in pairs:
                edge_0.append(edge[0].item())
                edge_1.append(edge[1].item())
        edge_index = torch.tensor([edge_0, edge_1], dtype=torch.long, device=device)

        h = model(single_graph.x, single_graph.edge_index)
        edge_predictions = predictor.multidomainforward(h, adj, edge_index, cndropprobs=[])

        neg_arc1 = edge_predictions[::2]
        neg_arc2 = edge_predictions[1::2]
        differences.append(neg_arc1 - neg_arc2)
        neg_sim = F.cosine_similarity(neg_arc1, neg_arc2, dim=-1)
        losses.append(torch.clamp(neg_sim, min=0).mean().item())
        arc1_by_key[k] = neg_arc1

    # Compare against key 0 regardless of dict iteration order.
    reference = arc1_by_key[0]
    differences_isomorfic = [reference - arc for k, arc in arc1_by_key.items() if k != 0]

    differences = torch.stack(differences, dim=0)[:q]
    differences_isomorfic = torch.stack(differences_isomorfic, dim=0)[:q]
    # Forward q so the statistics match the truncation above.
    result = major_procedure(differences, q=q)
    result_isomorfic = reliability_check(differences_isomorfic, q=q)
    accuracy = (result & result_isomorfic).float().mean().item()
    return np.mean(losses), accuracy
    
def parseargs():
    """Define this script's command-line options and parse ``sys.argv``.

    Options are registered in the order of ``option_specs`` below, which also
    determines their order in ``--help``.

    Returns:
        argparse.Namespace holding every option.
    """
    parser = argparse.ArgumentParser(description='OGBL-COLLAB (GNN)')
    flag = {'action': 'store_true'}

    option_specs = [
        ('--use_valedges_as_input', flag),
        ('--mplayers', dict(type=int, default=3)),
        ('--aggregation', dict(type=str, default='summation')),
        ('--nnlayers', dict(type=int, default=1)),
        ('--ln', flag),
        ('--lnnn', flag),
        ('--res', flag),
        ('--jk', flag),
        ('--maskinput', flag),
        ('--hiddim', dict(type=int, default=64)),
        ('--gnndp', dict(type=float, default=0.3)),
        ('--xdp', dict(type=float, default=0.3)),
        ('--tdp', dict(type=float, default=0.3)),
        ('--gnnedp', dict(type=float, default=0.3)),
        ('--predp', dict(type=float, default=0.3)),
        ('--preedp', dict(type=float, default=0.3)),
        ('--splitsize', dict(type=int, default=-1)),
        ('--gnnlr', dict(type=float, default=0.00003)),
        ('--prelr', dict(type=float, default=0.00003)),
        ('--batch_size', dict(type=int, default=4096)),
        ('--testbs', dict(type=int, default=8192)),
        ('--epochs', dict(type=int, default=1000)),
        ('--runs', dict(type=int, default=3)),
        ('--probscale', dict(type=float, default=5)),
        ('--proboffset', dict(type=float, default=3)),
        ('--beta', dict(type=float, default=1)),
        ('--alpha', dict(type=float, default=1)),
        ('--trndeg', dict(type=int, default=-1)),
        ('--tstdeg', dict(type=int, default=-1)),
        ('--predictor', dict(type=str, default='cn1')),
        ('--patience', dict(type=int, default=300)),
        ('--model', dict(choices=convdict.keys())),
        ('--cndeg', dict(type=int, default=-1)),
        ('--save_gemb', flag),
        ('--load', dict(type=str)),
        ('--cnprob', dict(type=float, default=0)),
        ('--pt', dict(type=float, default=0.5)),
        ('--learnpt', flag),
        ('--use_xlin', flag),
        ('--tailact', flag),
        ('--twolayerlin', flag),
        ('--depth', dict(type=int, default=2)),
        ('--increasealpha', flag),
        ('--savex', flag),
        ('--loadx', flag),
        ('--loadmod', flag),
        ('--savemod', flag),
        ('--device', dict(type=int, default=0)),
        ('--kill_cnt', dict(dest='kill_cnt', default=30, type=int, help='early stopping')),
        ('--seed', dict(type=int, default=999)),
        ('--output_dir', dict(type=str, default='output_test')),
        ('--save', dict(action='store_true', default=False)),
        ('--input_dir', dict(type=str, default=os.path.join(get_root_dir(), "dataset"))),
        ('--filename', dict(type=str, default='dataset.pt')),
        ('--eval_steps', dict(type=int, default=5)),
        ('--l2', dict(type=float, default=0.0, help='L2 Regularization for Optimizer')),
    ]

    # ``**options`` unpacks into a fresh dict per call, so sharing ``flag``
    # between entries is safe.
    for option_name, options in option_specs:
        parser.add_argument(option_name, **options)

    return parser.parse_args()


def main():
    """Train and evaluate for ``args.runs`` independent runs.

    Loads the dict of graphs from ``<this dir>/../data/<args.filename>``,
    appends reversed edges to every graph and moves it to the device. Then,
    for each run, it:

    1. re-seeds and builds a fresh encoder and predictor;
    2. trains on graph 0 with early stopping on the validation loss;
    3. restores the best encoder/predictor weights and reports test accuracy.
    """
    args = parseargs()
    device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
    # os.path.join also works when current_dir is '' (the old string
    # concatenation then produced an absolute '/../data/...' path).
    dataset_path = os.path.join(current_dir, '..', 'data', args.filename)
    graph = torch.load(dataset_path)

    for single_graph in graph.values():
        # Append the reversed edges so each edge is present in both directions.
        single_graph.edge_index = torch.cat(
            [single_graph.edge_index, single_graph.edge_index.flip(0)], dim=1)
        # NOTE(review): this stores the node count (x.size(0)) as
        # `num_features`, which is later passed as the encoder's input size;
        # confirm this is intended (the feature dimension would be x.size(1)).
        single_graph.num_features = single_graph.x.size(0)
        single_graph.to(device)

    predfn = predictor_dict[args.predictor]
    if args.predictor != "cn0":
        predfn = partial(predfn, cndeg=args.cndeg)
    if args.predictor in ["cn1", "incn1cn1", "scn1", "catscn1", "sincn1cn1"]:
        predfn = partial(predfn, use_xlin=args.use_xlin, tailact=args.tailact,
                         twolayerlin=args.twolayerlin, beta=args.beta)

    test_accuracies = []
    for run in range(args.runs):
        early_stopping = EarlyStopping(patience=args.patience)
        seed = args.seed if args.runs == 1 else run
        set_seed(seed)

        args.model = 'gcn'
        model = GCN_NCN(graph[0].num_features, args.hiddim, args.hiddim, 4).to(device)
        # Build the predictor per run (after seeding) so that runs do not
        # share weights and each run is reproducible from its seed.
        predictor = predfn(args.hiddim, args.hiddim, 1, args.nnlayers,
                           args.predp, args.preedp, args.lnnn).to(device)

        optimizer = torch.optim.Adam([{'params': model.parameters(), "lr": args.gnnlr},
                                      {'params': predictor.parameters(), 'lr': args.prelr}],
                                     weight_decay=args.l2)

        best_model_state = None
        best_predictor_state = None
        for epoch in range(1, 1 + args.epochs):
            alpha = max(0, min((epoch - 5) * 0.1, 1)) if args.increasealpha else None
            # alpha must be passed by keyword: the next positional parameter
            # of train() after maskinput is cnprobs.
            train_loss = train(model, predictor, graph[0], optimizer, args.aggregation,
                               args, device, args.maskinput, alpha=alpha)
            print(f"Epoch: {epoch}, Train Loss: {train_loss.item()}")
            val_loss, val_acc = test(model, predictor, graph, "val", args.aggregation, args, device)
            print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")
            # Presumably EarlyStopping returns True when val_loss improved
            # (as the original code assumed) -- confirm in utils.
            if early_stopping(val_loss):
                # state_dict() holds references to the live parameters;
                # deep-copy so later optimizer steps don't overwrite the snapshot.
                best_model_state = copy.deepcopy(model.state_dict())
                best_predictor_state = copy.deepcopy(predictor.state_dict())

            if early_stopping.early_stop:
                break

        if best_model_state is not None:
            model.load_state_dict(best_model_state)
            predictor.load_state_dict(best_predictor_state)
        test_loss, test_acc = test(model, predictor, graph, "test", args.aggregation, args, device)
        print(f"Test Accuracy: {test_acc:.4f}")
        test_accuracies.append(test_acc)
    print(f"Test accuracy over {args.runs} runs: {np.mean(test_accuracies)} ± {np.std(test_accuracies)}")


       

if __name__ == '__main__':
    # Run the full train/evaluate pipeline only when executed as a script.
    main()
   