import numpy as np

from relnet.agent.pytorch_agent import PyTorchAgent


class PredictMeanAgent(PyTorchAgent):
    """Constant baseline that predicts the training-set mean for every graph.

    ``train`` reduces the values returned by ``graph_ds.get_gts_for_hashes``
    for the training list to their arithmetic mean. It stores that single
    number as plain text at the model path. ``predict`` reads the number back
    and repeats it once per input graph.
    """

    algorithm_name = "predict_mean"
    is_deterministic = True
    is_trainable = True

    def train(self, train_g_list, validation_g_list, max_steps, **kwargs):
        """Compute the mean target over ``train_g_list`` and persist it.

        ``validation_g_list``, ``max_steps`` and ``kwargs`` are accepted to
        match the agent interface but are not used by this baseline.
        """
        targets = np.array(self.graph_ds.get_gts_for_hashes(train_g_list))
        mean_value = np.mean(targets)
        destination = self.get_model_path(model_suffix=None, init_dir=True)
        # The whole "model" is one number, serialized with str() so that
        # predict() can parse it back with float().
        serialized = str(mean_value)
        with open(destination, 'w') as out_file:
            out_file.write(serialized)

    def predict(self, g_list, **kwargs):
        """Return a float array holding the stored mean once per item of ``g_list``."""
        source = self.get_model_path(model_suffix=None, init_dir=True)
        with open(source, 'r') as in_file:
            first_line = in_file.readline()
        mean_value = float(first_line.strip())
        return np.full(len(g_list), mean_value)

    def finalize(self):
        """Intentionally a no-op for this baseline."""


class PredictMedianAgent(PyTorchAgent):
    """Constant baseline that predicts the training-set median for every graph.

    ``train`` reduces the values returned by ``graph_ds.get_gts_for_hashes``
    for the training list to their median. It stores that single number as
    plain text at the model path. ``predict`` reads the number back and
    repeats it once per input graph.
    """

    algorithm_name = "predict_median"
    is_deterministic = True
    is_trainable = True

    def train(self, train_g_list, validation_g_list, max_steps, **kwargs):
        """Compute the median target over ``train_g_list`` and persist it.

        ``validation_g_list``, ``max_steps`` and ``kwargs`` are accepted to
        match the agent interface but are not used by this baseline.
        """
        targets = np.array(self.graph_ds.get_gts_for_hashes(train_g_list))
        median_value = np.median(targets)
        destination = self.get_model_path(model_suffix=None, init_dir=True)
        # The whole "model" is one number, serialized with str() so that
        # predict() can parse it back with float().
        serialized = str(median_value)
        with open(destination, 'w') as out_file:
            out_file.write(serialized)

    def predict(self, g_list, **kwargs):
        """Return a float array holding the stored median once per item of ``g_list``."""
        source = self.get_model_path(model_suffix=None, init_dir=True)
        with open(source, 'r') as in_file:
            first_line = in_file.readline()
        median_value = float(first_line.strip())
        return np.full(len(g_list), median_value)

    def finalize(self):
        """Intentionally a no-op for this baseline."""
