import benchtemp as bt
import os
import torch
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.negative_sampler import NegativeEdgeSampler
from torch_geometric.data import TemporalData
import pandas as pd
from collections import defaultdict
from tqdm import tqdm
import numpy as np
from typing import Dict, List, Tuple, Iterator, Optional, Iterable
import simple_baseline.simple_run_scripts.trackers as trackers
from torch import Tensor, LongTensor
from simple_baseline.simple_run_scripts.constants import (
    TRAIN, VALID, TEST, MRR, FULL_MRR, RANK_OPT, RANK_PESS, 
    RANK_DIFF, TIME_RANKS, TIME_DELTAS, TIME_RANKS_NORMALIZED, 
    PERCENTILES, INDUCTIVE, INDUCTIVE_NEW_OLD, INDUCTIVE_NEW_NEW, 
    NEGATIVE_PROBS, POSITIVE_PROBS, BENCHTEMP_DATA_FOLDER
)

# Positions of two metrics inside a result sequence.
# NOTE(review): presumably average precision (AP) and ROC-AUC -- confirm against the callers.
AP_INDEX = 0
ROC_INDEX = 1

# Set this to True if the optimistic rank of an unscorable edge should be the total
# number of nodes during investigations (read by `_optimistic_worst_case_rank`).
# This gives a better plot and barely makes a difference to the results (?)
ALL_NODES_FOR_OPTIMISTIC_WORST_CASE = True

def _optimistic_worst_case_rank(total_in_group : int, num_nodes : int):
    """
    Pick the optimistic rank recorded for an edge that cannot be scored.

    Args:
    - total_in_group : int : Rank to fall back to when the module switch is off
    - num_nodes : int : Rank to use when the module switch is on

    Returns `num_nodes` while ALL_NODES_FOR_OPTIMISTIC_WORST_CASE is the literal
    `True`, and `total_in_group` otherwise.
    """
    use_all_nodes = ALL_NODES_FOR_OPTIMISTIC_WORST_CASE is True
    return num_nodes if use_all_nodes else total_in_group

def get_min_max_time(data : TemporalData) -> Tuple[int, int]:
    """
    Return the smallest and the largest timestamp stored in `data.t`,
    in that order, as plain Python scalars.
    """
    timestamps = data.t
    earliest = timestamps.min().item()
    latest = timestamps.max().item()
    return earliest, latest

def get_edges_by_time(src_nodes : LongTensor, dst_nodes : LongTensor, times : LongTensor) -> Dict[int, List[Tuple[int, int]]]:
    """
    Group edges by their timestamp.

    Args:
    - src_nodes : LongTensor : Source node of every edge
    - dst_nodes : LongTensor : Destination node of every edge, aligned with `src_nodes`
    - times : LongTensor : Timestamp of every edge, aligned with `src_nodes`

    Returns:
    - A defaultdict(list) mapping each timestamp to the (src, dst) pairs observed at
      that timestamp, in input order. One entry is produced per element of `src_nodes`.

    Raises:
    - IndexError : If `dst_nodes` or `times` is shorter than `src_nodes`.
    """
    num_edges = len(src_nodes)
    if len(dst_nodes) < num_edges or len(times) < num_edges:
        raise IndexError("dst_nodes and times must be at least as long as src_nodes")

    # One bulk conversion per tensor instead of three `.item()` calls per edge.
    # `zip` stops after `num_edges` items because `src_nodes` is the shortest input here.
    edge_stream = zip(src_nodes.tolist(), dst_nodes.tolist(), times.tolist())

    edges_by_time = defaultdict(list)
    for src, dst, t in tqdm(edge_stream, total=num_edges, desc="Creating edges by time"):
        edges_by_time[t].append((src, dst))
    return edges_by_time

def get_split(time : int, timesplits : dict) -> str:
    """
    Map a timestamp to the name of the split it falls in.

    Args:
    - time : int : The timestamp to classify
    - timesplits : dict : Must provide "val_min_time"; "test_min_time" is read
      only when `time` is not before "val_min_time"

    Returns TRAIN before "val_min_time", VALID from "val_min_time" up to (but not
    including) "test_min_time", and TEST from "test_min_time" onwards.
    """
    val_start = timesplits["val_min_time"]
    if time < val_start:
        return TRAIN

    test_start = timesplits["test_min_time"]
    if val_start <= time < test_start:
        return VALID

    assert time >= test_start
    return TEST

