import benchtemp as bt
import torch
import numpy as np
from typing import List, Iterable
import torch
import numpy as np
import argparse
import os
import json
from simple_baseline.simple_run_scripts.constants import (
    VALID, TEST, BASEFOLDER, BIPARTITE_DATASETS, 
    PERCENTILES, MRR, AP, ROC_AUC, INDUCTIVE, INDUCTIVE_NEW_NEW, 
    TRACKER_CHOICES, BENCHTEMP_DATASETS, INDUCTIVE_NEW_OLD,
    POSITIVE_PROBS, NEGATIVE_PROBS, TRANSDUCTIVE, AVG, STDEV, NUM_SAMPLES
)
from sklearn.metrics import average_precision_score, roc_auc_score
from simple_baseline.simple_run_scripts.helpers import get_min_max_time, get_edges_by_time, load_bt_data, is_bipartite, create_test_split_sequences_bt, ensure_correct_dataset_conversion, convert_metrics_to_results
from simple_baseline.simple_run_scripts.helpers import develop_ranks_from_tuples
import pandas as pd
from collections import defaultdict

def neg_iterator_yielder(num_examples : int, num_negs_per_example : int, high : int, low : int = 0):
    """Yield one row of random integers per example.

    The whole ``(num_examples, num_negs_per_example)`` matrix is drawn with a
    single ``np.random.randint`` call when the generator is first advanced, and
    its rows are then handed out one at a time.

    Args:
        num_examples: Number of rows that will be yielded.
        num_negs_per_example: Number of integers in every row.
        high: Upper bound passed to ``np.random.randint`` (exclusive, so
            ``high`` itself is never produced).
        low: Lower bound passed to ``np.random.randint`` (inclusive).

    Yields:
        A 1-D integer array of length ``num_negs_per_example``.
    """
    sampled_rows = np.random.randint(
        low=low,
        high=high,
        size=(num_examples, num_negs_per_example),
    )
    # Iterating a 2-D array walks its first axis, i.e. one row per example.
    yield from sampled_rows



def _evaluate_sampled_metrics(pos_probs : List[float], neg_probs : List[float], batchsize : int = 30):
    """Compute batched AP and ROC-AUC of positives against sampled negatives.

    ``pos_probs[i]`` is the score of the i-th positive example and
    ``neg_probs[i]`` holds the scores of the negatives sampled for it, one
    column per negative sample. The number of negatives per positive is read
    from ``len(neg_probs[0])``.

    If the entries of ``pos_probs`` are themselves iterables, positives and
    negatives are first turned into tuples and passed together through
    ``develop_ranks_from_tuples``; its output replaces the original scores.

    Metrics are computed per batch of ``batchsize`` positives, separately for
    every negative column. For each column the batch values are averaged over
    all batches; the reported average and standard deviation are then taken
    across the columns.

    Args:
        pos_probs: Scores of the positive examples (scalars, or one iterable
            of values per example).
        neg_probs: For every positive example, the scores of its sampled
            negatives.
        batchsize: Number of positives evaluated together in one batch.

    Returns:
        A dict keyed by ``AP`` and ``ROC_AUC``; each value is a dict with the
        keys ``AVG``, ``STDEV`` and ``NUM_SAMPLES`` (the number of negative
        columns). An empty dict is returned when ``pos_probs`` is empty.
    """
    # A setting without any examples has nothing to score, and indexing
    # pos_probs[0] below would raise on empty input.
    if len(pos_probs) == 0:
        return {}

    if isinstance(pos_probs[0], Iterable):
        num_pos, num_neg = len(pos_probs), len(neg_probs)
        # Fail fast on mismatched inputs before doing any ranking work.
        assert num_neg == num_pos
        numnegs = len(neg_probs[0])
        # View the negatives as (num_pos, numnegs, values-per-example) and
        # flatten them into one tuple per individual negative.
        stacked_negs = np.array(neg_probs).reshape(num_pos, numnegs, -1)
        num_heuristics = stacked_negs.shape[-1]
        neg_tuples = [tuple(row) for row in stacked_negs.reshape(-1, num_heuristics).tolist()]
        pos_tuples = [tuple(x) for x in pos_probs]
        # NOTE(review): assumes develop_ranks_from_tuples returns one score per
        # input tuple, in input order - the slicing below depends on it.
        scores = develop_ranks_from_tuples(pos_tuples + neg_tuples)
        pos_probs = scores[:num_pos]
        neg_probs = scores[num_pos:]
        assert len(pos_probs) == num_pos
        assert len(neg_probs) == num_pos*numnegs
        neg_probs = np.array(neg_probs).reshape(num_pos, numnegs)
    else:
        numnegs = len(neg_probs[0])

    # Simulate the batch evaluation done in benchtemp
    neg_probs = np.array(neg_probs)
    total = len(pos_probs)
    metrics = defaultdict(list)
    for batch_start in range(0, total, batchsize):
        batch_end = min(batch_start + batchsize, total)
        pos_batch = pos_probs[batch_start:batch_end]
        batch_aps = []
        batch_roc_aucs = []
        for j in range(numnegs):
            # Pair the positives of this batch with their j-th negative.
            neg_batch = neg_probs[batch_start:batch_end, j]
            assert len(pos_batch) == len(neg_batch)
            true_label = np.concatenate((np.ones(len(pos_batch)), np.zeros(len(neg_batch))))
            pred_score = np.concatenate((pos_batch, neg_batch))
            batch_aps.append(average_precision_score(true_label, pred_score))
            batch_roc_aucs.append(roc_auc_score(true_label, pred_score))
        metrics[AP].append(batch_aps)
        metrics[ROC_AUC].append(batch_roc_aucs)
    # Average over batches first (one value per negative column) ...
    metrics = {met : np.mean(np.array(values),axis=0) for met, values in metrics.items()}
    # ... then summarise across the negative columns.
    metrics = {met : {AVG : np.mean(values), STDEV : np.std(values), NUM_SAMPLES : numnegs} for met, values in metrics.items()}
    return metrics


