#!/usr/bin/env python3
"""
Created on 16:47, Dec. 20th, 2022

@author: Anonymous
"""
import tensorflow as tf
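# Local dependencies: when run as a script, put the parent directory on
# `sys.path` so that `utils` can be imported.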
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.pardir)
from utils import DotDict

# Public API of this module: only the parameter-generator class is exported.
__all__ = ["naive_cnn_params"]

class naive_cnn_params(DotDict):
    """
    Hierarchical hyper-parameter container for the `naive_cnn` model.

    On construction it builds two sub-dictionaries, `model` and `train`, whose
    values depend on the `dataset` name. It then runs `iteration(0)` once, so
    that iteration-dependent values (currently only `model.lr_i`) exist before
    training starts. In every generator, unknown dataset names fall back to a
    generic default configuration.
    """
    # Internal macro parameter: name of the `tf` dtype attribute used as the
    # training precision. It is resolved in `_gen_train_params`, which falls
    # back to `tf.float32` if `tf` has no attribute with this name.
    _precision = "float32"

    def __init__(self, dataset="meg_liu2019cell"):
        """
        Initialize `naive_cnn_params` object.

        Args:
            dataset: str - Name of the dataset whose parameter set is generated.
        """
        # Set up the `DotDict` base first so that attribute-style assignment
        # (and its inherited functionality) is available.
        super().__init__()

        # Generate all parameters hierarchically.
        # -- Model parameters
        self.model = naive_cnn_params._gen_model_params(dataset)
        # -- Training parameters
        self.train = naive_cnn_params._gen_train_params(dataset)

        # Initialize iteration-dependent parameters. This explicitly calls
        # this class's `iteration`, not whatever `self.iteration` resolves to.
        naive_cnn_params.iteration(self, 0)

    # ------------------------------------------------------------------
    # update funcs
    # ------------------------------------------------------------------
    def iteration(self, iteration):
        """
        Update parameters at every backpropagation iteration/gradient update.

        Args:
            iteration: int - Index of the current gradient update. It is
                currently unused because the learning rate is constant.
        """
        # Current learning rate; no schedule is applied yet.
        self.model.lr_i = self.model.lr

    # ------------------------------------------------------------------
    # generate funcs
    # ------------------------------------------------------------------
    @staticmethod
    def _gen_model_params(dataset):
        """
        Generate model parameters.

        Args:
            dataset: str - Name of the dataset.

        Returns:
            model_params: DotDict - Keys `dataset`, `n_channels`, `n_labels`,
                `lr`, plus the nested `cnn` and `fc` sub-dictionaries.
        """
        # (n_channels, n_labels) per dataset. Unknown datasets fall back to
        # the meg_liu2019cell shape.
        # NOTE(review): "meg_anonymous.mag" and "meg_anonymous.eeg" are not
        # listed here and therefore get the (273, 8) fallback -- confirm
        # their channel/label counts.
        model_shapes = {
            "meg_liu2019cell": (273, 8),
            "eeg_anonymous.task": (55, 15),
            "eeg_anonymous": (55, 15),
            "meg_anonymous": (204, 15),
            "meg_lv2023cpnl": (204, 12),
            "seeg_he2023xuanwu": (128, 61),
            "meg_hebart2023things": (271, 15),
            "eeg_palazzo2020decoding": (128, 40),
            "eeg_gifford2022large": (17, 200),
        }
        default_shape = (273, 8)

        model_params = DotDict()
        ## -- Normal parameters
        # The type of dataset.
        model_params.dataset = dataset
        # The size of input channels and the size of output classes.
        model_params.n_channels, model_params.n_labels = model_shapes.get(
            model_params.dataset, default_shape)
        # The learning rate of optimizer (identical for every dataset).
        model_params.lr = 3e-4
        ## -- CNN parameters
        model_params.cnn = naive_cnn_params._gen_model_cnn_params(model_params)
        ## -- Fully connect parameters
        model_params.fc = naive_cnn_params._gen_model_fc_params(model_params)

        return model_params

    @staticmethod
    def _gen_model_cnn_params(model_params):
        """
        Generate model.cnn parameters.

        Args:
            model_params: DotDict - Model parameters; `dataset` and
                `n_channels` are read.

        Returns:
            model_cnn_params: DotDict - Conv1d, MaxPool1d and Dropout settings.
        """
        # Max-pooling kernel size per dataset; every other dataset uses 2.
        # NOTE(review): 440 presumably pools over the whole time axis of an
        # eeg_palazzo2020decoding trial -- confirm against its sample length.
        pooling_kernels = {
            "eeg_palazzo2020decoding": 440,
        }
        default_pooling_kernel = 2

        # The literals below are created on every call, so list values are
        # never shared between parameter objects.
        model_cnn_params = DotDict()
        ## -- Normal parameters (related to Conv1d)
        # The dimension of input vector.
        model_cnn_params.d_input = model_params.n_channels
        # The number of filters of each CNN layer.
        model_cnn_params.n_filters = [256, 128]
        # The size of kernel of each CNN layer.
        model_cnn_params.d_kernel = [9, 11]
        # The length of stride of each CNN layer.
        model_cnn_params.strides = [1, 1]
        # The padding mode of each CNN layer.
        model_cnn_params.padding = ["same", "same"]
        # The dilation rate of each CNN layer.
        model_cnn_params.dilation_rate = [1, 2]
        ## -- Normal parameters (related to MaxPool1d)
        # The size of max pooling kernel of each CNN layer.
        model_cnn_params.d_pooling_kernel = pooling_kernels.get(
            model_params.dataset, default_pooling_kernel)
        ## -- Normal parameters (related to Dropout)
        # The dropout rate of dropout layer.
        model_cnn_params.dropout = 0.5

        return model_cnn_params

    @staticmethod
    def _gen_model_fc_params(model_params):
        """
        Generate model.fc parameters.

        Args:
            model_params: DotDict - Model parameters; `dataset` and
                `n_labels` are read.

        Returns:
            model_fc_params: DotDict - Keys `d_hidden`, `dropout`, `d_output`.
        """
        # Hidden-layer sizes per dataset; every other dataset has none.
        # NOTE(review): "eeg_anonymous.task" is not listed, so it gets no
        # hidden layer, unlike "eeg_anonymous" -- confirm this is intended.
        hidden_dims = {
            "eeg_anonymous": [128,],
        }

        model_fc_params = DotDict()
        ## -- Normal parameters
        # The dimensions of hidden layers (fresh list on every call).
        model_fc_params.d_hidden = hidden_dims.get(model_params.dataset, [])
        # The dropout rate of dropout layer.
        model_fc_params.dropout = 0.
        # The dimension of output vector.
        model_fc_params.d_output = model_params.n_labels

        return model_fc_params

    @staticmethod
    def _gen_train_params(dataset):
        """
        Generate training parameters.

        Args:
            dataset: str - Name of the dataset.

        Returns:
            train_params: DotDict - Keys `dataset`, `precision`,
                `use_graph_mode`, `train_ratio`, `buffer_size`, `n_epochs`,
                `batch_size`.
        """
        # (n_epochs, batch_size) per dataset. Unknown datasets use (200, 128).
        # NOTE(review): "eeg_anonymous.task" is not listed and therefore uses
        # the (200, 128) fallback -- confirm this is intended.
        schedules = {
            "meg_liu2019cell": (20, 16),
            "eeg_anonymous": (200, 256),
            "meg_anonymous": (500, 256),
            "meg_lv2023cpnl": (200, 256),
            "seeg_he2023xuanwu": (200, 256),
            "meg_hebart2023things": (200, 256),
            "eeg_palazzo2020decoding": (200, 256),
            "eeg_gifford2022large": (200, 256),
        }
        default_schedule = (200, 128)

        train_params = DotDict()
        ## -- Normal parameters
        # The type of dataset.
        train_params.dataset = dataset
        # Precision parameter: the `tf` dtype named by `_precision`, or
        # `tf.float32` if `tf` has no such attribute.
        train_params.precision = getattr(
            tf, naive_cnn_params._precision, tf.float32)
        # Whether use graph mode or eager mode.
        train_params.use_graph_mode = True
        # The ratio of train dataset. The rest is test dataset.
        train_params.train_ratio = 0.8
        # Size of buffer used in shuffle.
        train_params.buffer_size = 10000
        ## -- Dataset-specific parameters
        # Number of epochs and batch size used in training process.
        train_params.n_epochs, train_params.batch_size = schedules.get(
            train_params.dataset, default_schedule)

        return train_params

if __name__ == "__main__":
    # Quick smoke check: build the parameter set for the meg_liu2019cell dataset.
    naive_cnn_params_inst = naive_cnn_params("meg_liu2019cell")

