from distributed_pcg.utils import read_dataset
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import sys
import cvxpy as cp



def create_data(n, d_list, l, seed, a):
    """Generate one synthetic square design matrix per requested dimension.

    The global NumPy RNG is reseeded with ``seed`` first. Then, for every
    ``dim`` in ``d_list`` (in order):

    * an ``n x dim`` standard Gaussian matrix ``G`` is drawn and the singular
      vectors ``U, V^T`` of ``G^T G`` are taken;
    * the ``dim x dim`` design ``A = U diag((k / dim) ** a) ** 0.5 V^T``
      (``k = 0 .. dim - 1``) is formed, so ``A^T A`` has eigenvalues
      ``(k / dim) ** a``;
    * one value uniform in ``[1e-3, 2e-3)`` is drawn for the second return list.

    The effective dimension ``tr(H (H + l I)^{-1})`` with ``H = A^T A / n`` is
    also evaluated for each design, but it is not part of the return value;
    ``l`` is only used for that computation.

    Returns:
        ``(A_list, l_list)``: the list of designs and the list of random values.
    """
    np.random.seed(seed)
    designs = []
    eff_dims = []  # collected for inspection only; never returned
    reg_values = []
    for dim in d_list:
        gaussian = np.random.normal(size=(n, dim))
        left, _, right_t = np.linalg.svd(gaussian.T @ gaussian)
        spectrum = np.diag([(k / dim) ** a for k in range(dim)])
        # ** binds tighter than @, so this is U @ sqrt(spectrum) @ V^T.
        design = left @ spectrum ** 0.5 @ right_t
        gram = design.T @ design / n
        eff_dims.append(np.trace(gram @ np.linalg.inv(gram + l * np.eye(dim))))
        designs.append(design)
        reg_values.append(1e-3 * np.random.rand() + 1e-3)
    return designs, reg_values
    

def _find_hat_lambda(gram, l, tol, max_bisect):
    """Bisect on ``[5 * l / 12, l]`` for ``lam`` with ``tr((gram + lam I)^{-1}) / m`` close to ``1 / l``.

    ``gram`` is the ``m x m`` sketched matrix ``S H S^T``. For a PSD ``gram``
    the normalised trace is decreasing in ``lam``, so a value above ``1 / l``
    moves the lower end up and a value below moves the upper end down.

    Stops when the normalised trace is within ``tol`` of ``1 / l``, when the
    midpoint no longer changes (float precision exhausted), or after
    ``max_bisect`` steps; the last midpoint is returned in every case.

    Raises:
        RuntimeError: if ``[5 * l / 12, l]`` does not bracket the target for
            this ``gram``.
    """
    m = gram.shape[0]
    eye_m = np.eye(m)
    target = 1 / l

    def normalized_trace(lam):
        return np.trace(np.linalg.inv(gram + lam * eye_m)) / m

    lo, hi = 5 * l / 12, l
    if normalized_trace(lo) < target:
        raise RuntimeError("bisection bracket invalid: Sm(5*l/12) < 1/l for this sketch")
    if normalized_trace(hi) > target:
        raise RuntimeError("bisection bracket invalid: Sm(l) > 1/l for this sketch")

    mid = (lo + hi) / 2
    value = normalized_trace(mid)  # evaluated once per step, not once per comparison
    for _ in range(max_bisect):
        if abs(value - target) <= tol:
            break
        if value > target:
            lo = mid
        else:
            hi = mid
        new_mid = (lo + hi) / 2
        if new_mid == mid:
            break  # interval collapsed to float precision; further steps cannot help
        mid, value = new_mid, normalized_trace(new_mid)
    return mid


def compute_inverse(A_list, d_list, n, l, q=500, tol=1e-3, max_bisect=200):
    """Measure the error of sketched estimates of ``(H + l I)^{-1}``.

    For each design ``A`` paired with dimension ``d`` from ``d_list``
    (``A.T @ A`` must be ``d x d``), with ``H = A^T A / n``:

    * ``d_lambda = tr(H (H + l I)^{-1})`` and sketch size
      ``m = int(2 * d_lambda) + 1``;
    * one Gaussian sketch (entries with std ``1 / sqrt(m)``) fixes ``hat_l`` by
      bisection so that ``tr((S H S^T + hat_l I)^{-1}) / m`` is within ``tol``
      of ``1 / l``;
    * ``q`` fresh sketches are drawn, and ``S^T (S H S^T + mu I)^{-1} S`` is
      averaged for ``mu`` in ``(l, l * (1 - d_lambda / m), hat_l)``.

    All sketches come from the global NumPy RNG (``np.random.normal``); seed
    it beforehand for reproducibility.

    Args:
        A_list: design matrices.
        d_list: dimension for each design (paired with ``A_list`` positionally).
        n: normaliser in ``H = A^T A / n``.
        l: ridge parameter; ``1 / l`` is used, so it must be nonzero.
        q: number of Monte-Carlo sketch draws per design (must be >= 1).
        tol: absolute tolerance of the ``hat_l`` bisection.
        max_bisect: maximum number of bisection steps.

    Returns:
        ``(original_list, ideal_list, ours_list)``: for each design, the
        squared Frobenius distance to the exact inverse divided by ``d**2``,
        using shift ``l``, ``l * (1 - d_lambda / m)`` and ``hat_l`` respectively.

    Raises:
        ValueError: if ``q < 1``.
        RuntimeError: if the bisection bracket ``[5 * l / 12, l]`` does not
            contain the target for the drawn sketch.
    """
    if q < 1:
        raise ValueError(f"q must be a positive number of sketch draws, got {q}")
    original_list = []
    ideal_list = []
    ours_list = []
    for A, d in zip(A_list, d_list):
        H = A.T @ A / n
        true = np.linalg.inv(H + l * np.eye(d))
        dl = np.trace(H @ true)
        m = int(2 * dl) + 1
        scale = 1 / m**0.5
        eye_m = np.eye(m)

        # One sketch determines hat_l; independent sketches are used for the average.
        probe = np.random.normal(loc=0, scale=scale, size=(m, d))
        hat_l = _find_hat_lambda(probe @ H @ probe.T, l, tol, max_bisect)
        tilde_l = l * (1 - dl / m)

        shifts = (l, tilde_l, hat_l)
        totals = [np.zeros((d, d)) for _ in shifts]
        for _ in range(q):
            S = np.random.normal(loc=0, scale=scale, size=(m, d))
            gram = S @ H @ S.T  # shared by all three shifts
            for total, mu in zip(totals, shifts):
                # S^T (gram + mu I)^{-1} S via a linear solve instead of an explicit inverse
                total += S.T @ np.linalg.solve(gram + mu * eye_m, S)

        for out, total in zip((original_list, ideal_list, ours_list), totals):
            out.append(((true - total / q) ** 2).sum() / d**2)
    return original_list, ideal_list, ours_list



