import numpy as np
import math
import os
import pickle
import time


from sys import platform as sys_pf
if sys_pf == 'darwin':
    import matplotlib
    matplotlib.use("TkAgg")
from matplotlib import pyplot as plt
from sklearn.datasets import load_svmlight_file



def sqnorm(x):
    """Return the squared Euclidean norm of ``x`` (Frobenius norm for 2-D input)."""
    length = np.linalg.norm(x)
    return length ** 2


def x_mean(X):
    """Return the average of the rows of ``X`` (mean taken along axis 0)."""
    row_average = np.mean(X, axis=0)
    return row_average


def psi_func(X, x_bar, omega, t):
    """Return the penalty of row ``t``: ``omega / 2 * ||X[t, :] - x_bar||^2``."""
    deviation = X[t, :] - x_bar
    return omega * sqnorm(deviation) * 0.5


def psi_grad(X, x_bar, omega, t):
    """Return ``omega * (X[t, :] - x_bar)``.

    This is the gradient of ``omega / 2 * ||X[t, :] - x_bar||^2`` with respect
    to ``X[t, :]`` when ``x_bar`` is held fixed.
    """
    deviation = X[t, :] - x_bar
    return omega * deviation


def make_fg_logreg(A, b, mu):
    """Build per-sample L2-regularized logistic loss and gradient closures.

    For sample ``i`` with feature row ``a_i = A[i, :]`` and label ``b_i = b[i]``,
    and ``z = b_i * <x, a_i>``:

        f(x, i) = log(1 + exp(z)) + (mu / 2) * ||x||^2
        g(x, i) = b_i * a_i / (1 + exp(-z)) + mu * x

    ``g`` is the gradient of ``f`` with respect to ``x``. Note the sign
    convention: the margin enters the exponential with a ``+`` sign.

    Args:
        A: feature matrix indexed as ``A[i, :]``, one sample per row.
        b: labels indexed like the rows of ``A``.
        mu: L2 regularization strength.

    Returns:
        Tuple ``(f_log, g_log)`` of callables, each taking ``(x, i)``.
    """
    def f_log(x, i):
        margin = np.dot(x, A[i, :]) * b[i]
        # logaddexp(0, z) == log(1 + exp(z)), but it stays finite for large z
        # instead of overflowing exp() to inf.
        return np.logaddexp(0, margin) + mu*0.5*sqnorm(x)

    def g_log(x, i):
        return A[i, :] * b[i] / (1 + np.exp(-np.dot(x, A[i, :]) * b[i])) + mu*x

    return f_log, g_log


def objective(f, psi, X, m, omega):
    """Evaluate the full objective at the stacked iterate ``X``.

    The value is the average of ``f`` over all ``T * m`` samples, plus the
    sum over rows of the penalty ``psi``:

        (1 / (m*T)) * sum_t sum_i f(X[t, :], t*m + i)
            + sum_t psi(X, x_bar, omega, t)

    where ``x_bar`` is the row mean of ``X``. Row ``t`` owns the sample
    indices ``t*m .. t*m + m - 1``.

    Args:
        f: per-sample loss, called as ``f(x_row, sample_index)``.
        psi: penalty, called as ``psi(X, x_bar, omega, t)``.
        X: 2-D array with one row per local model.
        m: number of samples per row.
        omega: penalty weight forwarded to ``psi``.
    """
    n_rows, _ = X.shape
    centre = x_mean(X)
    scale = m * n_rows
    total = 0
    for t in range(n_rows):
        first_sample = t * m
        for offset in range(m):
            total += f(X[t, :], first_sample + offset) / scale
        total += psi(X, centre, omega, t)
    return total


def get_data(dataset):
    """Load ``datasets/<dataset>.txt`` with ``load_svmlight_file`` as dense data.

    Returns:
        ``(features, labels)``: the feature matrix converted with
        ``.toarray()`` and the label vector, both as returned by the loader.
    """
    path = 'datasets/' + dataset + '.txt'
    features, labels = load_svmlight_file(path)
    return features.toarray(), labels


