import numpy as np
from scipy.stats import ortho_group

# helper functions
def sq_matrix_with_given_cond(n, cond):
    """Return a random ``n x n`` symmetric positive definite matrix.

    The result is ``P @ P.T`` with ``P = U @ S @ V.T``, where ``U`` and ``V``
    are the orthogonal factors of QR decompositions of random matrices and
    ``S`` is diagonal with log-equispaced entries running from
    ``cond ** -0.25`` to ``cond ** 0.25``.  The eigenvalues of the result are
    the squares of those entries, so its 2-norm condition number is ``cond``.

    Args:
        n: Size of the matrix.
        cond: Target condition number; must be positive (its log is taken).

    Returns:
        A symmetric positive definite ``numpy.ndarray`` of shape ``(n, n)``.
    """
    log_cond = np.log(cond)
    # linspace yields exactly n points with exact end points.  A float-step
    # arange has a rounding-dependent length and breaks when the step is zero
    # (cond == 1) or when n == 1.
    exponents = np.linspace(-log_cond / 4., log_cond / 4., n)
    S = np.diag(np.exp(exponents))

    # Random orthogonal factors.  The draws (and their order) are kept as in
    # the original so seeded runs consume the global RNG stream identically.
    # NOTE(review): the shift is 5., not 0.5 -- all entries are negative;
    # Q is still orthogonal, but confirm this was intended.
    U, _ = np.linalg.qr((np.random.rand(n, n) - 5.) * 200)
    V, _ = np.linalg.qr((np.random.rand(n, n) - 5.) * 200)

    P = U.dot(S).dot(V.T)
    return P.dot(P.T)

def matrix_with_given_cond(n, d, cond, symmetric=False):
    """Return a random ``n x d`` matrix whose 2-norm condition number is ``cond``.

    The matrix is assembled from an SVD-like product ``P @ D @ Q`` with random
    orthogonal ``P`` (``n x n``) and ``Q`` (``d x d``).  The ``d`` singular
    values on the diagonal of ``D`` are ``sqrt(cond) ** u`` where the smallest
    and largest exponents are pinned to ``-1`` and ``+1`` and the remaining
    ``d - 2`` are drawn uniformly from ``[-1, 1]``; the ratio of the extreme
    singular values is therefore exactly ``cond``.

    Args:
        n: Number of rows (``n >= d``).
        d: Number of columns (``d >= 2``).
        cond: Target condition number (positive).
        symmetric: If true, use ``Q = P.T`` so the result is symmetric
            positive definite.  Only meaningful for square matrices.

    Returns:
        ``numpy.ndarray`` of shape ``(n, d)``.

    Raises:
        ValueError: If ``symmetric`` is requested with ``n != d``.
    """
    assert d >= 2
    assert n >= d
    if symmetric and n != d:
        # Previously surfaced as an opaque matmul shape-mismatch ValueError.
        raise ValueError("symmetric=True requires a square matrix (n == d)")

    P = ortho_group.rvs(dim=n)
    Q = P.T if symmetric else ortho_group.rvs(dim=d)

    t = np.sqrt(cond)
    inner = np.random.uniform(low=-1, high=1, size=d - 2)
    # np.insert / np.append return new arrays, so the end points have to be
    # captured explicitly; discarding them left the extreme singular values
    # (and hence the condition number) random.
    exponents = np.concatenate(([-1.0], inner, [1.0]))

    D = np.zeros((n, d))
    np.fill_diagonal(D, np.float_power(t, exponents))

    return P @ D @ Q



class RidgeRegression():
    """Ridge-regularised least-squares problem.

    The objective over a batch ``B`` of row indices is::

        f_B(x) = ||A[B] x - b[B]||^2 / (2 |B|) + (lambd / 2) ||x||^2

    The constructor builds (or accepts) the data ``A`` and ``b`` and
    precomputes the smoothness / strong-convexity constants of the full
    objective (``|B| = n``).
    """

    def __init__(self, n, d, lambd=0.0, A=None, b=None, cond=None, consistent=False):
        """Set up the problem data and its constants.

        Args:
            n: Number of rows of ``A`` (``n >= d``).
            d: Number of columns of ``A`` (``d >= 2``).
            lambd: Ridge penalty coefficient.
            A: Optional data matrix.  When omitted it is generated: with the
                prescribed condition number if ``cond`` is given, otherwise
                with i.i.d. standard normal entries.
            b: Optional right-hand side; ignored when ``consistent`` is true.
                When omitted (and not ``consistent``) it is standard normal.
            cond: Optional target condition number for a generated ``A``.
            consistent: If true, ``b`` is set to ``A @ x`` for a random ``x``
                so the linear system has an exact solution, and ``x_star`` /
                ``f_star`` are filled in.
        """
        self.name = "Ridge Regression"
        self.save_name = "ridgereg"
        self.n = n
        self.d = d
        assert self.d >= 2
        assert self.n >= self.d
        self.lambd = lambd
        self.A_cond = cond

        if A is not None:
            self.A = A
        elif cond is None:
            self.A = np.random.randn(n, d)
        elif n == d:
            self.A = sq_matrix_with_given_cond(n, cond)
        else:
            self.A = matrix_with_given_cond(n, d, cond)

        # Gram matrix; shared by the minimiser solve and the spectrum below.
        gram = self.A.T @ self.A

        if consistent:
            x_gen = np.random.randn(self.d)
            self.b = self.A @ x_gen
            if self.lambd == 0:
                # Unregularised: the generating vector attains f = 0.
                self.x_star = x_gen
                self.f_star = 0.0
            else:
                # With a ridge penalty the generating vector is no longer the
                # minimiser and the optimum is not 0; solve the normal
                # equations (A^T A / n + lambd I) x = A^T b / n instead.
                system = gram / self.n + self.lambd * np.eye(gram.shape[0])
                self.x_star = np.linalg.solve(system, self.A.T @ self.b / self.n)
                residual = self.A @ self.x_star - self.b
                self.f_star = float(
                    residual @ residual / (2 * self.n)
                    + (self.lambd / 2) * (self.x_star @ self.x_star)
                )
        else:
            self.f_star = None
            self.x_star = None
            if b is not None:
                self.b = b
            else:
                self.b = np.random.randn(self.n)

        # Per-sample smoothness: the top eigenvalue of the rank-one matrix
        # a_i a_i^T is ||a_i||^2, so no eigendecomposition is needed.
        self.L_i = np.einsum("ij,ij->i", self.A, self.A) + self.lambd
        self.L_max = self.L_i.max()
        self.L_avg = self.L_i.mean()

        # A^T A is symmetric, so eigvalsh applies; it returns real eigenvalues
        # in ascending order.
        spectrum = np.linalg.eigvalsh(gram)
        self.L = spectrum[-1] / self.n + self.lambd   # smoothness constant
        self.mu = spectrum[0] / self.n + self.lambd   # strong-convexity constant
        self.kappa = self.L / self.mu

    def func(self, x, batch):
        """Return the objective value at ``x`` restricted to rows ``batch``.

        ``batch`` must be a NumPy index array (its ``shape[0]`` is the
        batch size used for averaging).
        """
        residual = self.A[batch] @ x - self.b[batch]
        return np.linalg.norm(residual)**2/(2*batch.shape[0]) + (self.lambd/2)*np.linalg.norm(x)**2

    def grad(self, x, batch):
        """Return the gradient of :meth:`func` at ``x`` for rows ``batch``."""
        rows = self.A[batch]
        return (rows.T@(rows@x-self.b[batch]))/batch.shape[0] + self.lambd*x