def update_sequences(edges_by_time, tracker, t, disable_tqdm=True):
    """
    Feed every edge recorded at timestamp `t` into the tracker.

    Each (src, dst) pair in `edges_by_time[t]` is forwarded, in order, as
    `tracker.update(src, dst, t)`. `disable_tqdm` is accepted but not used here.
    """
    edges_at_t = edges_by_time[t]
    for source, destination in edges_at_t:
        tracker.update(source, destination, t)

def develop_ranks_from_tuples(scores):
    """
    Replace every score by its dense rank among the distinct scores.

    The smallest distinct score maps to 0, the next one to 1, and so on; equal
    scores share a rank. The output list is aligned with `scores`.
    """
    distinct_ascending = sorted(set(scores))
    rank_of = {}
    for position, score in enumerate(distinct_ascending):
        rank_of[score] = position
    return [rank_of[score] for score in scores]

def evaluate_sequences(tracker,
                        edges_by_time : Dict[int, List[Tuple[int, int]]],
                        evaluator : Evaluator, 
                        neg_sampler : NegativeEdgeSampler, 
                        t : int, 
                        split : str, 
                        num_nodes : int,
                        max_t : int,
                        do_full_mrr : bool=True,
                        metric : str="mrr",
                        neg_iterator : Optional[Iterator] = None,
                        bruteforce_test : bool = False
    ) -> dict:
    """
    Score every edge observed at timestamp `t` against the current tracker state.

    Args:
    - tracker : Object that answers `contains_dst`, `get_score`, `get_node_rank`, `get_total`,
      `tracker_score_type` and (when `do_full_mrr` is set) `get_full_mrr`
    - edges_by_time : Dict[int, List[Tuple[int, int]]] : (src, dst) edges keyed by timestamp
    - evaluator : Evaluator : Called as `evaluator.eval(input_dict)[MRR]` for the sampled-negative MRR
    - neg_sampler : NegativeEdgeSampler : Queried once per scorable edge via `query_batch`
    - t : int : The timestamp to evaluate
    - split : str : Must be VALID or TEST; forwarded to the negative sampler as `split_mode`
    - num_nodes : int : Placeholder rank for edges the tracker cannot score
    - max_t : int : Placeholder time delta for edges the tracker cannot score
    - do_full_mrr : bool : Also collect the results of `tracker.get_full_mrr`
    - metric : str : Not used by this function
    - neg_iterator : Optional[Iterator] : When given, one item is drawn per scorable edge
    - bruteforce_test : bool : Recompute the optimistic/pessimistic ranks from the sampled
      negatives and drop into `breakpoint()` when they disagree with the tracker

    Returns:
    - A dict from metric name to a list with one entry per evaluated edge. Entries whose
      value is an empty list are dropped.
    """
    assert split in [VALID, TEST]
    rankings = []
    full_ranks = []
    opt_ranks = []
    pess_ranks = []
    rank_diffs = []
    node_ranks = []
    node_ranks_normalized = []
    node_timedeltas = []
    pos_probs = []
    neg_probs = []
    # For every source, the destinations it connects to at this timestamp.
    # These are handed to `get_full_mrr` and excluded from the brute-force negatives.
    to_filter = defaultdict(set)
    for src, dst in edges_by_time[t]:
        to_filter[src].add(dst)
    
    # Making it a frozenset to make it immutable to ensure bugs do not occur. 
    to_filter = {k : frozenset(v) for k, v in to_filter.items()}

    for src, dst in edges_by_time[t]:
        if not tracker.contains_dst(src, dst):
            # If the true answer is not in the sequence, then we can't evaluate
            # And the model should receive the lower possible score (0)
            rankings.append(0)
            # Placeholders: rank = num_nodes, time delta = max_t, normalized rank = 2.
            # With tuple scores one placeholder is stored per entry of `tracker.trackers`.
            if tracker.tracker_score_type() == tuple:
                node_ranks.append([num_nodes]*len(tracker.trackers))
                node_timedeltas.append([max_t]*len(tracker.trackers))
                node_ranks_normalized.append([2]*len(tracker.trackers))
            else:
                node_ranks.append(num_nodes)
                node_timedeltas.append(max_t)
                node_ranks_normalized.append(2)
            if do_full_mrr:
                full_ranks.append(0)
                opt_ranks.append(_optimistic_worst_case_rank(tracker.get_total(src)+1, num_nodes))
                pess_ranks.append(num_nodes)
                rank_diffs.append(num_nodes)
                
            # NOTE: skipping here also skips the `neg_iterator` block below, so no
            # item is drawn from `neg_iterator` for an unscorable edge.
            continue
        else:
            # `query_batch` is called with a batch of one edge; [0] selects its negatives.
            negs = neg_sampler.query_batch(np.array([src]), np.array([dst]), np.array([t]), split_mode=split)[0]
            neg_scores = [tracker.get_score(src, neg) for neg in negs]
            pos_score = tracker.get_score(src, dst)

            if tracker.tracker_score_type() == tuple:
                # Tuple scores are mapped to dense integer ranks so they fit in numpy arrays.
                newscores = develop_ranks_from_tuples([pos_score]+neg_scores)
                positive_score = newscores[0]
                negative_scores = newscores[1:]
            else:
                positive_score, negative_scores = pos_score, neg_scores
                # assert all([neg_score < t for neg_score in neg_scores]), "Negative score is the same as the current time"
                # assert pos_score < t, "Positive score is the same as the current time"
            assert dst not in negs, "Negative sample is the same as the positive sample"
            
            input_dict = {
                "y_pred_pos" : np.array([positive_score]),
                "y_pred_neg" : np.array(negative_scores),
                "eval_metric" : [MRR]
            }
            rankings.append(evaluator.eval(input_dict)[MRR])

            
            
            trank, tdelta, totnum = tracker.get_node_rank(src, dst, t)
            node_ranks.append(trank)
            node_timedeltas.append(tdelta)
            if tracker.tracker_score_type() == tuple:
                node_ranks_normalized.append([tr/totn for tr, totn in zip(trank, totnum)])
            else:
                node_ranks_normalized.append(trank/totnum)
            if do_full_mrr:
                # print(src, dst, t, to_filter[src])
                full_mrr, rank_opt, rank_pess = tracker.get_full_mrr(src, dst, t, to_filter[src])
                # print("RANK PESS is ", rank_pess, "RANK OPT is ", rank_opt, "WITH NEW VALUES", src, dst, t, to_filter[src])
                # full_mrr, rank_opt, rank_pess = tracker.get_full_mrr(src, dst, t, None)
                if bruteforce_test:
                    # Debug aid: count the unfiltered negatives that score strictly higher
                    # (optimistic) or at least as high (pessimistic) as the positive, and
                    # stop in the debugger when the tracker reports something else.
                    try:
                        to_filter_out = np.isin(np.array(negs),np.array(list(to_filter[src])))
                        bruteforce_rank_opt = ((input_dict["y_pred_pos"] < input_dict["y_pred_neg"]) & ~to_filter_out).sum() + 1
                        bruteforce_rank_pess = ((input_dict["y_pred_pos"] <= input_dict["y_pred_neg"]) & ~to_filter_out).sum() + 1
                        # bruteforce_rank_opt = ((input_dict["y_pred_pos"] < input_dict["y_pred_neg"])).sum() + 1
                        # bruteforce_rank_pess = ((input_dict["y_pred_pos"] <= input_dict["y_pred_neg"])).sum() + 1
                        
                        # assert np.isclose(rankings[-1], full_mrr)
                        assert bruteforce_rank_opt == rank_opt
                        assert bruteforce_rank_pess == rank_pess
                    except AssertionError:
                        breakpoint()
                full_ranks.append(full_mrr)
                opt_ranks.append(rank_opt)
                pess_ranks.append(rank_pess)
                rank_diffs.append(rank_pess - rank_opt)
        
        if neg_iterator is not None:
            dst_neg = next(neg_iterator)
            pos_probs.append(node_ranks[-1])
            # NOTE(review): this stores the whole return value of `get_node_rank`
            # (unpacked above as rank, time delta, total), not just the rank -- confirm intent.
            neg_probs.append(tracker.get_node_rank(src, dst_neg, t))

    ret = {"TGB MRR": rankings, TIME_RANKS : node_ranks, TIME_DELTAS : node_timedeltas, TIME_RANKS_NORMALIZED : node_ranks_normalized}
    if do_full_mrr:
        # The TIME_* entries repeat the values already placed in `ret` above.
        ret.update({FULL_MRR: full_ranks, RANK_OPT: opt_ranks, RANK_PESS: pess_ranks, RANK_DIFF: rank_diffs, TIME_RANKS : node_ranks, TIME_DELTAS : node_timedeltas, TIME_RANKS_NORMALIZED : node_ranks_normalized})
    if neg_iterator is not None:
        ret.update({POSITIVE_PROBS: pos_probs, NEGATIVE_PROBS: neg_probs})
    # Keep non-iterable values and non-empty iterables; empty lists are dropped.
    return {key : value for key, value in ret.items() if isinstance(value, Iterable) and len(value) > 0 or not isinstance(value, Iterable)}

