import sys
import os

current_dir = os.path.dirname(__file__)
utils_path = os.path.abspath(os.path.join(current_dir, '..', 'utils'))
sys.path.append(utils_path)

from gnn_model import *
from utils import *

import torch
import numpy as np
import argparse
import scipy.sparse as ssp

import torch.nn.functional as F
import networkx as nx
import copy


from seal_dataset import SEALDataset, SEALDynamicDataset

import torch_geometric.utils as pyg_utils
from collections import deque

from torch_geometric.data import Data, Dataset, InMemoryDataset, DataLoader
import scipy.sparse as ssp
from tqdm import tqdm
from seal_utils import extract_enclosing_subgraphs, k_hop_subgraph, construct_pyg_graph
from torch_sparse import coalesce
from torch_geometric.utils import (negative_sampling, add_self_loops)

                
    
class homo_data(torch.nn.Module):
    """Minimal container for a homogeneous graph.

    Stores the arguments as plain attributes so the object can be passed
    around like a lightweight graph record.

    Args:
        edge_index: Edge list of the graph (stored as given).
        num_nodes: Number of nodes in the graph.
        x: Optional node features; ``None`` when the graph has none.
        edge_weight: Optional per-edge weights; ``None`` when unweighted.
    """

    def __init__(self, edge_index, num_nodes, x=None, edge_weight=None):
        super().__init__()
        self.edge_index = edge_index
        self.num_nodes = num_nodes

        # The optional fields are stored verbatim. The previous
        # ``value != None`` test was an element-wise comparison for array-like
        # inputs and could raise on multi-element arrays; since both branches
        # only ever stored the argument (or None), a plain assignment is
        # equivalent and safe.
        self.x = x
        self.edge_weight = edge_weight

def get_pos_neg_edges(split, split_edge, edge_index, num_nodes, percent=100):
    """Return the (sub-sampled) positive and negative edges of one split.

    Args:
        split: Key into ``split_edge``; ``'train'`` triggers on-the-fly
            negative sampling, any other value reads stored negatives.
        split_edge: Mapping ``split -> {'edge': ..., 'edge_neg': ...}``.
        edge_index: Edge index used to avoid sampling existing edges
            (only read for the ``'train'`` split).
        num_nodes: Number of nodes (only read for the ``'train'`` split).
        percent: Percentage of edges to keep.

    Returns:
        Tuple ``(pos_edge, neg_edge)``, both with the two endpoints on dim 0.

    Note:
        NumPy's global RNG is re-seeded with 123 before each permutation, so
        the sub-sample is reproducible but the global random state is reset
        as a side effect.
    """
    keep_fraction = percent / 100

    # Positive edges: transpose so endpoints sit on dim 0, then sub-sample.
    positives = split_edge[split]['edge'].t()
    np.random.seed(123)
    total_pos = positives.size(1)
    perm = np.random.permutation(total_pos)[:int(keep_fraction * total_pos)]
    positives = positives[:, perm]

    if split != 'train':
        # Stored negatives are indexed by positive edge, so the positive
        # permutation selects the matching rows; afterwards the trailing
        # endpoint axis is moved to the front and everything is flattened.
        negatives = split_edge[split]['edge_neg'][perm]
        negatives = torch.permute(negatives, (2, 0, 1)).view(2, -1)
        return positives, negatives

    # Training negatives: draw one per kept positive, with self-loops added
    # to the edge index handed to the sampler.
    looped_edge_index, _ = add_self_loops(edge_index)
    negatives = negative_sampling(
        looped_edge_index,
        num_nodes=num_nodes,
        num_neg_samples=positives.size(1),
    )

    # The sampled negatives are sub-sampled by ``percent`` once more.
    np.random.seed(123)
    total_neg = negatives.size(1)
    neg_perm = np.random.permutation(total_neg)[:int(keep_fraction * total_neg)]
    return positives, negatives[:, neg_perm]


