import numpy as np
import math
import os
import pickle
import time
from scipy.linalg import eigh
from scipy.linalg import sqrtm

#from sys import platform as sys_pf
#if sys_pf == 'darwin':
##    import matplotlib
#    matplotlib.use("TkAgg")
from matplotlib import pyplot as plt
from sklearn.datasets import load_svmlight_file



def sqnorm(x):
    """Return the square of ``np.linalg.norm(x)``."""
    magnitude = np.linalg.norm(x)
    return magnitude**2


def x_mean(X):
    """Average ``X`` along axis 0."""
    averaged = np.mean(X, axis=0)
    return averaged


def normalized_matrix(width, d):
    """Draw a ``(width, d)`` matrix whose rows are unit vectors built one at a time.

    Each row starts as a standard normal vector, has its component along the
    rows stored so far subtracted, and is then scaled to unit Euclidean norm.
    """
    rows = np.zeros((width, d))
    for row_idx in range(width):
        candidate = np.random.randn(d)
        # Rows that are not filled yet are still zero and contribute nothing.
        overlap = np.dot(rows.T, np.dot(rows, candidate))
        residual = candidate - overlap
        rows[row_idx, :] = residual / np.linalg.norm(residual)
    return rows

def make_fg_quadratic(T, d, type, mu):
    """Build a random finite-sum quadratic problem with ``T`` components in dimension ``d``.

    Component ``i`` is
    ``f_i(x) = 0.5*||A_i (x - xi_i)||^2 + 0.5*mu*||x - xi_i||^2`` where
    ``A_i = normalized_matrix(width, d) * sqrt(1 - mu)`` and ``xi_i`` is a
    scaled standard normal vector.

    Parameters
    ----------
    T : int
        Number of component functions.
    d : int
        Dimension of the variable.
    type : int
        Preset selecting ``(scale of xi_i, rows of A_i)``:
        0 -> (1, 1), 1 -> (1, 10), 2 -> (10000, 1), 3 -> (100, 10).
    mu : float
        Weight of the isotropic quadratic term.

    Returns
    -------
    f_quad, g_quad : callables
        ``f_quad(x, i)`` is the value and ``g_quad(x, i)`` the gradient of
        component ``i`` at ``x``.
    fstar : float
        Mean of the component values at ``xstar``.
    xstar : ndarray
        Point where the sum of the component gradients vanishes.

    Raises
    ------
    ValueError
        If ``type`` is not one of the presets. Without this check the
        function failed further down with an ``UnboundLocalError``.
    """
    # preset -> (scale of the random centres, number of rows of each A_i)
    presets = {0: (1, 1), 1: (1, 10), 2: (10000, 1), 3: (100, 10)}
    if type not in presets:
        raise ValueError(
            "unknown quadratic problem type {!r}; expected one of {}".format(type, sorted(presets)))
    scale, width = presets[type]

    # Centres are drawn before the matrices so the order of random draws is
    # the same for every preset.
    xi_star = scale*np.random.randn(T, d)
    A = [normalized_matrix(width, d)*np.sqrt(1-mu) for i in range(T)]

    def f_quad(x, i):
        shift = x-xi_star[i, :]
        return 0.5*sqnorm(np.dot(A[i], shift)) + 0.5*mu*sqnorm(shift)

    def g_quad(x, i):
        shift = x-xi_star[i, :]
        return np.dot(A[i].T, np.dot(A[i], shift)) + mu*shift

    # Each gradient is affine in x, so the minimiser of the sum solves
    # (sum_i H_i) x = -sum_i g_i(0) with H_i = mu*I + A_i^T A_i.
    v = -1*np.sum(np.asarray([g_quad(np.zeros(d), i) for i in range(T)]), axis=0)
    hessian_sum = np.sum(np.asarray([mu*np.eye(d) + np.dot(A[i].T, A[i]) for i in range(T)]), axis=0)

    xstar = np.linalg.solve(hessian_sum, v)
    fstar = np.mean(np.asarray([f_quad(xstar, i) for i in range(T)]))

    return f_quad, g_quad, fstar, xstar