def evaluate_sequences_benchtemp(tracker,
                        edges_by_time : Dict[int, List[Tuple[int, int]]],
                        # evaluator : bt.Evaluator, 
                        # neg_sampler : bt.lp.RandEdgeSampler, 
                        t : int, 
                        split : str, 
                        num_nodes : int,
                        max_t : int,
                        do_full_mrr : bool=True,
                        metric : str="mrr",
                        neg_iterator=None, 
                        bruteforce_test : bool = False,
                        dst_candidates : Optional[LongTensor] = None
    ) -> Tuple[dict, int]:
    """
    Score every edge observed at timestamp `t` against the current tracker state.

    Args:
    - tracker : Object that answers `contains_dst`, `get_score`, `get_node_rank`, `get_total`,
      `tracker_score_type` and (when `do_full_mrr` is set) `get_full_mrr`
    - edges_by_time : Dict[int, List[Tuple[int, int]]] : (src, dst) edges keyed by timestamp
    - t : int : The timestamp to evaluate
    - split : str : Must be VALID or TEST
    - num_nodes : int : Placeholder rank for edges the tracker cannot score
    - max_t : int : Placeholder time delta for edges the tracker cannot score
    - do_full_mrr : bool : Also collect the results of `tracker.get_full_mrr`
    - metric : str : Not used by this function
    - neg_iterator : When given, one collection of negative destinations is drawn from it
      for every edge, whether or not the tracker can score that edge
    - bruteforce_test : bool : Recompute the optimistic/pessimistic ranks over `dst_candidates`
      and drop into `breakpoint()` when they disagree with the tracker
    - dst_candidates : Optional[LongTensor] : Candidate destinations, required by `bruteforce_test`

    Returns:
    - A dict from metric name to a list with one entry per edge at `t` (entries whose value
      is an empty list are dropped), and the running total of `len(to_filter[src]) - 1`
      over the edges passed to `get_full_mrr`.
    """
    assert split in [VALID, TEST]
    full_ranks = []
    opt_ranks = []
    pess_ranks = []
    rank_diffs = []
    node_ranks = []
    node_ranks_normalized = []
    node_timedeltas = []
    pos_probs = []
    neg_probs = []
    tot_filtered = 0

    # For every source, the destinations it connects to at this timestamp.
    to_filter = defaultdict(set)
    for src, dst in edges_by_time[t]:
        to_filter[src].add(dst)

    for src, dst in edges_by_time[t]:
        if not tracker.contains_dst(src, dst):
            # The true answer is not in the sequence, so the edge cannot be scored and
            # receives the worst values. These placeholders are recorded regardless of
            # `do_full_mrr`, so that the per-edge lists keep one entry per edge and
            # `node_ranks[-1]` below always refers to the current edge.
            if tracker.tracker_score_type() == tuple:
                num_trackers = len(tracker.trackers)
                node_ranks.append([num_nodes] * num_trackers)
                node_timedeltas.append([max_t] * num_trackers)
                node_ranks_normalized.append([2] * num_trackers)
            else:
                node_ranks.append(num_nodes)
                node_timedeltas.append(max_t)
                node_ranks_normalized.append(2)
            if do_full_mrr:
                full_ranks.append(0)
                opt_ranks.append(_optimistic_worst_case_rank(tracker.get_total(src) + 1, num_nodes))
                pess_ranks.append(num_nodes)
                rank_diffs.append(num_nodes)
        else:
            trank, tdelta, totnum = tracker.get_node_rank(src, dst, t)
            node_ranks.append(trank)
            node_timedeltas.append(tdelta)
            if tracker.tracker_score_type() == tuple:
                node_ranks_normalized.append([tr / totn for tr, totn in zip(trank, totnum)])
            else:
                node_ranks_normalized.append(trank / totnum)

            if do_full_mrr:
                full_mrr, rank_opt, rank_pess = tracker.get_full_mrr(src, dst, t, to_filter[src])
                tot_filtered += (len(to_filter[src]) - 1)
                if bruteforce_test:
                    # Debug aid: rank the positive against every candidate destination that
                    # is not one of this source's other true destinations at `t`.
                    assert dst_candidates is not None
                    pos_score = tracker.get_score(src, dst)
                    neg_scores = [
                        tracker.get_score(src, idx.item())
                        for idx in dst_candidates
                        if idx.item() not in to_filter[src]
                    ]
                    if tracker.tracker_score_type() == tuple:
                        # Tuple scores are mapped to dense integer ranks so they can be compared in numpy.
                        newscores = develop_ranks_from_tuples([pos_score] + neg_scores)
                        positive_score = newscores[0]
                        negative_scores = newscores[1:]
                    else:
                        positive_score, negative_scores = pos_score, neg_scores

                    bruteforce_rank_opt = ((positive_score < np.array(negative_scores))).sum() + 1
                    bruteforce_rank_pess = ((positive_score <= np.array(negative_scores))).sum() + 1
                    try:
                        assert bruteforce_rank_opt == rank_opt
                        assert bruteforce_rank_pess == rank_pess
                    except AssertionError:
                        # Stop in the debugger with all locals of the mismatching edge in scope.
                        breakpoint()
                full_ranks.append(full_mrr)
                opt_ranks.append(rank_opt)
                pess_ranks.append(rank_pess)
                rank_diffs.append(rank_pess - rank_opt)

        if neg_iterator is not None:
            # One draw per edge keeps the iterator aligned with the edge stream.
            # Ranks are negated before being stored as the positive/negative entries.
            dst_negs = next(neg_iterator)
            if tracker.tracker_score_type() == tuple:
                pos_probs.append([-x for x in node_ranks[-1]])
                neg_probs.append([[-x for x in tracker.get_node_rank(src, dst_neg, t)[0]] for dst_neg in dst_negs])
            else:
                pos_probs.append(-node_ranks[-1])
                neg_probs.append([-tracker.get_node_rank(src, dst_neg, t)[0] for dst_neg in dst_negs])

    ret = {TIME_RANKS : node_ranks, TIME_DELTAS : node_timedeltas, TIME_RANKS_NORMALIZED : node_ranks_normalized}
    if do_full_mrr:
        ret.update({FULL_MRR: full_ranks, RANK_OPT: opt_ranks, RANK_PESS: pess_ranks, RANK_DIFF: rank_diffs})
    if neg_iterator is not None:
        ret.update({POSITIVE_PROBS: pos_probs, NEGATIVE_PROBS: neg_probs})

    # Keep non-iterable values and non-empty iterables; empty lists are dropped.
    non_empty = {
        key : value
        for key, value in ret.items()
        if (isinstance(value, Iterable) and len(value) > 0) or not isinstance(value, Iterable)
    }
    return non_empty, tot_filtered


