import numpy as np
from collections import defaultdict

from utils.DataLoader import Data

def time_compare_link_prediction(
                                 node_interact_times: np.ndarray,
                                 neighbor_sampler: Data,
                                 positive_edges: tuple, negative_edges: tuple,
                                 positive_random_edges: tuple, negative_random_edges: tuple):
    """Score positive/negative edges by comparing their time gaps to a random baseline.

    Every edge is mapped to a time gap with ``compute_edge_time``. The gaps of
    the positive (resp. negative) edges are compared with the average gap of
    the positive (resp. negative) random edges. An edge whose gap falls below
    that average is scored 1.0, otherwise 0.0.

    :param node_interact_times: not used by this function.
        NOTE(review): kept only for interface compatibility — confirm callers.
    :param neighbor_sampler: forwarded unchanged to ``compute_edge_time``.
    :param positive_edges: edges to score as positives, in the tuple layout
        ``compute_edge_time`` unpacks (src ids, dst ids, interaction times).
    :param negative_edges: edges to score as negatives, same layout.
    :param positive_random_edges: edges whose mean gap is the positive threshold.
    :param negative_random_edges: edges whose mean gap is the negative threshold.
    :return: ``(positive_probabilities, negative_probabilities)`` — float arrays
        of 0.0/1.0, one entry per positive / negative edge.

    NOTE: if a random edge set is empty, ``np.mean`` yields nan (with a
    RuntimeWarning), every comparison against it is False, and all scores
    for that side become 0.0.
    """
    # Per-edge gaps for the edges being scored, each of shape (batch_size, ).
    scored_gaps = [compute_edge_time(edge_set, neighbor_sampler)
                   for edge_set in (positive_edges, negative_edges)]
    # One scalar reference gap per random edge set (mean over that set).
    reference_gaps = [np.mean(compute_edge_time(edge_set, neighbor_sampler))
                      for edge_set in (positive_random_edges, negative_random_edges)]

    pos_gaps, neg_gaps = scored_gaps
    pos_reference, neg_reference = reference_gaps

    # Ties count as a hit for positive edges (<=) but not for negative edges (<).
    positive_hits = pos_gaps <= pos_reference
    negative_hits = neg_gaps < neg_reference

    return positive_hits.astype(np.float64), negative_hits.astype(np.float64)

def compute_edge_time(edges: tuple, neighbor_sampler: "Data"):
    """Return, per edge, its timestamp minus the mean time of the endpoints' earlier interactions.

    For each ``(src, dst, t)`` in ``edges``, the interaction times of ``src``
    and of ``dst`` before ``t`` are fetched via
    ``neighbor_sampler.find_neighbors_before`` and pooled together. The edge's
    value is ``t - mean(pooled_times)``. If neither endpoint has any earlier
    interaction, the value is ``t`` itself.

    :param edges: tuple ``(src_node_ids, dst_node_ids, node_interact_times)``
        of sequences iterated in lockstep.
    :param neighbor_sampler: object exposing
        ``find_neighbors_before(node_id=..., interact_time=...)`` that returns a
        4-tuple whose third element holds the neighbor interaction times.
        NOTE(review): annotated as ``Data`` but used as a neighbor sampler —
        confirm the intended type.
    :return: ``np.ndarray`` with one value per edge, shape (len(src_node_ids), ).
    """
    src_nodes_ids, dst_nodes_ids, node_interact_times = edges
    edge_times = []
    for src_node_id, dst_node_id, node_interact_time in zip(src_nodes_ids, dst_nodes_ids, node_interact_times):
        _, _, src_neighbor_times, _ = neighbor_sampler.find_neighbors_before(
            node_id=src_node_id, interact_time=node_interact_time)
        _, _, dst_neighbor_times, _ = neighbor_sampler.find_neighbors_before(
            node_id=dst_node_id, interact_time=node_interact_time)

        if len(src_neighbor_times) + len(dst_neighbor_times) > 0:
            # Gap to the MEAN of all earlier interaction times of both endpoints
            # (not to the most recent interaction).
            history_times = np.concatenate([src_neighbor_times, dst_neighbor_times])
            edge_time_average = node_interact_time - np.mean(history_times)
        else:
            # Neither endpoint has history: fall back to the raw timestamp.
            edge_time_average = node_interact_time

        edge_times.append(edge_time_average)

    return np.array(edge_times)


"""def compute_edge_time(edges: tuple, neighbor_sampler: Data):
    edge_times = []
    src_nodes_ids, dst_nodes_ids, node_interact_times = edges
    for idx, (src_node_id, dst_node_id, node_interact_time) in enumerate(zip(src_nodes_ids, dst_nodes_ids, node_interact_times)):
        _, _, src_neighbor_times, _ = \
            neighbor_sampler.find_neighbors_before(node_id=src_node_id, interact_time=node_interact_time)

        _, _, dst_neighbor_times, _ = \
            neighbor_sampler.find_neighbors_before(node_id=dst_node_id, interact_time=node_interact_time)

        if len(src_neighbor_times) > 0:
            src_time_interval = node_interact_time - src_neighbor_times[-1]
        else:
            src_time_interval = node_interact_time

        if len(dst_neighbor_times) > 0:
            dst_time_interval = node_interact_time - dst_neighbor_times[-1]
        else:
            dst_time_interval = node_interact_time

        edge_time_average = (src_time_interval + dst_time_interval) / 2
        edge_times.append(edge_time_average)

    return np.array(edge_times)"""