import torch
import torch.nn.functional as F
from torch_geometric.utils import to_undirected, add_self_loops
from torch_sparse import SparseTensor
import numpy as np
from sklearn.metrics import roc_auc_score, f1_score
import networkx as nx
from os import path
# import pymetis
import metis


def load_splits(data_dir, name):
    """Load the stored train/valid/test index splits for dataset ``name``.

    Args:
        data_dir: accepted for interface compatibility but not consulted; the
            split file is always read from ``splits/{name}-split.npy``
            relative to the current working directory.
        name: dataset name; must be one of the names listed below.

    Returns:
        A list with one dict per stored split, each mapping ``'train'``,
        ``'valid'`` and ``'test'`` to a tensor built from the stored indices.

    Raises:
        NotImplementedError: if ``name`` is not a supported dataset.
    """
    supported = (
        'cora', 'citeseer', 'pubmed', 'amazon-computer', 'amazon-photo',
        'coauthor-cs', 'coauthor-physics', 'squirrel', 'chameleon', 'minesweeper',
    )
    if name not in supported:
        raise NotImplementedError

    # Each stored entry is indexed by key, so the array holds Python objects
    # and has to be loaded with allow_pickle.
    stored = np.load(f'splits/{name}-split.npy', allow_pickle=True)
    return [
        {part: torch.from_numpy(np.asarray(stored[k][part])) for part in ('train', 'valid', 'test')}
        for k in range(stored.shape[0])
    ]


def calculate_cluster_edge_index(N, edge_index, clusters, node_feat):
    """Assign the N nodes to ``clusters`` groups and build node->cluster edges.

    Node ``i`` is attached to a virtual cluster node with id
    ``N + membership[i]``.

    Args:
        N: number of original nodes.
        edge_index: edge tensor whose transpose yields ``(u, v)`` pairs,
            i.e. shaped ``(2, E)``; only used when ``N >= clusters``.
        clusters: number of clusters.
        node_feat: node feature tensor; row ``i`` holds node ``i``'s features.

    Returns:
        node2cluster: LongTensor ``(2, N)``; row 0 holds cluster-node ids
            (``N + membership``), row 1 the original node ids ``0..N-1``.
        membership: LongTensor ``(N,)`` with the cluster index of every node.
        cluster_feat: ``(clusters, *node_feat.shape[1:])`` mean feature of each
            cluster's members; clusters with no members get zeros.
    """
    if N < clusters:
        # Fewer nodes than clusters: give each node its own random cluster.
        # Keep only N entries so membership has exactly one entry per node
        # (an unsliced permutation indexes node_feat out of range below).
        membership = torch.randperm(clusters)[:N]
    else:
        G = nx.Graph()
        G.add_nodes_from(np.arange(N))
        G.add_edges_from(edge_index.t().tolist())
        # part_graph returns (edgecuts, parts); parts[i] is node i's partition.
        _, parts = metis.part_graph(G, clusters, recursive=True)
        membership = torch.as_tensor(parts, dtype=torch.long)

    # Edges from each cluster node to its member nodes.
    node2cluster = torch.stack([membership + N, torch.arange(N)])

    # Per-cluster mean via scatter-add. Clamping the counts at 1 turns the
    # 0/0 of an empty cluster into a zero feature instead of NaN.
    feat_sum = node_feat.new_zeros((clusters,) + tuple(node_feat.shape[1:]))
    feat_sum.index_add_(0, membership, node_feat)
    counts = torch.bincount(membership, minlength=clusters).clamp(min=1)
    counts = counts.to(feat_sum.dtype).view(-1, *([1] * (feat_sum.dim() - 1)))
    cluster_feat = feat_sum / counts

    return node2cluster, membership, cluster_feat

