from utils import generate_u_batch_accelerated
import numpy as np
import argparse
import matplotlib.pyplot as plt
import random
import scipy.io as scio
from numba import jit
# Matplotlib font dictionaries: `font` is used for the x-axis label,
# `legend_font` for the y-axis label and the legend.
font = dict(family='Times New Roman', weight='normal', size=20)
legend_font = dict(family='Times New Roman', weight='normal', size=15)

# Integer index -> single-letter matplotlib colour code
# (not referenced in the code visible here).
cmap = dict(enumerate('cbygrk'))

# Problem dimensions swept by the experiment driver.
ds = [5000, 10000, 50000]
# Number of nonzero entries in the target vector (its last `kstar` coordinates).
kstar = 5
# Sparsity level kept by the hard-thresholding step.
k = 500




def _hard_threshold(x, k):
    """Return a copy of ``x`` keeping only its ``k`` largest-magnitude entries (others set to 0)."""
    top_k_idx = np.argsort(-np.abs(x))[0:k]
    x_sparse = np.zeros_like(x)
    x_sparse[top_k_idx] = x[top_k_idx]
    return x_sparse


def plot_figure(fn, qs, etas, s2s, nm, num, axes, iter=100000, r=0.1, lam=10, k=2, miu=1e-4, SEED=1, ccurves=None, ddists=None, d=None):
    """Run zeroth-order hard thresholding (ZOHT) on a sparse quadratic and record trajectories.

    The objective is ``f(x) = 0.5 * ||x - v||^2``, where ``v`` is zero except for
    its last ``kstar`` entries, which equal ``(1/kstar) * [1, ..., kstar]``.
    Every (q, s2) configuration restarts from ``x0 = 1/d`` (last ``kstar``
    entries zeroed) and performs ``iter // q`` steps. Each step does three things:

    1. It estimates the gradient from ``q`` function queries, as
       ``num/miu * (f(x + miu*u) - f(x)) * u`` averaged over random directions.
    2. It takes a gradient step of size ``etas[q_i, s2_i]``.
    3. It keeps only the ``k`` entries of largest magnitude.

    Args:
        fn, nm, axes, r, lam: unused; kept for call-site compatibility.
        qs: 1-D array of query budgets per step. Each must be a multiple of the
            internal batch size (1007).
        etas: step sizes indexed as ``etas[q_i, s2_i]``.
        s2s: 1-D array of values forwarded as the first argument of
            ``generate_u_batch_accelerated``.
        num: dimension passed to ``generate_u_batch_accelerated`` and used for
            the gradient buffer; expected to equal ``d``.
        iter: total function-evaluation budget per configuration.
        k: number of entries kept by hard thresholding.
        miu: finite-difference step of the gradient estimator.
        SEED: seed for ``random`` and ``numpy.random``.
        ccurves: container indexed by q-index. For every (q, s2) run, an array
            with rows ``[evaluations, f(x), ||x - v||]`` is appended to
            ``ccurves[q_i]``. A fresh dict is created when omitted.
        ddists: unused.
        d: problem dimension; defaults to ``num`` when omitted.

    Returns:
        The ``ccurves`` container that was filled.

    Raises:
        ValueError: if a query budget is not a multiple of the batch size.
    """
    if d is None:
        d = num
    if ccurves is None:
        ccurves = {q_i: [] for q_i in range(len(qs))}

    vec = np.ones(d)
    v = np.zeros(d)
    v[-kstar:] = 1/(kstar)*np.arange(1, kstar+1)

    def f(x):
        y = (x - v) * (vec)
        return 0.5 * (y) @ (y).T

    def f_bis_batch(x):
        # Row-wise objective for a stack of points; returns a (n_points, 1) column.
        y = (x - v) * (vec)[None]
        return 0.5 * np.sum(y ** 2, axis=1, keepdims=True)

    random.seed(SEED)
    np.random.seed(SEED)

    batch_size = int(2014/2)
    x0 = np.ones(d)/d
    x0[-kstar:] = 0
    f0 = f(x0)
    dist0 = np.linalg.norm(x0 - v)

    for q_i, q in enumerate(qs):
        if q % batch_size != 0:
            raise ValueError('batch_size should be a divider of q')
        num_batches = int(q // batch_size)
        # `np.int` no longer exists in NumPy >= 1.24; floor division matches the
        # old int(iter / q) truncation for positive budgets.
        n_steps = int(iter // q)

        for s2_i, s2 in enumerate(s2s):
            current_eta = etas[q_i, s2_i]
            x_new = x0
            f_x = f0
            trajectory = [[0, f0, dist0]]
            for t in range(1, n_steps + 1):
                gradient = np.zeros(num)
                for _ in range(num_batches):
                    # NOTE(review): assumes u is (batch_size, num) — confirm in utils.
                    u = generate_u_batch_accelerated(s2, num, batch_size)
                    func_2 = f_bis_batch(x_new + miu * u)
                    gradient_i = num/miu*(func_2-f_x)*u
                    gradient += np.sum(gradient_i, axis=0)/batch_size
                gradient /= num_batches

                x_new = _hard_threshold(x_new - current_eta * gradient, k)

                # f at the new iterate is also the base value for the next step.
                f_x = f(x_new)
                dist = np.linalg.norm(x_new - v)
                trajectory.append([t*q, f_x, dist])
                print('Estimated f(x_k): %f  iters: %d' %
                      (f_x, t*q))
            ccurves[q_i].append(np.array(trajectory))
    return ccurves



def _step_sizes(d, qs, s2s):
    """Return the step-size grid ``etas[q_i, s2_i] = 1 / (4 * epsilon_f + 1)`` for dimension ``d``.

    ``epsilon_f`` is evaluated for every (q, s2) pair by broadcasting
    ``qs`` over rows and ``s2s`` over columns, with ``s = 2 * k + kstar``
    taken from the module-level sparsity settings.
    """
    s = 2*k + kstar
    epsilon_f = 2*d / (qs[:, None] * (s2s[None] + 2)) * ((s - 1) * (s2s[None] - 1) / (d - 1) + 3) + 2
    return np.ones((qs.shape[0], s2s.shape[0])) / (4 * epsilon_f + 1)


def _run_experiment(d, n_seeds=1):
    """Run ZOHT at dimension ``d`` for seeds ``0..n_seeds-1`` and return the recorded curves.

    Returns a dict mapping each query-budget index to the list of
    trajectory arrays that ``plot_figure`` appended for it.
    """
    qs = np.array([2014])
    s2s = np.array([d])
    etas = _step_sizes(d, qs, s2s)
    curves = {q_i: [] for q_i in range(qs.shape[0])}
    dists = {q_i: [] for q_i in range(qs.shape[0])}
    for seed in range(n_seeds):
        plot_figure(fn='port5.txt', qs=qs, etas=etas, s2s=s2s, nm='am5_q', num=d,
                    axes=[0, 6000, 0.04, 0.106], miu=1e-9, k=k, SEED=seed,
                    ccurves=curves, ddists=dists, d=d)
    return curves


def _plot_average(curves, column, d):
    """Plot, per query budget, the run-averaged ``column`` of the trajectories against evaluations.

    In each trajectory, column 0 holds the evaluation count, column 1 the
    objective value and column 2 the distance to the target vector.
    """
    for runs in curves.values():
        avg = np.zeros(len(runs[0]))
        for run in runs:
            avg += run[:, column]
        avg /= len(runs)
        plt.plot(runs[0][:, 0], avg, label=f"d={d}")


if __name__ == "__main__":
    import os

    # The runs are expensive: do each dimension once and reuse the curves
    # for both figures instead of re-running everything per figure.
    results = {d: _run_experiment(d) for d in ds}

    os.makedirs('result', exist_ok=True)
    figures = (
        (1, '$f(x)$', 'result/f_ds.png'),
        # Raw string: '\|' is an invalid escape sequence in a normal literal.
        (2, r'$\|x - x^*\|$', 'result/dis_ds.png'),
    )
    for column, ylabel, out_path in figures:
        plt.figure()
        for d in ds:
            _plot_average(results[d], column, d)
        plt.xlabel('Function Evaluations', font)
        plt.ylabel(ylabel, legend_font)
        plt.legend(prop=legend_font)
        plt.savefig(out_path)


