# -*- coding: utf-8 -*-
# @date: 20220429
import numpy as np

from utils import *

"""
An implementation of general first order optimization algorithm allowing delayed
first order information.

The optimizer implements general interface for

- Initialization and reset
- Iteration (with momentum) 
- Parameter setup
- Iteration extraction

"""


class mpOpt(object):
    """Base class for first-order optimizers acting on a flat vector.

    Stores the hyper-parameters ``gamma`` and ``momentum`` together with the
    current and previous iterates. Concrete optimizers override
    :meth:`iterate` and :meth:`sync_opt`.

    Attributes (as set by ``__init__``):
        n: dimension of the parameter vector.
        _gamma: the ``gamma`` hyper-parameter (presumably the step size).
        _beta: the ``momentum`` coefficient.
        _x_old: previous iterate; set to a copy of the initial point by
            :meth:`initialize`.
        _x: current iterate.
        _y: auxiliary buffer of length ``n``; unused by this base class
            (presumably for subclasses — confirm).
    """

    def __init__(self, n_dim: int, gamma: numeric, momentum: numeric = 0.0) -> None:
        """Allocate zero-initialised buffers for an ``n_dim``-dimensional problem.

        Args:
            n_dim: dimension of the parameter vector.
            gamma: hyper-parameter stored as ``_gamma`` (presumably the step
                size used by subclasses).
            momentum: momentum coefficient stored as ``_beta``.
        """
        self.n = n_dim
        self._gamma = gamma
        self._beta = momentum
        # Every iterate buffer starts at the origin.
        self._x_old = np.zeros(self.n, dtype=numeric)
        self._x = np.zeros(self.n, dtype=numeric)
        self._y = np.zeros(self.n, dtype=numeric)

    def initialize(self, x_0: np.ndarray) -> None:
        """Set the starting point of the iteration.

        Args:
            x_0: initial point; an array (or array-like) whose shape matches
                the buffers allocated in ``__init__``.

        Raises:
            RuntimeError: if ``x_0`` does not have the expected shape.
        """
        x_0 = np.asarray(x_0)
        if x_0.shape != self._x.shape:
            raise RuntimeError("Invalid initial point x_0")
        # astype() copies by default, so the caller's array is never aliased.
        self._x = x_0.astype(numeric)
        # Independent copy: aliasing _x here would make any in-place update
        # of the current iterate silently overwrite the previous one too.
        self._x_old = self._x.copy()

    def reset(self) -> None:
        """Zero the hyper-parameters and all iterate buffers in place."""
        # NOTE(review): n is set to 0 but the buffers keep their original
        # length, so initialize() still expects the original shape — confirm.
        self.n = 0
        self._gamma = 0.0
        self._beta = 0.0
        # fill() instead of `*= 0.0`: multiplication leaves NaN/inf as NaN.
        self._x.fill(0.0)
        self._x_old.fill(0.0)
        self._y.fill(0.0)

    def iterate(self, g: np.ndarray) -> None:
        """Perform one optimizer step using first-order information ``g``.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("Iteration is undefined")

    def sync_opt(self, epoch, A_data, b_data, alg):
        """Run a synchronous optimization routine.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("No synchronous optimizaiton block")

    def get_x(self) -> np.ndarray:
        """Return the current iterate with dtype ``numeric``.

        When the buffer already has dtype ``numeric`` it is returned itself
        (no copy), so callers should not modify the result in place;
        otherwise a converted copy is returned.
        """
        return self._x.astype(numeric, copy=False)