def calculate_global_edge_index(N, num_clusters, global_nodes, global_nodes_per_class, label, node_feat, train_idx):
    """Build edges from per-class global nodes to the training nodes of that class.

    Global node ids start at ``N + num_clusters``. Class ``c`` owns the
    ``global_nodes_per_class`` consecutive ids starting at
    ``N + num_clusters + global_nodes_per_class * c``. Each of those ids is
    connected to every training node whose label is ``c``.

    Args:
        N: number of original nodes.
        num_clusters: number of cluster nodes placed between the original
            and the global nodes.
        global_nodes: total number of global nodes (only used for ``global2node``).
        global_nodes_per_class: global nodes allocated to each class.
        label: label tensor indexed by node id; either ``(N,)`` or ``(N, 1)``.
        node_feat: node feature tensor; row ``i`` holds node ``i``'s features.
        train_idx: tensor of training node indices.

    Returns:
        node2global: LongTensor ``(2, E)``; row 0 global-node ids, row 1
            training-node ids.
        global2node: ``[arange(N), arange(N + num_clusters, N + num_clusters + global_nodes)]``.
        global_feat: one row per class ``0..label.max()`` with the mean feature
            of that class's training nodes; zeros for a class without any
            training node (a plain mean would be NaN).
    """
    base = N + num_clusters
    global2node = [torch.arange(N), torch.arange(base, base + global_nodes)]

    train_labels = label[train_idx]
    # torch.where(...)[0] gives row positions for both 1-D and (N, 1) labels.
    label_idx = [train_idx[torch.where(train_labels == c)[0]] for c in range(int(label.max()) + 1)]

    feat_tail = tuple(node_feat.shape[1:])
    global_feat = torch.stack(
        [node_feat[idx].mean(dim=0) if idx.numel() > 0 else node_feat.new_zeros(feat_tail)
         for idx in label_idx],
        dim=0,
    )

    src, dst = [], []
    for c, idx in enumerate(label_idx):
        idx = idx.long()
        for j in range(global_nodes_per_class):
            gid = base + global_nodes_per_class * c + j
            src.append(torch.full((idx.numel(),), gid, dtype=torch.long))
            dst.append(idx)
    if src:
        node2global = torch.stack([torch.cat(src), torch.cat(dst)])
    else:
        node2global = torch.empty((2, 0), dtype=torch.long)
    return node2global, global2node, global_feat

def modify_globalMask(args, dataset, split_idx, run):
    """Rebuild the global-node edges and features for one data split, in place.

    Recomputes the class-wise global edges from ``split_idx['train']``. It then
    overwrites ``dataset.graph["edge_masks"][2]``, so ``edge_masks`` must
    already have been created (``augment_graph`` does that), and sets
    ``dataset.graph["global_feat"]``.

    Args:
        args: run configuration; reads ``num_clusters``,
            ``global_nodes_per_class`` and (optionally) ``learn_global``.
        dataset: dataset object holding ``graph`` and ``label``.
        split_idx: dict with a ``'train'`` index tensor.
        run: unused; kept for interface compatibility.
    """
    device = dataset.graph["edge_index"].device
    N = dataset.graph["num_nodes"]
    num_classes = max(dataset.label.max().item() + 1, dataset.label.shape[1])
    global_nodes = num_classes * args.global_nodes_per_class

    node2global, global2node, global_feat = calculate_global_edge_index(
        N, args.num_clusters, global_nodes, args.global_nodes_per_class,
        dataset.label.cpu(), dataset.graph["node_feat"].cpu(), split_idx['train'])

    node2global = node2global.to(device)
    global2node = [t.to(device) for t in global2node]
    global2node.append(None)
    dataset.graph["edge_masks"][2] = [node2global, global2node]

    # Same rule as augment_graph: with learned global embeddings, no fixed
    # global features are supplied.
    if getattr(args, "learn_global", False):
        dataset.graph["global_feat"] = None
    else:
        dataset.graph["global_feat"] = global_feat.to(device)
    

def _load_or_build_aug_cache(args, dataset, N, global_nodes):
    """Return the augmentation dict from the on-disk cache, computing and saving any missing parts."""
    import os

    cache_file = f'AugEdges/{args.dataset}_{args.num_clusters}_{args.global_nodes_per_class}.pt'
    dirty = False
    if path.exists(cache_file):
        data = torch.load(cache_file)
    else:
        node2cluster, membership, cluster_feat = calculate_cluster_edge_index(
            N, dataset.graph["edge_index"], args.num_clusters, dataset.graph["node_feat"].cpu())
        data = {"node2cluster": node2cluster, "membership": membership, "node2global": None,
                "global2node": None, "cluster_feat": cluster_feat, "global_feat": None}
        dirty = True

    # The cache key ignores args.split and args.learn_global. A file written by a
    # non-fixed-split run lacks the global edges, and one written by an older
    # learn_global run lacks the global features. A fixed-split run needs both,
    # so fill the gaps here instead of crashing on None.
    if args.split == "fixed" and (data.get("node2global") is None or data.get("global_feat") is None):
        node2global, global2node, global_feat = calculate_global_edge_index(
            N, args.num_clusters, global_nodes, args.global_nodes_per_class,
            dataset.label.cpu(), dataset.graph["node_feat"].cpu(), dataset.split['train'])
        data.update(node2global=node2global, global2node=global2node, global_feat=global_feat)
        dirty = True

    if dirty:
        os.makedirs(path.dirname(cache_file), exist_ok=True)
        torch.save(data, cache_file)
    return data


