#!/usr/bin/env python3
"""
Created on 18:29, May. 17th, 2023

@author: Anonymous
"""
import tensorflow as tf
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.pardir)
from utils import DotDict

# Public API of this module: only the parameter container class is exported.
__all__ = ["conv_net_params"]

class conv_net_params(DotDict):
    """
    Hierarchical hyper-parameter container for `conv_net`.

    An instance is a `DotDict` holding two top-level groups:
      - `model`: model parameters, built by `_gen_model_params`.
      - `train`: training parameters, built by `_gen_train_params`.
    It is provided to `conv_net` on initialization.
    """
    # Name of the `tf` dtype attribute used as training precision; resolved
    # in `_gen_train_params`, which falls back to `tf.float32` if missing.
    _precision = "float32"

    def __init__(self, dataset="eeg_anonymous"):
        """
        Initialize `conv_net_params`.

        Args:
            dataset: str - The name of the dataset. "eeg_anonymous" selects its
                dedicated settings; any other value uses the generic settings.
        """
        ## First call super class init function to set up `DotDict`
        ## style object and inherit its functionality.
        super().__init__()

        ## Generate all parameters hierarchically.
        # -- Model parameters
        self.model = conv_net_params._gen_model_params(dataset)
        # -- Train parameters
        self.train = conv_net_params._gen_train_params(dataset)

        ## Do init iteration, so that the per-iteration fields
        ## (`train.lr_pretrain_i`, `train.lr_finetune_i`) exist right away.
        conv_net_params.iteration(self, 0)

    # ------------------------------------------------------------------
    # update funcs
    # ------------------------------------------------------------------
    # def iteration func
    def iteration(self, iteration):
        """
        Update parameters at every backpropagation iteration/gradient update.

        Args:
            iteration: The index of the current iteration (0 is passed on init).
                Currently unused: both learning rates are copied unchanged,
                i.e. no learning-rate schedule is applied.

        Returns:
            None. `self.train.lr_pretrain_i` and `self.train.lr_finetune_i`
            are set in place.
        """
        ## -- Train parameters
        # Current learning rate of the pretrain part (constant for now).
        self.train.lr_pretrain_i = self.train.lr_pretrain
        # Current learning rate of the finetune part (constant for now).
        self.train.lr_finetune_i = self.train.lr_finetune

    # ------------------------------------------------------------------
    # generate funcs
    # ------------------------------------------------------------------
    ## def _gen_model_* funcs
    # def _gen_model_params func
    @staticmethod
    def _gen_model_params(dataset):
        """
        Generate model parameters.

        Args:
            dataset: str - The name of the dataset.

        Returns:
            model_params: DotDict - The model parameters, including the nested
                `subject` and `cnn` groups.
        """
        # Initialize `model_params`.
        model_params = DotDict()

        ## -- Normal parameters
        # The type of dataset.
        model_params.dataset = dataset
        # The mode of contrastive prediction. The list enumerates the available
        # options; `[-1]` selects the last one.
        model_params.contra_pred_mode = ["max_z", "prob_z", "max_y", "prob_y"][-1]
        # The ratio of contrastive prediction.
        model_params.contra_pred_ratio = 0.
        # The ratio of contrastive loss.
        model_params.contra_loss_ratio = 0.
        # The scale factor of contrastive loss.
        model_params.contra_loss_scale = 1.
        # The number of models in ensemble.
        model_params.n_models = 10
        # The type of model ensemble (options listed; `[-1]` selects the last one).
        model_params.ensemble_type = ["min_loss", "average"][-1]
        if model_params.ensemble_type in ["min_loss",]:
            print("WARNING: The ensemble type is {}.".format(model_params.ensemble_type))
        ## -- Dataset-specific parameters
        # Normal parameters related to eeg_anonymous dataset.
        if model_params.dataset == "eeg_anonymous":
            # The number of input channels.
            model_params.n_channels = 55
            # The size of output labels.
            model_params.n_labels = 15
        # Normal parameters related to other dataset.
        else:
            # The number of input channels.
            model_params.n_channels = 64
            # The size of output labels.
            model_params.n_labels = 15
        # The dimension of contrastive layer.
        model_params.d_contra = 128
        ## -- Subject parameters
        model_params.subject = conv_net_params._gen_model_subject_params(model_params)
        ## -- CNN parameters
        model_params.cnn = conv_net_params._gen_model_cnn_params(model_params)

        # Return the final `model_params`.
        return model_params

    # def _gen_model_subject_params func
    @staticmethod
    def _gen_model_subject_params(model_params):
        """
        Generate model.subject parameters.

        Args:
            model_params: DotDict - The partially-built model parameters; only
                `model_params.dataset` is read here.

        Returns:
            model_subject_params: DotDict - With `cnn`, `hidden` and `subject` groups.
        """
        # Initialize `model_subject_params`.
        model_subject_params = DotDict()

        ## -- Dataset-specific parameters
        # NOTE: Both branches currently hold identical values; they are kept
        # separate so each dataset can be tuned independently.
        # Normal parameters related to eeg_anonymous dataset.
        if model_params.dataset == "eeg_anonymous":
            # The parameters related to conv block.
            model_subject_params.cnn = DotDict({
                # The number of filters.
                "n_filters": [256,],
                # The dropout ratio.
                "dropout_ratio": 0.3,
            })
            # The parameters related to hidden block.
            model_subject_params.hidden = DotDict({
                # The number of hidden units.
                "n_hiddens": [256, 128],
            })
            # The parameters related to subject layer.
            model_subject_params.subject = DotDict({
                # The number of output channels of subject layer.
                "n_channels_output": 256,
            })
        # Normal parameters related to other dataset.
        else:
            # The parameters related to conv block.
            model_subject_params.cnn = DotDict({
                # The number of filters.
                "n_filters": [256,],
                # The dropout ratio.
                "dropout_ratio": 0.3,
            })
            # The parameters related to hidden block.
            model_subject_params.hidden = DotDict({
                # The number of hidden units.
                "n_hiddens": [256, 128],
            })
            # The parameters related to subject layer.
            model_subject_params.subject = DotDict({
                # The number of output channels of subject layer.
                "n_channels_output": 256,
            })

        # Return the final `model_subject_params`.
        return model_subject_params

    # def _gen_model_cnn_params func
    @staticmethod
    def _gen_model_cnn_params(model_params):
        """
        Generate model.cnn parameters.

        The per-layer lists (`n_filters`, `kernel_size`, `padding`,
        `dilation_rate`, `kernel_initializer`) all have the same length (5).

        Args:
            model_params: DotDict - The partially-built model parameters; only
                `model_params.dataset` is read here.

        Returns:
            model_cnn_params: DotDict - The conv-block parameters.
        """
        # Initialize `model_cnn_params`.
        model_cnn_params = DotDict()

        ## -- Dataset-specific parameters
        # NOTE: Both branches currently hold identical values; they are kept
        # separate so each dataset can be tuned independently.
        # Normal parameters related to eeg_anonymous dataset.
        if model_params.dataset == "eeg_anonymous":
            # The number of filters.
            model_cnn_params.n_filters = [128, 256, 512, 256, 64]
            # The size of filter kernel.
            model_cnn_params.kernel_size = [5, 7, 11, 9, 3]
            # The mode of padding.
            model_cnn_params.padding = ["same", "same", "same", "same", "same"]
            # The dilation rate to do convolution.
            model_cnn_params.dilation_rate = [1, 1, 2, 2, 2]
            # The mode of kernel initializer.
            model_cnn_params.kernel_initializer = [
                "glorot_uniform", "glorot_uniform", "he_uniform", "he_uniform", "he_uniform",
            ]
            # The size of max-pooling kernel.
            model_cnn_params.pool_size = 4
            # The dropout ratio.
            model_cnn_params.dropout_ratio = 0.3
        # Normal parameters related to other dataset.
        else:
            # The number of filters.
            model_cnn_params.n_filters = [128, 256, 512, 256, 64]
            # The size of filter kernel.
            model_cnn_params.kernel_size = [5, 7, 11, 9, 3]
            # The mode of padding.
            model_cnn_params.padding = ["same", "same", "same", "same", "same"]
            # The dilation rate to do convolution.
            model_cnn_params.dilation_rate = [1, 1, 2, 2, 2]
            # The mode of kernel initializer.
            model_cnn_params.kernel_initializer = [
                "glorot_uniform", "glorot_uniform", "he_uniform", "he_uniform", "he_uniform",
            ]
            # The size of max-pooling kernel.
            model_cnn_params.pool_size = 4
            # The dropout ratio.
            model_cnn_params.dropout_ratio = 0.3

        # Return the final `model_cnn_params`.
        return model_cnn_params

    ## def _gen_train_* funcs
    # def _gen_train_params func
    @staticmethod
    def _gen_train_params(dataset):
        """
        Generate training parameters.

        Args:
            dataset: str - The name of the dataset.

        Returns:
            train_params: DotDict - The training parameters.
        """
        # Initialize `train_params`.
        train_params = DotDict()

        ## -- Normal parameters
        # The type of dataset.
        train_params.dataset = dataset
        # The modality to evaluate.
        train_params.modality = "N2/3"
        # Precision parameter: the `tf` dtype named by `_precision`,
        # falling back to `tf.float32` when `tf` has no such attribute.
        train_params.precision = getattr(tf, conv_net_params._precision, tf.float32)
        # Whether use graph mode or eager mode.
        train_params.use_graph_mode = True
        # The ratio of train dataset. The rest is test dataset.
        train_params.train_ratio = 0.8
        # Period of iterations to save model.
        train_params.i_model = 5
        ## -- Dataset-specific parameters
        # NOTE: `buffer_size` (shuffle buffer) is set per dataset below; every
        # branch must define it.
        # Normal parameters related to eeg_anonymous dataset.
        if train_params.dataset == "eeg_anonymous":
            # Number of epochs used in training process.
            train_params.n_epochs = DotDict({"pretrain": 500, "finetune": 200,})
            # Number of batch size used in training process.
            train_params.batch_size = 256
            # The size of shuffle buffer.
            train_params.buffer_size = 1024
            # The learning rate of pretrain part of training process.
            train_params.lr_pretrain = 3e-4
            # The learning rate of finetune part of training process.
            train_params.lr_finetune = 1e-4
        # Normal parameters related to other dataset.
        else:
            # Number of epochs used in training process.
            train_params.n_epochs = DotDict({"pretrain": 20, "finetune": 10,})
            # Number of batch size used in training process.
            train_params.batch_size = 16
            # The size of shuffle buffer.
            train_params.buffer_size = 128
            # The learning rate of pretrain part of training process.
            train_params.lr_pretrain = 3e-4
            # The learning rate of finetune part of training process.
            train_params.lr_finetune = 1e-4

        # Return the final `train_params`.
        return train_params

if __name__ == "__main__":
    # Smoke test: build the parameter set for the `eeg_anonymous` dataset.
    params = conv_net_params("eeg_anonymous")

