import numpy as np
import matplotlib.pyplot as plt
import time
from sklearn.datasets import load_svmlight_file as loader
import pickle

#locals
from problem import DR_s_LR
from projSplitting import ProjSplit
from Tseng import tseng_product
from FRB import frb, frb_var_reduced


########################################################################################
# choose one of the three problems by uncommenting
problem = 'eps'
#problem = 'susy'
#problem = 'real_sim'
########################################################################################
# dataloc must point at the folder that holds the three data files
dataloc = '../../data'
########################################################################################

# wall-clock reference for the total running time printed once all solvers are done
tstart = time.time()
# seed in [0, 10000): used to draw the initial point reproducibly and handed to the
# ProjSplit runs; the global RNG is re-seeded without a fixed seed afterwards
seed2use = np.random.randint(10000)

# Per-dataset experiment settings.
#   x_axis_lim          : right limit of the time axis in the final plot
#   alpha_const_0       : constant passed to the fixed-stepsize ProjSplit run
#   niter_* / iterfrb   : iteration budgets of the individual solvers
#   measure_freq_ps     : history frequency passed to the stochastic ProjSplit runs
#   measure_freq_frb_s  : kept for compatibility; not referenced in the visible script
#   data_size           : "BIG" enables the pickle cache when loading the data
#   file                : path of the svmlight data file
#   convert_to_numpy    : densify the data matrix after loading
#   Ctseng / dtseng     : Cstep / dstep arguments of the stochastic Tseng run
if problem == "eps":
    x_axis_lim = 130
    alpha_const_0 = 1e0
    niter_ps = 5000
    niter_ps_d = 20
    niter_tseng = 20
    iterfrb = 40
    measure_freq_ps = 200
    measure_freq_frb_s = 400
    data_size = "BIG"
    file = dataloc+'/epsilon_normalized'
    convert_to_numpy = True

    Ctseng = 0.56
    dtseng = 0.6
    niter_stseng = 5000

elif problem == "susy":
    x_axis_lim = 100
    alpha_const_0 = 5e0
    niter_ps = 200
    niter_ps_d = 30
    niter_tseng = 20
    iterfrb = 60
    measure_freq_ps = 10
    measure_freq_frb_s = 10
    data_size = "BIG"
    file = dataloc+'/SUSY'
    convert_to_numpy = True

    Ctseng = 0.56
    dtseng = 0.6
    niter_stseng = 200

elif problem == "real_sim":
    x_axis_lim = 5
    alpha_const_0 = 2e0
    niter_ps = 1000
    niter_ps_d = 500
    niter_tseng = 500
    iterfrb = 1000
    measure_freq_ps = 10
    measure_freq_frb_s = 10
    data_size = "small"
    file = dataloc+'/real-sim'
    convert_to_numpy = False

    Ctseng = 0.77
    dtseng = 0.55
    niter_stseng = 1000

else:
    # fail here with a clear message instead of a NameError further down the script
    raise ValueError(
        f"unknown problem {problem!r}; expected one of 'eps', 'susy', 'real_sim'")

# Settings that take the same form for every problem.
historyFreq_Stseng = measure_freq_ps
probability = 0.01
niter_frbvr = niter_ps
historyFreq_frbVR = measure_freq_ps

# Problem settings
# constants handed to the DR_s_LR constructor, in this order
delta, kappa, lrcoef = 1e0, 1e0, 1e-3

######################################################
# data
######################################################

# Load the data matrix X and label vector y, timing the whole step.
tloaddata = time.time()
if data_size == "BIG":
    # For the big datasets a pickle cache '<file>.pkl' is tried first; if it cannot be
    # read for any reason the raw svmlight file is parsed and the cache is (re)written.
    cache_path = file + '.pkl'
    try:
        # NOTE(review): pickle.load executes arbitrary code from the file; only use
        # cache files that this script wrote itself.
        with open(cache_path, 'rb') as handle:
            dataFile = pickle.load(handle)

        X = dataFile['X']
        y = dataFile['y']

    except Exception:
        # Deliberately broad (missing, truncated or incompatible cache), but unlike a
        # bare ``except:`` it lets KeyboardInterrupt / SystemExit propagate, so Ctrl-C
        # during a long load no longer triggers a re-parse and a cache rewrite.
        X, y = loader(file)
        dataFile = {'X': X, 'y': y}
        with open(cache_path, 'wb') as handle:
            pickle.dump(dataFile, handle)

else:
    X, y = loader(file)

tloaddata = time.time()-tloaddata
print(f"time to load data: {tloaddata}")

# map the labels to +/-1: strictly positive entries become +1.0, all others -1.0
y = np.where(y > 0, 1.0, -1.0)
m, d = X.shape