class SEALDataset(InMemoryDataset):
    """In-memory dataset of enclosing subgraphs around positive/negative links.

    The subgraphs are extracted once in :meth:`process`, saved to a single
    processed file, and loaded back at the end of ``__init__``.

    Args:
        root: Dataset root directory handed to ``InMemoryDataset``.
        data: Graph object exposing ``edge_index``, ``num_nodes``, ``x`` and
            optionally ``edge_weight``.
        split_edge: Mapping of split name to its positive/negative edges.
        num_hops: Number of hops used for subgraph extraction.
        percent: Percentage of edges to use; values >= 1 are truncated to int.
        split: Which entry of ``split_edge`` to build the dataset from.
        use_coalesce: Merge duplicate edges into weighted edges first.
        node_label: Node labeling scheme forwarded to the extractor.
        ratio_per_hop: Forwarded to the extractor.
        max_nodes_per_hop: Forwarded to the extractor.
        directed: When true, a CSC copy of the adjacency is also passed on.
    """

    def __init__(self, root, data, split_edge, num_hops, percent=100, split='train',
                 use_coalesce=False, node_label='drnl', ratio_per_hop=1.0,
                 max_nodes_per_hop=None, directed=False):
        # All attributes must exist before the base constructor runs, because
        # process() and processed_file_names read them.
        self.data = data
        self.split_edge = split_edge
        self.num_hops = num_hops
        self.percent = int(percent) if percent >= 1.0 else percent
        self.split = split
        self.use_coalesce = use_coalesce
        self.node_label = node_label
        self.ratio_per_hop = ratio_per_hop
        self.max_nodes_per_hop = max_nodes_per_hop
        self.directed = directed
        super().__init__(root)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """File name of the processed data; carries the percentage unless 100."""
        if self.percent == 100:
            name = f'SEAL_{self.split}_data'
        else:
            name = f'SEAL_{self.split}_data_{self.percent}'
        return [name + '.pt']

    def process(self):
        """Extract all enclosing subgraphs and save them to the processed file."""
        pos_edge, neg_edge = get_pos_neg_edges(self.split, self.split_edge,
                                               self.data.edge_index,
                                               self.data.num_nodes,
                                               self.percent)

        if self.use_coalesce:  # compress multi-edges into one weighted edge
            self.data.edge_index, self.data.edge_weight = coalesce(
                self.data.edge_index, self.data.edge_weight,
                self.data.num_nodes, self.data.num_nodes)

        # Identity check instead of ``!= None``: comparing a tensor to None by
        # value is not a reliable "is it set" test.
        if getattr(self.data, 'edge_weight', None) is not None:
            edge_weight = self.data.edge_weight.view(-1)
        else:
            # Unweighted graph: every edge counts once.
            edge_weight = torch.ones(self.data.edge_index.size(1), dtype=int)
        A = ssp.csr_matrix(
            (edge_weight, (self.data.edge_index[0], self.data.edge_index[1])),
            shape=(self.data.num_nodes, self.data.num_nodes)
        )

        # The CSC copy is only built for directed graphs.
        A_csc = A.tocsc() if self.directed else None

        # Extract enclosing subgraphs for pos (label 1) and neg (label 0) edges.
        pos_list = extract_enclosing_subgraphs(
            pos_edge, A, self.data.x, 1, self.num_hops, self.node_label,
            self.ratio_per_hop, self.max_nodes_per_hop, self.directed, A_csc)
        neg_list = extract_enclosing_subgraphs(
            neg_edge, A, self.data.x, 0, self.num_hops, self.node_label,
            self.ratio_per_hop, self.max_nodes_per_hop, self.directed, A_csc)

        torch.save(self.collate(pos_list + neg_list), self.processed_paths[0])
        del pos_list, neg_list