def make_fg_logreg(A, b, mu):
    """Build the per-sample regularised logistic loss and its gradient.

    With ``z = dot(x, A[i, :]) * b[i]`` the loss of sample ``i`` is
    ``log(1 + exp(z)) + 0.5*mu*||x||^2``.

    Parameters
    ----------
    A : ndarray
        Data matrix, indexed as ``A[i, :]`` per sample.
    b : ndarray
        Per-sample multipliers, indexed as ``b[i]``.
    mu : float
        Weight of the squared-norm regulariser.

    Returns
    -------
    f_log, g_log : callables
        ``f_log(x, i)`` is the loss and ``g_log(x, i)`` its gradient.
    """
    def f_log(x, i):
        z = np.dot(x, A[i, :]) * b[i]
        # logaddexp(0, z) == log(1 + exp(z)) but does not overflow to inf
        # once z exceeds ~709.
        return np.logaddexp(0, z) + mu*0.5*sqnorm(x)

    def g_log(x, i):
        z = np.dot(x, A[i, :]) * b[i]
        # Sigmoid 1/(1+exp(-z)) written as exp(-log(1+exp(-z))) so that a
        # very negative z underflows to 0 instead of overflowing in exp(-z).
        sigma = np.exp(-np.logaddexp(0, -z))
        return A[i, :] * b[i] * sigma + mu*x

    return f_log, g_log


def full_grad(g, x, indices):
    """Return the mean of ``g(x, i)`` over every ``i`` in ``indices``."""
    per_index = []
    for idx in indices:
        per_index.append(g(x, idx))
    stacked = np.asarray(per_index)
    return np.mean(stacked, axis=0)

def create_blocks(n, block_size):
    """Cut ``range(n)`` into consecutive ``(start, stop)`` pairs of length ``block_size``.

    Only ``int(n/block_size)`` full blocks are produced; a shorter trailing
    remainder is not included.
    """
    num_blocks = int(n/block_size)
    blocks = []
    for k in range(num_blocks):
        start = block_size*k
        blocks.append((start, start + block_size))
    return blocks


def get_Cis(A, blocks):
    """Compute SVD factors of every scaled row block of ``A``.

    For a block ``(start, stop)`` the rows ``A[start:stop]`` are divided by
    ``sqrt(stop - start)`` and passed to ``np.linalg.svd``.

    Returns a list with one ``(S, V)`` pair per block: ``S`` holds the
    singular values and ``V`` the first ``len(S)`` rows of the right factor.
    """
    _rows, _cols = A.shape  # unused values; the unpacking requires a 2-D A
    factors = []
    for blk in blocks:
        start, stop = blk[0], blk[1]
        scaled = A[start:stop, :] / np.sqrt(stop - start)
        _left, sing_vals, right = np.linalg.svd(scaled)
        factors.append((sing_vals, right[:len(sing_vals), :]))
    return factors

def L_power_product(CC, gg, mu, poww):
    """Return ``mu**poww * gg + V^T diag((S*S + mu)**poww - mu**poww) V gg``.

    ``CC`` is an ``(S, V)`` pair. When the rows of ``V`` are orthonormal this
    equals ``(mu*I + V^T diag(S*S) V)**poww`` applied to ``gg``.
    """
    sing_vals, basis = CC
    mu_pow = mu**(poww)
    # Difference between the powered spectrum and the isotropic part.
    spectrum_shift = np.diag(((sing_vals*sing_vals + mu)**(poww)) - mu_pow)
    coords = np.dot(basis, gg)
    correction = np.dot(basis.T, np.dot(spectrum_shift, coords))
    return mu_pow*gg + correction