def normalize_data(A, target_norm=4.0):
    """Return a copy of ``A`` whose rows are rescaled to Euclidean norm ``target_norm``.

    Rows that are entirely zero cannot be rescaled. They are returned unchanged
    (previously they became NaN through a division by zero). Non-floating
    input (e.g. integer arrays) is converted to float64 so the rescaling can
    be done in place on the copy. ``A`` itself is never modified.

    Args:
        A: 2-D array, one sample per row.
        target_norm: desired norm of every non-zero row (default 4.0).

    Returns:
        A new array with the same shape as ``A``.
    """
    dtype = A.dtype if np.issubdtype(A.dtype, np.inexact) else np.float64
    B = A.astype(dtype)  # astype copies, so A is left untouched
    norms = np.linalg.norm(B, axis=1)
    nonzero = norms > 0
    # Divide each non-zero row by (norm / target_norm), matching the original per-row update.
    B[nonzero] /= (norms[nonzero] / target_norm)[:, np.newaxis]
    return B


def rearrange_data(A, method, b):
    """Reorder the rows of ``A`` and the entries of ``b`` with one shared permutation.

    Args:
        A: 2-D array, one sample per row.
        method: ``"random"`` shuffles with ``np.random.permutation`` (NumPy's
            global RNG). ``"same"`` sorts by label using ``np.argsort(b)``
            (default sort kind), which places equal labels next to each other.
        b: labels aligned with the rows of ``A``.

    Returns:
        ``(B, bb)``: new reordered arrays; the inputs are not modified.

    Raises:
        ValueError: if ``method`` is not one of the supported strings.
    """
    n = A.shape[0]
    if method == "random":
        order = np.random.permutation(n)
    elif method == "same":
        order = np.argsort(b)
    else:
        raise ValueError("unknown rearrange method: {!r} (expected 'random' or 'same')".format(method))
    # Fancy indexing already returns copies, so no explicit .copy() is needed.
    return A[order, :], b[order]


def get_stepsize_saga(v=0, p=0, pagg=0, pwork=1, omega=0, n=0, T=0, mu=0):
    """Return the SAGA step size as the smaller of two bounds.

        local bound       = T * (1 - pagg) / (4 * v[0] + mu / p[0])
        aggregation bound = pagg / (4 * omega + mu / T)

    The original author notes this is valid only for the simplest version
    of the method. ``pwork`` and ``n`` are accepted but not used.

    NOTE(review): the defaults for ``v``, ``p`` and ``T`` are placeholders;
    ``v[0]`` / ``p[0]`` cannot be indexed on the int default, and ``T=0``
    divides by zero, so callers must always pass them. Presumably ``v[0]``
    is a smoothness constant and ``p[0]`` a sampling probability — confirm.
    """
    local_bound = T * (1 - pagg) / (4 * v[0] + mu / p[0])
    aggregation_bound = pagg / (4 * omega + mu / T)
    return np.minimum(local_bound, aggregation_bound)


def create_it(T, skip_it, tau=1):
    """Return the float x-axis grid ``[0, skip_it*tau, 2*skip_it*tau, ...]``.

    The grid has ``1 + T // skip_it`` entries; entry ``k`` equals ``k * skip_it * tau``.
    """
    n_points = 1 + T // skip_it
    grid = np.zeros(n_points)
    grid[:] = [k * skip_it * tau for k in range(n_points)]
    return grid

def create_distances(Xlist, X):
    """Return the squared distance from each iterate in ``Xlist`` to ``X``.

    For every 2-D array ``Xk`` in ``Xlist``, sums ``||Xk[t] - X[t]||^2`` over
    the rows ``t`` of ``Xk``. ``X`` must have at least as many rows as each ``Xk``.

    Returns:
        1-D float array with one entry per element of ``Xlist``.
    """
    dist = np.zeros(len(Xlist))
    for idx, Xk in enumerate(Xlist):
        n_rows, _ = Xk.shape
        dist[idx] = sum(np.linalg.norm(Xk[t] - X[t]) ** 2 for t in range(n_rows))
    return dist

