import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from ogb.linkproppred import PygLinkPropPredDataset, Evaluator
from sklearn import metrics
from sklearn.manifold import TSNE
from sklearn.metrics import average_precision_score


def save_data_Embedding(all_embeddings, all_targets, data_name, dir_to_save):
    """Pickle embeddings and their targets to ``<dir_to_save>/tsne_data_<data_name>.dat``.

    The file holds a single dict with keys ``'embeddings'`` and ``'targets'``.

    Args:
        all_embeddings: Embeddings to store (any picklable object).
        all_targets: Targets/labels to store (any picklable object).
        data_name: Name embedded in the output file name.
        dir_to_save: Existing directory to write into; ``str`` or path-like.

    Note:
        The file is written with ``pickle``; only load it back from a trusted
        source, since unpickling can execute arbitrary code.
    """
    data_to_save = {
        'embeddings': all_embeddings,
        'targets': all_targets,
    }

    # os.path.join handles path-like directories, trailing separators and the
    # platform separator; an empty directory now means "current directory"
    # instead of the filesystem root.
    file_path = os.path.join(dir_to_save, f'tsne_data_{data_name}.dat')
    with open(file_path, 'wb') as f:
        pickle.dump(data_to_save, f)


def draw_TSNE_embeding_nodeclass(dataname, node_features, node_labels, save_path=None):
    """Project node features to 2-D with t-SNE and plot them coloured by label.

    Args:
        dataname: Dataset name shown in the plot title (concatenated as a str).
        node_features: Feature matrix passed to ``TSNE.fit_transform``.
        node_labels: Per-node values used as scatter colours.
        save_path: If truthy, the figure is also saved there as a 300-dpi PNG.

    The figure is displayed with ``plt.show()`` after the optional save.
    """
    # Fixed random_state keeps the embedding reproducible across runs.
    reducer = TSNE(n_components=2, random_state=42)
    coords = reducer.fit_transform(node_features)
    x_vals = coords[:, 0]
    y_vals = coords[:, 1]

    plt.figure(figsize=(8, 8))
    # "bwr" is a diverging two-colour map; a qualitative map such as "tab10"
    # may read better for many classes (NOTE(review): pick per dataset).
    points = plt.scatter(x_vals, y_vals, c=node_labels, cmap="bwr", s=10)
    plt.colorbar(points)
    plt.title("t-SNE Visualization of " + dataname + " Dataset")

    if not save_path:
        plt.show()
        return

    plt.savefig(save_path, format='png', dpi=300)
    print(f"Plot saved to {save_path}")
    plt.show()

def evaluate_hits(evaluator, pos_pred, neg_pred, k_list):
    """Compute Hits@K for every K in ``k_list`` using an OGB-style evaluator.

    For each K the evaluator's ``K`` attribute is set and ``evaluator.eval`` is
    called with ``y_pred_pos``/``y_pred_neg``; the ``'hits@K'`` entry of its
    result is rounded to 4 decimals.

    Args:
        evaluator: Object with a settable ``K`` attribute and an ``eval(dict)``
            method returning a dict containing ``'hits@{K}'``.
        pos_pred: Scores of positive pairs.
        neg_pred: Scores of negative pairs.
        k_list: Iterable of K values, evaluated in order.

    Returns:
        dict mapping ``'Hits@{K}'`` to the rounded score, in ``k_list`` order.

    Side effect: ``evaluator.K`` is left set to the last K evaluated.
    """
    def _rounded_hits(k):
        evaluator.K = k
        outcome = evaluator.eval({
            'y_pred_pos': pos_pred,
            'y_pred_neg': neg_pred,
        })
        return round(outcome[f'hits@{k}'], 4)

    return {f'Hits@{k}': _rounded_hits(k) for k in k_list}


def hits_at_n_ogb(scores, targets, k_list):
    """Hits@K for each K in ``k_list`` using the ``ogbl-collab`` OGB evaluator.

    Positives are ``scores[targets == 1]`` and negatives ``scores[targets == 0]``;
    entries with any other target value are ignored.
    NOTE(review): assumes ``scores``/``targets`` are arrays supporting boolean
    mask indexing (numpy or torch) — confirm against callers.

    Returns:
        dict mapping ``'Hits@{K}'`` to the score rounded to 4 decimals.
    """
    collab_evaluator = Evaluator(name='ogbl-collab')
    positive_mask = targets == 1
    negative_mask = targets == 0
    return evaluate_hits(
        collab_evaluator,
        scores[positive_mask],
        scores[negative_mask],
        k_list,
    )


def evaluate_mrr_scaled(scores, targets):
    """Mean reciprocal rank via the OGB ``ogbl-citation2`` evaluator.

    Positives are ``scores[targets == 1]``. Negatives ``scores[targets == 0]``
    are reshaped row-major to ``(num_positives, -1)``, so row i of negatives is
    ranked against positive i.
    NOTE(review): this assumes negatives are stored grouped per positive, in the
    same order as the positives, and that their count is a multiple of the
    number of positives — confirm against the caller.

    Args:
        scores: Score array (numpy array or torch tensor).
        targets: Array of the same length with 1 for positives, 0 for negatives.

    Returns:
        float: mean of the evaluator's ``mrr_list``.
    """
    evaluator_mrr = Evaluator(name='ogbl-citation2')
    # torch.as_tensor accepts numpy arrays and tensors alike; unlike
    # torch.tensor it does not warn or copy when handed an existing tensor.
    # detach() keeps the result free of autograd history, matching the
    # detached copy torch.tensor used to produce.
    pos_test_pred = torch.as_tensor(scores[targets == 1]).detach()
    neg_test_pred = torch.as_tensor(scores[targets == 0]).detach()
    neg_test_pred = neg_test_pred.view(pos_test_pred.shape[0], -1)
    test_mrr = evaluator_mrr.eval({
        'y_pred_pos': pos_test_pred,
        'y_pred_neg': neg_test_pred,
    })['mrr_list'].mean().item()

    return test_mrr