def spars_step(CC, grad, mu, probs, simple):
    """Randomly sparsify ``grad`` and return two masked versions of it.

    Coordinates are sampled in one of two ways:

    * if ``probs`` sums to 1 (within ``1e-4``), exactly one coordinate is
      drawn with ``np.random.choice`` using the renormalised ``probs``;
    * otherwise every coordinate ``i`` is kept independently with
      probability ``probs[i]``.

    Two masks are built from the sample: one holding ``1/probs[i]`` and one
    holding ``1`` on the sampled coordinates (zero elsewhere).

    With ``simple`` true the masks are applied directly to ``grad``.
    Otherwise they are applied between calls to ``L_power_product`` with
    the factors ``CC`` and shift ``mu``:
    first output ``L^{1/2} S_p L^{-1/2} grad``, second output
    ``L^{-1/2} S L S L^{-1/2} grad`` (``S_p``/``S`` being the two masks).
    """
    d = len(grad)
    inv_prob_mask = np.zeros(d)
    unit_mask = np.zeros(d)

    total = np.sum(probs)
    if np.abs(total - 1) < 1e-4:
        sampled = [np.random.choice(d, None, False, probs / total)]
    else:
        # One uniform draw per coordinate, in index order.
        sampled = [k for k in range(d) if np.random.rand() < probs[k]]

    for k in sampled:
        unit_mask[k] = 1.0
        inv_prob_mask[k] = 1 / probs[k]

    if simple:
        return inv_prob_mask*grad, unit_mask*grad

    whitened = L_power_product(CC, grad, mu, -0.5)

    second = L_power_product(CC, unit_mask*whitened, mu, 1)
    second = L_power_product(CC, unit_mask*second, mu, -0.5)

    first = L_power_product(CC, inv_prob_mask*whitened, mu, 0.5)
    return first, second




# computes p once c is given: p = m / (m + c)
def impP(m, c):
    """Return ``m / (m + c)``, elementwise when ``m`` is an array."""
    shifted = m + c
    return m / shifted


#computes p for importance minibatch sampling so that ESO is satisfied
def getP(M,tau):
    """Bisect on the offset ``c`` so that ``impP(diag(M), c)`` sums to about ``tau``.

    The search bracket starts at ``[0, sum(diag(M))/tau]`` and is halved
    until it is narrower than its initial upper end divided by ``1e10``.
    The lower end moves up whenever the probabilities still sum to more
    than ``tau``. The probabilities at the final midpoint are returned.
    """
    weights = np.diag(M)
    lo = 0.0
    hi = 1.0*(sum(weights))/tau
    tol = hi/(1e10)
    while hi - lo > tol:
        mid = (lo + hi)/2
        if np.sum(impP(weights, mid)) > tau:
            lo = mid
        else:
            hi = mid
    return impP(weights, (lo + hi)/2)


def create_probs_and_stepsize(A, blocks, mu, importance, tau, simple_spars, nmu):
    """Compute per-block coordinate sampling probabilities and a combined constant.

    For every block ``(start, stop)`` the matrix
    ``Lmat = A_b^T A_b / (stop - start) + mu*I`` is formed from the rows
    ``A_b = A[start:stop]``. Then:

    * ``simple_spars``: probabilities are uniform ``tau/d`` and the block
      constant is ``lambda_max(Lmat) * d / tau``;
    * otherwise: probabilities come from ``getP2(diag(Lmat), tau, nmu)``
      when ``importance`` is set (uniform ``tau/d`` if not), and the block
      constant is ``max(diag(Lmat) * (1/p - 1))``.

    Parameters
    ----------
    A : ndarray, shape (n, d)
    blocks : sequence of (start, stop) row ranges
    mu : float
    importance : bool
        Must be false when ``simple_spars`` is true (checked by assertion).
    tau : number
    simple_spars : bool
    nmu : float
        Forwarded to ``getP2``.

    Returns
    -------
    probs : ndarray, shape (len(blocks), d)
    LLL : float
        ``6 * max(block constants) / len(blocks)``.
    """
    n, d = A.shape
    alphas = np.zeros(len(blocks))
    probs = np.zeros((len(blocks), d))
    assert not(simple_spars and importance)

    uniform = np.ones(d) / d * tau

    for i, bl in enumerate(blocks):
        rows = A[bl[0]:bl[1], :]
        Lmat = np.dot(rows.T, rows/(bl[1]-bl[0])) + mu*np.eye(d)
        if simple_spars:
            probs[i] = uniform
            # eigvalsh returns eigenvalues in ascending order, so the last
            # one is the largest. This replaces scipy's
            # eigh(..., eigvals=(d-1, d-1)), whose `eigvals` keyword was
            # deprecated and later removed from SciPy.
            alphas[i] = np.linalg.eigvalsh(Lmat)[-1] * d / tau
        else:
            if importance:
                probs[i] = getP2(np.diag(Lmat), tau, nmu)
            else:
                probs[i] = uniform

            alphas[i] = np.max(np.diag(Lmat)*(1/probs[i]-1))

    LLL = 6*np.max(alphas)/len(blocks)
    return probs, LLL

