from distributed_pcg.utils import read_dataset
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import sys
import cvxpy as cp



def create_data(n,d,l,seed):
    """Generate a synthetic regression problem with a geometrically decaying spectrum.

    NumPy's global random state is seeded with ``seed`` (so later draws from
    ``np.random`` are affected as well).  A Gaussian ``n x d`` matrix is drawn
    and its singular values are replaced by ``sqrt(0.99**i)`` for
    ``i = 0..d-1``; the left/right singular vectors are kept.

    Args:
        n: Number of rows (samples).  The construction expects ``n >= d``.
        d: Number of columns (features).
        l: Ridge parameter used only for the effective-dimension computation.
        seed: Seed passed to ``np.random.seed``.

    Returns:
        Tuple ``(A, y, eff_dim)`` where ``A`` is the ``n x d`` matrix,
        ``y = A @ x + 0.1 * noise`` for a standard-normal ``x`` and
        standard-normal per-sample ``noise``, and
        ``eff_dim = trace(H @ inv(H + l*I))`` with ``H = A.T @ A / n``.
    """
    np.random.seed(seed)
    gaussian = np.random.normal(size=(n, d))
    U, _, Vt = np.linalg.svd(gaussian, full_matrices=False)

    # Scaling the columns of U is the same product as U @ diag(s), without
    # materialising a dense d x d diagonal matrix.
    singular_values = np.array([(0.99)**i for i in range(d)])**0.5
    A = (U * singular_values) @ Vt

    H = (A.T @ A) / n
    eff_dim = np.trace(H @ np.linalg.inv(H + l*np.eye(d)))

    x = np.random.normal(size=d)
    # One independent noise sample per row.  ``np.random.normal(n)`` would
    # instead return a single scalar drawn with *mean* n.
    y = A @ x + 0.1*np.random.normal(size=n)
    return A, y, eff_dim

