
import sys
import os

current_dir = os.path.dirname(__file__)
utils_path = os.path.abspath(os.path.join(current_dir, '..', 'utils'))
sys.path.append(utils_path)

from gnn_model import *
from utils import *

import argparse
import time
from math import inf
import sys

import numpy as np
from torch_geometric.data import Data
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import  to_undirected
from BUDDY.data import get_loaders_hard_neg
from BUDDY.utils import select_embedding, select_model, get_num_samples, get_split_samples, str2bool

from torch.utils.data import DataLoader


def get_edge_label(pos_edges, neg_edges):
    """Combine positive and negative edges into one labelled edge set.

    Args:
        pos_edges: tensor with one positive edge per row.
        neg_edges: tensor with one negative edge per row.

    Returns:
        (edge_label, edge_label_index): ``edge_label`` is 1.0 for every positive
        row followed by 0.0 for every negative row; ``edge_label_index`` is the
        row-wise concatenation of both inputs, transposed.
    """
    labels = torch.cat(
        (torch.ones(pos_edges.shape[0]), torch.zeros(neg_edges.shape[0])),
        dim=0,
    )
    index = torch.cat((pos_edges, neg_edges), dim=0).t()
    return labels, index

def _hotelling_t2(d, q=32):
    """Return a Hotelling-style T^2 statistic for each item along axis 1 of ``d``.

    ``d`` is indexed as ``d[:, i, :]``: axis 0 holds the samples, axis 1 the
    items tested independently, axis 2 the feature dimensions. For item ``i``
    the statistic is ``q * m^T S^{-1} m`` where ``m`` is the mean over axis 0
    and ``S`` the covariance normalised by ``q - 1`` plus a ``1e-7`` ridge.

    Args:
        d: 3-D tensor of differences.
        q: sample count used for the scaling; ``None`` means ``d.shape[0]``.
            For the statistic to be meaningful it should equal ``d.shape[0]``.

    Returns:
        List of Python floats, one per item along axis 1.
    """
    if q is None:
        q = d.shape[0]
    stats = []
    for i in range(d.shape[1]):
        d_i = d[:, i, :]
        mean = d_i.mean(dim=0)
        centered = d_i - mean
        cov = centered.T @ centered / (q - 1)
        # Small ridge keeps the covariance invertible when samples are (near) identical.
        cov = cov + torch.eye(cov.shape[0], device=cov.device, dtype=cov.dtype) * 1e-7
        # Solve cov @ x = mean instead of forming cov^{-1} explicitly; mean is 1-D,
        # so a dot product replaces the deprecated `.T` on a 1-D tensor.
        x = torch.linalg.solve(cov, mean)
        stats.append((q * torch.dot(mean, x)).item())
    return stats


def reliability_check(d, threshold=71.34, q=32):
    """Return, per item along axis 1 of ``d``, whether its T^2 is below ``threshold``.

    Args:
        d: 3-D tensor; samples on axis 0, items on axis 1, features on axis 2.
        threshold: cut-off the statistic is compared against.
        q: sample count used by the statistic; ``None`` uses ``d.shape[0]``.

    Returns:
        1-D bool tensor on the CPU, ``True`` where ``T^2 < threshold``.
        Empty (but still bool) when ``d.shape[1] == 0``.
    """
    return torch.tensor([t2 < threshold for t2 in _hotelling_t2(d, q)], dtype=torch.bool)


def major_procedure(d, threshold=71.34, q=32):
    """Return, per item along axis 1 of ``d``, whether its T^2 exceeds ``threshold``.

    Args:
        d: 3-D tensor; samples on axis 0, items on axis 1, features on axis 2.
        threshold: cut-off the statistic is compared against.
        q: sample count used by the statistic; ``None`` uses ``d.shape[0]``.

    Returns:
        1-D bool tensor on the CPU, ``True`` where ``T^2 > threshold``.
        Empty (but still bool) when ``d.shape[1] == 0``.
    """
    return torch.tensor([t2 > threshold for t2 in _hotelling_t2(d, q)], dtype=torch.bool)

