#%%
# Simulation: norms of vector-valued partial sums
#     S_t = sum_{j <= t} eps_j * x_j,
# where each x_j is a standard Gaussian vector in R^d rescaled to unit norm and
# eps_j ~ N(0, s^2) are scalar noise draws.  Several independent runs are
# plotted for each dimension d, together with a high-probability envelope.
from math import log

import numpy as np
from matplotlib import pyplot as plt

# Seeded generator so every execution draws the same samples.
# NOTE: this name shadows the stdlib `random` module inside this file; it is
# kept unchanged in case other cells refer to it.
random = np.random.default_rng(0)

#%%
n = 1000  # number of steps per trajectory (t = 1..n)
s = 0.1   # standard deviation of the scalar noise eps_j
r = 10    # independent runs per dimension d

delta = 0.0001
loginvdelta = log(1/delta)
t = np.arange(1, n+1)
# Per-t tail bound with failure probability delta.  It scales like s**2 * t,
# i.e. like E||S_t||^2 = s**2 * t for unit-norm x_j, so it bounds the SQUARED
# norm ||S_t||^2.  Its form matches a quadratic-form (Hsu-Kakade-Zhang style)
# tail bound with trace t -- TODO confirm the intended reference.
bound = s**2 * t * (1 + 2*np.sqrt(loginvdelta) + 2*loginvdelta)

plt.figure()
d_range = [int(d) for d in 2**np.arange(2, 6, 2)]  # dimensions 4 and 16
for color_idx, d in enumerate(d_range):
    X = random.standard_normal(size=(r, d, n))  # (run, coordinate, step)
    # Rescale each column x_j (norm over the coordinate axis) to unit length.
    X_n = X / np.linalg.norm(X, axis=1, keepdims=True)

    noise = s * random.standard_normal(size=(r, n))  # eps_j for every run

    # process_s[k, :, j] = eps_j * x_j for run k (broadcast over coordinates);
    # the running sum along the step axis gives S_t for t = 1..n.
    process_s = X_n * noise[:, np.newaxis, :]
    process_s_cumsum = np.cumsum(process_s, axis=2)
    norms = np.linalg.norm(process_s_cumsum, axis=1)  # (run, step): ||S_t||

    # All runs for one d share a color and contribute a single legend entry;
    # with the default color cycle, runs of different d would share colors.
    lines = plt.plot(t, norms.T, color=f"C{color_idx}")
    lines[0].set_label(f"d={d}")

# `bound` controls ||S_t||^2, so compare its square root with ||S_t||.
plt.plot(t, np.sqrt(bound), "k--", label=f"sqrt(bound), delta={delta}")
plt.xlabel("t")
plt.ylabel("||S_t||")
plt.legend()



  



# %%