def create_test_split_sequences_bt(
        train_edges_by_time : Dict[int, List[Tuple[int, int]]],
        val_edges_by_time : Dict[int, List[Tuple[int, int]]],
        test_edges_by_time : Dict[int, List[Tuple[int, int]]],
        tracker_type : str,
        evaluator : bt.Evaluator, 
        neg_sampler : bt.lp.RandEdgeSampler, 
        timesplits : Dict[str, int], 
        num_nodes : int,
        metric="mrr",
        neg_iterator=None, 
        bruteforce_test : bool=False,
        dst_candidates : Optional[LongTensor] = None
    ):
    """
    Streamer that creates sequences for each node and evaluates the model at each timestep
    Note that evaluation happens *before* the next timestep is added to the sequence
    Otherwise we would have data leakage. 

    Args:
    - train_edges_by_time / val_edges_by_time / test_edges_by_time : Dict[int, List[Tuple[int, int]]] :
      For each split, a dictionary where the key is the time and the value is a list of (src, dst) edges
    - tracker_type : str : Name of the class in the `trackers` module to instantiate
    - evaluator : Not used by this function
    - neg_sampler : Not used by this function
    - timesplits : Dict[str, int] : Passed to `get_split`, which reads the following keys:
        - val_min_time : int : The minimum time in the validation set
        - test_min_time : int : The minimum time in the test set
    - num_nodes : int : Forwarded to the tracker and to the evaluation
    - metric : str : Forwarded to the evaluation (default is "mrr")
    - neg_iterator : Forwarded to the evaluation; shared between validation and test
    - bruteforce_test : bool : Forwarded to the evaluation
    - dst_candidates : Optional[LongTensor] : Forwarded to the evaluation

    Returns:
    - tracker : The tracker after it has consumed the train, validation and test edges
    - rankings : {VALID: ..., TEST: ...}, each mapping a metric name to the per-edge values
    - full_rankings : {VALID: ..., TEST: ...} of empty defaultdicts; nothing is written to it here
    """
    train_times = sorted(train_edges_by_time.keys())
    all_timestamps = train_times + sorted(val_edges_by_time.keys()) + sorted(test_edges_by_time.keys())
    max_t = max(all_timestamps)
    # Timestamp -> position in the concatenated list. A timestamp that occurs in more
    # than one split keeps the position of its last occurrence.
    all_timestamps = {x : idx for idx, x in enumerate(all_timestamps)}
    num_edges = sum(
        len(edges)
        for split_edges in (train_edges_by_time, val_edges_by_time, test_edges_by_time)
        for edges in split_edges.values()
    )
    tracker = getattr(trackers, tracker_type)(all_timestamps=all_timestamps, num_nodes=num_nodes, num_edges=num_edges)
    do_full_mrr = hasattr(tracker, "get_full_mrr")

    # Ranking scores
    rankings = {VALID : defaultdict(list), TEST : defaultdict(list)}
    full_rankings = {VALID : defaultdict(list), TEST : defaultdict(list)}

    # Fill up with the training data
    for t in tqdm(train_times, desc="Training"):
        update_sequences(train_edges_by_time, tracker, t)

    tot_filtered = {VALID : 0, TEST : 0}
    evaluation_passes = (
        (VALID, val_edges_by_time, "Validation"),
        (TEST, test_edges_by_time, "Testing"),
    )
    for expected_split, split_edges, description in evaluation_passes:
        for t in tqdm(sorted(split_edges.keys()), desc=description):
            split = get_split(t, timesplits)
            assert split == expected_split

            new_metrics, num_filtered = evaluate_sequences_benchtemp(
                tracker, split_edges, t, split,
                metric=metric, num_nodes=num_nodes, max_t=max_t, do_full_mrr=do_full_mrr,
                neg_iterator=neg_iterator, bruteforce_test=bruteforce_test, dst_candidates=dst_candidates,
            )
            tot_filtered[expected_split] += num_filtered
            # The loop variable must not be called `metric`: that would overwrite the
            # `metric` argument forwarded to the evaluation on the next timestamp.
            for metric_name, values in new_metrics.items():
                rankings[expected_split][metric_name].extend(values)

            # Only now may the tracker see the edges at `t` (no leakage into the evaluation).
            update_sequences(split_edges, tracker, t)

    print("Total number of filtered edges in validation set is ", tot_filtered[VALID])
    print("Total number of filtered edges in test set is ", tot_filtered[TEST])
    return tracker, rankings, full_rankings