class obj_func():
    """Ridge-regularised least-squares objective with sketched inverse Hessians.

    Represents ``f(x) = ||A x - y||^2 / (2 n) + (l / 2) ||x||^2`` and provides
    randomised approximations of the inverse Hessian.  Every approximation is
    an average over ``q`` independent sketches drawn from NumPy's global
    random state.
    """

    def __init__(self,A,y,n,l,q):
        """Store the problem data.

        Args:
            A: Data matrix, indexed as 2-D with one column per coefficient.
            y: Target vector subtracted from ``A @ x``.
            n: Divisor used by ``eval``, ``grad`` and ``hessian``.
            l: Ridge (L2) regularisation strength.
            q: Number of sketches averaged by ``hessian_inv``.
        """
        self.A = A
        self.y = y
        self.n = n
        self.l = l
        self.q = q

    def eval(self,x):
        """Return the objective value at ``x``."""
        residual = self.A@x-self.y
        return (residual**2).sum()/(2*self.n)+(self.l/2)*(x**2).sum()

    def grad(self,x):
        """Return the gradient of the objective at ``x``."""
        return (1/self.n)*self.A.T@(self.A@x-self.y)+self.l*x

    def hessian(self,x):
        """Return the exact Hessian (``x`` is unused: the objective is quadratic)."""
        return (self.A.T@self.A)/self.n+self.l*np.eye(self.A.shape[1])

    def hessian_inv(self,dl,method='sample'):
        """Return a sketched approximation of the inverse Hessian.

        Note that the sketched Gram matrices are normalised by ``A.shape[0]``
        rather than by ``self.n``.

        Args:
            dl: Effective-dimension estimate; the sketch size is
                ``4 * (int(dl) + 1)`` for ``'sample'`` and ``'feature'`` and
                is ignored by the ``*_debias`` methods.
            method: One of
                ``'sample'`` - Rademacher sketch applied to the rows of ``A``;
                ``'sample_debias'`` - same, with the sketched Gram term
                scaled by ``lam / l``;
                ``'feature'`` - Gaussian sketch applied to the columns;
                ``'feature_debias'`` - same, with ``lam`` replacing ``l`` in
                the sketched system.
                ``lam`` is produced by ``_debiased_ridge``.

        Returns:
            A ``d x d`` array, the average over ``q`` sketches.

        Raises:
            ValueError: If ``method`` is not one of the names above.
            RuntimeError: If the debias search cannot bracket its target.
        """
        n = self.A.shape[0]
        if method == 'sample':
            return self._mean_sample_inverse(4*(int(dl)+1), 1.0)
        if method == 'sample_debias':
            H = self.A.T@self.A/n
            hat_l, m = self._debiased_ridge(H, self._rademacher_sketch)
            return self._mean_sample_inverse(m, hat_l/self.l)
        if method == 'feature':
            H = self.A.T@self.A/n
            return self._mean_feature_inverse(H, 4*(int(dl)+1), self.l)
        if method == 'feature_debias':
            H = self.A.T@self.A/n
            hat_l, m = self._debiased_ridge(H, self._gaussian_sketch)
            return self._mean_feature_inverse(H, m, hat_l)
        raise ValueError(f"unknown hessian_inv method: {method!r}")

    @staticmethod
    def _rademacher_sketch(rows, cols):
        """Draw a ``rows x cols`` matrix of random signs scaled by ``1/sqrt(rows)``."""
        signs = np.where(np.random.rand(rows, cols) > 0.5, 1.0, -1.0)
        return signs/rows**0.5

    @staticmethod
    def _gaussian_sketch(rows, cols):
        """Draw a ``rows x cols`` standard-normal matrix scaled by ``1/sqrt(rows)``."""
        return np.random.normal(size=(rows, cols))/rows**0.5

    def _mean_sample_inverse(self, m, scale):
        """Average ``inv(scale * (S A)^T (S A) / n + l I)`` over ``q`` row sketches."""
        n = self.A.shape[0]
        ridge = self.l*np.eye(self.A.shape[1])
        total = 0
        for _ in range(self.q):
            # Form the small m x d product first; this avoids the d x n
            # intermediate of A.T @ S.T @ S.
            SA = self._rademacher_sketch(m, n)@self.A
            total += np.linalg.inv(scale*(SA.T@SA)/n+ridge)
        return total/self.q

    def _mean_feature_inverse(self, H, m, lam):
        """Average ``S^T inv(S H S^T + lam I) S`` over ``q`` column sketches."""
        d = self.A.shape[1]
        ridge = lam*np.eye(m)
        total = 0
        for _ in range(self.q):
            S = self._gaussian_sketch(m, d)
            total += S.T@np.linalg.inv(S@H@S.T+ridge)@S
        return total/self.q

    def _debiased_ridge(self, H, draw_sketch):
        """Pick a sketch size and a corrected ridge value for the debias methods.

        The sketch size ``m`` starts at 2 and doubles while ``m < d`` until
        ``trace(inv(S H S^T + (5 l / 12) I)) / m`` exceeds ``1 / l`` for a
        freshly drawn ``m x d`` sketch ``S``.  With that ``S`` fixed, ``lam``
        is found by bisection on ``[5 l / 12, l]`` so that
        ``trace(inv(S H S^T + lam I)) / rows(S)`` is within ``1e-3`` of
        ``1 / l``.

        Returns:
            Tuple ``(lam, m)``.

        Raises:
            RuntimeError: If the target is not bracketed by the interval.
        """
        l = self.l
        d = H.shape[0]
        target = 1/l

        def mean_resolvent(S, lam):
            k = S.shape[0]
            return np.trace(np.linalg.inv(S@H@S.T+lam*np.eye(k)))/k

        m = 2
        S = None
        while m < d:
            S = draw_sketch(m, d)
            if mean_resolvent(S, 5*l/12) > target:
                break
            m = 2*m
        if S is None:
            # d <= 2: the search loop never ran, so draw the one sketch
            # the bisection below needs.
            S = draw_sketch(m, d)

        lo, hi = 5*l/12, l
        if not (mean_resolvent(S, lo) >= target and mean_resolvent(S, hi) <= target):
            raise RuntimeError(
                'debias search: 1/l is not bracketed on [5l/12, l] for the drawn sketch')

        mid = (lo+hi)/2
        gap = mean_resolvent(S, mid)-target
        while abs(gap) > 1e-3:
            if gap > 0:
                lo = mid
            else:
                hi = mid
            next_mid = (lo+hi)/2
            if next_mid == mid:
                # The interval cannot shrink further in floating point.
                break
            mid = next_mid
            gap = mean_resolvent(S, mid)-target
        return mid, m