def createfilename(experiment, dataset, T, omega, mu, pwork, method, backup=False):
    """Build the pickle path for an experiment configuration.

    Fields are joined with ``_``. Normal results go to
    ``pickles/<fields>.pickle``. With ``backup=True`` the current
    ``time.time()`` is appended as an extra field and the file goes to
    ``pickles/backup/``, so repeated backups get distinct names.
    """
    fields = [experiment, dataset, str(T), str(omega), str(mu), str(pwork), method]
    if not backup:
        return "pickles/" + "_".join(fields) + ".pickle"
    fields.append(str(time.time()))
    return "pickles/backup/" + "_".join(fields) + ".pickle"


def load_pickle(experiment, dataset, T, omega, mu, pwork, method):
    """Load and return the object pickled for this experiment configuration.

    The path comes from ``createfilename`` (non-backup variant). The file is
    opened in a ``with`` block, so the handle is closed even if unpickling fails.

    NOTE(review): ``pickle.load`` can execute arbitrary code. Only load files
    this project wrote itself, never files from untrusted sources.
    """
    filename = createfilename(experiment, dataset, T, omega, mu, pwork, method)
    with open(filename, "rb") as pickle_in:
        return pickle.load(pickle_in)


# Default per-curve styles for the plotting helpers; entry k styles the k-th curve.
# linestyles and markers have 8 entries, colors has 10.
linestyles = [
    '-', '--', ':', '-.',
    '--', ':', '-.', '-',
]
markers = [
    '*', 'o', 's', 'x',
    'd', 'v', 'P', '1',
]
colors = ['tab:' + name for name in (
    'blue', 'orange', 'green', 'red', 'purple',
    'brown', 'pink', 'gray', 'olive', 'cyan',
)]


def visualize(fvals, it, plot_name, dataset, labels, plots_path=None, linestyles=linestyles,
              markers=markers, colors=colors, alphas=None, ppe=0, groups=None, muT=0):
    """Plot relative-suboptimality curves on a log scale and save them as a PDF.

    Each run ``fvals[k]`` is drawn as ``(f - fmin) / (f0 - fmin)`` against the
    x-values ``it[counter]``, where ``counter`` is the position of the run in
    plotting order. ``fmin`` is the smallest value within the run's group and
    ``f0`` is the first value of the group's first run. Only the first 2/5 of
    each run is drawn. The figure is saved to
    ``<plots_path><plot_name>Dataset<dataset>.pdf``.

    Args:
        fvals: list of 1-D sequences of objective values, one per run.
        it: list of x-value sequences, indexed in plotting order.
        plot_name: file-name prefix. If it contains "epoch" or
            "communication", the x-axis label is chosen accordingly.
        dataset: dataset name used in the title and file name.
        labels: legend labels in plotting order. Not modified.
        plots_path: output directory including a trailing separator.
            Defaults to ``<cwd>/plots/``, resolved at call time.
        linestyles, markers, colors: per-curve style sequences.
        alphas: optional per-run values (indexed like ``fvals``). When given,
            a reference curve ``(1 - alphas[k] * muT) ** (ppe * it[0])`` is
            drawn for each run, over its first 3/5.
        ppe, muT: parameters of that reference curve.
        groups: optional list of index collections into ``fvals``. Runs in a
            group share ``fmin`` / ``f0``. Defaults to a single group of all runs.
    """
    if plots_path is None:
        plots_path = os.getcwd() + '/plots/'
    if groups is None:
        groups = [range(len(fvals))]

    plt.figure()
    counter = 0  # position of the current run across all groups

    for group in groups:
        fvals_curr = [np.asarray(fvals[k], dtype=float) for k in group]
        fmin = min(np.min(ff) for ff in fvals_curr)
        # Normalize by this group's own starting value, consistent with the per-group fmin.
        f0 = fvals_curr[0][0]

        for k, ff in zip(group, fvals_curr):
            x_vals = it[counter]
            # markevery must be >= 1; 0 would make matplotlib fail on empty runs.
            markevery = max(1, int(math.ceil(len(x_vals) / 15)))
            nn = int(len(ff) * 2.0 / 5.0)
            plt.semilogy(x_vals[:nn], (ff[:nn] - fmin) / (f0 - fmin), label=labels[counter],
                         linestyle=linestyles[counter], color=colors[counter],
                         marker=markers[counter], markevery=markevery)

            if alphas is not None:
                rate = alphas[k] * muT
                theory = (1 - rate) ** (ppe * np.asarray(it[0]))
                nn_theory = int(len(theory) * 3.0 / 5.0)
                plt.semilogy(x_vals[:nn_theory], theory[:nn_theory], label="T" + labels[counter],
                             linestyle=linestyles[counter], color=colors[counter],
                             marker="s", markevery=markevery)

            counter += 1

    # Clip the lower y-limit so curves reaching ~0 suboptimality do not stretch the log axis.
    bottom, top = plt.ylim()
    plt.ylim([np.maximum(bottom, 1e-13), top])

    plt.legend(fontsize='x-large', loc='best')
    if "epoch" in plot_name:
        plt.xlabel('Data passes', fontsize='x-large')
    elif "communication" in plot_name:
        plt.xlabel('Rounds of communication', fontsize='x-large')
    else:
        plt.xlabel('Iteration', fontsize='x-large')

    plt.ylabel('Relative suboptimality', fontsize='x-large')
    plt.title('Dataset: {}'.format(dataset), fontsize='x-large')
    plt.tight_layout()
    plt.savefig('{}{}Dataset{}.pdf'.format(plots_path, plot_name, dataset))