def getP2(DL,tau, nmu):
    """Run ``getP`` on the entries of ``DL`` after adding ``nmu`` to each of them."""
    shift = nmu*np.ones(len(DL))
    shifted_diag = np.diag(DL + shift)
    return getP(shifted_diag, tau)

def getP3(DL, tau, nmu):
    """Greedily pick probabilities from ``create_probs`` that lower a score.

    The score of a probability vector ``p`` is
    ``max(1/p) + max(DL*(1/p - 1))/nmu``. Starting from uniform ``tau/d``,
    candidates ``create_probs(DL, tau, k)`` for ``k = 2, 3, ...`` are tried;
    a candidate is adopted while it scores strictly lower than the current
    one and ``k`` is still below ``d``.

    NOTE(review): when the loop stops because ``k`` reached ``d``, the last
    candidate is not adopted even if it scored lower; and ``len(DL) < 2``
    makes the first ``create_probs`` call index out of range. Both are
    unchanged here; confirm whether they are intended.
    """
    d = len(DL)

    def score(p):
        return np.max(1.0/p) + np.max(DL*(1/p-1))/nmu

    best_probs = np.ones(d)*tau/d
    best_score = score(best_probs)

    counter = 2
    trial_probs = create_probs(DL, tau, counter)
    trial_score = score(trial_probs)

    while trial_score < best_score and counter < d:
        counter += 1
        best_probs = trial_probs*1.0
        best_score = trial_score*1.0

        trial_probs = create_probs(DL, tau, counter)
        trial_score = score(trial_probs)

    return best_probs


def create_probs(DL,tau, curr):
    """Build probabilities from ``DL`` after raising small entries to a floor.

    The floor is the ``curr``-th largest entry of ``DL``; every entry not
    above it is replaced by it. For ``tau == 1`` the result is normalised to
    sum to one, otherwise it is handed to ``getP`` as a diagonal matrix.
    """
    values = DL*1.0
    ordered = np.sort(values)
    floor = ordered[-curr]
    lifted = np.where(values > floor, values, floor)

    if tau != 1:
        return getP(np.diag(lifted), tau)
    return lifted/np.sum(lifted)



def get_alpha_probs(A, blocks, mu, importance, tau, simple_spars, nmu):
    """Return sampling probabilities and the step size ``1/(L + LLL)``.

    ``probs`` and ``LLL`` come from ``create_probs_and_stepsize``; ``L`` is
    the largest eigenvalue of ``A^T A / n + mu*I``.

    Returns
    -------
    probs, alpha, L, LLL
    """
    n, d = A.shape
    probs, LLL = create_probs_and_stepsize(A, blocks, mu, importance, tau, simple_spars, nmu)
    Lmat = np.dot(A.T, A) / n + mu * np.eye(d)
    # eigvalsh returns eigenvalues in ascending order, so the last one is
    # the largest. This replaces scipy's eigh(..., eigvals=(d-1, d-1)),
    # whose `eigvals` keyword was deprecated and later removed from SciPy.
    L = np.linalg.eigvalsh(Lmat)[-1]
    alpha = 1/(L + LLL)
    return probs, alpha, L, LLL