def line_search(func,x,v):
    """Backtracking line search along ``v`` using a sufficient-decrease test.

    Starting from ``t = 100`` the step is halved until
    ``func.eval(x + t*v) <= func.eval(x) + 0.25 * t * (func.grad(x) @ v)``.

    Args:
        func: Object exposing ``eval(x)`` and ``grad(x)``.
        x: Current point.
        v: Search direction.

    Returns:
        The first step size that satisfies the test.

    Raises:
        RuntimeError: If the step size drops below ``1e-20`` first.
    """
    # Neither quantity depends on t, so evaluate each exactly once.
    fx = func.eval(x)
    slope = func.grad(x)@v

    t = 100
    while func.eval(x+t*v) > fx+0.25*t*slope:
        t = t*0.5
        if t < 1e-20:
            raise RuntimeError('line search error')
    return t

def newton(A,y,n,l,cvx_opt,eff_dim,max_iter,q,method='sample'):
    """Run sketched Newton iterations on the ridge objective and log the values.

    Each iteration records the current objective value, builds a direction
    from ``obj_func.hessian_inv`` applied to the gradient, picks a step with
    ``line_search`` and updates the iterate, starting from the zero vector.

    Args:
        A: Data matrix; ``A.shape[1]`` sets the dimension of the iterate.
        y: Target vector.
        n: Normalisation constant forwarded to ``obj_func``.
        l: Ridge parameter forwarded to ``obj_func``.
        cvx_opt: Unused; kept so existing call sites stay valid.
        eff_dim: Passed to ``hessian_inv`` as ``dl``.
        max_iter: Number of iterations to run.
        q: Number of sketches averaged per inverse-Hessian estimate.
        method: ``'sample'``, ``'sample_debias'``, ``'feature'`` or
            ``'feature_debias'``; forwarded to ``hessian_inv``.

    Returns:
        List of ``(iteration, objective_value)`` tuples, where the value is
        measured before the step of that iteration is taken.

    Raises:
        ValueError: If ``method`` is not one of the four supported names.
    """
    supported = ('sample', 'sample_debias', 'feature', 'feature_debias')
    if method not in supported:
        raise ValueError(f"unknown newton method: {method!r}")

    f = obj_func(A,y,n,l,q)
    x = np.zeros(A.shape[1])
    record = []
    step = 0
    while step < max_iter:
        record.append((step, f.eval(x)))
        v = -f.hessian_inv(dl=eff_dim, method=method)@f.grad(x)
        t = line_search(f, x, v)
        x += t*v
        step += 1
    return record

def plot(X, Y, n, lambd, m, d, seed):
    """Plot two ``newton`` traces on a log-scaled y axis and save ``ihs.pdf``.

    NOTE(review): the ``newton`` defined in this file takes
    ``(A, y, n, l, cvx_opt, eff_dim, max_iter, q)`` and only handles the
    methods 'sample', 'sample_debias', 'feature' and 'feature_debias'.  The
    calls below pass a leading start vector and the methods 'ihs', 'ihs_s'
    and 'h', so they look written against a different ``newton`` - confirm
    before relying on this function.
    """
    x_init = torch.zeros(d)
    shared_args = (x_init, X, Y, n, lambd, m, d, seed)

    newton_ihs = newton(*shared_args, method='ihs')[0]
    newton_ihs_s = newton(*shared_args, method='ihs_s')[0]
    newton(*shared_args, method='h')

    # Column 0 goes on the x axis, column 1 on the y axis.
    for trace in (newton_ihs, newton_ihs_s):
        xs = torch.Tensor(trace).cpu().numpy()[:, 0]
        ys = torch.Tensor(trace).cpu().numpy()[:, 1]
        plt.plot(xs, ys)

    plt.yscale('log')
    plt.legend(['IHS', 'IHS With Shrinkage'])
    plt.savefig('ihs.pdf')

def get_x_axis(data_set):
    """Build a common x grid spanning the run that reaches the largest x.

    Each run is a sequence of records whose first field is the x position.
    The run whose last record has the largest x (first one wins on ties, and
    only values above -1 count) fixes the right end of the grid and its
    resolution: exactly 2 points if that run has 2 records, otherwise at
    least 100 points.
    """
    furthest = -1
    record_count = 0
    for run in data_set:
        last_x = run[-1][0]
        if last_x > furthest:
            furthest, record_count = last_x, len(run)

    grid_points = 2 if record_count == 2 else max(record_count, 100)
    return np.linspace(0, furthest, grid_points)