def augment_graph(args, dataset):
    """Attach cluster and global virtual nodes to ``dataset``, in place.

    Node id layout: ``[0, N)`` original nodes, ``[N, N + num_clusters)``
    cluster nodes, then ``num_classes * global_nodes_per_class`` global nodes.

    Sets ``dataset.graph["tot_nodes"]``, ``["edge_masks"]``
    (``[[edge_index], [node2cluster], [node2global, global2node]]``),
    ``["membership"]``, ``["node_feat"]`` (cluster features appended) and
    ``["global_feat"]``, plus ``dataset.label_global``. The cluster structure
    (and, for ``args.split == "fixed"``, the global structure) is cached in
    ``AugEdges/{dataset}_{num_clusters}_{global_nodes_per_class}.pt``.
    """
    N = dataset.graph["num_nodes"]
    edge_index = dataset.graph["edge_index"]
    device = edge_index.device
    num_classes = max(dataset.label.max().item() + 1, dataset.label.shape[1])
    global_nodes = num_classes * args.global_nodes_per_class
    dataset.graph["tot_nodes"] = N + args.num_clusters + global_nodes

    data = _load_or_build_aug_cache(args, dataset, N, global_nodes)

    # node2cluster holds cluster->node edges only; symmetrise and add self loops.
    num_local = N + args.num_clusters
    node2cluster = data["node2cluster"].to(device)
    node2cluster = to_undirected(node2cluster, num_nodes=num_local)
    node2cluster = add_self_loops(node2cluster, num_nodes=num_local)[0]

    if args.split == "fixed":
        node2global = data["node2global"].to(device)
        global2node = [t.to(device) for t in data["global2node"]]
        global2node.append(None)
        global_feat = data["global_feat"].to(device)
    else:
        # NOTE(review): presumably filled in per split later (see modify_globalMask).
        node2global, global2node, global_feat = None, None, None

    if args.learn_global:
        global_feat = None

    dataset.graph["edge_masks"] = [[edge_index], [node2cluster], [node2global, global2node]]
    dataset.graph["membership"] = data["membership"]
    # One class label per global node. Shape (G, 1) matches the 2-D
    # dataset.label; previously only the fresh path unsqueezed, so cached runs
    # produced (G,).
    dataset.label_global = torch.arange(num_classes).repeat_interleave(
        args.global_nodes_per_class).unsqueeze(-1).to(device)
    dataset.graph["node_feat"] = torch.cat(
        [dataset.graph["node_feat"], data["cluster_feat"].to(device)], dim=0)
    dataset.graph["global_feat"] = global_feat
    

def eval_f1(y_true, y_pred):
    """Micro-averaged F1 of the arg-max prediction against the labels.

    Args:
        y_true: integer label tensor, shape ``(N, 1)`` or ``(N,)``.
        y_pred: per-class score tensor, shape ``(N, C)``.

    Returns:
        The micro-F1 score computed by ``sklearn.metrics.f1_score``.
    """
    # A single score over the flattened labels. The former loop over label
    # columns never indexed column i; it recomputed and averaged this same value.
    labels = y_true.detach().cpu().numpy().ravel()
    preds = y_pred.argmax(dim=-1).detach().cpu().numpy()
    return f1_score(labels, preds, average='micro')


