#!/usr/bin/env python3
"""
Created on 02:07, Jul. 21st, 2022

@author: Anonymous
"""
import tensorflow as tf
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.pardir)
    from naive_cnn_params import naive_cnn_params
else:
    from .naive_cnn_params import naive_cnn_params

# Public API of this module: only the parameter class is exported by
# `from <module> import *`.
__all__ = [
    "cnn_ensemble_params",
]

class cnn_ensemble_params(naive_cnn_params):
    """
    Parameter container handed to `cnn_ensemble` on initialization.

    Builds on `naive_cnn_params` and adds the ensemble-specific model
    settings (`model.n_models`, `model.ensemble_type`) together with the
    train settings (`train.lr`, `train.lr_i`).
    """
    # Internal macro parameter.
    _precision = "float32"

    def __init__(self, dataset="eeg_anonymous"):
        """
        Build the parameter set for `dataset`.

        Args:
            dataset: Dataset identifier, forwarded unchanged to the
                `naive_cnn_params` initializer.
        """
        # The base initializer runs first. The helpers below read and write
        # `self.model` / `self.train`, which are presumably created there
        # (not visible from this file -- verify against `naive_cnn_params`).
        super().__init__(dataset)

        # Layer the ensemble-specific settings on top: model first, because
        # the train settings read from `self.model`.
        self._update_model_params()
        self._update_train_params()

        # Initial iteration. Called through the class on purpose, so an
        # `iteration` override in a subclass is not dispatched to here.
        cnn_ensemble_params.iteration(self, 0)

    # ---------------------------------------------------------------- #
    # update funcs
    # ---------------------------------------------------------------- #
    def _update_model_params(self):
        """
        Set the ensemble-specific model parameters on `self.model`.

        Writes `n_models` and `ensemble_type`; prints a warning line when
        the selected ensemble type is "min_loss".
        """
        # The number of model ensembles.
        self.model.n_models = 10

        # The type of ensemble: the last entry of the candidate list is the
        # one in effect (currently "average").
        ensemble_types = ["min_loss", "average"]
        self.model.ensemble_type = ensemble_types[-1]

        # Only the "min_loss" type is announced on stdout.
        if self.model.ensemble_type in ("min_loss",):
            message = "WARNING: The ensemble type is {}.".format(self.model.ensemble_type)
            print(message)

    def _update_train_params(self):
        """
        Set the train parameters on `self.train`.
        """
        # The learning rate of optimizer, copied from `self.model.lr`.
        # NOTE(review): `model.lr` is not assigned in this class; presumably
        # it is provided by `naive_cnn_params` -- confirm.
        self.train.lr = self.model.lr

    def iteration(self, iteration):
        """
        Update parameters at every backpropagation iteration/gradient update.

        Args:
            iteration: Index of the current iteration. It is not used by the
                current implementation: the per-iteration learning rate is
                simply the base learning rate.
        """
        # -- Train parameters: current learning rate.
        self.train.lr_i = self.train.lr

if __name__ == "__main__":
    # Smoke check when executed as a script: construct the parameter object
    # for the "eeg_anonymous" dataset.
    cnn_ensemble_params_inst = cnn_ensemble_params("eeg_anonymous")