def visualize_dist(dist_0, dist_inf, omegas_np):
    """Plot ||x(lambda)-x(0)||^2 and ||x(lambda)-x(inf)||^2 against lambda on a log x-axis.

    The first lambda value is replaced by half of the second so it can be
    placed on the log axis (presumably ``omegas_np[0] == 0`` — confirm with
    the caller). The replacement is made on a float copy, so the caller's
    array is no longer modified, and integer input is not truncated.
    The plot is drawn on a new figure and saved to ``<cwd>/plots/evolution.pdf``.

    Args:
        dist_0: y-values of the first curve, one per lambda.
        dist_inf: y-values of the second curve, one per lambda.
        omegas_np: lambda values (at least two entries).
    """
    omegas = np.array(omegas_np, dtype=float)  # copy: leave the caller's data untouched
    omegas[0] = omegas[1] / 2

    # Use a fresh figure, like visualize(), instead of drawing on whatever figure is current.
    plt.figure()

    counter = 0
    plt.plot(omegas, dist_0, label= r'$||x(\lambda)-x(0)||^2$' , linestyle=linestyles[counter], color=colors[counter],
             marker=markers[counter], markevery=1)

    counter = 1
    plt.plot(omegas, dist_inf, label= r'$||x(\lambda)-x(\infty)||^2$', linestyle=linestyles[counter], color=colors[counter],
             marker=markers[counter], markevery=1)

    plt.legend(fontsize='x-large', loc='best')

    plt.xlabel(r'$\lambda$', fontsize='x-large')

    plt.xlim([omegas[0], omegas[-1]*1.1])
    plt.xscale('log')

    plt.tight_layout()
    plt.savefig('{}evolution.pdf'.format(os.getcwd()+'/plots/'))


def visualize_(comm, ps):
    """Plot communication rounds against p on log-log axes and save to ``<cwd>/plots/commp2log.pdf``.

    Args:
        comm: y-values (communication rounds needed for an epsilon-solution).
        ps: x-values (the probability p), plotted with fixed ticks at
            0.01, 0.03, 0.1 and 0.5 shown as plain numbers.
    """
    # The module only binds the name `matplotlib` on macOS (inside the
    # platform check), so referencing matplotlib.ticker directly raised
    # NameError elsewhere. Import the formatter explicitly instead.
    from matplotlib.ticker import ScalarFormatter

    counter = 0

    _, ax = plt.subplots()
    ax.plot(ps, comm, label='_nolegend_', linestyle=linestyles[counter], color=colors[counter],
            marker=markers[counter], markevery=1)
    ax.set_xscale('log')
    ax.set_yscale('log')

    ax.set_xticks([0.01, 0.03, 0.1, 0.5])
    # Show tick labels as plain decimals rather than powers of ten.
    ax.get_xaxis().set_major_formatter(ScalarFormatter())

    ax.set_xlabel(r'$p$', fontsize='x-large')
    ax.set_ylabel(r'Communication rounds for $\varepsilon$ solution', fontsize='x-large')
    plt.tight_layout()
    plt.savefig('{}commp2log.pdf'.format(os.getcwd()+'/plots/'))