import torch as th
import configuration as cfg


class KernelParameters:
    """
    Configuration container for the Kernel Network.

    Apart from ``pk_batches`` and ``device``, which the caller supplies,
    every value is copied from the ``configuration`` module (imported as
    ``cfg``) at construction time.

    Args:
        pk_batches: Number of PK batches; ``KernelTensors`` uses it as the
            leading dimension of the tensors it allocates.
        device: Torch device on which the network tensors are allocated.

    Note:
        The Active Tuning attributes (``at_tf_rate``, ``at_length``,
        ``at_cycles``, ``at_eta`` and the ``adam_*`` settings) are only set
        when ``cfg.ACTIVE_TUNING`` is truthy. Otherwise they do not exist,
        and reading them raises ``AttributeError``.
    """

    def __init__(self, pk_batches, device):
        # --- System ---------------------------------------------------------
        self.device = device

        # --- Network-wide settings ------------------------------------------
        self.seq_len = cfg.SEQ_LEN

        # --- PK topology ----------------------------------------------------
        self.pk_batches = pk_batches
        # Multiplies the lateral input/output widths of each PK.
        self.pk_neighbors = cfg.PK_NEIGHBORS

        # Input dimensions
        self.pk_dyn_in_size = cfg.PK_DYN_IN_SIZE
        self.pk_lat_in_size = cfg.PK_LAT_IN_SIZE

        # Layer widths (neurons per layer)
        self.pk_pre_layer_size = cfg.PK_PRE_LAYER_SIZE
        self.pk_num_lstm_cells = cfg.PK_NUM_LSTM_CELLS

        # Output dimensions
        self.pk_dyn_out_size = cfg.PK_DYN_OUT_SIZE
        self.pk_lat_out_size = cfg.PK_LAT_OUT_SIZE

        # --- Active Tuning --------------------------------------------------
        self.active_tuning = cfg.ACTIVE_TUNING
        # Nothing further to configure when Active Tuning is disabled.
        if not self.active_tuning:
            return

        self.at_tf_rate = cfg.AT_TF_RATE

        self.at_length = cfg.AT_LENGTH
        self.at_cycles = cfg.AT_CYCLES
        self.at_eta = cfg.AT_ETA

        # Adam optimizer settings used in the Active Tuning setup
        self.adam_beta1 = cfg.ADAM_BETA1
        self.adam_beta2 = cfg.ADAM_BETA2
        self.adam_epsilon = cfg.ADAM_EPSILON
        self.adam_reset_optimizer = cfg.ADAM_RESET_OPTIMIZER


class KernelTensors:
    """
    State tensors of the Kernel Network.

    Every tensor is two-dimensional with shape ``(pk_num, width)`` and is
    zero-initialised on ``params.device``. Here ``pk_num`` is the value
    passed to :meth:`reset`; the constructor passes
    ``params.pk_batches``.

    Args:
        params: Parameter object (e.g. ``KernelParameters``) that provides
            ``device``, ``pk_batches``, ``pk_neighbors`` and the
            ``pk_*_size`` / ``pk_num_lstm_cells`` widths.
    """

    def __init__(self, params):
        self.params = params

        # Allocate all tensors up front by delegating to reset().
        self.reset(self.params.pk_batches)

    def reset(self, pk_num):
        """
        Re-allocate all PK tensors as zeros with ``pk_num`` rows.

        Args:
            pk_num: Number of rows (leading dimension) of every tensor.
        """
        cfg_ = self.params
        device = cfg_.device

        def blank(width):
            # torch.zeros defaults to requires_grad=False, so none of
            # these tensors track gradients on creation.
            return th.zeros(size=(pk_num, width), device=device)

        # Lateral channels are stacked once per neighbor.
        lat_in_width = cfg_.pk_lat_in_size * cfg_.pk_neighbors
        lat_out_width = cfg_.pk_lat_out_size * cfg_.pk_neighbors
        lstm_width = cfg_.pk_num_lstm_cells

        # Inputs
        self.pk_dyn_in = blank(cfg_.pk_dyn_in_size)
        self.pk_lat_in = blank(lat_in_width)

        # LSTM cell (c) and hidden (h) states
        self.pk_lstm_c = blank(lstm_width)
        self.pk_lstm_h = blank(lstm_width)

        # Outputs
        self.pk_dyn_out = blank(cfg_.pk_dyn_out_size)
        self.pk_lat_out = blank(lat_out_width)

    def detach(self):
        """
        Replace every PK tensor with a detached copy.

        This cuts each tensor off from the autograd graph that produced it.
        """
        tensor_names = (
            "pk_dyn_in",
            "pk_lat_in",
            "pk_lstm_c",
            "pk_lstm_h",
            "pk_dyn_out",
            "pk_lat_out",
        )
        for attr in tensor_names:
            setattr(self, attr, getattr(self, attr).detach())
