import numpy as np
from tqdm import tqdm
import math

# internal imports
from tgb.utils.utils import set_random_seed
from torch_geometric.loader import TemporalDataLoader
from utils.adv_attacks import add_negative_neighbor

# ==================
# ==================
# ==================

def test(edgebank, data, mask, bs, neg_sampler, split_mode, metric, evaluator,
         adv_edges_param=None):
    r"""
    Evaluate dynamic link prediction for an EdgeBank model.

    Evaluation is 'one vs. many': each positive edge is scored against the
    negative edges that ``neg_sampler`` returns for it, ``metric`` is computed
    for that positive edge, and the per-edge values are averaged. EdgeBank
    memory is updated with a batch's positive edges only after that batch has
    been scored.

    Parameters:
        edgebank: model exposing ``predict_link(query_src, query_dst)`` and
            ``update_memory(src, dst, t)``.
        data: temporal graph data indexable by ``mask``; the batches produced
            by ``TemporalDataLoader`` expose ``src``, ``dst``, ``t``, ``msg``.
        mask: selector for the edges of the split to evaluate.
        bs: number of positive edges per batch.
        neg_sampler: object whose ``query_batch(src, dst, t, split_mode=...)``
            returns one collection of negative destinations per positive edge.
        split_mode: tells ``neg_sampler`` whether this is the validation or
            the test set, so the matching negatives are loaded.
        metric: key of the metric to read from ``evaluator.eval``'s result.
        evaluator: object whose ``eval(input_dict)`` returns a mapping that
            contains ``metric``.
        adv_edges_param: optional ``(adv_edges, time_sd, budget, att_pos)``
            tuple enabling the adversarial attack; ``None`` disables it.
            ``adv_edges`` is the initial edge bank handed to
            ``add_negative_neighbor`` (a dict, presumably keyed by
            ``(src, dst, t)`` -- confirm against the attack module). It is
            copied, never mutated, and the copy grows with every attacked
            batch. ``budget`` scales the number of adversarial edges per
            batch (``round(len(batch) * budget)``). ``att_pos`` is
            ``"front"`` to write adversarial edges into EdgeBank memory before
            a batch is scored, or ``"back"`` to write them after the batch's
            positives are added to memory (so they only affect later
            batches). Any other value disables the attack.

    Returns:
        perf_metric: mean of ``metric`` over all evaluated positive edges,
        or ``nan`` if the split contains no edges.
    """
    att_pos = None
    adv_bank = None
    if adv_edges_param:
        # Unpack once: the bank must persist across batches, otherwise the
        # per-batch additions below would be thrown away.
        adv_edges, time_sd, budget, att_pos = adv_edges_param
        # Private copy so growing the bank never mutates the caller's object.
        adv_bank = adv_edges.copy()

    loader = TemporalDataLoader(data[mask], batch_size=bs)
    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )

        if att_pos == "front":
            # Poison memory before this batch is scored.
            _inject_adversarial_edges(
                edgebank, adv_bank, pos_src, pos_dst, pos_t, pos_msg,
                time_sd, budget, adv_first=True,
            )

        neg_batch_list = neg_sampler.query_batch(
            pos_src, pos_dst, pos_t, split_mode=split_mode
        )
        for idx, neg_batch in enumerate(neg_batch_list):
            # Element 0 is the positive destination, the rest are its
            # negatives; every query shares the positive edge's source.
            query_src = np.full(len(neg_batch) + 1, int(pos_src[idx]))
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])

            y_pred = edgebank.predict_link(query_src, query_dst)
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # Only now reveal this batch's positives to EdgeBank, so scoring above
        # did not see the edges being predicted.
        edgebank.update_memory(pos_src.numpy(), pos_dst.numpy(), pos_t.numpy())

        if att_pos == "back":
            # Poison memory after this batch; only later batches are affected.
            _inject_adversarial_edges(
                edgebank, adv_bank, pos_src, pos_dst, pos_t, pos_msg,
                time_sd, budget, adv_first=False,
            )

    if not perf_list:
        # Same value np.mean([]) would give, without the RuntimeWarning.
        return float("nan")
    return float(np.mean(perf_list))


def _edge_records(srcs, dsts, ts, msgs):
    """Map each edge's ``(src, dst, t)`` (each reshaped to 1-D) to its message
    reshaped to a single ``(1, -1)`` row.

    NOTE(review): if these are torch tensors, the keys hash by object
    identity, so entries are never deduplicated -- confirm this is what
    ``add_negative_neighbor`` expects.
    """
    return {
        (s.reshape(-1), d.reshape(-1), t.reshape(-1)): m.reshape(1, -1)
        for s, d, t, m in zip(srcs, dsts, ts, msgs)
    }


def _inject_adversarial_edges(edgebank, adv_bank, pos_src, pos_dst, pos_t,
                              pos_msg, time_sd, budget, adv_first):
    """Write ``round(len(pos_src) * budget)`` adversarial edges into EdgeBank.

    The adversarial edges are produced by ``add_negative_neighbor`` from
    ``adv_bank``, using the batch's first timestamp ``pos_t[0]`` and the
    message width ``pos_msg.shape[-1]``. When at least one edge is added,
    ``adv_bank`` is updated IN PLACE with both the adversarial edges and the
    batch's positive edges (adversarial first when ``adv_first`` is true,
    positives first otherwise), so later batches draw from the grown bank.
    Does nothing when the rounded count is zero.
    """
    n = round(len(pos_src) * budget)
    if n <= 0:
        return

    adv_sources, adv_dests, adv_t, adv_msgs = add_negative_neighbor(
        adv_bank, time_sd, n, pos_msg.shape[-1], pos_t[0], device="cpu"
    )
    edgebank.update_memory(adv_sources.numpy(), adv_dests.numpy(), adv_t.numpy())

    adv_records = _edge_records(adv_sources, adv_dests, adv_t, adv_msgs)
    pos_records = _edge_records(pos_src, pos_dst, pos_t, pos_msg)
    # Preserve the original insertion order for each attack position.
    first, second = (adv_records, pos_records) if adv_first else (pos_records, adv_records)
    adv_bank |= first
    adv_bank |= second