def train_elph(model, optimizer, train_loader, args, device):
    """Run one training epoch of the ELPH pair-separation objective.

    A random subset of the training links (size from ``args.train_samples``)
    is visited in shuffled mini-batches. For every batch the GNN is re-run on
    the whole graph, link logits come from ``model.predictor``, and rows
    0/1, 2/3, ... are paired. The loss is the mean positive part of each
    pair's cosine similarity. An optimizer step is taken after every batch.

    Args:
        model: ELPH model; ``model(x, edge_index)`` returns
            ``(node_features, hashes, cards)``.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: loader whose ``dataset`` exposes ``links``, ``labels``,
            ``x``, ``edge_index`` and ``subgraph_features``.
        args: parsed CLI namespace.
        device: device batches are moved to.

    Returns:
        0-d tensor with the pair-weighted mean loss over the epoch.

    Raises:
        ValueError: if no batch contained at least one pair of links.
    """
    model.train()
    data = train_loader.dataset
    num_links = len(data.labels)
    train_samples = get_num_samples(args.train_samples, num_links)
    sample_indices = torch.randperm(num_links)[:train_samples]
    links = data.links[sample_indices]

    # The full graph is reused by every batch; move it to the device once.
    x = data.x.to(device)
    edge_index = data.edge_index.to(device)

    total_loss, total_pairs = 0.0, 0
    loader = DataLoader(range(len(links)), args.batch_size, shuffle=True)
    for indices in loader:
        if model.node_embedding is not None:
            if args.propagate_embeddings:
                emb = model.propagate_embeddings_func(edge_index)
            else:
                emb = model.node_embedding.weight
        else:
            emb = None

        # Re-run the GNN every batch: the previous optimizer step changed its weights.
        node_features, hashes, cards = model(x, edge_index)
        curr_links = links[indices].to(device)
        batch_node_features = None if node_features is None else node_features[curr_links]
        batch_emb = None if emb is None else emb[curr_links].to(device)
        if args.use_struct_feature:
            subgraph_features = model.elph_hashes.get_subgraph_features(curr_links, hashes, cards).to(device)
        else:  # TODO: fix this branch; only the shape of data.subgraph_features is used
            subgraph_features = torch.zeros(data.subgraph_features[indices].shape, device=device)

        optimizer.zero_grad()
        logits = model.predictor(subgraph_features, batch_node_features, batch_emb)
        # Drop a trailing unpaired row so both halves have the same length.
        n_pairs = logits.shape[0] // 2
        if n_pairs == 0:
            continue
        neg_arc1 = logits[0:2 * n_pairs:2]
        neg_arc2 = logits[1:2 * n_pairs:2]
        # Hinge on cosine similarity: only positively similar pairs are penalised.
        # NOTE(review): train_buddy and the evaluators use exp(-||a - b||^2); confirm
        # which objective is intended for ELPH.
        neg_sim = F.cosine_similarity(neg_arc1, neg_arc2, dim=-1)
        loss_sim = torch.clamp(neg_sim, min=0).mean()
        loss_sim.backward(retain_graph=True)
        optimizer.step()

        total_loss += loss_sim.item() * n_pairs
        total_pairs += n_pairs

    if total_pairs == 0:
        raise ValueError("train_elph: no batch contained a pair of links to train on")
    return torch.tensor(total_loss / total_pairs)
   