def _end_of_timestamp_run(timestamps, idx : int, num_edges : int) -> int:
    """
    Return the first index at or after `idx` whose timestamp differs from
    `timestamps[idx]`, capped at `num_edges`.
    """
    if idx >= num_edges:
        return num_edges
    boundary_t = timestamps[idx]
    while idx < num_edges and timestamps[idx] == boundary_t:
        idx += 1
    return idx


def split_bt(full_data, split : List[float]):
    """
    Build boolean train/validation/test masks over the edges of `full_data`.

    The edges are cut at the fractions `split[0]` and `split[0] + split[1]` of the
    edge count. Each cut is then moved forward to the end of the run of equal
    timestamps it lands in, so edges sharing the timestamp at a cut all end up in
    the earlier split.

    Args:
    - full_data : Object with `sources` and `timestamps` sequences of equal length
    - split : List[float] : Fraction of edges for training and for validation; the rest is test

    Returns:
    - train_mask, val_mask, test_mask : torch bool tensors with one entry per edge.
      A split is empty when its cut reaches the end of the data.
    """
    num_edges = len(full_data.sources)
    timestamps = full_data.timestamps

    # The scans are bounded by `num_edges`: a boundary run that reaches the last edge
    # (or a cut that lands exactly on the end) leaves the later split(s) empty.
    train_end_idx = _end_of_timestamp_run(timestamps, int(num_edges * split[0]), num_edges)
    val_end_idx = _end_of_timestamp_run(
        timestamps, int(len(full_data.sources) * (split[0] + split[1])), num_edges
    )

    train_mask = torch.zeros(num_edges, dtype=torch.bool)
    val_mask = torch.zeros(num_edges, dtype=torch.bool)
    test_mask = torch.zeros(num_edges, dtype=torch.bool)

    train_mask[slice(0, train_end_idx)] = True
    val_mask[slice(train_end_idx, val_end_idx)] = True
    test_mask[slice(val_end_idx, None)] = True

    return train_mask, val_mask, test_mask
    