def objective(f, x, n):
    """Return the average of ``f(x, i)`` for ``i`` in ``range(n)`` (0 when ``n`` is 0)."""
    running = 0
    for sample in range(n):
        contribution = f(x, sample)/(n)
        running += contribution
    return running


def get_data(dataset):
    """Load ``datasets/<dataset>.txt`` with ``load_svmlight_file``.

    The ``"mushrooms"`` dataset is read from ``datasets/mushrooms2.txt``
    instead. Returns the first loaded item converted with ``toarray()`` and
    the second item unchanged.
    """
    suffix = '2.txt' if dataset == "mushrooms" else '.txt'
    datapath = 'datasets/' + dataset + suffix

    loaded = load_svmlight_file(datapath)
    return loaded[0].toarray(), loaded[1]


def normalize_data(A):
    """Return a copy of ``A`` whose non-zero rows are rescaled to Euclidean norm 2.

    All-zero rows are left as they are. ``A`` itself is not modified.
    """
    n, d = A.shape
    B = A.copy()
    for row in range(n):
        length = np.linalg.norm(B[row, :])
        if length != 0.0:
            # Dividing by half the norm leaves the row with norm 2.
            B[row, :] /= length/2.0
    return B


def rearrange_data(A, method, b):
    """Reorder the rows of ``A`` and the entries of ``b`` with one common ordering.

    Parameters
    ----------
    A : ndarray, shape (n, d)
    method : str
        ``"random"`` applies a random permutation drawn with
        ``np.random.permutation``; ``"same"`` orders by ascending ``b``.
    b : ndarray, shape (n,)

    Returns
    -------
    B, bb : reordered copies of ``A`` and ``b``.

    Raises
    ------
    ValueError
        If ``method`` is neither ``"random"`` nor ``"same"``. Without this
        check the function failed with an ``UnboundLocalError`` at return.
    """
    n, _ = A.shape
    if method == "random":
        order = np.random.permutation(n)
    elif method == "same":
        # Computed once so A and b are guaranteed to share the ordering.
        order = np.argsort(b)
    else:
        raise ValueError(
            "unknown rearrangement method {!r}; expected 'random' or 'same'".format(method))
    # Indexing with an integer array already returns new arrays.
    return A[order, :], b[order]


def get_stepsize_saga(v=0, p=0, pagg=0, pwork=1, omega=0, n=0, T=0, mu=0):
    """Return the smaller of two step-size bounds.

    The bounds are ``T*(1-pagg)/(4*v[0] + mu/p[0])`` and
    ``pagg/(4*omega + mu/T)``. Only the first entries of ``v`` and ``p``
    are read; ``pwork`` and ``n`` are accepted but not used.
    """
    # only for simplest version
    first_denominator = 4*v[0]+mu/p[0]
    second_denominator = 4*omega+mu/T
    first_bound = T*(1-pagg)/first_denominator
    second_bound = pagg/second_denominator
    return np.minimum(first_bound, second_bound)


def create_it(T, skip_it, tau=1):
    """Return the float array ``[0, 1, ..., T // skip_it] * skip_it * tau``."""
    count = 1 + T // skip_it
    ticks = np.zeros(count)
    ticks[:] = [k*skip_it*tau for k in range(count)]
    return ticks

def create_distances(Xlist, X):
    """For every array in ``Xlist`` sum the squared norms of its row-wise gaps to ``X``.

    Entry ``i`` of the result is ``sum_t ||Xlist[i][t] - X[t]||^2`` over the
    rows ``t`` of ``Xlist[i]``.
    """
    dist = np.zeros(len(Xlist))

    for idx, stacked in enumerate(Xlist):
        num_rows, _ = stacked.shape
        for t in range(num_rows):
            gap = stacked[t] - X[t]
            dist[idx] += np.linalg.norm(gap)**2

    return dist