def train_buddy(model, optimizer, train_loader, args, device, emb=None):
    """Run one training epoch of the BUDDY pair-separation objective.

    A random subset of the training links (size from ``args.train_samples``)
    is visited in shuffled mini-batches. For each batch the model scores the
    links, rows 0/1, 2/3, ... of the output are paired, and
    ``mean(exp(-||a - b||^2))`` is minimised. A backward pass and an optimizer
    step are taken after every batch.

    Args:
        model: BUDDY model; called as ``model(sf, x, deg_u, deg_v, RA, emb)``.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: loader whose ``dataset`` exposes ``links``, ``labels``,
            ``x``, ``degrees``, ``subgraph_features`` and (if ``args.use_RA``) ``RA``.
        args: parsed CLI namespace.
        device: device batches are moved to.
        emb: ignored; node embeddings are re-derived from ``model`` every
            batch. Kept for signature compatibility.

    Returns:
        0-d tensor with the pair-weighted mean loss over the epoch.

    Raises:
        ValueError: if no batch contained at least one pair of links.
    """
    model.train()
    data = train_loader.dataset
    num_links = len(data.labels)
    train_samples = get_num_samples(args.train_samples, num_links)
    sample_indices = torch.randperm(num_links)[:train_samples]
    links = data.links[sample_indices]

    # edge_index only feeds embedding propagation; move it to the device once.
    edge_index = None
    if model.node_embedding is not None and args.propagate_embeddings:
        edge_index = data.edge_index.to(device)

    total_loss, total_pairs = 0.0, 0
    loader = DataLoader(range(len(links)), args.batch_size, shuffle=True)
    for indices in loader:
        if model.node_embedding is not None:
            if args.propagate_embeddings:
                emb = model.propagate_embeddings_func(edge_index)
            else:
                emb = model.node_embedding.weight
        else:
            emb = None
        curr_links = links[indices]
        batch_emb = None if emb is None else emb[curr_links].to(device)

        # Subgraph features and RA are stored per original link, so map the batch
        # positions back through the sampling permutation.
        orig_indices = sample_indices[indices]
        if args.use_struct_feature:
            subgraph_features = data.subgraph_features[orig_indices].to(device)
        else:
            subgraph_features = torch.zeros(data.subgraph_features[indices].shape, device=device)
        node_features = data.x[curr_links].to(device)
        degrees = data.degrees[curr_links].to(device)
        RA = data.RA[orig_indices].to(device) if args.use_RA else None

        optimizer.zero_grad()
        logits = model(subgraph_features, node_features, degrees[:, 0], degrees[:, 1], RA, batch_emb)
        # Drop a trailing unpaired row so both halves have the same length.
        n_pairs = logits.shape[0] // 2
        if n_pairs == 0:
            continue
        neg_arc1 = logits[0:2 * n_pairs:2]
        neg_arc2 = logits[1:2 * n_pairs:2]
        neg_sim = F.pairwise_distance(neg_arc1, neg_arc2, p=2) ** 2
        loss_sim = torch.exp(-neg_sim).mean()
        # Step once per batch (previously only the last batch was ever trained on).
        loss_sim.backward(retain_graph=True)
        optimizer.step()

        total_loss += loss_sim.item() * n_pairs
        total_pairs += n_pairs

    if total_pairs == 0:
        raise ValueError("train_buddy: no batch contained a pair of links to train on")
    return torch.tensor(total_loss / total_pairs)

