from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
import torch
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.negative_sampler import NegativeEdgeSampler
from tqdm import tqdm
import argparse
from typing import Dict, List, Tuple
import os
import json
from simple_baseline.simple_run_scripts.constants import VALID, TEST, BASEFOLDER, BIPARTITE_DATASETS, TRACKER_CHOICES, DATASET_COL, TRACKER_COL
from simple_baseline.simple_run_scripts.helpers import get_min_max_time, get_split, update_sequences, evaluate_sequences, get_edges_by_time, ensure_correct_dataset_conversion, convert_metrics_to_results
import simple_baseline.simple_run_scripts.trackers as trackers
from collections import defaultdict


def create_test_split_sequences(
        train_edges_by_time : Dict[int, List[Tuple[int, int]]],
        val_edges_by_time : Dict[int, List[Tuple[int, int]]],
        test_edges_by_time : Dict[int, List[Tuple[int, int]]],
        tracker_type : str,
        evaluator : Evaluator, 
        neg_sampler : NegativeEdgeSampler, 
        timesplits : Dict[str, int], 
        num_nodes : int,
        metric="mrr",
        bruteforce_test : bool=False,
        do_full_mrr : bool=True
    ):
    """
    Streamer that feeds edges to a tracker timestep by timestep and evaluates
    it on the validation and test splits.
    Note that evaluation happens *before* the edges of the current timestep are
    added to the tracker. Otherwise we would have data leakage.

    Args:
    - train_edges_by_time : Dict[int, List[Tuple[int, int]]] : Training edges; the key is the time and the value is a list of (src, dst) edges
    - val_edges_by_time : Dict[int, List[Tuple[int, int]]] : Validation edges, same layout as train_edges_by_time
    - test_edges_by_time : Dict[int, List[Tuple[int, int]]] : Test edges, same layout as train_edges_by_time
    - tracker_type : str : Name of the class in the `trackers` module to instantiate
    - evaluator : Evaluator : An evaluator object created by TGB authors.
    - neg_sampler : NegativeEdgeSampler : A negative edge sampler for evaluation created by TGB authors.
    - timesplits : Dict[str, int] : A dictionary with the following keys:
        - train_min_time : int : The minimum time in the training set
        - train_max_time : int : The maximum time in the training set
        - val_min_time : int : The minimum time in the validation set
        - val_max_time : int : The maximum time in the validation set
        - test_min_time : int : The minimum time in the test set
        - test_max_time : int : The maximum time in the test set
    - num_nodes : int : Number of nodes, forwarded to the tracker and to evaluate_sequences
    - metric : str : The metric to evaluate on (default is "mrr")
    - bruteforce_test : bool : Forwarded unchanged to evaluate_sequences
    - do_full_mrr : bool : Forwarded to evaluate_sequences, but only honoured when the
      tracker exposes a `get_full_mrr` attribute

    Returns:
    - tracker : The tracker after it has consumed all train, validation and test edges
    - rankings : Dict[str, Dict[str, list]] : For VALID and TEST, the values returned by
      evaluate_sequences, concatenated over timesteps and keyed by metric name
    - full_rankings : Dict[str, Dict[str, list]] : Same outer keys as rankings; this function
      never fills it, so it is returned empty

    Raises:
    - ValueError : If all three edge dictionaries are empty (max() of an empty list)
    """

    # Timestamps are concatenated split by split (each split sorted on its own).
    all_timestamps = sorted(train_edges_by_time.keys()) + sorted(val_edges_by_time.keys()) + sorted(test_edges_by_time.keys())
    num_edges = sum(
        len(edges)
        for split_edges in (train_edges_by_time, val_edges_by_time, test_edges_by_time)
        for edges in split_edges.values()
    )
    max_t = max(all_timestamps)
    # Map every timestamp to its position in the concatenated list.
    timestamp_index = {timestamp : idx for idx, timestamp in enumerate(all_timestamps)}
    tracker = getattr(trackers, tracker_type)(timestamp_index, num_nodes, num_edges)
    # Full ranking is only possible for trackers that implement it.
    do_full_mrr = hasattr(tracker, "get_full_mrr") and do_full_mrr
    # Ranking scores
    rankings = {VALID : defaultdict(list), TEST : defaultdict(list)}
    full_rankings = {VALID : defaultdict(list), TEST : defaultdict(list)}

    # Fill up with the training data
    for t in tqdm(sorted(train_edges_by_time.keys()), desc="Training"):
        update_sequences(train_edges_by_time, tracker, t)

    evaluation_phases = (
        (VALID, val_edges_by_time, "Validation"),
        (TEST, test_edges_by_time, "Testing"),
    )
    for expected_split, edges_by_time, description in evaluation_phases:
        for t in tqdm(sorted(edges_by_time.keys()), desc=description):
            split = get_split(t, timesplits)
            assert split == expected_split
            # Evaluate first, then reveal the edges of this timestep to the tracker.
            new_metrics = evaluate_sequences(tracker, edges_by_time, evaluator, neg_sampler, t, split, max_t=max_t, metric=metric, num_nodes=num_nodes, do_full_mrr=do_full_mrr, bruteforce_test=bruteforce_test)
            # Use a dedicated loop variable: reusing `metric` here would overwrite
            # the requested metric for every following evaluate_sequences call.
            for metric_name, values in new_metrics.items():
                rankings[expected_split][metric_name].extend(values)
            update_sequences(edges_by_time, tracker, t)

    return tracker, rankings, full_rankings