def interpolate(data, axis, optimal):
    """Resample every run onto ``axis`` and summarise across runs.

    Row ``i`` of the working table is ``get_numbers(data[i], axis, optimal[i])``.

    Returns:
        Tuple ``(median, q20, q80)``: the 0.5, 0.2 and 0.8 quantiles taken
        across runs at every grid position.
    """
    run_count, grid_size = len(data), len(axis)
    table = np.zeros((run_count, grid_size))
    for row, run in enumerate(data):
        table[row] = get_numbers(run, axis, optimal[row])
    assert table.shape == (run_count, grid_size)

    return tuple(np.quantile(table, level, axis=0) for level in (0.5, 0.2, 0.8))

def get_numbers(data, axis, optimal=None):
    """Interpolate a run's error curve onto ``axis``.

    Args:
        data: Sequence of ``(x, value)`` records for one run.
        axis: x positions at which the curve is evaluated.
        optimal: Reference value.  When given, the curve is
            ``|value - optimal|``; when ``None`` it is the error relative to
            the run's final value, ``|value - last| / |last|``.

    Returns:
        Array of the curve linearly interpolated at ``axis``.
    """
    # Convert straight to float64.  Going through ``torch.Tensor`` would cast
    # to float32 and round away gaps smaller than ~1e-7 of the value.
    data = np.asarray(data, dtype=np.float64)
    values = data[:, 1]
    if optimal is None:
        y = np.abs(values-values[-1])/np.abs(values[-1])
    else:
        optimal = np.array(optimal)
        y = np.abs(values-optimal)
    return np.interp(axis, data[:, 0], y)

def plot_multi_realdata(n=10000, d=500, l=1e-3, q=100, trials=10, max_iter=10,
                        out_path='rademacher_ridge.pdf'):
    """Compare 'feature' and 'feature_debias' Newton runs on synthetic data.

    For each trial a problem is generated with ``create_data`` (the trial
    index is the seed), the ridge optimum is computed with cvxpy as the
    reference value, and both methods are run with ``newton``.  The median
    optimality gap across trials is plotted with a 20%-80% quantile band and
    the figure is written to ``out_path``.

    Args:
        n: Number of samples per problem.
        d: Number of features per problem.
        l: Ridge parameter.
        q: Number of sketches averaged per inverse-Hessian estimate.
        trials: Number of independent problems (seeds ``0..trials-1``).
        max_iter: Newton iterations per run.
        out_path: Destination of the saved figure.
    """
    methods = ('feature', 'feature_debias')
    runs = {name: [] for name in methods}
    optimals = []

    for seed in range(trials):
        print(seed)
        A, y, eff_dim = create_data(n, d, l, seed=seed)

        # Reference optimum of the same objective, solved with cvxpy.
        x_cvx = cp.Variable(d)
        objective = cp.sum_squares(A@x_cvx-y)/(2*n)+(l/2)*cp.sum_squares(x_cvx)
        prob = cp.Problem(cp.Minimize(objective))
        prob.solve(solver='CLARABEL')
        cvx_opt = prob.value
        optimals.append(cvx_opt)

        # Keep this order: both runs draw from the global NumPy random state.
        for name in methods:
            runs[name].append(
                newton(A, y, n, l, cvx_opt, eff_dim, max_iter, q, method=name))

    # Only one figure is created (no stray oversized canvas left behind).
    fig, ax = plt.subplots()
    clrs = sns.color_palette("husl", 10)
    for name, colour in zip(methods, (clrs[7], clrs[4])):
        grid = get_x_axis(runs[name])
        median, lower, upper = interpolate(runs[name], grid, optimals)
        # Labels are attached to the lines; no legend is drawn.
        ax.plot(grid, median, label=name, c=colour)
        ax.fill_between(grid, lower, upper, alpha=0.3, facecolor=colour)

    ax.set_yscale('log')
    ax.set_xlim(left=0, right=6)
    ax.set_ylim(bottom=1e-2)
    ax.set_xlabel('Newton Steps', fontsize=20)
    ax.set_ylabel('Log Optimality Gap', fontsize=20)
    ax.tick_params(axis='both', labelsize=17)
    fig.tight_layout()
    fig.savefig(out_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)




# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    plot_multi_realdata()