@torch.no_grad()
def test_buddy(model, loaders, device, args, val_test, split=None):
    """Evaluate a BUDDY model on every graph with the T^2 pair-separation protocol.

    For each entry of ``loaders`` the chosen split is scored in order. Output
    rows 0/1, 2/3, ... are treated as pairs:

    - their per-graph difference feeds ``major_procedure``, using at most 32 graphs;
    - the even-row outputs of graph ``0`` minus those of every other graph feed
      ``reliability_check``.

    A pair column counts as correct when both tests return ``True``.

    Args:
        model: BUDDY model; called as ``model(sf, x, deg_u, deg_v, RA, emb)``.
        loaders: mapping graph-key -> dict with ``'val_loader'`` / ``'test_loader'``.
            Key ``0`` must be iterated first; it is the reference graph.
        device: device batches are moved to.
        args: parsed CLI namespace (eval_batch_size, use_struct_feature, use_RA,
            propagate_embeddings are read).
        val_test: ``"val"`` selects ``val_loader``; anything else ``test_loader``.
        split: unused; kept for signature compatibility.

    Returns:
        (mean_loss, accuracy): mean of per-batch ``exp(-||a - b||^2)`` losses, and
        the fraction of pair columns passing both tests as a Python float.

    Raises:
        ValueError: if a non-zero key is seen before the reference graph ``0``.
    """
    model.eval()
    split_key = 'val_loader' if val_test == "val" else 'test_loader'
    differences, differences_isomorfic, losses = [], [], []
    reference = None
    for k, v in loaders.items():
        data = v[split_key].dataset
        links = data.links
        loader = DataLoader(range(len(links)), args.eval_batch_size, shuffle=False)
        if model.node_embedding is not None:
            if args.propagate_embeddings:
                emb = model.propagate_embeddings_func(data.edge_index.to(device))
            else:
                emb = model.node_embedding.weight
        else:
            emb = None

        arc1_batches, diff_batches = [], []
        for indices in loader:
            curr_links = links[indices]
            batch_emb = None if emb is None else emb[curr_links].to(device)
            if args.use_struct_feature:
                subgraph_features = data.subgraph_features[indices].to(device)
            else:
                subgraph_features = torch.zeros(data.subgraph_features[indices].shape, device=device)
            node_features = data.x[curr_links].to(device)
            degrees = data.degrees[curr_links].to(device)
            RA = data.RA[indices].to(device) if args.use_RA else None
            logits = model(subgraph_features, node_features, degrees[:, 0], degrees[:, 1], RA, batch_emb)
            neg_arc1 = logits[::2]
            neg_arc2 = logits[1::2]
            arc1_batches.append(neg_arc1)
            diff_batches.append(neg_arc1 - neg_arc2)
            neg_sim = F.pairwise_distance(neg_arc1, neg_arc2, p=2) ** 2
            losses.append(torch.exp(-neg_sim).mean().item())

        # One entry per graph: concatenating the batches keeps pair columns aligned
        # across graphs even when a split spans several eval batches.
        differences.append(torch.cat(diff_batches, dim=0))
        arc1 = torch.cat(arc1_batches, dim=0)
        if k == 0:
            reference = arc1
        elif reference is None:
            raise ValueError("test_buddy: loaders must yield the reference graph (key 0) first")
        else:
            differences_isomorfic.append(reference - arc1)

    # Within-pair test on at most 32 graphs; q follows the actual sample count.
    differences = torch.stack(differences[:32], dim=0)
    result = major_procedure(differences, q=differences.shape[0])

    differences_isomorfic = torch.stack(differences_isomorfic, dim=0)
    result_isomorfic = reliability_check(differences_isomorfic, q=differences_isomorfic.shape[0])
    accuracy = (result & result_isomorfic).float().mean().item()
    return np.mean(losses), accuracy


