# ----------------------------------------------------------------
# Copyright (c) 2020 Matthias Karlbauer, Sebastian Otte
# ----------------------------------------------------------------

__author__ = "Matthias Karlbauer, Sebastian Otte"


import torch

class ActiveTuning:
    """
    Tune one tensor of a model's past state/output towards observations.

    The class keeps a sliding window of ``tuning_length + 1`` model outputs
    and states plus the most recent observations. On every call to
    :meth:`tune` the model is unrolled in closed loop (each output is fed
    back as the next input) starting from the left-most window entry, the
    mean squared error between the unrolled outputs and the buffered
    observations is computed, and the tensor selected by ``opt_accessor``
    is adapted with an ADAM-style update. The model itself is not trained.

    Parameters
    ----------
    model
        Object exposing ``forward(in_t, state_t)`` that returns the pair
        ``(out_t, state_t)``.
    initial_model_state
        State used as the left-most seed of the state buffer.
    initial_model_output
        Output used as the left-most seed of the output buffer.
    opt_accessor
        Callable ``opt_accessor(output, state)`` returning the tensor to be
        optimized. The returned tensor is flagged with ``requires_grad_()``
        and its ``.data`` is overwritten by the update.
        NOTE(review): presumably this must return (a part of) the very
        objects handed to ``model.forward``; otherwise no gradient can
        reach it -- confirm against callers.
    tuning_length
        Number of model steps unrolled per optimization cycle (>= 1).
    tuning_cycles
        Number of optimization cycles performed per :meth:`tune` call.
    tf_rate
        Teacher forcing rate; with a value > 0 the model input is blended
        as ``(1 - tf_rate) * output + tf_rate * observation`` wherever an
        observation is available. ``0`` disables teacher forcing.
    eta
        Step size of the update.
    beta1, beta2
        Decay rates of the first and second moment estimates.
    epsilon
        Constant added to the square root of the second moment estimate.
    bias_correction
        Whether to apply bias correction to both moment estimates.

    Raises
    ------
    ValueError
        If ``tuning_length`` is smaller than 1.
    """

    def __init__(self, model, initial_model_state, initial_model_output,
                 opt_accessor, tuning_length=1, tuning_cycles=1, tf_rate=0.0,
                 eta=0.1, beta1=0.9, beta2=0.999, epsilon=1e-8,
                 bias_correction=False):

        # With an empty tuning window the buffers degenerate to [None]
        # after the first call to tune(), so reject it right away.
        if tuning_length < 1:
            raise ValueError(
                "tuning_length must be at least 1, got {}".format(
                    tuning_length))

        self._model = model
        self._criterion = torch.nn.MSELoss()

        self._tuning_length = tuning_length
        self._tuning_cycles = tuning_cycles
        self._tf_rate = tf_rate
        self._eta = eta
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon

        self._bias_correction = bias_correction
        self._opt_accessor = opt_accessor

        # Get an example tensor to determine the required shape for first and
        # second order momentum tensors of ADAM
        example_tensor = self._opt_accessor(initial_model_output,
                                            initial_model_state)
        self._m = torch.zeros_like(example_tensor)
        self._v = torch.zeros_like(example_tensor)
        self._cycles_ctr = 0

        # Buffers to store the histories of states, outputs and observations.
        # Index 0 holds the (left-most) seed, the remaining tuning_length
        # entries are filled by the forward passes.
        self._model_states = [initial_model_state] + [None] * tuning_length
        self._model_outputs = [initial_model_output] + [None] * tuning_length
        self._observations = []

    @property
    def model(self):
        """The wrapped model."""
        return self._model

    @property
    def tuning_length(self):
        """Number of model steps unrolled per optimization cycle."""
        return self._tuning_length

    @property
    def tuning_cycles(self):
        """Number of optimization cycles per call to :meth:`tune`."""
        return self._tuning_cycles

    @tuning_cycles.setter
    def tuning_cycles(self, value):
        self._tuning_cycles = value

    @property
    def tf_rate(self):
        """Teacher forcing rate (0 disables teacher forcing)."""
        return self._tf_rate

    @tf_rate.setter
    def tf_rate(self, value):
        self._tf_rate = value

    @property
    def eta(self):
        """Step size of the update."""
        return self._eta

    @eta.setter
    def eta(self, value):
        self._eta = value

    @property
    def beta1(self):
        """Decay rate of the first moment estimate."""
        return self._beta1

    @beta1.setter
    def beta1(self, value):
        self._beta1 = value

    @property
    def beta2(self):
        """Decay rate of the second moment estimate."""
        return self._beta2

    @beta2.setter
    def beta2(self, value):
        self._beta2 = value

    @property
    def epsilon(self):
        """Constant added to the square root of the second moment."""
        return self._epsilon

    @epsilon.setter
    def epsilon(self, value):
        self._epsilon = value

    @property
    def bias_correction(self):
        """Whether bias correction is applied to the moment estimates."""
        return self._bias_correction

    @bias_correction.setter
    def bias_correction(self, value):
        self._bias_correction = value

    def _forward_pass(self, observations_length):
        """
        Unroll the model over ``tuning_length`` steps from buffer entry 0.

        Every step feeds the previous output back as input (optionally
        blended with the matching observation) and stores the resulting
        output and state at the next buffer position.

        Returns the output and state of the last step.
        """
        out_t = self._model_outputs[0]
        state_t = self._model_states[0]

        # While the observation buffer is still filling up, the first
        # steps of the window have no observation; this offset maps a step
        # index to its index in the observation buffer.
        obs_offset = (self._tuning_length + 1) - observations_length

        for at_t in range(self._tuning_length):

            in_t = out_t

            # Optionally, apply a fraction of teacher forcing
            if self._tf_rate > 0:
                obs_t = at_t - obs_offset
                if obs_t >= 0:
                    # Compute portions of teacher forcing and previous
                    # output for the current input
                    in_t = (1.0 - self._tf_rate) * out_t \
                        + self._tf_rate * self._observations[obs_t].detach()

            # Forward pass
            out_t, state_t = self._model.forward(in_t, state_t)

            # Update buffers
            self._model_outputs[at_t + 1] = out_t
            self._model_states[at_t + 1] = state_t

        return out_t, state_t

    def tune(self, observation_t, reset_optimizer=False):
        """
        Incorporate a new observation and return the tuned prediction.

        The observation is appended to the observation buffer (at most
        ``tuning_length + 1`` observations are kept). Then
        ``tuning_cycles`` optimization cycles are run, each consisting of a
        forward pass over the window, an MSE loss between the newest
        outputs and observations, and one update of the tensor selected by
        ``opt_accessor``. Finally the window is unrolled once more without
        gradient tracking and both buffers are shifted left by one entry.

        Parameters
        ----------
        observation_t
            The newest observation; it is stacked with the other buffered
            observations and compared against the stacked model outputs.
        reset_optimizer
            If true, both moment estimates and the cycle counter (used for
            bias correction) are reset before tuning.

        Returns
        -------
        tuple
            ``(out_t, state_t)`` of the last step of the final forward
            pass.
        """
        if reset_optimizer:
            self._m = torch.zeros_like(self._m)
            self._v = torch.zeros_like(self._v)
            self._cycles_ctr = 0

        # Shift observations by one, if necessary.
        self._observations.append(observation_t)
        self._observations = self._observations[-(self._tuning_length + 1):]
        observations_length = len(self._observations)

        # The loss window and its targets do not change between cycles.
        loss_length = min(observations_length, self._tuning_length)
        loss_targets = torch.stack(self._observations[-loss_length:])

        # Perform active tuning optimization cycles.
        for _ in range(self._tuning_cycles):

            opt_tensor = self._opt_accessor(self._model_outputs[0],
                                            self._model_states[0])
            opt_tensor.requires_grad_()

            # Forward pass over tuning_length steps
            self._forward_pass(observations_length)

            # Compute loss and error
            loss_outputs = self._model_outputs[-loss_length:]
            mse = self._criterion(torch.stack(loss_outputs), loss_targets)

            # Backward pass. Only the gradient of the optimized tensor is
            # requested, so nothing is accumulated into the ``.grad`` of
            # other leaves of the graph (e.g. the model parameters).
            g = torch.autograd.grad(mse, [opt_tensor])[0]

            # Perform optimization
            with torch.no_grad():

                # ADAM update
                self._m = self._beta1 * self._m + (1 - self._beta1) * g
                self._v = self._beta2 * self._v + (1 - self._beta2) * (g * g)

                m_tmp = self._m
                v_tmp = self._v

                # Bias correction
                if self._bias_correction:
                    step = self._cycles_ctr + 1
                    m_tmp = self._m / (1.0 - self._beta1 ** step)
                    v_tmp = self._v / (1.0 - self._beta2 ** step)

                update = (-self._eta / (torch.sqrt(v_tmp) + self._epsilon)) \
                    * m_tmp

                # Tensor update by applying the ADAM gradients
                opt_tensor.data = (opt_tensor + update).data

                self._cycles_ctr += 1

        # Latent state/output has been optimized; this optimized state/output
        # is now propagated once more in forward direction in order to generate
        # the final output and state to be returned.
        with torch.no_grad():
            out_t, state_t = self._forward_pass(observations_length)

        # Final buffer update (add new entry from right side and drop the
        # left-most entry)
        self._model_outputs = self._model_outputs[1:] + [None]
        self._model_states = self._model_states[1:] + [None]

        return out_t, state_t