def createfilename(experiment, dataset, T, mu, backup=False):
    """Build the pickle path for an experiment run.

    Regular files live in ``pickles/``; with ``backup`` set the file goes
    to ``pickles/backup/`` and carries the current ``time.time()`` value.
    """
    stem = "".join([experiment, "_", dataset, "_", str(T), "_", "_", str(mu)])
    if not backup:
        return "".join(["pickles/", stem, "_", "_", ".pickle"])
    return "".join(["pickles/backup/", stem, "_", str(time.time()), ".pickle"])


def load_pickle(experiment, dataset, T, mu):
    """Load and return the object pickled at ``createfilename(experiment, dataset, T, mu)``."""
    filename = createfilename(experiment, dataset, T, mu)
    # NOTE: unpickling can execute arbitrary code; only load files produced
    # by this project.
    # The context manager closes the handle even if unpickling fails.
    with open(filename, "rb") as pickle_in:
        return pickle.load(pickle_in)


# Default per-curve plot styles, indexed by the position of the curve.
linestyles = [
    '-', '--', ':', '-.',
    '--', ':', '-.', '-',
]
markers = [
    '*', 'o', 's', 'x',
    'd', 'v', 'P', '1',
]
colors = [
    'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
    'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan',
]


def visualize(fvals, it, plot_name, plottitle, labels, plots_path=os.getcwd()+'/plots/', linestyles=linestyles,
              markers=markers, colors=colors, alphas=None, ppe=0,  muT=0, fstar = None, maxit = None):
    """Plot every series of ``fvals`` against ``it`` on a log y-axis and save the figure as a PDF.

    Parameters
    ----------
    fvals : sequence of arrays
        One series of values per curve.
    it : sequence of arrays
        x-coordinates of each curve, same order as ``fvals``.
    plot_name : str
        Used in the output file name; the substrings ``"epoch"``,
        ``"communication"`` and ``"bits"`` select the x-axis label.
    plottitle : str
        Figure title; also appended to the file name with ``:`` removed.
    labels : sequence of str
        Legend label per curve.
    plots_path : str
        Output directory prefix.
    linestyles, markers, colors : sequences
        Style per curve, indexed by curve position.
    alphas : sequence or None
        When given, a second curve ``(1 - alphas[k]*muT) ** (ppe * it[0])``
        is drawn for curve ``k``, cut to its first 3/5 and labelled with a
        ``"T"`` prefix.
    ppe, muT : numbers
        Used only for the ``alphas`` curves.
    fstar, maxit :
        Accepted but not used.

    The caller's ``fvals``, ``it`` and ``labels`` are left untouched.
    """
    # Rebind to shallow copies: the truncation below replaces entries and
    # must not leak into the lists owned by the caller.
    fvals = list(fvals)
    it = list(it)

    plt.figure()

    minF = np.min(np.asarray([series[-1] for series in fvals]))
    if minF < 0:
        # Shift every series by the (negative) smallest final value and keep
        # only the first 3/5 of the points.
        for i in range(len(fvals)):
            keep = int(len(fvals[i])*3.0/5.0)
            fvals[i] = fvals[i][:keep] - minF
            it[i] = it[i][:keep]

    for counter, ff in enumerate(fvals):
        mar_mult = 1.0*it[0][-1]/it[counter][-1]
        plt.semilogy(it[counter], ff, label=labels[counter], linestyle=linestyles[counter],
                     color=colors[counter], marker=markers[counter], markersize=12,
                     markevery=int(math.ceil(len(it[counter]) / 10*mar_mult)))

        if alphas is not None:
            # NOTE(review): this used to read ``alphas[group[ii]]`` with two
            # names that are not defined anywhere, so the branch always
            # raised NameError. Indexing by curve position is assumed to be
            # the intent; confirm against the callers.
            rate = alphas[counter]*muT
            theory = (1-rate)**(ppe*np.asarray(it[0]))
            nn = int(len(theory)*3.0/5.0)
            plt.semilogy(it[counter][:nn], theory[:nn], label="T"+labels[counter], linestyle=linestyles[counter],
                         color=colors[counter], marker="s", markevery=int(math.ceil(len(it[counter]) / 15)))

    # Clamp the y-range to [1e-13, 100].
    bottom, top = plt.ylim()
    plt.ylim([np.maximum(bottom, 1e-13), np.minimum(top, 100.0)])

    plt.xlim([0, it[-1][-1]])

    plt.locator_params(axis='x', nbins=6)

    plt.legend(fontsize=17, loc='best')
    if "epoch" in plot_name:
        plt.xlabel('Data passes', fontsize='xx-large')
    elif "communication" in plot_name:
        plt.xlabel('Rounds of communication', fontsize='xx-large')
    elif "bits" in plot_name:
        plt.xlabel('Coordinates sent to server', fontsize='xx-large')
    else:
        plt.xlabel('Iteration', fontsize=20)

    plt.xticks(fontsize=17)
    plt.yticks(fontsize=17)
    plt.ylabel('Relative suboptimality', fontsize=20)
    plt.title(plottitle, fontsize=20)
    plt.tight_layout()
    plt.savefig('{}{}{}.pdf'.format(plots_path, plot_name, plottitle.replace(':', '')))