@torch.no_grad()
def test_edge_elph(model, loaders, device, args, val_test, split=None):
    """Evaluate an ELPH model on every graph with the T^2 pair-separation protocol.

    For each entry of ``loaders`` the GNN is run once on that graph. The chosen
    split's links are then scored in order with ``model.predictor``. Output
    rows 0/1, 2/3, ... are treated as pairs:

    - their per-graph difference feeds ``major_procedure``, using at most 32 graphs;
    - the even-row outputs of graph ``0`` minus those of every other graph feed
      ``reliability_check``.

    A pair column counts as correct when both tests return ``True``.

    Args:
        model: ELPH model; ``model(x, edge_index)`` returns
            ``(node_features, hashes, cards)``.
        loaders: mapping graph-key -> dict with ``'val_loader'`` / ``'test_loader'``.
            Key ``0`` must be iterated first; it is the reference graph.
        device: device batches are moved to.
        args: parsed CLI namespace.
        val_test: ``"val"`` selects ``val_loader``; anything else ``test_loader``.
        split: unused; kept for signature compatibility.

    Returns:
        (mean_loss, accuracy): mean of per-batch ``exp(-||a - b||^2)`` losses, and
        the fraction of pair columns passing both tests as a Python float.

    Raises:
        ValueError: if a non-zero key is seen before the reference graph ``0``.
    """
    model.eval()
    split_key = 'val_loader' if val_test == "val" else 'test_loader'
    differences, differences_isomorfic, losses = [], [], []
    reference = None
    for k, v in loaders.items():
        data = v[split_key].dataset
        links = data.links
        loader = DataLoader(range(len(links)), args.eval_batch_size, shuffle=False)
        edge_index = data.edge_index.to(device)

        if model.node_embedding is not None:
            if args.propagate_embeddings:
                emb = model.propagate_embeddings_func(edge_index)
            else:
                emb = model.node_embedding.weight
        else:
            emb = None
        node_features, hashes, cards = model(data.x.to(device), edge_index)

        arc1_batches, diff_batches = [], []
        for indices in loader:
            curr_links = links[indices].to(device)
            batch_emb = None if emb is None else emb[curr_links].to(device)
            if args.use_struct_feature:
                subgraph_features = model.elph_hashes.get_subgraph_features(curr_links, hashes, cards).to(device)
            else:
                subgraph_features = torch.zeros(data.subgraph_features[indices].shape, device=device)
            batch_node_features = None if node_features is None else node_features[curr_links]
            logits = model.predictor(subgraph_features, batch_node_features, batch_emb)
            neg_arc1 = logits[::2]
            neg_arc2 = logits[1::2]
            arc1_batches.append(neg_arc1)
            diff_batches.append(neg_arc1 - neg_arc2)
            neg_sim = F.pairwise_distance(neg_arc1, neg_arc2, p=2) ** 2
            losses.append(torch.exp(-neg_sim).mean().item())

        # One entry per graph: concatenating the batches keeps pair columns aligned
        # across graphs even when a split spans several eval batches.
        differences.append(torch.cat(diff_batches, dim=0))
        arc1 = torch.cat(arc1_batches, dim=0)
        if k == 0:
            reference = arc1
        elif reference is None:
            raise ValueError("test_edge_elph: loaders must yield the reference graph (key 0) first")
        else:
            differences_isomorfic.append(reference - arc1)

    # Within-pair test on at most 32 graphs; q follows the actual sample count.
    differences = torch.stack(differences[:32], dim=0)
    result = major_procedure(differences, q=differences.shape[0])

    differences_isomorfic = torch.stack(differences_isomorfic, dim=0)
    result_isomorfic = reliability_check(differences_isomorfic, q=differences_isomorfic.shape[0])
    accuracy = (result & result_isomorfic).float().mean().item()
    return np.mean(losses), accuracy

def get_test_func(model_str):
    """Return the evaluation function matching ``model_str``.

    Args:
        model_str: ``'ELPH'`` or ``'BUDDY'``.

    Returns:
        ``test_edge_elph`` for ``'ELPH'``, ``test_buddy`` for ``'BUDDY'``.

    Raises:
        ValueError: for any other model name.
    """
    if model_str == 'ELPH':
        return test_edge_elph
    if model_str == 'BUDDY':
        return test_buddy
    raise ValueError(f"unsupported model: {model_str!r}")
    