def divide_into_splits(pos_probs : List[float], neg_probs : List[float], edges_by_time, marks : List[int], split):
    """Group per-edge scores by evaluation setting.

    Edges are visited in ascending timestamp order and, within a timestamp, in
    the order stored in ``edges_by_time``; the k-th visited edge is matched
    with ``pos_probs[k]`` and ``neg_probs[k]``.

    An edge whose ``(src, dst, t)`` tuple is contained in ``marks[INDUCTIVE]``
    is recorded under ``INDUCTIVE`` and additionally under
    ``INDUCTIVE_NEW_NEW`` if it is also contained in
    ``marks[INDUCTIVE_NEW_NEW]``, otherwise under ``INDUCTIVE_NEW_OLD``. Every
    other edge is recorded under ``TRANSDUCTIVE``.

    Args:
        pos_probs: One positive score entry per edge, in visiting order.
        neg_probs: One negative score entry per edge, in visiting order.
        edges_by_time: Mapping from timestamp to an iterable of ``(src, dst)``
            pairs.
        marks: Container indexed by setting key; each entry must support
            membership tests with ``(src, dst, t)`` tuples.
        split: Not used by this function.

    Returns:
        A dict mapping each of the four settings to a dict with the keys
        ``POSITIVE_PROBS`` and ``NEGATIVE_PROBS``, each holding a list.
    """
    probs_by_setting = {}
    for setting in (TRANSDUCTIVE, INDUCTIVE, INDUCTIVE_NEW_NEW, INDUCTIVE_NEW_OLD):
        probs_by_setting[setting] = {POSITIVE_PROBS : [], NEGATIVE_PROBS : []}

    edge_idx = 0
    for t in sorted(edges_by_time):
        for src, dst in edges_by_time[t]:
            edge_key = (src, dst, t)
            # Decide which settings this edge contributes to.
            if edge_key not in marks[INDUCTIVE]:
                targets = (TRANSDUCTIVE,)
            elif edge_key in marks[INDUCTIVE_NEW_NEW]:
                targets = (INDUCTIVE, INDUCTIVE_NEW_NEW)
            else:
                targets = (INDUCTIVE, INDUCTIVE_NEW_OLD)

            for setting in targets:
                bucket = probs_by_setting[setting]
                bucket[POSITIVE_PROBS].append(pos_probs[edge_idx])
                bucket[NEGATIVE_PROBS].append(neg_probs[edge_idx])
            edge_idx += 1

    # Every score entry must have been consumed by exactly one edge.
    assert edge_idx == len(pos_probs)
    return probs_by_setting