def load_bt_data(dataset_name = "mooc"):
    """
    Load a BenchTemp dataset and repackage it as TemporalData objects.

    Args:
    - dataset_name : Name of the dataset folder inside BENCHTEMP_DATA_FOLDER

    Returns:
    - data : TemporalData : The train, transductive validation and transductive test
      objects concatenated in that order
    - train_d, val_d, test_d : TemporalData : The individual splits
    - marks : {VALID: ..., TEST: ...}, each mapping INDUCTIVE / INDUCTIVE_NEW_OLD /
      INDUCTIVE_NEW_NEW to the set of (src, dst, time) triples of that inductive setting
    """
    dataset_dir = os.path.join(BENCHTEMP_DATA_FOLDER, dataset_name) + "/"
    loader = bt.lp.DataLoader(dataset_path=dataset_dir, dataset_name=dataset_name)

    (
        _node_features, edge_features, _full_data,
        train_split, val_split, test_split,
        new_node_val, new_node_test,
        new_old_val, new_old_test,
        new_new_val, new_new_test,
        _unseen_nodes_num,
    ) = loader.load()

    def _as_temporal(split_data):
        # NOTE(review): every split receives the same `edge_features[1:]` as `msg`,
        # not only the rows of its own edges -- confirm this is intended.
        return TemporalData(
            t=LongTensor(split_data.timestamps),
            src=LongTensor(split_data.sources),
            dst=LongTensor(split_data.destinations),
            msg=Tensor(edge_features[1:]),
        )

    train_d = _as_temporal(train_split)
    val_d = _as_temporal(val_split)
    test_d = _as_temporal(test_split)

    parts = (train_d, val_d, test_d)
    merged = TemporalData(
        t=torch.cat([part.t for part in parts]),
        src=torch.cat([part.src for part in parts]),
        dst=torch.cat([part.dst for part in parts]),
        msg=torch.cat([part.msg for part in parts]),
    )

    inductive_subsets = (
        (VALID, INDUCTIVE, new_node_val),
        (TEST, INDUCTIVE, new_node_test),
        (VALID, INDUCTIVE_NEW_OLD, new_old_val),
        (TEST, INDUCTIVE_NEW_OLD, new_old_test),
        (VALID, INDUCTIVE_NEW_NEW, new_new_val),
        (TEST, INDUCTIVE_NEW_NEW, new_new_test),
    )
    marks = {VALID : defaultdict(set), TEST : defaultdict(set)}
    for split, setting, subset in inductive_subsets:
        edge_triples = zip(subset.sources, subset.destinations, subset.timestamps)
        for src, dst, time in edge_triples:
            # Indexing inside the loop means a setting only gets a key once it has an edge.
            marks[split][setting].add((src.item(), dst.item(), time.item()))

    return merged, train_d, val_d, test_d, marks