def _build_arg_parser():
    """Return the command-line parser for this training script."""
    parser = argparse.ArgumentParser(description='homo')
    parser.add_argument('--data_name', type=str, default='cora')

    # gnn setting
    parser.add_argument('--hidden_channels', type=int, default=64)

    # train setting
    parser.add_argument('--batch_size', type=int, default=30000)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--eval_steps', type=int, default=5)
    parser.add_argument('--runs', type=int, default=5)
    parser.add_argument('--kill_cnt', dest='kill_cnt', default=10, type=int, help='early stopping')
    parser.add_argument('--output_dir', type=str, default='output_test')
    parser.add_argument('--input_dir', type=str, default=os.path.join(get_root_dir(), "dataset"))
    parser.add_argument('--filename', type=str, default='dataset.pt')
    parser.add_argument('--l2', type=float, default=0.0, help='L2 Regularization for Optimizer')
    parser.add_argument('--seed', type=int, default=999)

    parser.add_argument('--save', action='store_true', default=False)
    parser.add_argument('--use_saved_model', action='store_true', default=False)
    parser.add_argument('--patience', type=int, default=10)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)

    # model setting
    parser.add_argument('--model', type=str, default='BUDDY')
    parser.add_argument('--max_hash_hops', type=int, default=3, help='the maximum number of hops to hash')
    parser.add_argument('--floor_sf', type=str2bool, default=0,
                        help='the subgraph features represent counts, so should not be negative. If --floor_sf the min is set to 0')
    parser.add_argument('--minhash_num_perm', type=int, default=128, help='the number of minhash perms')
    parser.add_argument('--hll_p', type=int, default=8, help='the hyperloglog p parameter')
    parser.add_argument('--use_zero_one', type=str2bool,
                        help="whether to use the counts of (0,1) and (1,0) neighbors")
    parser.add_argument('--load_features', action='store_true', help='load node features from disk')
    parser.add_argument('--load_hashes', action='store_true', help='load hashes from disk')
    parser.add_argument('--cache_subgraph_features', action='store_true',
                        help='write / read subgraph features from disk')
    parser.add_argument('--use_feature', type=str2bool, default=True,
                        help="whether to use raw node features as GNN input")
    parser.add_argument('--use_RA', type=str2bool, default=False, help='whether to add resource allocation features')
    parser.add_argument('--sign_k', type=int, default=0)
    parser.add_argument('--num_negs', type=int, default=1, help='number of negatives for each positive')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--train_node_embedding', action='store_true',
                        help="also train free-parameter node embeddings together with GNN")
    parser.add_argument('--pretrained_node_embedding', type=str, default=None,
                        help="load pretrained node embeddings as additional node features")
    parser.add_argument('--label_dropout', type=float, default=0.5)
    parser.add_argument('--feature_dropout', type=float, default=0.5)
    parser.add_argument('--propagate_embeddings', action='store_true',
                        help='propagate the node embeddings using the GCN diffusion operator')
    parser.add_argument('--add_normed_features', dest='add_normed_features', type=str2bool,
                        help='Adds a set of features that are normalsied by sqrt(d_i*d_j) to calculate cosine sim')
    # argparse %-formats help strings, so a literal percent sign must be written as %%.
    parser.add_argument('--train_samples', type=float, default=inf, help='the number of training edges or %% if < 1')
    parser.add_argument('--use_struct_feature', type=str2bool, default=True,
                        help="whether to use structural graph features as GNN input")
    parser.add_argument('--loss', default='bce', type=str, help='bce or auc')

    parser.add_argument('--dynamic_train', action='store_true',
                        help="dynamically extract enclosing subgraphs on the fly")
    parser.add_argument('--dynamic_val', action='store_true')
    parser.add_argument('--dynamic_test', action='store_true')
    parser.add_argument('--eval_batch_size', type=int, default=1024*64,
                        help='eval batch size should be largest the GPU memory can take - the same is not necessarily true at training time')

    parser.add_argument('--no_sf_elph', action='store_true',
                        help='use the structural feature in elph or not')
    parser.add_argument('--feature_prop', type=str, default='gcn',
                        help='how to propagate ELPH node features. Values are gcn, residual (resGCN) or cat (jumping knowledge networks)')

    parser.add_argument('--eval_mrr_data_name', type=str, default='ogbl-citation2')
    parser.add_argument('--test_batch_size', type=int, default=4096)
    return parser