def main():
    """Run a tracker baseline on a BenchTemp dataset and save its metrics.

    Parses the command-line options, loads the dataset splits, runs
    ``create_test_split_sequences_bt`` over the train/validation/test edges and
    then evaluates the returned positive/negative scores per setting with
    ``_evaluate_sampled_metrics``. The results are written as JSON under
    ``BASEFOLDER/results``; with ``--saveranks`` the raw rankings are also
    dumped under ``BASEFOLDER/debugdata``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="mooc", choices=BENCHTEMP_DATASETS)
    parser.add_argument("--tracker", type=str, default="GlobalRecencyTracker", choices=TRACKER_CHOICES)
    parser.add_argument("--numnegs", type=int, default=3)
    parser.add_argument("--bruteforce-test", action="store_true", help="Bruteforce the rankings and ensure it complies with the estimates. ")
    parser.add_argument("--saveranks", action="store_true", help="Save the ranks for debugging purposes")

    args = parser.parse_args()
    DATA = args.dataset
    print(f"Running tracker {args.tracker} on dataset {DATA}")

    data, train_data, val_data, test_data, marks = load_bt_data(DATA)
    # Candidate destinations: only destination nodes for bipartite graphs,
    # otherwise every node that appears as a source or a destination.
    if DATA in BIPARTITE_DATASETS or is_bipartite(data):
        dst_candidates = data.dst.unique()
    else:
        dst_candidates = torch.cat((data.src, data.dst)).unique()
    num_nodes = dst_candidates.shape[0]

    train_min_time, train_max_time = get_min_max_time(train_data)
    val_min_time, val_max_time = get_min_max_time(val_data)
    test_min_time, test_max_time = get_min_max_time(test_data)
    timesplits= {
        "train_min_time" : train_min_time,
        "train_max_time" : train_max_time,
        "val_min_time" : val_min_time,
        "val_max_time" : val_max_time,
        "test_min_time" : test_min_time,
        "test_max_time" : test_max_time
    }

    evaluator = bt.Evaluator("LP")
    neg_sampler = bt.lp.RandEdgeSampler(val_data.src.numpy(), val_data.dst.numpy())
    min_dst, max_dst = data.dst.min().item(), data.dst.max().item()
    # np.random.randint treats `high` as exclusive, so pass max_dst + 1 to keep
    # the largest destination id eligible as a negative sample.
    neg_iterator = neg_iterator_yielder(
        num_examples=len(val_data.src)+len(test_data.src),
        num_negs_per_example=args.numnegs,
        high=max_dst + 1,
        low=min_dst,
    )

    train_edges_by_time = get_edges_by_time(train_data.src, train_data.dst, train_data.t)
    val_edges_by_time = get_edges_by_time(val_data.src, val_data.dst, val_data.t)
    test_edges_by_time = get_edges_by_time(test_data.src, test_data.dst, test_data.t)

    ensure_correct_dataset_conversion(train_edges_by_time, val_edges_by_time, test_edges_by_time, train_data, val_data, test_data)
    metric = MRR
    _, rankings, _ = create_test_split_sequences_bt(train_edges_by_time,
                                              val_edges_by_time,
                                              test_edges_by_time,
                                              tracker_type=args.tracker,
                                              evaluator=evaluator, neg_sampler=neg_sampler, metric=metric, timesplits=timesplits, num_nodes=num_nodes, neg_iterator=neg_iterator, bruteforce_test=args.bruteforce_test, dst_candidates=(dst_candidates if args.bruteforce_test else None))
    results = {}

    results["Dataset"] = DATA
    results["Tracker"] = args.tracker

    # NOTE(review): the per-setting assignments further down index
    # results[VALID] / results[TEST], so these calls are presumably what
    # creates those entries - confirm in helpers.
    convert_metrics_to_results(rankings, results, val_data, VALID)
    convert_metrics_to_results(rankings, results, test_data, TEST)

    # Now do the negative sampling evaluation:
    # divide the pos_probs and neg_probs into the different groups marked
    edges_by_time = {
        VALID : val_edges_by_time,
        TEST : test_edges_by_time
    }
    split_sets = {}
    for split in [VALID, TEST]:
        split_sets[split] = divide_into_splits(
            pos_probs=rankings[split][POSITIVE_PROBS],
            neg_probs=rankings[split][NEGATIVE_PROBS],
            edges_by_time=edges_by_time[split],
            marks=marks[split],
            split=split
        )
    for split in [VALID, TEST]:
        for setting in split_sets[split]:
            results[split][setting] = _evaluate_sampled_metrics(split_sets[split][setting][POSITIVE_PROBS], split_sets[split][setting][NEGATIVE_PROBS], batchsize = 30)

    # Create the output folder up front so a missing directory does not throw
    # away the finished computation.
    results_dir = os.path.join(BASEFOLDER, "results")
    os.makedirs(results_dir, exist_ok=True)
    with open(os.path.join(results_dir, f"{DATA}_{args.tracker}_regular_with_sampled.json"), "w") as f:
        json.dump(results, f)

    if args.saveranks:
        debug_dir = os.path.join(BASEFOLDER, "debugdata")
        os.makedirs(debug_dir, exist_ok=True)
        with open(os.path.join(debug_dir, f"{DATA}_{args.tracker}_ranks.json"), "w") as f:
            json.dump(rankings, f)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()