
from copy import deepcopy
from sklearn.metrics import accuracy_score, r2_score
from torch_kmeans import KMeans
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score,silhouette_score,calinski_harabasz_score,davies_bouldin_score
from torch_geometric.data import InMemoryDataset, Data
import numpy as np
from utils import do_edge_split_direct, edgemask_um
from torch_geometric.utils import to_undirected, add_self_loops, negative_sampling
from torch_sparse import SparseTensor
from model.decoder import LPDecoder
from sklearn.metrics import roc_auc_score, average_precision_score
from torch.utils.data import DataLoader
from torch_geometric.utils import dense_to_sparse


def evaluate_clustering(embeddings, true_labels, n_clusters):
    """Cluster ``embeddings`` with k-means and score the resulting partition.

    ``KMeans`` here is ``sklearn.cluster.KMeans``: the sklearn import at the top
    of the module is the later binding and shadows the torch_kmeans one.

    Args:
        embeddings: Feature matrix handed to ``KMeans.fit_predict`` and to the
            internal (label-free) clustering metrics.
        true_labels: Ground-truth labels used by the external metrics.
        n_clusters: Number of k-means clusters to fit.

    Returns:
        Tuple ``(nmi, ari, ss, chi, dbi)``: normalized mutual information,
        adjusted Rand index, silhouette score, Calinski-Harabasz score and
        Davies-Bouldin score, in that order.
    """
    # Fixed seed plus 10 restarts so repeated evaluations yield the same clustering.
    predicted = KMeans(n_clusters=n_clusters, random_state=42, n_init=10).fit_predict(embeddings)

    # External metrics: agreement between predicted clusters and ground truth.
    external = (
        normalized_mutual_info_score(true_labels, predicted),
        adjusted_rand_score(true_labels, predicted),
    )
    # Internal metrics: cluster geometry in embedding space, no labels needed.
    internal = tuple(
        metric(embeddings, predicted)
        for metric in (silhouette_score, calinski_harabasz_score, davies_bouldin_score)
    )
    return external + internal


def evaluate_auc(train_pred, train_true, val_pred, val_true, test_pred, test_true):
    """Compute ROC-AUC and average precision for the train, valid and test splits.

    Args:
        train_pred, val_pred, test_pred: Predicted scores for each split.
        train_true, val_true, test_true: Binary ground-truth labels (1 = edge,
            0 = non-edge) aligned with the corresponding predictions.

    Returns:
        dict with keys ``'AUC'`` and ``'AP'``; each value is a
        ``(train, valid, test)`` tuple of scores.
    """
    splits = (
        (train_true, train_pred),
        (val_true, val_pred),
        (test_true, test_pred),
    )
    # All AUCs are computed before all APs, split order train -> valid -> test.
    return {
        'AUC': tuple(roc_auc_score(labels, scores) for labels, scores in splits),
        'AP': tuple(average_precision_score(labels, scores) for labels, scores in splits),
    }


def train(model, predictor, data, split_edge, optimizer, args):
    """Run a single optimisation step of the binary link-prediction objective.

    Args:
        model: Graph encoder called as ``model(x, adj)`` to produce node embeddings.
        predictor: Edge decoder called as ``predictor(h, edges)``; its outputs are
            treated as probabilities (they are fed to ``log`` directly).
        data: Graph whose ``x`` (node features) and ``num_nodes`` are read.
        split_edge: Edge split forwarded to ``edgemask_um``.
        optimizer: Optimizer over the parameters of ``model`` and ``predictor``.
        args: Unused; kept for interface compatibility.

    Returns:
        The scalar loss (positive + negative BCE terms) as a Python float.
    """
    device = data.x.device
    for module in (model, predictor):
        module.train()

    # edgemask_um yields the message-passing structure, the edge set excluded from
    # negative sampling, and the positive target edges (presumably the masked-out
    # edges of a masking pretext task -- confirm in utils.edgemask_um).
    adj, edge_index, pos_edges = edgemask_um(split_edge, device, data.x.shape[0])

    optimizer.zero_grad()
    h = model(data.x, adj.to(device))

    # Positive term: push predicted probability of target edges towards 1.
    pos_loss = -torch.log(predictor(h, pos_edges) + 1e-15).mean()

    # Negative term: draw as many random node pairs as there are positives,
    # avoiding the known edges and self-loops, and push them towards 0.
    forbidden, _ = add_self_loops(edge_index.cpu())
    neg_edges = negative_sampling(
        forbidden, num_nodes=data.num_nodes, num_neg_samples=pos_edges.shape[1]
    ).to(device)
    neg_loss = -torch.log(1 - predictor(h, neg_edges) + 1e-15).mean()

    loss = pos_loss + neg_loss
    loss.backward()

    # Gradient norms are clipped per module, not jointly.
    for module in (model, predictor):
        torch.nn.utils.clip_grad_norm_(module.parameters(), 1.0)

    optimizer.step()
    return loss.item()