def eval_acc(y_true, y_pred):
    """Fraction of rows whose arg-max prediction equals the label.

    Args:
        y_true: integer labels, expected ``(N, 1)`` like ``dataset.label``;
            a 1-D ``(N,)`` tensor is also accepted.
        y_pred: per-class scores, shape ``(N, C)``.

    Returns:
        A 0-d float tensor holding the accuracy.
    """
    if y_true.dim() == 1:
        # Without this, (N,) == (N, 1) broadcasts to an (N, N) comparison and
        # the mean is silently wrong.
        y_true = y_true.unsqueeze(-1)
    y_pred = y_pred.argmax(dim=-1, keepdim=True)
    return (y_true == y_pred).float().mean()


def eval_rocauc(y_true, y_pred):
    """ adapted from ogb
    https://github.com/snap-stanford/ogb/blob/master/ogb/nodeproppred/evaluate.py

    Mean ROC-AUC over the label columns of ``y_true`` that contain both a 0 and a 1.

    Args:
        y_true: label tensor of shape ``(N, T)``.
        y_pred: score tensor. When ``T == 1`` it holds per-class scores
            ``(N, C)`` and column 1 of its softmax is used as the positive
            score; otherwise it is used as-is, shape ``(N, T)``.

    Returns:
        The average of the per-column ROC-AUC scores.

    Raises:
        RuntimeError: if no column has both positive and negative labels.
    """
    rocauc_list = []
    y_true = y_true.detach().cpu().numpy()
    if y_true.shape[1] == 1:
        # use the predicted class for single-class classification
        y_pred = F.softmax(y_pred, dim=-1)[:, 1].unsqueeze(1).detach().cpu().numpy()
    else:
        y_pred = y_pred.detach().cpu().numpy()

    # Must return a copy: .numpy() of a CPU tensor shares its memory, and the
    # old positional call nan_to_num(y_pred, 0) meant copy=False, i.e. in place.
    y_pred = np.nan_to_num(y_pred, nan=0.0)

    for i in range(y_true.shape[1]):
        # AUC is only defined when there is at least one positive data.
        if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0:
            is_labeled = y_true[:, i] == y_true[:, i]  # False only for NaN labels
            score = roc_auc_score(y_true[is_labeled, i], y_pred[is_labeled, i])

            rocauc_list.append(score)

    if len(rocauc_list) == 0:
        raise RuntimeError(
            'No positively labeled data available. Cannot compute ROC-AUC.')

    return sum(rocauc_list)/len(rocauc_list)


# Dataset name -> remote file identifier.
# NOTE(review): the values look like Google Drive file IDs consumed by a
# downloader elsewhere; confirm before editing.
dataset_drive_url = {}
dataset_drive_url['snap-patents'] = '1ldh23TSY1PwXia6dU0MYcpyEgX-w3Hia'
dataset_drive_url['pokec'] = '1dNs5E7BrWJbgcHeQ_zuy5Ozp2tRCWG0y'
dataset_drive_url['yelp-chi'] = '1fAXtTVQS4CfEk4asqrFw9EPmlUPGbGtJ'

def calculate_norm_A(edge_index, val=None):
    """Symmetrically normalise a sparse adjacency matrix: D^{-1/2} A D^{-1/2}.

    Args:
        edge_index: ``(2, E)`` index tensor of (row, col) pairs; the matrix
            size is taken as ``edge_index.max() + 1``.
        val: optional ``(E,)`` edge weights; defaults to all ones.

    Returns:
        ``(indices, values)``: the ``(2, nnz)`` COO indices of the normalised
        matrix and its values reshaped to ``(nnz, 1, 1)``.
    """
    N = edge_index.max().item() + 1
    if val is None:
        val = torch.ones(edge_index.shape[1], device=edge_index.device)

    A = SparseTensor.from_edge_index(
        edge_index=edge_index, edge_attr=val, sparse_sizes=(N, N))
    deg = A.sum(1)  # row sums
    deg_inv_sqrt = deg.pow(-0.5).view(-1, 1)  # [N, 1]
    # A node with zero row-degree yields inf, which would turn every stored
    # entry in its column into inf; use a scaling factor of 0 instead.
    deg_inv_sqrt = deg_inv_sqrt.masked_fill(torch.isinf(deg_inv_sqrt), 0.0)
    norm_A = A * deg_inv_sqrt * deg_inv_sqrt.T

    row, col, val = norm_A.coo()
    return torch.stack([row, col]), val.view(-1, 1, 1)