from experiments import BaseExperiment


class Exp_Extrap(BaseExperiment):
    """Experiment that reports overall and extrapolation MSE when evaluating.

    Both evaluation entry points run ``self.compute_results_all_batches`` on a
    dataloader, log a fixed set of metrics through ``self.logger`` and return
    a single loss value. ``compute_results_all_batches``, ``logger``,
    ``dlval`` and ``dltest`` are not defined in this class; presumably they
    are provided by ``BaseExperiment`` -- confirm there.
    """

    def _log_and_select_loss(self, dataloader, prefix):
        """Evaluate ``dataloader``, log its metrics and return the loss.

        Shared implementation of ``validation_step`` and ``test_step`` so the
        two cannot drift apart.

        Args:
            dataloader: Passed unchanged to ``compute_results_all_batches``.
            prefix: Label prepended to every logged metric name
                (``"val"`` or ``"test"``).

        Returns:
            ``results['val_loss']`` when it compares unequal to 0, otherwise
            ``results['loss']``.

        Raises:
            KeyError: If the results mapping lacks one of the keys read here
                (``mse``, ``mse_extrap``, ``forward_time``, ``val_loss``, and
                ``loss`` when the fallback is taken).
        """
        results = self.compute_results_all_batches(dataloader)

        # Logged in a fixed order; each value is formatted with 5 decimals.
        for metric in ("mse", "mse_extrap", "forward_time"):
            self.logger.info(f"{prefix}_{metric}={results[metric]:.5f}")

        # A 'val_loss' equal to 0 is treated as "not available" and the
        # generic 'loss' entry is returned instead.
        # NOTE(review): a genuine validation loss of exactly 0 would also
        # take the fallback -- confirm this is the intended sentinel.
        if results['val_loss'] != 0:
            return results['val_loss']
        return results['loss']

    def validation_step(self, epoch):
        """Evaluate on ``self.dlval``, log ``val_*`` metrics, return the loss.

        Args:
            epoch: Accepted for interface compatibility; not used here.

        Returns:
            ``results['val_loss']`` if it is non-zero, else ``results['loss']``.
        """
        return self._log_and_select_loss(self.dlval, "val")

    def test_step(self):
        """Evaluate on ``self.dltest``, log ``test_*`` metrics, return the loss.

        Returns:
            ``results['val_loss']`` if it is non-zero, else ``results['loss']``.
        """
        return self._log_and_select_loss(self.dltest, "test")
