import time
import argparse
from torch_geometric.loader import TemporalDataLoader


from tgb.linkproppred.negative_generator import NegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from tgb.utils.utils import save_pkl


def generate_neg(
    dataset_name,
    *,
    num_neg_e_per_pos=1,
    neg_sample_strategy="hist_rnd",
    rnd_seed=42,
    batch_size=200,
    save_file=True,
    checkpoint_every=50000,
):
    r"""
    Generate negative edges for the test split of a link prediction dataset.

    The test split is processed batch by batch. For each batch a fresh
    ``NegativeEdgeGenerator`` is built whose ``historical_data`` is every edge
    that precedes the batch (``data[:cur_idx]``), and the negatives it returns
    are merged into one dictionary. The merged dictionary is pickled to
    ``artefacts/{dataset_name}_recent-neg_B{batch_size}.pkl``.

    NOTE(review): the bookkeeping assumes the test edges are the last
    ``test_mask.sum()`` entries of the dataset; confirm for new datasets.

    Parameters:
        dataset_name: name of the dataset passed to ``PyGLinkPropPredDataset``
        num_neg_e_per_pos: number of negative edges requested per positive edge
        neg_sample_strategy: sampling strategy forwarded to the generator
        rnd_seed: random seed forwarded to the generator
        batch_size: number of positive test edges processed per batch
        save_file: whether to pickle the final negative set
        checkpoint_every: dump an intermediate ``*_temp.pkl`` each time the
            number of collected entries crosses a multiple of this value;
            ``0`` or ``None`` disables intermediate dumps

    Returns:
        None
    """
    import os  # local import keeps this function self-contained

    print("*** Negative Sample Generation ***")

    filename = f"artefacts/{dataset_name}_recent-neg_B{batch_size}.pkl"
    # splitext only strips the trailing ".pkl"; splitting on the first "."
    # would truncate the path for dataset names that contain a dot.
    temp_filename = f"{os.path.splitext(filename)[0]}_temp.pkl"
    # Create the output directory up front so a long run cannot fail at the
    # very end (or at the first checkpoint) because the folder is missing.
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    dataset = PyGLinkPropPredDataset(name=dataset_name, root="datasets")
    test_mask = dataset.test_mask
    data = dataset.get_TemporalData()

    test_split = data[test_mask]

    # Ensure to only sample actual destination nodes as negatives.
    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
    n_edge = len(data)
    # Index of the first test edge. Converted to a plain int so that slicing,
    # logging and pickling do not carry a tensor/array scalar around.
    cur_idx = n_edge - int(test_mask.sum())

    loader = TemporalDataLoader(test_split, batch_size=batch_size)
    neg_set = {}
    # Size of neg_set at which the next intermediate dump is due.
    next_checkpoint = checkpoint_every
    print(
        f"INFO: Start generating test negative samples: {dataset_name} --- {neg_sample_strategy}"
    )
    start_time = time.time()
    for pos_batch in loader:
        batch_start_time = time.time()
        # Everything strictly before the current batch counts as history.
        historical_data = data[:cur_idx]

        # The generator is rebuilt every batch because its history grows.
        # NOTE(review): the same rnd_seed is passed for every batch; confirm
        # this is the intended behavior of the generator.
        neg_edge_generator = NegativeEdgeGenerator("none",
                                                first_dst_id=min_dst_idx,
                                                last_dst_id=max_dst_idx,
                                                num_neg_e=num_neg_e_per_pos,
                                                strategy=neg_sample_strategy,
                                                hist_ratio=1.0,
                                                rnd_seed=rnd_seed,
                                                historical_data=historical_data,
                                                save_to_file=False)

        # generate test negative edge set
        neg_set_batch = neg_edge_generator.generate_negative_samples(pos_batch, split_mode="test", partial_path="./")
        # In-place merge: same result as ``neg_set | neg_set_batch`` without
        # copying the whole accumulated dictionary on every batch.
        neg_set.update(neg_set_batch)

        # update cur idx
        cur_idx += len(pos_batch)

        # A threshold test (rather than an exact-multiple test) still fires
        # when a batch makes the size jump over the checkpoint boundary.
        if checkpoint_every and len(neg_set) >= next_checkpoint:
            print("dumping intermediate set:", temp_filename)
            save_pkl((neg_set, cur_idx), temp_filename)
            next_checkpoint = (len(neg_set) // checkpoint_every + 1) * checkpoint_every

        print(f"INFO: End of batch upto {cur_idx}/{n_edge}. Elapsed Time (s): {time.time() - batch_start_time: .4f}")

    print(
        f"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}"
    )
    if save_file:
        save_pkl(neg_set, filename)


if __name__ == "__main__":
    # Command-line entry point: choose the dataset, then generate and store
    # the negative edges for its test split.
    parser = argparse.ArgumentParser(
        description="Generate test-split negative edges for a TGB link prediction dataset."
    )
    parser.add_argument(
        "-d",
        "--dataset",
        type=str,
        default="tgbl-wiki",
        help="Name of the dataset to process (default: %(default)s).",
    )
    args = parser.parse_args()
    generate_neg(args.dataset)
