import matplotlib.pyplot as plt
import shutil
from datetime import datetime
from joblib import Memory
from functools import partial
import functools
import pandas as pd
import itertools
from datasets import Dataset, concatenate_datasets
import os



# Number of worker processes. Read from the NSLOTS environment variable
# (presumably the slot count exported by a Grid Engine style cluster
# scheduler -- confirm), falling back to 7 when it is unset.
NUM_PROC = int(os.getenv('NSLOTS', default=7))

# Directory (relative to the working directory) where figures are written.
PLOT_LOC = "plots"
# Directory backing the joblib on-disk memoization cache.
CACHE_LOC = "cache"

# `cache` is a decorator factory: functions decorated with it are memoized on
# disk under CACHE_LOC, and an argument named "verbose" is excluded from the
# cache key so changing verbosity does not invalidate cached results.
memory = Memory(CACHE_LOC, verbose=0)
cache = partial(memory.cache, ignore=["verbose"])

# Global matplotlib configuration: render all text through LaTeX and load the
# amsfonts package so labels may use its symbols. NOTE(review): usetex needs a
# working LaTeX installation on the machine that produces the plots.
plt.rcParams.update({
    "text.usetex": True,
    'text.latex.preamble': r'\usepackage{amsfonts}',
})



def to_dataset(setup):
    """Expand a parameter grid into a Hugging Face ``Dataset``.

    Every key of ``setup`` becomes a column, and every value is an iterable
    of candidate values for that column. The result holds one row per
    element of the Cartesian product of those iterables, in the order
    produced by ``itertools.product``.
    """
    column_names = setup.keys()
    candidate_lists = setup.values()
    grid_rows = itertools.product(*candidate_lists)
    frame = pd.DataFrame(grid_rows, columns=column_names)
    return Dataset.from_pandas(frame)


# Human-readable display names for dataset keys, algorithm identifiers and
# metric/column names, used for plot legends and axis labels.
# LaTeX-bearing entries are raw strings: the previous non-raw literals
# contained "\c" (from "\cdot"), an invalid escape sequence that Python
# reports as a DeprecationWarning (< 3.12) or SyntaxWarning (>= 3.12).
# The resulting string values are unchanged.
labels = {
    "erdos_renyi": "Erdos-Renyi",
    "barabasi_albert": "Barabasi-Albert",
    "karate_club": "Karate Club",
    "communities": "Communities",
    "football": "Football",
    "facebook": "Facebook",
    "time": "Running time [s]",
    "num_samples": "Number of samples",
    "num_random_walks": "Number of random walks",
    "error": "Estimation error",
    "n": "Number of nodes $n$",
    "cutoff": "Cutoff Algorithm",
    "exact": "Exact Algorithm",
    "exact-sparse": "Exact Algorithm (Sparse)",
    "local-delete": "Meeting Time Algorithm",
    "sampling": "Sampling Algorithm",
    "relative_error": "Relative estimation error",
    "absolute_error": "Absolute estimation error",
    "hitting_time": "Hitting time",
    "deg_prod": r"$\deg(u) \cdot \deg(v)$",
    "deg_ratio": r"$\deg(u) / \deg(v)$",
    "pagerank_prod": r"$\mathrm{pagerank}(u) \cdot \mathrm{pagerank}(v)$",
    "pagerank_ratio": r"$\mathrm{pagerank}(u) / \mathrm{pagerank}(v)$",
    "true_hitting_time": "Hitting time",
}

def get_label(key, **kwargs):
    """Return the display label for ``key``.

    Unknown keys are returned unchanged. When keyword arguments are given,
    each string value is first translated through ``get_label`` itself, and
    the resulting values are substituted into the label with ``str.format``.
    Non-string values are substituted as-is.

    NOTE(review): labels containing literal LaTeX braces (e.g.
    ``\\mathrm{pagerank}``) cannot be formatted, so passing kwargs for those
    keys raises ``KeyError``.
    """
    text = labels.get(key, key)
    if not kwargs:
        return text
    substitutions = {}
    for name, value in kwargs.items():
        substitutions[name] = get_label(value) if isinstance(value, str) else value
    return text.format(**substitutions)


def is_notebook():
    """Report whether the code is running inside a Jupyter kernel.

    Returns True only for IPython's ZMQ-based shell (Jupyter notebook or
    qtconsole). A terminal IPython session, any other shell class, or a
    plain Python interpreter (where ``get_ipython`` is undefined) yields
    False.
    """
    try:
        # ``get_ipython`` only exists when running under IPython.
        shell_name = get_ipython().__class__.__name__
    except NameError:
        return False
    return shell_name == 'ZMQInteractiveShell'


def save_figure(filename, filename_prefix=None):
    """Show the current figure in a notebook, otherwise save it as PDF.

    Outside a Jupyter notebook, the current matplotlib figure is written to
    ``<PLOT_LOC>/[<filename_prefix>-]<filename>.pdf``. A timestamped copy
    ``..._<YYYY-mm-dd_HH-MM-SS>.pdf`` is kept next to it, so re-running
    overwrites the undated "latest" file while earlier outputs are preserved.
    The figure is closed afterwards, even if saving fails.

    Args:
        filename: Base file name, without the ``.pdf`` extension; may include
            sub-directories relative to ``PLOT_LOC``.
        filename_prefix: Optional prefix joined to ``filename`` with ``-``.

    Raises:
        Whatever ``plt.savefig`` or ``shutil.copy`` raise on I/O failure.
    """
    if is_notebook():
        plt.show()
        return

    if filename_prefix is not None:
        filename = filename_prefix + "-" + filename
    basename = f"{PLOT_LOC}/{filename}"
    # plt.savefig does not create missing parent directories; without this
    # the first run in a fresh checkout fails with FileNotFoundError.
    os.makedirs(os.path.dirname(basename) or ".", exist_ok=True)

    latest_path = f"{basename}.pdf"
    datestring = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    archived_path = f"{basename}_{datestring}.pdf"
    try:
        plt.savefig(latest_path)
    finally:
        # Release the figure even when saving fails, so failed saves do not
        # accumulate open figures.
        plt.close()
    shutil.copy(latest_path, archived_path)

def savefig(f):
    """Decorator that injects a ready-made ``savefig`` callback into ``f``.

    The returned wrapper accepts an extra keyword-only ``filename_prefix``
    argument (default None), which is not forwarded to ``f``. Instead, ``f``
    is called with ``savefig=partial(save_figure, filename_prefix=...)``, so
    plotting code only needs to call ``savefig("name")``.

    Returns:
        The wrapper. It carries ``f``'s metadata via ``functools.wraps`` and
        returns whatever ``f`` returns.
    """
    @functools.wraps(f)
    def caller(*args, filename_prefix=None, **kwargs):
        save = partial(save_figure, filename_prefix=filename_prefix)
        return f(*args, **kwargs, savefig=save)
    return caller

