from .trackers import __all__ as TRACKER_CHOICES

# Identifiers of the dataset splits.
TRAIN, VALID, TEST = "train", "val", "test"


# Placeholder left as None; it must be filled in by the user, since this
# module checks at import time that it has been set.
BASEFOLDER = None
# Metric identifier for MRR.
MRR = "mrr"

# Placeholder left as None; must be set by the user before this module can be
# imported. Presumably the directory holding the BenchTemp dataset files --
# confirm against the code that reads it.
BENCHTEMP_DATA_FOLDER = None

# Refuse to import while either folder is still unconfigured. An explicit
# ``raise`` is used instead of ``assert`` because assert statements are
# stripped when Python runs with -O, which would let None paths silently
# flow into the rest of the code.
if BASEFOLDER is None or BENCHTEMP_DATA_FOLDER is None:
    # Only built on the failure path, so no helper name lingers in the module
    # namespace after a successful import.
    _unset = [
        name
        for name, value in (
            ("BASEFOLDER", BASEFOLDER),
            ("BENCHTEMP_DATA_FOLDER", BENCHTEMP_DATA_FOLDER),
        )
        if value is None
    ]
    raise RuntimeError(
        "Please set the BASEFOLDER and BENCHTEMP_DATA_FOLDER variables in "
        "constants.py (currently unset: {})".format(", ".join(_unset))
    )
# Dataset identifiers that are treated as bipartite graphs (presumably
# interaction data between two node types -- confirm with the code using it).
BIPARTITE_DATASETS = ["tgbl-wiki", "tgbl-subreddit", "tgbl-lastfm", "tgbl-review"]
# Percentile levels from 0.0 to 1.0 in steps of 0.01 (101 values, both ends
# included). Dividing the exact integers 0..100 by 100 gives the same floats as
# writing the literals 0.0, 0.01, ..., 1.0 by hand: float division and
# float-literal parsing are both correctly rounded to the nearest double.
PERCENTILES = [hundredths / 100 for hundredths in range(101)]


# Metric labels.
FULL_MRR = "Full MRR"
RANK_OPT = "Rank Opt"
RANK_PESS = "Rank Pess"
RANK_DIFF = "Rank diff"
TIME_RANKS, TIME_RANKS_NORMALIZED, TIME_DELTAS = (
    "Time Ranks",
    "Time Ranks Normalized",
    "Time Deltas",
)
ACCURACY, ROC_AUC, AP, F1 = "Accuracy", "ROC_AUC", "AP", "f1-score"

# Summary-statistic labels.
AVG = "mean"
STDEV = "std"
NUM_SAMPLES = "n"

# Names of the BenchTemp datasets.
BENCHTEMP_DATASETS = [
    "UNvote",
    "enron",
    "uci",
    "UNtrade",
    "USLegis",
    "SocialEvo",
    "lastfm",
    "YoutubeRedditLarge",
    "CollegeMsg",
    "mooc",
    "reddit",
    "Contacts",
    "CanParl",
    "TaobaoLarge",
    "DGraphFin",
    "Flights",
    "wikipedia",
    "YoutubeRedditSmall",
    "taobao",
]

# Evaluation-setting labels.
INDUCTIVE = "inductive"
INDUCTIVE_NEW_OLD = "inductive_new_old"
INDUCTIVE_NEW_NEW = "inductive_new_new"
TRANSDUCTIVE = "transductive"
# Labels for positive / negative probabilities.
POSITIVE_PROBS = "pos_probs"
NEGATIVE_PROBS = "neg_probs"

# Column headers for tabular results.
METRIC_COL, SUBMETRIC = "Metric", "Submetric"
DATASET_COL, TRACKER_COL = "Dataset", "Tracker"
VALUE_COL, SPLIT_COL = "Value", "Split"
RANK_COL, MODEL_COL = "Rank", "Model"

# Hand-picked display names, keyed by tracker class name. Trackers not listed
# here get a name derived from their CamelCase class name instead
# (see _convert_tracker_name).
TRACKER_NAMES = dict(
    GlobalRecencyTracker="Global Recency",
    EdgeBankPredictor="EdgeBank",
    LocalRecencyTracker="Local Recency",
    NeighborBasedTracker="Neighborhood",
    NeighborBasedTrackerSimplified="Neighborhood Simplified",
    NeighborBasedTrackerHardcore="Neighborhood Hardcore",
    ReviewScorer="Review Scorer",
)

def _camel_case_split(s):
    # Copied from https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
    idx = list(map(str.isupper, s))
    # mark change of case
    l = [0]
    for (i, (x, y)) in enumerate(zip(idx, idx[1:])):
        if x and not y:  # "Ul"
            l.append(i)
        elif not x and y:  # "lU"
            l.append(i+1)
    l.append(len(s))
    # for "lUl", index of "U" will pop twice, have to filter that
    return [s[x:y] for x, y in zip(l, l[1:]) if x < y]

def _convert_tracker_name(tracker: str) -> str:
    """Return the human-readable display name for a tracker class name.

    An explicit entry in ``TRACKER_NAMES`` takes precedence. Otherwise the name
    is derived by joining the CamelCase parts of ``tracker`` with spaces, e.g.
    "FooBarTracker" -> "Foo Bar Tracker".
    """
    try:
        display_name = TRACKER_NAMES[tracker]
    except KeyError:
        display_name = " ".join(_camel_case_split(tracker))
    return display_name

# Extend the display-name table with every tracker exported by the trackers
# package. The comprehension is fully evaluated against the unmodified table
# before ``update`` runs, and it is applied to a fresh copy, so hand-picked
# names are kept and unlisted trackers get their CamelCase-derived name.
TRACKER_NAMES = TRACKER_NAMES.copy()
TRACKER_NAMES.update({choice: _convert_tracker_name(choice) for choice in TRACKER_CHOICES})