def _load_graphs(dataset_path):
    """Load the graph mapping from disk and make each graph's edge_index bidirectional."""
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    graph = torch.load(dataset_path)
    for g in graph.values():
        # Append every edge in the reverse direction (duplicates are not removed).
        g.edge_index = torch.cat([g.edge_index, g.edge_index.flip(0)], dim=1)
        # NOTE(review): x.size(0) is the node count, not the feature dimension;
        # confirm what downstream code expects from num_features.
        g.num_features = g.x.size(0)
    return graph


def _build_splits(graph, args, device):
    """Build train/train-eval/val/test loaders for every graph from its edge-pair masks."""
    mask_attrs = {'train': 'train_mask', 'valid': 'val_mask', 'test': 'test_mask'}
    root_split = {}
    for k, v in graph.items():
        splits = {}
        for key, attr in mask_attrs.items():
            # Each mask entry is iterated as a group of edges; edge[0]/edge[1] are endpoints.
            edge_0, edge_1 = [], []
            for pairs in getattr(v, attr):
                for edge in pairs:
                    edge_0.append(edge[0].item())
                    edge_1.append(edge[1].item())
            edge_label_index = torch.tensor([edge_0, edge_1], device=device)
            splits[key] = Data(x=v.x, edge_index=v.edge_index, edge_label_index=edge_label_index,
                               edge_label=torch.ones(edge_label_index.shape[1], dtype=torch.float32))
        root_split[k] = {}
        (root_split[k]['train_loader'], root_split[k]['train_eval_loader'],
         root_split[k]['val_loader'], root_split[k]['test_loader']) = get_loaders_hard_neg(args, splits)
    return root_split


def main():
    """Parse CLI args, train BUDDY/ELPH for ``args.runs`` runs and report test accuracy.

    Each run trains on graph ``0``'s training links, early-stops on the
    validation loss, restores the best-validation weights and evaluates once
    on the test split of every graph.
    """
    import copy

    args = _build_arg_parser().parse_args()
    print(args)
    init_seed(args.seed)

    device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')

    # os.path.join keeps the path relative when current_dir is '' (string
    # concatenation would turn it into the absolute '/../data/').
    dataset_path = os.path.join(current_dir, '..', 'data', args.filename)
    graph = _load_graphs(dataset_path)
    root_split = _build_splits(graph, args, device)

    if args.model == 'BUDDY':
        train_fn, eval_fn = train_buddy, test_buddy
    else:
        train_fn, eval_fn = train_elph, test_edge_elph

    test_accuracies = []
    for run in range(args.runs):
        early_stopping = EarlyStopping(patience=args.patience)
        best_model_state = None
        print('#################################          ', run, '          #################################')
        seed = args.seed if args.runs == 1 else run
        print('seed: ', seed)
        init_seed(seed)

        emb = select_embedding(args, graph[0].x.size(0), device)
        model, optimizer = select_model(args, emb, device)

        for epoch in range(1, 1 + args.epochs):
            train_loss = train_fn(model, optimizer, root_split[0]['train_loader'], args, device)
            val_loss, val_acc = eval_fn(model, root_split, device, args, 'val')
            print(f"Epoch: {epoch}, train loss: {train_loss.item()}")
            print(f"Validation loss: {val_loss}, Validation Accuracy: {val_acc:.4f}")
            if early_stopping(val_loss):
                # state_dict() returns live references to the parameters; copy it so
                # later optimizer steps do not overwrite the best snapshot.
                best_model_state = copy.deepcopy(model.state_dict())
            if early_stopping.early_stop:
                break

        if best_model_state is not None:
            model.load_state_dict(best_model_state)

        test_loss, test_acc = eval_fn(model, root_split, device, args, 'test')
        print(f"Test accuracy: {test_acc:.4f}")
        test_accuracies.append(test_acc)
    print(f"Test accuracy over {args.runs} runs: {np.mean(test_accuracies)} ± {np.std(test_accuracies)}")
    
if __name__ == "__main__":
    # Train only when executed as a script, never on import.
    main()

   