print(f"m: {m},d: {d}")
# percentage of stored nonzeros; computed before any densification, so X still has to
# provide count_nonzero() here
nnz_percent = 100*X.count_nonzero()/(m*d)
print(f"X % nz: {nnz_percent}")


if convert_to_numpy:
    # replace X by its dense array form
    X = X.toarray()

######################################################
######################################################
######################################################



# create problem instance

# mini-batch size used by the stochastic solvers below
batchsz = 100

# problem object shared by every solver in this script
dr_s_lr = DR_s_LR(X, y, delta, kappa, lrcoef)

# constant reported by the problem object; it enters the FRB-VR step size and the
# deterministic PS step size further down (presumably a Lipschitz constant -- confirm
# against problem.py)
L = dr_s_lr.get_L()


##################################################################################################
# Tseng initialization
##################################################################################################

def prox1(z, step):
    """Return ``dr_s_lr.project_conePlusBall(z)``; ``step`` is accepted but ignored."""
    return dr_s_lr.project_conePlusBall(z)


# second operator and the vector-field evaluator, taken directly from the problem object
prox2 = dr_s_lr.prox_L1
eval_vec_field = dr_s_lr.getStochasticUpdate


class Init:
    """Empty attribute holder; callers attach ``z`` and ``w`` after construction."""

# Draw the starting point reproducibly from seed2use, then re-seed the global RNG
# without a fixed seed so later sampling is not tied to seed2use.
np.random.seed(seed2use)
init = Init()
n_vars = dr_s_lr.num_var
init.z = np.random.normal(loc=0, scale=1, size=n_vars)  # standard-normal z
init.w = np.zeros(n_vars)                               # all-zero w
np.random.seed()


##################################################################################################
# FRB var reduced
##################################################################################################
# Status message fixed to say FRB (it previously said "FBF"), matching the function
# called here, the timing message below and the "FRB-VR" plot label.
print("Running FRB var reduced...")

# step size computed from the sampling probability and the constant L
tau_frb_var_red = (1.0-np.sqrt(1.0-probability))/(2.0*L)

tfrb_var_red = time.time()
frb_vr_results = frb_var_reduced(prox1, prox2, eval_vec_field, init, probability,
                                 iter=niter_frbvr, tau=tau_frb_var_red, verbose=False,
                                 batchsz=batchsz, historyFreq=historyFreq_frbVR)

# wall-clock duration of the call above
tfrb_var_red = time.time() - tfrb_var_red
print(f"frb var red run time: {tfrb_var_red}")




##################################################################################################
# STOCHASTIC PS
##################################################################################################
print("Running Stochastic PS...")

# parameters forwarded positionally to ProjSplit.run for the run plotted as "SPS-decay"
rho1, rho_exp = 1e0, -0.25
alpha1, alpha_exp = 1e0, -0.51
tau = 1e0
dualFac = 1e0
reuseBatch = False
optimality_type = "new"

# the timer starts before construction, so the reported time includes building ProjSplit
tic_sps = time.time()
ps = ProjSplit(dr_s_lr, True)
# a literal 1.0 sits in the slot where the other ProjSplit runs pass alpha_const_0
ps.run(niter_ps, rho1, alpha1, rho_exp, alpha_exp, 1.0, tau, dualFac,
       measure_freq_ps, batchsz, reuseBatch, optimality_type, False, seed2use)
tps = time.time() - tic_sps
print(f"sps decay run time: {tps}")

# drop the reference held in the prob field, which includes the (possibly huge) matrix
ps.prob = []

##################################################################################################
# STOCHASTIC PS fixed stepsize
##################################################################################################
print("Running stochastic PS fixed step...")

optimality_type = "new"
tau, dualFac = 1e0, 1e0

ps_fix = ProjSplit(dr_s_lr, True)
# the solver object supplies its own step-size parameters for niter_ps iterations
rho1, alpha1, rho_exp, alpha_exp = ps_fix.getFixedStepsize(niter_ps)

# only the run itself is timed here; reuseBatch is whatever the preceding section left set
tic_sps_fix = time.time()
ps_fix.run(niter_ps, rho1, alpha1, rho_exp, alpha_exp, alpha_const_0, tau, dualFac,
           measure_freq_ps, batchsz, reuseBatch, optimality_type, False, seed2use)
tps = time.time() - tic_sps_fix
print(f"sps fixed run time: {tps}")

# drop the reference held in the prob field, which includes the (possibly huge) matrix
ps_fix.prob = []

##################################################################################################
# DETERMINISTIC PS
##################################################################################################
print("Running deterministic PS...")