def get_x_axis(data_set):
    """Return an x grid from 0 to the largest final x value across series.

    Each series is indexed as ``series[-1][0]``, i.e. it is assumed to be a
    sequence of points whose first coordinate is x, with the last point
    holding the series' final x. The series with the strictly largest final
    x (the first one, on ties) decides the grid:

    * its length is 2 -> a 2-point grid;
    * otherwise -> ``max(length, 100)`` points.

    An empty ``data_set`` yields ``np.linspace(0, -1, 100)``.
    """
    best_x, best_len = -1, 0
    for series in data_set:
        last_x = series[-1][0]
        if last_x > best_x:
            best_x, best_len = last_x, len(series)
    n_points = 2 if best_len == 2 else max(best_len, 100)
    return np.linspace(0, best_x, n_points)

def interpolate(data):
    """Summarise repeated runs by their column-wise median and 20%/80% quantiles.

    Despite its name, this performs no interpolation. ``data`` is a sequence
    of equal-length rows (one per run, lists or arrays); statistics are taken
    across rows (axis 0).

    Returns:
        ``(median, lower, upper)``: the 0.5, 0.2 and 0.8 quantiles per column.

    Raises:
        ValueError: if ``data`` is empty or its rows have unequal lengths.
    """
    rows = np.asarray(data, dtype=float)
    if rows.ndim == 0 or rows.shape[0] == 0:
        raise ValueError("interpolate() needs at least one row of data")
    median, lower, upper = np.quantile(rows, [0.5, 0.2, 0.8], axis=0)
    return (median, lower, upper)


def plot_multi_realdata(n_trials=10, output_path='bias_vs_d.pdf'):
    """Plot sketched-inverse error against dimension and save it to ``output_path``.

    For each ``trial`` in ``range(n_trials)``, synthetic designs are built with
    ``create_data`` (``seed=trial``, spectrum exponent ``a=trial * 0.5``) for
    every dimension in ``d_list`` and scored with ``compute_inverse``. Across
    trials, the median and 20%-80% quantile band (via ``interpolate``) of the
    un-debiased ("original") and ``hat_lambda``-shifted ("ours") errors are
    plotted against ``d``.

    Args:
        n_trials: number of seeded trials to aggregate.
        output_path: file the figure is written to.
    """
    n = 10000
    d_list = [100, 200, 300, 400, 500, 600, 700, 800]
    l = 1e-3
    original_runs = []
    ours_runs = []
    for trial in range(n_trials):
        print(trial)  # progress indicator
        A_list, _ = create_data(n, d_list, l, seed=trial, a=trial * 0.5)
        original, _, ours = compute_inverse(A_list, d_list, n, l)
        original_runs.append(np.array(original))
        ours_runs.append(ours)

    # Variable names are kept because the f-string "=" specifier prints them.
    dl_plot_data = interpolate(original_runs)
    dl_b_plot_data = interpolate(ours_runs)
    print(f'{dl_plot_data=}')
    print(f'{dl_b_plot_data=}')

    x = np.array(d_list)
    clrs = sns.color_palette("husl", 10)
    # A single figure from plt.subplots(); closed after saving so repeated
    # calls do not accumulate open figures.
    fig, ax = plt.subplots()
    try:
        ax.plot(x, dl_plot_data[0], label=r'$\frac{1}{d^2}\|H(\lambda)^{-1}-H_S(\lambda)^{-1}\|_F^2$ (original)', c=clrs[3])
        ax.fill_between(x, dl_plot_data[1], dl_plot_data[2], alpha=0.3, facecolor=clrs[3])
        ax.plot(x, dl_b_plot_data[0], label=r'$\frac{1}{d^2}\|H(\lambda)^{-1}-H_S(\hat \lambda)^{-1}\|_F^2$ (ours)', c=clrs[1])
        ax.fill_between(x, dl_b_plot_data[1], dl_b_plot_data[2], alpha=0.3, facecolor=clrs[1])
        ax.legend(fontsize=18, loc="upper right")
        ax.set_xlabel('d', fontsize=20)
        ax.set_ylabel('bias', fontsize=20)
        ax.tick_params(axis='both', labelsize=17)
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        plt.close(fig)




if __name__ == "__main__":
    # Script entry point: run the full experiment and write the figure.
    plot_multi_realdata()