def main():
    """
    Command line entry point: run one tracker on one TGB link prediction dataset
    and store the validation and test results as JSON.

    Command line arguments:
    - --dataset : str : Name of the TGB dataset (default "tgbl-wiki")
    - --tracker : str : Tracker class to use, one of TRACKER_CHOICES (default "GlobalRecencyTracker")
    - --bruteforce-test : flag : Forwarded as bruteforce_test to create_test_split_sequences
    - --skip-full-ranks : flag : Disables do_full_mrr in create_test_split_sequences

    The results are written to <BASEFOLDER>/results/<dataset>_<tracker>_regular.json.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="tgbl-wiki")
    parser.add_argument("--tracker", type=str, default="GlobalRecencyTracker", choices=TRACKER_CHOICES)
    parser.add_argument("--bruteforce-test", action="store_true", help="Bruteforce the test set")
    parser.add_argument("--skip-full-ranks", action="store_true", help="Skip full ranks")
    args = parser.parse_args()
    dataset_name = args.dataset
    print(f"Running tracker {args.tracker} on dataset {dataset_name}")

    # Make sure the output folder exists *before* the expensive evaluation, so a
    # missing or unwritable folder does not cost a finished run its results.
    results_dir = os.path.join(BASEFOLDER, "results")
    os.makedirs(results_dir, exist_ok=True)

    device = torch.device("cpu")
    dataset = PyGLinkPropPredDataset(name=dataset_name, root="datasets")
    train_mask = dataset.train_mask
    val_mask = dataset.val_mask
    test_mask = dataset.test_mask

    # Every edge must belong to exactly one of the three splits.
    assert (torch.stack([train_mask, val_mask, test_mask]).sum(dim=0) == 1).all()

    data = dataset.get_TemporalData()
    data = data.to(device)
    metric = dataset.eval_metric
    train_data = data[train_mask]
    val_data = data[val_mask]
    test_data = data[test_mask]
    if dataset_name in BIPARTITE_DATASETS:
        # For bipartite datasets only the distinct destination nodes are counted.
        num_nodes = data.dst.unique().shape[0]
    else:
        num_nodes = torch.cat((data.src.unique(), data.dst.unique())).unique().shape[0]

    train_min_time, train_max_time = get_min_max_time(train_data)
    val_min_time, val_max_time = get_min_max_time(val_data)
    test_min_time, test_max_time = get_min_max_time(test_data)
    timesplits = {
        "train_min_time" : train_min_time,
        "train_max_time" : train_max_time,
        "val_min_time" : val_min_time,
        "val_max_time" : val_max_time,
        "test_min_time" : test_min_time,
        "test_max_time" : test_max_time
    }

    evaluator = Evaluator(name=dataset_name)
    neg_sampler = dataset.negative_sampler

    train_edges_by_time = get_edges_by_time(train_data.src, train_data.dst, train_data.t)
    val_edges_by_time = get_edges_by_time(val_data.src, val_data.dst, val_data.t)
    test_edges_by_time = get_edges_by_time(test_data.src, test_data.dst, test_data.t)

    ensure_correct_dataset_conversion(train_edges_by_time, val_edges_by_time, test_edges_by_time, train_data, val_data, test_data)

    # Load the negative samples for both evaluated splits before streaming.
    dataset.load_val_ns()
    dataset.load_test_ns()
    _, rankings, _ = create_test_split_sequences(
        train_edges_by_time,
        val_edges_by_time,
        test_edges_by_time,
        tracker_type=args.tracker,
        evaluator=evaluator,
        neg_sampler=neg_sampler,
        metric=metric,
        timesplits=timesplits,
        num_nodes=num_nodes,
        bruteforce_test=args.bruteforce_test,
        do_full_mrr=(not args.skip_full_ranks),
    )

    results = {}
    results[DATASET_COL] = dataset_name
    results[TRACKER_COL] = args.tracker

    convert_metrics_to_results(rankings, results, val_data, VALID)
    convert_metrics_to_results(rankings, results, test_data, TEST)

    with open(os.path.join(results_dir, f"{dataset_name}_{args.tracker}_regular.json"), "w") as f:
        json.dump(results, f)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()