# both exponents are zero for the deterministic run
rho_exp, alpha_exp = 0.0, 0.0
reuseBatch = False
tau = 1e0
dualFac = 1e0

rho1 = 0.9*L**(-1)
# alpha1 is found as part of the projection onto the hyperplane in the deterministic
# algorithm, so the value given here is not used
alpha1 = 1.0
batchsz_det = "full"
alpha_const_0 = 1e0
optimality_type = "new"

ps_d = ProjSplit(dr_s_lr, True)

# only the run is timed; a literal 1 is passed where the stochastic runs pass
# measure_freq_ps, and True where they pass False
tic_ps_d = time.time()
ps_d.run(niter_ps_d, rho1, alpha1, rho_exp, alpha_exp, alpha_const_0, tau, dualFac,
         1, batchsz_det, reuseBatch, optimality_type, True, seed2use)
tps = time.time() - tic_ps_d
print(f"ps d run time: {tps}")

# drop the reference held in the prob field, which includes the (possibly huge) matrix
ps_d.prob = []



##################################################################################################
# Tseng
##################################################################################################
print("Running Tseng's method...")

theta = 0.8
# batch size equal to the number of rows of X
batchsz_tseng = m

tic_tseng = time.time()
tseng_out = tseng_product(
    None, prox1, prox2, eval_vec_field, init,
    theta=theta,
    getFuncVals=False,
    batchsz=batchsz_tseng,
    iter=niter_tseng,
    verbose=False,
)
tTseng = time.time() - tic_tseng
print(f"Tseng run time: {tTseng}")

##################################################################################################
# Stochastic Tseng
##################################################################################################
print("Running Stochastic Tseng's method...")

# mini-batches instead of the full data; getAvResid is switched on because the plot
# uses the residAvs field of the result for this run
getAvResid = True
batchsz_tseng = batchsz

tic_stseng = time.time()
Stseng_out = tseng_product(
    None, prox1, prox2, eval_vec_field, init,
    getAvResid=getAvResid,
    getFuncVals=False,
    batchsz=batchsz_tseng,
    iter=niter_stseng,
    verbose=False,
    doBT=False,
    Cstep=Ctseng,
    dstep=dtseng,
    historyFreq=historyFreq_Stseng,
)
tStseng = time.time() - tic_stseng
print(f"stoch Tseng run time: {tStseng}")


##################################################################################################
# FRB
##################################################################################################
print("Running frb...")

# rebuild the starting point from seed2use (standard-normal z, zero w), then re-seed
# the global RNG without a fixed seed
np.random.seed(seed2use)
init = Init()
n_vars = dr_s_lr.num_var
init.z = np.random.normal(loc=0, scale=1, size=n_vars)
init.w = np.zeros(n_vars)
np.random.seed()

tic_frb = time.time()
outfrb = frb(prox1, prox2, eval_vec_field, init, iter=iterfrb, delta=0.8)
tfrb = time.time() - tic_frb
print(f"frb run time {tfrb}")


# elapsed wall-clock time since tstart was taken near the top of the script
tend = time.time()
print(f"total running time: {tend-tstart}")


##################################################################################################
# Plotting
##################################################################################################

msz = 10   # marker size
fsz = 17   # font size for labels, ticks and legend
mkvry = 2 if problem == "eps" else 5   # draw a marker on every mkvry-th point

# (x values, y values, legend label, marker) for every curve, in drawing order.
# NOTE(review): "SPS-fixed (ours)" and "FRB-VR" both use the 'v' marker.
curves = (
    (ps.tstamp, ps.OptCond, "SPS-decay (ours)", None),
    (ps_fix.tstamp, ps_fix.OptCond, "SPS-fixed (ours)", 'v'),
    (ps_d.tstamp, ps_d.OptCond, "deterministic PS", 'D'),
    (tseng_out.times, tseng_out.residuals, "Tseng's method", '*'),
    (outfrb.tstamps, outfrb.residuals, "FRB", '^'),
    (Stseng_out.times, Stseng_out.residAvs, "S-Tseng", 'h'),
    (frb_vr_results[0], frb_vr_results[1], "FRB-VR", 'v'),
)
for t_vals, resid_vals, curve_label, curve_marker in curves:
    plt.semilogy(t_vals, resid_vals, label=curve_label, marker=curve_marker,
                 markevery=mkvry, markersize=msz)

plt.xlabel("running time (s)", fontsize=fsz)
plt.ylabel("approximation residual", fontsize=fsz)

# the time axis is cut at the problem-specific limit
plt.xlim([0, x_axis_lim])

plt.legend(loc='upper right', fontsize=fsz)

plt.xticks(fontsize=fsz)
plt.yticks(fontsize=fsz)
plt.tight_layout()

plt.show()

