from collections import defaultdict
from functools import partial
import numpy as np


class Recorder(dict):
    """Record the results (scores) of trials in an auto-vivifying nested dict.

    Implements Perl-style autovivification: reading a missing key creates,
    stores and returns an empty ``Recorder`` under that key, so nested levels
    can be assigned without creating the intermediate dicts first. A side
    effect is that *reading* a missing path (e.g. via ``get_score``) inserts
    empty entries along that path.

    Two nesting layouts are used by the methods below:

    * 4-level (``record`` / ``get_score`` / ``add`` / ``get_plot_data``)::

          self[experiment][permutation][task_id][n_samples] = score

    * 3-level (``record2`` / ``get_plot_data2``)::

          self[experiment][task_id][n_samples] = score
    """

    def __missing__(self, item):
        """Create, store and return an empty ``Recorder`` for a missing key.

        Called by ``dict.__getitem__`` only when *item* is not present.
        """
        value = self[item] = type(self)()
        return value

    def record(self, experiment, permutation, task_id, n_samples, score):
        """Store *score* in the 4-level layout, creating levels as needed."""
        self[experiment][permutation][task_id][n_samples] = score

    def get_score(self, experiment, permutation, task_id, n_samples):
        """Return the score stored by ``record`` for the given keys.

        Because of autovivification, a missing path returns an empty
        ``Recorder`` (and inserts it) instead of raising ``KeyError``.
        """
        return self[experiment][permutation][task_id][n_samples]

    def add(self, recorder):
        """Merge the experiments of another 4-level recorder into this one.

        Experiments from *recorder* are renumbered to follow this recorder's
        experiments: key ``e`` becomes ``e + offset``, where ``offset`` is one
        more than the largest existing experiment key, or 0 when this
        recorder is empty.

        NOTE(review): assumes experiment keys are integers and that
        *recorder* uses the 4-level layout written by ``record``. A 3-level
        (``record2``) recorder would fail here. Confirm with callers.
        """
        # default=-1 lets merging into an empty recorder start at 0 instead
        # of raising ValueError from max() on an empty sequence.
        experiment_offset = max(self, default=-1) + 1
        # Snapshot the top level so that self.add(self) does not fail with
        # "dictionary changed size during iteration" when new experiment
        # keys are inserted. Nested dicts of the source are never mutated,
        # because every merged experiment lands under a brand-new key.
        for experiment, by_permutation in list(recorder.items()):
            for permutation, by_task in by_permutation.items():
                for task, by_samples in by_task.items():
                    for samples, score in by_samples.items():
                        self.record(experiment + experiment_offset,
                                    permutation, task, samples, score)

    def record2(self, experiment, task_id, n_samples, score):
        """Store *score* in the 3-level layout (no permutation level)."""
        self[experiment][task_id][n_samples] = score

    def get_plot_data(self, threshold):
        """Summarise a 4-level recorder as nested NumPy arrays.

        Returns ``array[experiment][permutation][task]``, where each entry is
        the sample count selected by ``_get_samples`` for that task's
        ``{n_samples: score}`` mapping. Order follows dict insertion order.

        NOTE(review): if experiments or permutations hold different numbers
        of entries, the nested arrays are ragged. Recent NumPy versions raise
        ``ValueError`` for that.
        """
        return np.array([
            np.array([
                np.array([self._get_samples(by_samples, threshold)
                          for by_samples in by_task.values()])
                for by_task in by_permutation.values()
            ])
            for by_permutation in self.values()
        ])

    def get_plot_data2(self, threshold):
        """Summarise a 3-level (``record2``) recorder as nested NumPy arrays.

        Returns ``array[experiment][task]`` of the sample counts selected by
        ``_get_samples``, in dict insertion order.
        """
        return np.array([
            np.array([self._get_samples(by_samples, threshold)
                      for by_samples in by_task.values()])
            for by_task in self.values()
        ])

    def _get_samples(self, A, threshold):
        """Pick a sample count from a ``{n_samples: score}`` mapping.

        Returns the first key, in insertion order, whose score is strictly
        greater than *threshold*. If no score exceeds it, returns the last
        key. Returns ``None`` when *A* is empty.
        """
        last = None
        for samples, score in A.items():
            last = samples
            if score > threshold:
                return samples
        return last