def _predict_in_batches(predictor, h, edges, batch_size):
    """Score ``edges`` with ``predictor`` in chunks of ``batch_size``; return a flat CPU tensor."""
    if edges.size(0) == 0:
        # torch.cat on an empty list raises; an empty split simply has no scores.
        return torch.empty(0)
    # reshape(-1) rather than squeeze(): squeeze() collapses a size-1 final batch
    # into a 0-d tensor, which torch.cat(..., dim=0) refuses to concatenate.
    chunks = [
        predictor(h, batch.t()).reshape(-1).cpu()
        for batch in edges.split(batch_size)
    ]
    return torch.cat(chunks, dim=0)


@torch.no_grad()
def model_test(model, predictor, data, adj, split_edge, batch_size):
    """Score all train/valid/test positive and negative edges and compute AUC/AP.

    Args:
        model: Graph encoder; node embeddings are computed once as ``model(data.x, adj)``.
        predictor: Edge decoder called as ``predictor(h, edges)`` on transposed
            edge batches.
        data: Graph whose ``x`` provides node features and the target device.
        adj: Message-passing structure passed to ``model``.
        split_edge: Mapping ``split -> {'edge': ..., 'edge_neg': ...}`` for the
            splits ``'train'``, ``'valid'`` and ``'test'``. Edge tensors are indexed
            along dim 0 and transposed before scoring, so they are presumably
            ``(num_edges, 2)`` -- confirm against do_edge_split_direct.
        batch_size: Number of edges scored per predictor call.

    Returns:
        The dict produced by ``evaluate_auc``: ``{'AUC': (train, valid, test),
        'AP': (train, valid, test)}``.
    """
    model.eval()
    # The predictor must also be in eval mode; otherwise its dropout stays active
    # during evaluation and the reported scores are stochastic.
    predictor.eval()
    h = model(data.x, adj)
    device = data.x.device

    def score(split, key):
        # Move one edge set to the embedding device and score it batch by batch.
        return _predict_in_batches(predictor, h, split_edge[split][key].to(device), batch_size)

    def labelled(split):
        # Concatenate positive then negative scores with matching 1/0 labels.
        pos = score(split, 'edge')
        neg = score(split, 'edge_neg')
        pred = torch.cat([pos, neg], dim=0)
        true = torch.cat([torch.ones_like(pos), torch.zeros_like(neg)], dim=0)
        return pred, true

    train_pred, train_true = labelled('train')
    val_pred, val_true = labelled('valid')
    test_pred, test_true = labelled('test')

    return evaluate_auc(train_pred, train_true, val_pred, val_true, test_pred, test_true)


