import numpy as np


class QuadraticProblem:
    """Quadratic objective ``(theta - optima)^T A (theta - optima)`` minimised by SGD.

    Each call to :meth:`step` pops a mini-batch from ``data_stream`` and moves
    ``theta`` against the gradient ``(A + A^T)(theta - target)``. Here
    ``target`` is the mean of the batch over its first axis, standing in for
    ``optima``. Every iterate and its loss are recorded in ``theta_hist`` and
    ``loss_hist``. The first entry of each corresponds to the starting point.

    Args:
        data_stream: object exposing ``dim`` (int, must be >= 2) and ``pop()``,
            which returns a batch of samples or ``None`` when exhausted.
            Presumably batches are ``(batch_size, dim)`` arrays whose mean
            approximates ``optima`` (TODO: confirm against the stream class).
        learning_rate: SGD step size.
        ini_theta: optional starting point. It is copied, so the caller's
            array is never modified by :meth:`step`. Non-floating input is
            promoted to float64. Defaults to a float32 zero vector.
        A: optional ``(dim, dim)`` matrix of the quadratic form, stored as
            given. Defaults to a random ``M^T D M`` matrix, where ``M`` has
            standard-normal entries and ``D`` is diagonal with uniform [0, 1)
            entries. ``D[0, 0]`` is zeroed, and ``D[1, 1]`` is also zeroed
            when ``dim >= 3``, so the default form is rank-deficient.
    """

    def __init__(self, data_stream, learning_rate, ini_theta=None, A=None):
        self.data_stream = data_stream
        self.dim = data_stream.dim
        self.learning_rate = learning_rate
        assert self.dim >= 2

        self.optima = np.array([0.5] * self.dim, dtype=np.float32)
        if ini_theta is not None:
            # Copy so the in-place update in step() never mutates the caller's array.
            theta = np.array(ini_theta, copy=True)
            if not np.issubdtype(theta.dtype, np.inexact):
                # Integer/bool arrays cannot hold the float SGD update in place.
                theta = theta.astype(np.float64)
            self.theta = theta
        else:
            self.theta = np.array([0.0] * self.dim, dtype=np.float32)

        if A is not None:
            self.A = A
        else:
            # Draw order (normal, then uniform) is kept so seeded runs reproduce.
            M = np.random.normal(size=(self.dim, self.dim))
            D = np.diagflat(np.random.uniform(size=self.dim))
            # Zero out eigen-directions so the problem has flat directions.
            D[0, 0] = 0.0
            if self.dim >= 3:
                D[1, 1] = 0.0
            self.A = np.matmul(np.matmul(M.transpose(), D), M)

        # Store copies: self.theta is updated in place by step().
        self.theta_hist = [self.theta.copy()]
        self.loss_hist = [self._loss(self.theta)]

    def _loss(self, theta):
        """Return ``(theta - optima)^T A (theta - optima)``."""
        delta = theta - self.optima
        return delta.dot(self.A).dot(delta)

    def step(self):
        """Perform one SGD update on the next mini-batch.

        Returns ``None`` in every case. When the stream is exhausted
        (``pop()`` returned ``None``), nothing is updated or recorded.
        """
        data = self.data_stream.pop()
        if data is None:
            return None
        # Gradient of (theta - c)^T A (theta - c) is (A + A^T)(theta - c); the
        # batch mean is used as a stochastic estimate of c.
        grad = (self.A + self.A.transpose()).dot(self.theta - np.mean(data, axis=0))
        self.theta -= self.learning_rate * grad

        # Append a copy; appending self.theta itself would alias every entry.
        self.theta_hist.append(self.theta.copy())
        self.loss_hist.append(self._loss(self.theta))