def visualize_dist(dist_0, dist_inf, omegas_np):
    """Plot ``dist_0`` and ``dist_inf`` against ``omegas_np`` on a log x-axis.

    The figure is saved as ``evolution.pdf`` in ``<cwd>/plots/``. Note that
    ``omegas_np[0]`` is overwritten in place with ``omegas_np[1]/2`` before
    plotting, so the caller's array is modified.
    """
    omegas_np[0] = omegas_np[1]/2

    curves = (
        (dist_0, r'$||x(\lambda)-x(0)||^2$'),
        (dist_inf, r'$||x(\lambda)-x(\infty)||^2$'),
    )
    for style_idx, (values, legend_text) in enumerate(curves):
        plt.plot(omegas_np, values, label=legend_text, linestyle=linestyles[style_idx],
                 color=colors[style_idx], marker=markers[style_idx], markevery=1)

    plt.legend(fontsize='x-large', loc='best')
    plt.xlabel(r'$\lambda$', fontsize='x-large')

    left_edge = omegas_np[0]
    right_edge = omegas_np[-1]*1.1
    plt.xlim([left_edge, right_edge])
    plt.xscale('log')

    plt.tight_layout()
    out_dir = os.getcwd()+'/plots/'
    plt.savefig('{}evolution.pdf'.format(out_dir))


def visualize_(comm, ps):
    """Plot ``comm`` against ``ps`` on log-log axes and save ``commp2log.pdf`` in ``<cwd>/plots/``.

    The x-axis uses fixed ticks at 0.01, 0.03, 0.1 and 0.5, printed as plain
    numbers.
    """
    # Only pyplot is imported at module level, so the bare name
    # ``matplotlib`` used here before raised NameError. Import the formatter
    # directly instead.
    from matplotlib.ticker import ScalarFormatter

    counter = 0

    fig1, ax1 = plt.subplots()
    ax1.plot(ps, comm, label='_nolegend_', linestyle=linestyles[counter], color=colors[counter],
             marker=markers[counter], markevery=1)
    ax1.set_xscale('log')
    ax1.set_yscale('log')

    ax1.set_xticks([0.01, 0.03, 0.1, 0.5])
    ax1.get_xaxis().set_major_formatter(ScalarFormatter())

    plt.xlabel(r'$p$', fontsize='x-large')

    plt.ylabel(r'Communication rounds for $\varepsilon$ solution', fontsize='x-large')
    plt.tight_layout()
    plt.savefig('{}commp2log.pdf'.format(os.getcwd()+'/plots/'))

    return