class Link_pred:
    """Link-prediction evaluation: train on a condensed graph, test on the original.

    The encoder and ``LPDecoder`` are optimised on ``con_data``. Model selection
    uses validation AUC on a held-out edge split of ``ori_data``, and the selected
    checkpoint is reported on that split's test edges.
    """

    def __init__(self, args, ori_data, con_data):
        self.args = args
        self.ori_data = ori_data      # original graph; edges are split into train/valid/test
        self.con_data = con_data      # condensed graph used for training
        self.dim = ori_data.x.shape[1]  # node-feature dimension (encoder input size)
        self.num_classes = ori_data.num_classes

    def _prune_condensed_edges(self):
        """For pubmed condensed with ``pgc``, drop condensed edges below a weight threshold.

        The thresholds (0.012 at reduction_rate 0.25, else 0.005) are hard-coded
        tuning constants whose origin is not documented here. Replaces
        ``self.con_data`` with a filtered clone; other configurations are untouched.
        """
        args = self.args
        if not (args.dataset == 'pubmed' and args.gc_method == 'pgc'):
            return
        threshold = 0.012 if args.reduction_rate == 0.25 else 0.005
        keep = self.con_data.edge_weight >= threshold
        pruned = self.con_data.clone()
        pruned.edge_index = self.con_data.edge_index[:, keep]
        pruned.edge_weight = self.con_data.edge_weight[keep]
        self.con_data = pruned

    def model_train(self):
        """Train encoder + decoder and return test metrics of the best-validation checkpoint.

        Returns:
            dict with ``'auc'`` and ``'ap'``: test-split ROC-AUC and average
            precision of the epoch with the highest validation AUC.
        """
        ori_data = self.ori_data
        args = self.args
        num_nodes = ori_data.x.shape[0]

        self._prune_condensed_edges()

        split_edge = do_edge_split_direct(ori_data, val_ratio=0.1, test_ratio=0.05)
        split_edge_con = do_edge_split_direct(self.con_data, val_ratio=0, test_ratio=0)

        # NOTE(review): this overwrites ori_data.edge_index (the same object as
        # self.ori_data, shared with the caller) with the undirected train edges
        # plus self-loops. A second model_train call would therefore re-split this
        # reduced graph -- confirm that is intended.
        edge_index_with_self_loops, _ = add_self_loops(split_edge['train']['edge'].t(), num_nodes=num_nodes)
        ori_data.edge_index = to_undirected(edge_index_with_self_loops, num_nodes=num_nodes)
        ori_adj = SparseTensor.from_edge_index(ori_data.edge_index).t()

        # Resolve the encoder class dynamically: module model.<name>, class <NAME>.
        model_module = getattr(__import__('model', fromlist=[args.model]), args.model)
        model_class = getattr(model_module, args.model.upper())
        model = model_class(args, self.dim, args.hidden_dim, args.hidden_dim).cuda()
        predictor = LPDecoder(args.hidden_dim, args.hidden_dim, 1, args.dropout).cuda()

        params = list(model.parameters()) + list(predictor.parameters())
        epochs = args.epochs
        lr = args.lr
        optimizer = torch.optim.Adam(params, lr=lr)

        # Seed the checkpoint with the initial weights so restoring below never hits
        # an unbound name (epochs == 0, or validation AUC never exceeding 0.0, e.g. NaN).
        best_valid = 0.0
        weights1 = deepcopy(model.state_dict())
        weights2 = deepcopy(predictor.state_dict())

        for i in range(epochs):
            if i == epochs // 2 and i > 0:
                # Halve the learning rate at the midpoint. The rebuilt optimizer must use
                # the halved `lr`; previously it used args.lr, so the decay had no effect.
                lr = lr * 0.5
                optimizer = torch.optim.Adam(params, lr=lr)
            model.train()
            predictor.train()
            optimizer.zero_grad()

            train(model, predictor, self.con_data, split_edge_con, optimizer, args)
            results = model_test(model, predictor, ori_data, ori_adj, split_edge, 2048)

            valid_auc = results['AUC'][1]
            if valid_auc > best_valid:
                best_valid = valid_auc
                weights1 = deepcopy(model.state_dict())
                weights2 = deepcopy(predictor.state_dict())

        model.load_state_dict(weights1)
        predictor.load_state_dict(weights2)
        results = model_test(model, predictor, ori_data, ori_adj, split_edge, 2048)

        all_results = {
            'auc': results['AUC'][2],
            'ap': results['AP'][2],
        }
        print(f"Test set results: test_auc= {results['AUC'][2]:.5f}, test_ap = {results['AP'][2]:.5f}")

        return all_results
