import json
import logging
import os


class Result:
    """Collect experiment scores in memory and write them to a JSON file.

    Used as a context manager: entering creates an empty result structure,
    ``add_entry`` records scores into it, and leaving the ``with`` block
    dumps the structure to ``save_path``.

    Layout of the saved document::

        {
            "params": <params>,
            "experiments": {
                "labels": [<labelled count>, ...],
                "result": {
                    <model name>: {
                        <acquire name>: [[<score>, ...], ...]
                    }
                }
            }
        }

    Each inner score list belongs to one experiment repeat, and the score at
    position ``i`` of that list corresponds to ``labels[i]``.
    """

    KEY_PARAMS = "params"
    KEY_EXPRTS = "experiments"
    KEY_LABELS = "labels"
    KEY_RESULT = "result"

    def __init__(self, params, save_path):
        """Store the parameters and the target path; no I/O happens here.

        Args:
            params: Value saved under the ``"params"`` key. It is passed to
                ``json.dump`` on exit, so it must be JSON-serialisable.
            save_path: Path of the JSON file written on exit.
        """
        self._params = params
        self._save_path = save_path
        self._data = None

    def __enter__(self):
        """Initialise an empty result structure and return ``self``.

        Raises:
            AssertionError: If ``save_path`` already exists as a file.
        """
        # Raised explicitly (rather than via ``assert``) so the guard against
        # overwriting earlier results is still active under ``python -O``.
        if os.path.isfile(self._save_path):
            raise AssertionError(
                "Result file already exists: {}".format(self._save_path)
            )
        self._data = {
            Result.KEY_PARAMS: self._params,
            Result.KEY_EXPRTS: {
                Result.KEY_LABELS: [],
                Result.KEY_RESULT: {}
            }
        }
        return self

    def __exit__(self, *args):
        """Write the collected data to ``save_path`` as indented JSON.

        The exception details in ``args`` are ignored: the file is written
        whether or not the ``with`` body raised, and the exception (if any)
        is not suppressed.
        """
        dpath = os.path.dirname(self._save_path)
        # A bare filename has an empty dirname; os.makedirs("") would raise,
        # so only create the directory when there is one to create.
        if dpath and not os.path.isdir(dpath):
            os.makedirs(dpath)

        with open(self._save_path, "w") as f:
            json.dump(self._data, f, indent=4, sort_keys=True)
        logging.info("Updated results: {}".format(self._save_path))

    def add_entry(self, datapool, model, acquiref, expt_repeat_id, score):
        """Record ``score`` for the current labelled count of ``datapool``.

        Args:
            datapool: Object providing ``count_labelled()``.
            model: Object providing ``get_descriptive_name()``; the name is
                the first-level key in the result mapping.
            acquiref: Object providing ``get_descriptive_name()``; the name
                is the second-level key in the result mapping.
            expt_repeat_id: Zero-based index of the experiment repeat.
            score: Value appended to the repeat's score list; it is also
                logged with a ``.4f`` format.

        Raises:
            AssertionError: If the entry breaks the ordering invariants
                checked by ``_index_labelled`` / ``_insert_result``.
        """
        labelled = datapool.count_labelled()
        i, result = self._index_labelled(labelled)

        model_name = model.get_descriptive_name()
        acquire_name = acquiref.get_descriptive_name()
        self._insert_result(result, model_name, acquire_name, expt_repeat_id, i, score)

        logging.info("Test score for {} training labels: {:.4f}".format(
            labelled, score
        ))

    # === PROTECTED ===

    def _index_labelled(self, labelled):
        """Return ``(index of labelled in the labels list, result mapping)``.

        A count that has not been seen before is appended to the labels
        list; it must be larger than the last recorded count.
        """
        experiments = self._data[Result.KEY_EXPRTS]
        labels = experiments[Result.KEY_LABELS]
        result = experiments[Result.KEY_RESULT]
        try:
            return labels.index(labelled), result
        except ValueError:
            # New label counts may only be appended in increasing order.
            assert len(labels) == 0 or labels[-1] < labelled, \
                "labelled counts must be added in increasing order"
            labels.append(labelled)
            return len(labels)-1, result

    def _insert_result(self, result, model_name, acquire_name, expt_repeat_id, i, score):
        """Append ``score`` at position ``i`` of the selected repeat's list.

        Missing ``model_name`` / ``acquire_name`` levels are created on
        demand. A new repeat list is started only when ``expt_repeat_id``
        equals the number of repeats recorded so far.
        """
        repeats = result.setdefault(model_name, {}).setdefault(acquire_name, [])
        if expt_repeat_id == len(repeats):
            repeats.append([])

        assert expt_repeat_id < len(repeats), \
            "experiment repeat ids must be introduced consecutively"
        result_chain = repeats[expt_repeat_id]
        # Scores are append-only, so the label index must be the next slot.
        assert i == len(result_chain), \
            "score position does not match the label index"
        result_chain.append(score)