import argparse
import logging
import sys
import json
import os
import traceback

import numpy as np
import matplotlib.pyplot as plt

from summarize_experiments import Trials, experiment_id_from_slug

# Module-wide logging setup: timestamped records at INFO level and above,
# tagged with the level and the emitting logger's name.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s [%(levelname)s] (%(name)s):  %(message)s',
)
logger = logging.getLogger(__name__)

def default_getcolor(trials):
    """Return the plot color for the active-learning score method of ``trials``.

    The score method is read from
    ``trials.get_config()['run_config']['al_config']['score_method']``.

    Args:
        trials: object exposing ``get_config()``, which returns a dict that
            contains a ``'run_config'`` mapping.

    Returns:
        A matplotlib color name for the known score methods, or ``None`` for
        any other method. When ``None`` is passed as ``color=`` to matplotlib,
        it falls back to its default color cycle.

    Raises:
        ValueError: if no score method is configured.
    """
    default_color_dict = {
        'greedy_coreset': 'green',
        'batchbald': 'blue',
        'bald': 'cyan',
        'entropy': 'orange',
        'random': 'black',
    }
    config = trials.get_config()
    # Treat a missing or null ``al_config`` as empty instead of crashing on ``.get``.
    al_config = config['run_config'].get('al_config') or {}
    al_score_type = al_config.get('score_method', None)
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if al_score_type is None:
        raise ValueError("run_config.al_config.score_method is not set")

    return default_color_dict.get(al_score_type)

def default_getlabel(trials):
    """Return the legend label for ``trials``: its active-learning score method.

    The score method is read from
    ``trials.get_config()['run_config']['al_config']['score_method']``.

    Args:
        trials: object exposing ``get_config()``, which returns a dict that
            contains a ``'run_config'`` mapping.

    Returns:
        The configured score method, unchanged.

    Raises:
        ValueError: if no score method is configured.
    """
    config = trials.get_config()
    # Treat a missing or null ``al_config`` as empty instead of crashing on ``.get``.
    al_config = config['run_config'].get('al_config') or {}
    al_score_type = al_config.get('score_method', None)
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if al_score_type is None:
        raise ValueError("run_config.al_config.score_method is not set")

    return al_score_type

def plot_slug_record(experiment_slug, finder, tolerate_errors=False, color_fn=None, label_fn=None, **kwargs):
    """Plot mean ``bysize_test_acc`` across trials, with standard-error bars.

    Args:
        experiment_slug: slug passed to ``experiment_id_from_slug`` and to
            ``finder.experiment_dirs_by_slug``.
        finder: object exposing ``experiment_dirs_by_slug(slug)``.
        tolerate_errors: forwarded to ``Trials``.
        color_fn: callable ``trials -> color``; defaults to ``default_getcolor``.
            It is not called when ``color`` is supplied in ``kwargs``.
        label_fn: callable ``trials -> label``; defaults to ``default_getlabel``.
            It is not called when ``label`` is supplied in ``kwargs``.
        **kwargs: extra keyword arguments forwarded to ``plt.errorbar``. They
            override the default ``capsize=3`` and the computed color/label.

    Returns:
        Whatever ``plt.errorbar`` returns, or ``None`` when no experiment
        directories or no curve data were found (a message is logged).
    """
    if color_fn is None:
        color_fn = default_getcolor

    if label_fn is None:
        label_fn = default_getlabel

    experiment_id = experiment_id_from_slug(experiment_slug)

    experiment_dirs = finder.experiment_dirs_by_slug(experiment_slug)
    if len(experiment_dirs) == 0:
        logger.error("Failed to find experiment dirs for slug %s", experiment_slug)
        return None

    trials = Trials(experiment_id, experiment_dirs, tolerate_errors=tolerate_errors)

    curvedata = trials.get_curves('bysize_test_acc')
    if len(curvedata) == 0:
        # Use the module logger rather than print, consistent with the error above.
        logger.warning('missing curvedata for %s', experiment_slug)
        return None

    # Each curve is iterated as (x, y) pairs. Materialize them so they can be measured and sliced.
    curves = [list(curve) for curve in curvedata]
    # Trials may stop at different points (e.g. unfinished runs). Stacking
    # ragged curves into one array fails, so keep only the common prefix.
    # NOTE(review): this assumes all trials share the same x schedule, as the
    # original code did by taking xs from the first trial only.
    n_points = min(len(curve) for curve in curves)
    if any(len(curve) != n_points for curve in curves):
        logger.warning(
            "Curves for %s have different lengths; truncating to the shortest (%d points)",
            experiment_slug, n_points,
        )

    xs = [point[0] for point in curves[0][:n_points]]
    all_ys = np.array([[point[1] for point in curve[:n_points]] for curve in curves])
    mean_ys = all_ys.mean(axis=0)
    # NOTE(review): population std (ddof=0); a sample-based standard error
    # would use ddof=1. Kept as-is to avoid changing existing plots.
    stderr_ys = all_ys.std(axis=0) / np.sqrt(all_ys.shape[0])

    errorbar_kwargs = {'capsize': 3}
    errorbar_kwargs.update(kwargs)
    # Let callers override color/label via kwargs. Previously this raised
    # "got multiple values for keyword argument".
    if 'color' not in errorbar_kwargs:
        errorbar_kwargs['color'] = color_fn(trials)
    if 'label' not in errorbar_kwargs:
        errorbar_kwargs['label'] = label_fn(trials)

    return plt.errorbar(xs, mean_ys, yerr=stderr_ys, **errorbar_kwargs)