class SEALDynamicDataset(Dataset):
    """Dataset that extracts the enclosing subgraph of a link on access.

    Only the list of links, their labels and the adjacency matrix are built
    up front; each subgraph is constructed in :meth:`get`.

    Args:
        root: Dataset root directory handed to ``Dataset``.
        data: Graph object exposing ``edge_index``, ``num_nodes``, ``x`` and
            optionally ``edge_weight``.
        split_edge: Mapping of split name to its positive/negative edges.
        num_hops: Number of hops used for subgraph extraction.
        percent: Percentage of edges to use.
        split: Which entry of ``split_edge`` to build the link list from.
        use_coalesce: Merge duplicate edges into weighted edges first.
        node_label: Node labeling scheme used when building each subgraph.
        ratio_per_hop: Forwarded to ``k_hop_subgraph``.
        max_nodes_per_hop: Forwarded to ``k_hop_subgraph``.
        directed: When true, a CSC copy of the adjacency is also kept.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, root, data, split_edge, num_hops, percent=100, split='train',
                 use_coalesce=False, node_label='drnl', ratio_per_hop=1.0,
                 max_nodes_per_hop=None, directed=False, **kwargs):
        self.data = data
        self.split_edge = split_edge
        self.num_hops = num_hops
        self.percent = percent
        self.use_coalesce = use_coalesce
        self.node_label = node_label
        self.ratio_per_hop = ratio_per_hop
        self.max_nodes_per_hop = max_nodes_per_hop
        self.directed = directed
        super().__init__(root)

        pos_edge, neg_edge = get_pos_neg_edges(split, self.split_edge,
                                               self.data.edge_index,
                                               self.data.num_nodes,
                                               self.percent)
        # Positives first, then negatives; labels follow the same order.
        self.links = torch.cat([pos_edge, neg_edge], 1).t().tolist()
        self.labels = [1] * pos_edge.size(1) + [0] * neg_edge.size(1)

        if self.use_coalesce:  # compress multi-edges into one weighted edge
            self.data.edge_index, self.data.edge_weight = coalesce(
                self.data.edge_index, self.data.edge_weight,
                self.data.num_nodes, self.data.num_nodes)

        # Identity check instead of ``!= None``: comparing a tensor to None by
        # value is not a reliable "is it set" test.
        if getattr(self.data, 'edge_weight', None) is not None:
            edge_weight = self.data.edge_weight.view(-1)
        else:
            # Unweighted graph: every edge counts once.
            edge_weight = torch.ones(self.data.edge_index.size(1), dtype=int)
        self.A = ssp.csr_matrix(
            (edge_weight, (self.data.edge_index[0], self.data.edge_index[1])),
            shape=(self.data.num_nodes, self.data.num_nodes)
        )
        # The CSC copy is only built for directed graphs.
        self.A_csc = self.A.tocsc() if self.directed else None

    def __len__(self):
        """Number of links (positives plus negatives)."""
        return len(self.links)

    def len(self):
        """Same as ``__len__``; provided for the ``Dataset`` interface."""
        return self.__len__()

    def get(self, idx):
        """Build and return the labeled enclosing subgraph of link ``idx``."""
        src, dst = self.links[idx]
        y = self.labels[idx]
        tmp = k_hop_subgraph(src, dst, self.num_hops, self.A, self.ratio_per_hop,
                             self.max_nodes_per_hop, node_features=self.data.x,
                             y=y, directed=self.directed, A_csc=self.A_csc)
        return construct_pyg_graph(*tmp, self.node_label)
    

        
def train(model, graph, optimizer, aggregation):
    """Run one optimisation step on the training pairs of ``graph``.

    The model is evaluated twice, once with node labels ``graph.z1`` and once
    with ``graph.z2``. Each row of ``graph.train_mask`` holds two node pairs
    (``[:, 0, :]`` and ``[:, 1, :]``); the first pair is embedded from the
    first pass and the second pair from the second pass. The loss is the mean
    of ``exp(-squared L2 distance)`` between the two pair embeddings, so
    minimising it pushes the two embeddings apart.

    Args:
        model: Callable ``model(z, edge_index, x)`` returning node embeddings.
        graph: Object exposing ``x``, ``z1``, ``z2``, ``edge_index`` and
            ``train_mask``.
        optimizer: Optimizer over the model parameters.
        aggregation: ``"concatenation"`` or ``"summation"`` of the two
            endpoint embeddings of a pair.

    Returns:
        The scalar loss tensor computed before the parameter update.

    Raises:
        ValueError: If ``aggregation`` is not one of the supported modes.
    """
    # Fail fast with a clear message; previously an unknown mode only surfaced
    # as an unbound-variable error after both forward passes had run.
    if aggregation not in ("concatenation", "summation"):
        raise ValueError(
            f"unknown aggregation {aggregation!r}; "
            "expected 'concatenation' or 'summation'"
        )

    model.train()
    node_features = graph.x
    pair_batch = graph.train_mask

    emb_first = model(graph.z1, graph.edge_index, node_features)
    emb_second = model(graph.z2, graph.edge_index, node_features)
    optimizer.zero_grad()

    pair_embeddings = []
    for slot, emb in ((0, emb_first), (1, emb_second)):
        a = pair_batch[:, slot, 0].long()
        b = pair_batch[:, slot, 1].long()
        # Order endpoints as (smaller id, larger id) so that concatenation
        # does not depend on the order in which a pair was stored.
        low = torch.minimum(a, b)
        high = torch.maximum(a, b)
        if aggregation == "concatenation":
            pair_embeddings.append(torch.cat((emb[low], emb[high]), dim=1))
        else:
            pair_embeddings.append(emb[low] + emb[high])
    neg_arc1, neg_arc2 = pair_embeddings

    neg_sim = F.pairwise_distance(neg_arc1, neg_arc2, p=2) ** 2
    loss_sim = torch.exp(-neg_sim).mean()

    # NOTE(review): retain_graph=True is kept from the original; it is only
    # required if the model reuses graph state across calls — confirm.
    loss_sim.backward(retain_graph=True)

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    return loss_sim

def _hotelling_t2(d, q):
    """Compute a Hotelling-style T^2 statistic for every slice ``d[:, i, :]``.

    For each index ``i`` along dim 1 the rows of ``d[:, i, :]`` are treated as
    samples: their mean and (lightly regularised) covariance are estimated and
    ``q * mean^T S^-1 mean`` is returned.

    Args:
        d: 3-D tensor indexed as ``[sample, item, feature]``.
        q: Sample count used in the covariance divisor ``q - 1`` and as the
            scale of the statistic. It is not derived from ``d.shape[0]``.

    Returns:
        A list with one Python float per item.
    """
    stats = []
    for i in range(d.shape[1]):
        d_i = d[:, i, :]
        d_mean = torch.mean(d_i, dim=0)
        centered = d_i - d_mean
        S = torch.matmul(centered.T, centered) / (q - 1)
        # Small ridge so the covariance stays invertible.
        S += torch.eye(S.shape[0], device=S.device) * 1e-7
        S_inv = torch.inverse(S)
        # d_mean is 1-D, so matmul already treats it as a row/column vector;
        # the former ``d_mean.T`` was a deprecated no-op on a 1-D tensor.
        t2 = q * torch.matmul(torch.matmul(d_mean, S_inv), d_mean)
        stats.append(t2.item())
    return stats


def reliability_check(d, threshold=71.34, q=32):
    """Flag items whose T^2 statistic is strictly below ``threshold``.

    Args:
        d: 3-D tensor indexed as ``[sample, item, feature]``.
        threshold: Cut-off for the statistic.
            NOTE(review): 71.34 is presumably a critical value matched to the
            default ``q`` and feature size — confirm before changing either.
        q: Sample count used by the statistic.

    Returns:
        Boolean tensor with one entry per item (empty and boolean when ``d``
        has no items).
    """
    flags = [value < threshold for value in _hotelling_t2(d, q)]
    # Explicit dtype: an empty list would otherwise produce a float tensor,
    # which cannot be combined with ``&`` by callers.
    return torch.tensor(flags, dtype=torch.bool)


def major_procedure(d, threshold=71.34, q=32):
    """Flag items whose T^2 statistic is strictly above ``threshold``.

    Args:
        d: 3-D tensor indexed as ``[sample, item, feature]``.
        threshold: Cut-off for the statistic.
            NOTE(review): 71.34 is presumably a critical value matched to the
            default ``q`` and feature size — confirm before changing either.
        q: Sample count used by the statistic.

    Returns:
        Boolean tensor with one entry per item (empty and boolean when ``d``
        has no items).
    """
    flags = [value > threshold for value in _hotelling_t2(d, q)]
    return torch.tensor(flags, dtype=torch.bool)

@torch.no_grad()
def test(model, graph, val_test, aggregation, q=32):
    """Evaluate the model on every graph of ``graph`` without gradients.

    For each graph the two node pairs stored in every row of the selected
    mask are embedded (first pair from the ``z1`` pass, second pair from the
    ``z2`` pass). Two sets of differences are collected:

    * per graph, the difference between the two pair embeddings;
    * for every graph other than key ``0``, the difference between the
      first-pair embedding of graph ``0`` and that of the current graph.

    ``major_procedure`` is applied to the first set (truncated to ``q``
    graphs) and ``reliability_check`` to the second; an item counts as correct
    when both flags are true.

    Args:
        model: Callable ``model(z, edge_index, x)`` returning node embeddings.
        graph: Mapping of key to graph object; the entry with key ``0`` is the
            reference and must be iterated before any other entry.
        val_test: ``"val"`` selects ``val_mask``; anything else ``test_mask``.
        aggregation: ``"concatenation"`` or ``"summation"``.
        q: Number of graphs used as samples by the statistical checks.

    Returns:
        Tuple ``(mean loss, accuracy)`` where the loss is the mean clamped
        cosine similarity averaged over graphs.

    Raises:
        ValueError: If ``aggregation`` is unsupported, or a graph is visited
            before the reference graph with key ``0``.
    """
    if aggregation not in ("concatenation", "summation"):
        raise ValueError(
            f"unknown aggregation {aggregation!r}; "
            "expected 'concatenation' or 'summation'"
        )

    model.eval()
    differences, differences_isomorfic, losses = [], [], []
    reference = None
    for k, single_graph in graph.items():
        if val_test == "val":
            pair_batch = single_graph.val_mask
        else:
            pair_batch = single_graph.test_mask

        emb_first = model(single_graph.z1, single_graph.edge_index, single_graph.x)
        emb_second = model(single_graph.z2, single_graph.edge_index, single_graph.x)

        pair_embeddings = []
        for slot, emb in ((0, emb_first), (1, emb_second)):
            a = pair_batch[:, slot, 0].long()
            b = pair_batch[:, slot, 1].long()
            # Order endpoints as (smaller id, larger id) so that concatenation
            # does not depend on the order in which a pair was stored.
            low = torch.minimum(a, b)
            high = torch.maximum(a, b)
            if aggregation == "concatenation":
                pair_embeddings.append(torch.cat((emb[low], emb[high]), dim=1))
            else:
                pair_embeddings.append(emb[low] + emb[high])
        neg_arc1, neg_arc2 = pair_embeddings

        neg_sim = F.cosine_similarity(neg_arc1, neg_arc2, dim=-1)
        differences.append(neg_arc1 - neg_arc2)

        loss_sim = torch.clamp(neg_sim, min=0).mean()
        losses.append(loss_sim.detach().cpu().numpy())

        if k == 0:
            reference = neg_arc1
        elif reference is None:
            raise ValueError(
                "graph with key 0 must be iterated first; "
                f"encountered key {k!r} before it"
            )
        else:
            differences_isomorfic.append(reference - neg_arc1)

    differences = torch.stack(differences, dim=0)[:q]
    # Forward q so the statistics use the same sample count as the truncation
    # above; previously they silently fell back to their own default of 32.
    result = major_procedure(differences, q=q)

    differences_isomorfic = torch.stack(differences_isomorfic, dim=0)
    result_isomorfic = reliability_check(differences_isomorfic, q=q)

    accuracy = (result & result_isomorfic).sum() / len(result)
    return np.mean(losses), accuracy


def filter_edges_by_group(edge_index, node_groups, target_group):
    """Keep only the edges whose two endpoints both belong to ``target_group``.

    Args:
        edge_index: Tensor with source nodes in row 0 and targets in row 1.
        node_groups: Tensor giving the group of every node, indexed by node id.
        target_group: Group value to keep.

    Returns:
        The columns of ``edge_index`` whose endpoints are both in the group.
    """
    # Look up the group of every endpoint at once (same layout as edge_index).
    endpoint_in_group = node_groups[edge_index] == target_group
    both_in_group = endpoint_in_group[0] & endpoint_in_group[1]
    return edge_index[:, both_in_group]

def bfs_pyg(edge_index, num_nodes, start_node):
    """Run a breadth-first search over a PyTorch Geometric style edge index.

    Edges are followed from row 0 (source) to row 1 (target) of
    ``edge_index``.

    Args:
        edge_index: Tensor with source nodes in row 0 and targets in row 1.
        num_nodes: Length of the returned distance vector.
        start_node: Node id the search starts from.

    Returns:
        Floating-point tensor of length ``num_nodes`` holding the hop distance
        from ``start_node``; unreachable nodes are ``inf``.
    """
    # Build the adjacency lists once; scanning the whole edge list for every
    # dequeued node made the search O(V * E).
    adjacency = {}
    for src, dst in zip(edge_index[0].tolist(), edge_index[1].tolist()):
        adjacency.setdefault(src, []).append(dst)

    unreached = float('inf')
    start_node = int(start_node)  # accept 0-dim tensors as well as ints
    distances = [unreached] * num_nodes
    distances[start_node] = 0.0

    queue = deque([start_node])
    while queue:
        node = queue.popleft()
        next_distance = distances[node] + 1
        for neighbor in adjacency.get(node, ()):
            if distances[neighbor] == unreached:
                distances[neighbor] = next_distance
                queue.append(neighbor)

    return torch.tensor(distances, dtype=torch.get_default_dtype())

def compute_drnl_labels(data, target_edges):
    """Label nodes by their distance to the endpoints of the target edges.

    For every target edge ``(u, v)`` the search is restricted to the edges
    whose endpoints share the ``data.batch`` group of ``u``. Every node that
    appears in those edges receives ``1 + min(dist(u, node), dist(v, node))``,
    which gives ``u`` and ``v`` themselves the label 1. Later target edges
    overwrite labels written by earlier ones in the same group.

    Note:
        Despite the name, this is the min-distance rule above rather than the
        DRNL hashing formula.

    Args:
        data: Graph object exposing ``edge_index``, ``num_nodes`` and a
            per-node group tensor ``batch``.
        target_edges: Iterable of ``(u, v)`` node id pairs.

    Returns:
        Long tensor of length ``num_nodes``; nodes that were never labelled
        (or are unreachable from both endpoints) hold the sentinel ``10**6``.
    """
    edge_index = data.edge_index
    num_nodes = data.num_nodes

    INF = 10**6
    labels = torch.full((num_nodes,), INF, dtype=torch.long)

    # The filtered edges and node set depend only on the group, so compute
    # them once per group instead of once (or twice) per target edge.
    group_cache = {}

    def group_view(node):
        group = data.batch[node]
        key = group.item()
        if key not in group_cache:
            filtered = filter_edges_by_group(edge_index, data.batch, group)
            nodes = torch.unique(filtered).to(labels.device)
            group_cache[key] = (filtered, nodes)
        return group_cache[key]

    bfs_results = {}
    for u, v in target_edges:
        filtered_edge_index, nodes = group_view(u)
        for endpoint in (u, v):
            if endpoint not in bfs_results:
                bfs_results[endpoint] = bfs_pyg(filtered_edge_index, num_nodes, endpoint)

        nearest = torch.minimum(bfs_results[u][nodes], bfs_results[v][nodes])
        reachable = torch.isfinite(nearest)
        # dist(u, u) == dist(v, v) == 0, so the endpoints get label 1 here.
        labels[nodes[reachable]] = 1 + nearest[reachable].long()
        # Casting inf to an integer is undefined; keep the sentinel instead.
        labels[nodes[~reachable]] = INF

    return labels




def compute_in_parallel(graph, device):
    """Attach node labels to every graph in ``graph`` and return the mapping.

    Despite the name, the entries are processed one after another. The
    ``device`` argument is accepted but not used.

    Args:
        graph: Mapping of key to graph object; updated in place.
        device: Unused.

    Returns:
        The same mapping, with every value replaced by its processed graph.
    """
    for key, item in graph.items():
        _same_key, processed = process_graph_item(key, item, graph)
        graph[key] = processed

    return graph

def process_graph_item(k, v, graph):
    """Compute the two label vectors ``z1`` and ``z2`` for one graph.

    Every entry of the train, validation and test masks holds two node pairs;
    the first pairs drive ``z1`` and the second pairs drive ``z2``.

    Args:
        k: Key of the graph inside ``graph``.
        v: Graph object handed to ``compute_drnl_labels``.
        graph: Mapping the graph is taken from.

    Returns:
        Tuple ``(k, graph[k])`` with ``z1`` and ``z2`` set on the graph.
    """
    target = graph[k]

    masks = (target.train_mask, target.val_mask, target.test_mask)
    entries = [entry for mask in masks for entry in mask]
    first_pairs = [entry[0].tolist() for entry in entries]
    second_pairs = [entry[1].tolist() for entry in entries]

    target.z1 = compute_drnl_labels(v, first_pairs)
    target.z2 = compute_drnl_labels(v, second_pairs)
    return k, target
    
def main():
    """Parse the command line, then train and evaluate the model ``--runs`` times.

    The dataset (a mapping of key to graph) is loaded from
    ``<this dir>/../data/<--filename>``, made undirected by adding the reversed
    edges, labelled, and moved to the selected device. Each run trains on
    ``graph[0]`` with early stopping on the validation loss and reports the
    test accuracy; the mean and standard deviation over runs are printed last.
    """
    parser = argparse.ArgumentParser(description='homo')
    parser.add_argument('--data_name', type=str, default='cora')
    parser.add_argument('--neg_mode', type=str, default='equal')
    parser.add_argument('--gnn_model', type=str, default='DGCNN')
    parser.add_argument('--score_model', type=str, default='mlp_score')
    parser.add_argument('--aggregation', type=str, default='concatenation')
    ##gnn setting
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--num_layers_predictor', type=int, default=2)
    parser.add_argument('--hidden_channels', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.0)

    ### train setting
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--eval_steps', type=int, default=5)
    parser.add_argument('--runs', type=int, default=5)
    parser.add_argument('--kill_cnt', dest='kill_cnt', default=10, type=int, help='early stopping')
    parser.add_argument('--output_dir', type=str, default='output_test')
    parser.add_argument('--input_dir', type=str, default=os.path.join(get_root_dir(), "dataset"))
    parser.add_argument('--filename', type=str, default='dataset.pt')
    parser.add_argument('--l2', type=float, default=0.0, help='L2 Regularization for Optimizer')
    parser.add_argument('--seed', type=int, default=999)

    parser.add_argument('--save', action='store_true', default=False)
    parser.add_argument('--use_saved_model', action='store_true', default=False)
    parser.add_argument('--metric', type=str, default='MRR')
    parser.add_argument('--device', type=int, default=1)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=4,
                        help="number of workers for dynamic mode; 0 if not dynamic")

    ####### gin
    parser.add_argument('--gin_mlp_layer', type=int, default=2)

    ######gat
    parser.add_argument('--gat_head', type=int, default=1)

    ######mf
    parser.add_argument('--cat_node_feat_mf', default=False, action='store_true')

    ####### seal
    parser.add_argument('--patience', type=int, default=10)
    parser.add_argument('--dynamic_train', action='store_true',
                        help="dynamically extract enclosing subgraphs on the fly")
    parser.add_argument('--dynamic_val', action='store_true')
    parser.add_argument('--dynamic_test', action='store_true')
    parser.add_argument('--train_percent', type=float, default=100)
    parser.add_argument('--val_percent', type=float, default=100)
    parser.add_argument('--test_percent', type=float, default=100)

    parser.add_argument('--node_label', type=str, default='drnl', help="which specific labeling trick to use")
    parser.add_argument('--ratio_per_hop', type=float, default=1.0)
    parser.add_argument('--max_nodes_per_hop', type=int, default=None)
    parser.add_argument('--num_hops', type=int, default=3)
    parser.add_argument('--save_appendix', type=str, default='',
                        help="an appendix to the save directory")
    parser.add_argument('--data_appendix', type=str, default='',
                        help="an appendix to the data directory")
    parser.add_argument('--use_feature', action='store_true',
                        help="whether to use raw node features as GNN input")
    parser.add_argument('--train_node_embedding', action='store_true',
                        help="also train free-parameter node embeddings together with GNN")
    parser.add_argument('--sortpool_k', type=float, default=0.6)
    parser.add_argument('--use_edge_weight', action='store_true',
                        help="whether to consider edge weight in GNN")
    parser.add_argument('--eval_mrr_data_name', type=str, default='ogbl-citation2')
    parser.add_argument('--test_bs', type=int, default=1024)

    args = parser.parse_args()
    print(args)

    init_seed(args.seed)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)
    print(device)

    if not args.dynamic_train and not args.dynamic_val and not args.dynamic_test:
        args.num_workers = 0

    dataset_path = os.path.join(current_dir+'/../data/', args.filename)
    # NOTE: torch.load unpickles arbitrary objects; only load trusted files.
    graph = torch.load(dataset_path)

    # Add the reversed copy of every edge so the graphs are treated as undirected.
    for single_graph in graph.values():
        single_graph.edge_index = torch.cat(
            [single_graph.edge_index, single_graph.edge_index.flip(0)], dim=1)
    graph = compute_in_parallel(graph, torch.device('cpu'))
    # Keep the returned object so the move also takes effect for graph types
    # whose ``to`` is not in-place.
    for k in graph:
        graph[k] = graph[k].to(device)

    max_z = 1000
    test_accuracies = []
    for run in range(args.runs):
        early_stopping = EarlyStopping(patience=args.patience)
        best_model_state = None
        print('#################################          ', run, '          #################################')

        if args.runs == 1:
            seed = args.seed
        else:
            seed = run
        print('seed: ', seed)
        init_seed(seed)

        train_dataset = graph[0]
        model = DGCNN(args.hidden_channels, args.num_layers, max_z, args.sortpool_k,
                      train_dataset, args.dynamic_train, use_feature=args.use_feature,
                      node_embedding=None).to(device)

        parameters = list(model.parameters())
        optimizer = torch.optim.Adam(params=parameters, lr=args.lr, weight_decay=args.l2)

        for epoch in range(args.epochs):

            train_loss = train(model, graph[0], optimizer, args.aggregation)
            print(f"Epoch {epoch+1}/{args.epochs}, Train Loss: {train_loss:.4f}")

            val_loss, val_acc = test(model, graph, "val", args.aggregation)
            print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")
            if early_stopping(val_loss):
                # state_dict() returns references to the live parameters, so
                # the snapshot must be copied or it would keep tracking the
                # weights of later epochs.
                best_model_state = copy.deepcopy(model.state_dict())

            if early_stopping.early_stop:
                break

        if best_model_state is not None:
            model.load_state_dict(best_model_state)

        test_loss, test_acc = test(model, graph, "test", args.aggregation)
        print(f"Test Accuracy: {test_acc}")
        # Store plain floats so the NumPy statistics below do not depend on
        # tensor-to-array conversion.
        test_accuracies.append(float(test_acc))
    print(f"Test accuracy over {args.runs} runs: {np.mean(test_accuracies)} ± {np.std(test_accuracies)}")

        

# Run the experiment only when the file is executed as a script, not on import.
if __name__ == "__main__":

    main()

    