def is_bipartite(data):
    """
    Tell whether no node id occurs both in `data.src` and in `data.dst`.

    Returns True when the source ids and the destination ids are disjoint.
    """
    source_ids = data.src.unique().numpy().tolist()
    destination_ids = data.dst.unique().numpy().tolist()
    return set(source_ids).isdisjoint(set(destination_ids))

def convert_metrics_to_results(rankings, results, data, split : str):
    """
    Summarise the per-edge metrics of one split into `results[split]` (in place).

    Args:
    - rankings : Mapping split -> metric name -> per-edge values (or a single int/float)
    - results : Mapping that receives a fresh dict under `results[split]`
    - data : Object whose `t` has one entry per edge; every non-scalar metric must match its length
    - split : str : The split to convert

    For each metric:
    - TIME_RANKS additionally gets a "<metric>_value_counts" entry
    - a metric whose lower-cased name contains MRR is stored as its mean
    - POSITIVE_PROBS and NEGATIVE_PROBS are not stored
    - a plain int/float is stored unchanged
    - anything else is stored as `describe(percentiles=PERCENTILES)`
    When the per-edge entries are themselves iterable, the summary is computed per
    position and stored as a dict keyed by that position.
    """
    def _value_counts(series):
        return series.value_counts().sort_index().to_dict()

    def _describe(series):
        return series.describe(percentiles=PERCENTILES).to_dict()

    def _per_position(values, summarise):
        # The width is taken from the first entry only.
        width = int(len(values[0]))
        return {i : summarise(pd.Series([entry[i] for entry in values])) for i in range(width)}

    results[split] = {}
    split_rankings = rankings[split]
    pbar = tqdm(split_rankings, desc="Converting metrics")
    for metric in pbar:
        pbar.set_description(f"Converting {metric}")
        values = split_rankings[metric]
        is_scalar = isinstance(values, (float, int))
        if not is_scalar:
            assert len(values) == len(data.t)

        if metric == TIME_RANKS:
            counts_key = f"{metric}_value_counts"
            if isinstance(values[0], Iterable):
                results[split][counts_key] = _per_position(values, _value_counts)
            else:
                results[split][counts_key] = _value_counts(pd.Series(values))

        if MRR in metric.lower():
            results[split][metric] = np.mean(values)
        elif metric in (POSITIVE_PROBS, NEGATIVE_PROBS):
            continue
        elif is_scalar:
            results[split][metric] = values
        elif isinstance(values[0], Iterable):
            results[split][metric] = _per_position(values, _describe)
        else:
            results[split][metric] = _describe(pd.Series(values))

def ensure_correct_dataset_conversion(train_edges_by_time, val_edges_by_time, test_edges_by_time, train_data, val_data, test_data):
    """
    Sanity-check that the edges-by-time dictionaries match the data they were built from.

    For the train, validation and test split (in that order) this asserts first that
    the dictionary has one key per distinct value of `<split>_data.t`, and then that the
    number of stored edges equals `len(<split>_data.t)`.

    Raises:
    - AssertionError : On the first mismatch.
    """
    pairs = (
        (train_edges_by_time, train_data),
        (val_edges_by_time, val_data),
        (test_edges_by_time, test_data),
    )
    # All timestamp-count checks run before any edge-count check.
    for edges_by_time, split_data in pairs:
        assert len(edges_by_time) == len(split_data.t.unique())
    for edges_by_time, split_data in pairs:
        stored_edges = sum([len(edges) for edges in edges_by_time.values()])
        assert stored_edges == len(split_data.t)

