import argparse

from workers import MasterNode
from models import LinReg, LogReg, LogRegNoncvx, NN_1d_regression
from utils import read_run, get_alg, create_plot_dir, PLOT_PATH
from sklearn.datasets import dump_svmlight_file

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from prep_data import number_of_features
import math
import torch

from numpy.random import default_rng
from numpy import linalg as la
from prep_data import DATASET_PATH
import copy

# Global matplotlib look shared by every figure this script produces.
# The 'fast' style is applied first so the explicit overrides below win.
plt.style.use('fast')
mpl.rcParams.update({
    # Computer Modern for mathtext ('dejavusans' is the alternative that was tried).
    'mathtext.fontset': 'cm',
    # Type 42 (TrueType) font embedding in PDF/PS output instead of Type 3.
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'lines.linewidth': 2.0,
    'legend.fontsize': 'large',
    'axes.titlesize': 'xx-large',
    'xtick.labelsize': 'x-large',
    'ytick.labelsize': 'x-large',
    'axes.labelsize': 'xx-large',
})

# Per-curve marker and colour tables, indexed by the curve's position in the sweep.
markers = ['x', '.', '+', '1', 'p', '*', 'D', '.', 's']
colors = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
]

# Command-line configuration for the local-SGD sweep.
parser = argparse.ArgumentParser(description="Evaluate Sigma2")
parser.add_argument("--dataset", default="mushrooms", type=str)
parser.add_argument("--iid", default="clusters", type=str)
parser.add_argument("--alpha", default=1.0, type=float)
parser.add_argument("--n_workers", default=100, type=int)
parser.add_argument("--ratio", default=0.3, type=float, help="Dirichlet ratio")
parser.add_argument("--max_it", default=2500, type=int)
# type=int so a value given on the command line matches the int default instead
# of reaching MasterNode(cluster=...) as a string.
parser.add_argument("--clusters", default=10, type=int)
parser.add_argument("--lr", default=0.1, type=float)
parser.add_argument("--mute_print", default=False, action="store_true")
parser.add_argument("--strategy", default=1, type=int, help="Choosing strategies for minibatch SPPM")
parser.add_argument("--epsilon", default=1e-6, type=float, help="Inexact proximal tolerance")
parser.add_argument("--inexact", default=False, action="store_true")
parser.add_argument("--notrain", default=False, action="store_true")
parser.add_argument("--opt_iter", default=5, type=int)
parser.add_argument("--opt_criterion", default="epsilon", type=str)
parser.add_argument("--vr", default=False, action="store_true")

args = parser.parse_args()
# NOTE(review): dataset_names is not read anywhere in this script (the per-dataset
# loop is commented out); args.dataset is what actually gets used, even for "all".
if args.dataset == "all":
    dataset_names = ["mushrooms", "ijcnn1.bz2", "a6a", "w6a"]
else:
    dataset_names = [args.dataset]

ratio = args.ratio
max_it = args.max_it
# Run identifier shared by the saved experiments and the output plot file name.
name = f"baselines_{args.iid}_{args.n_workers}_{args.dataset}_{args.vr}_{args.inexact}_{args.epsilon}_{args.opt_iter}_{args.lr}"
exp_sgd = f'sgd_{name}'
# Placeholder; the sweep below rebinds this per local-step setting.
exp_localsgd = f'localsgd_{name}'
alg = LogReg
logreg = True

# Sweep over the number of local steps per communication round. Each value is
# passed to ``run_localsgd`` as ``local_epoch``, tags the experiment name, and is
# used as the curve's legend label.
labels = [1, 2, 4, 8, 16, 32, 64, 128, 256]

# Two stacked panels plotted against communication rounds:
# axs[0] gets the 'fval' series, axs[1] the 'grad' series.
fig, axs = plt.subplots(2, figsize=(5, 7), constrained_layout=True)

# Maximum number of recorded iterations plotted per run.
n_iter_shown = 5000

# enumerate() yields (position, value): ``ind`` selects marker, colour and marker
# offset; ``label`` is the local-step count itself.
for ind, label in enumerate(labels):
    exp_localsgd = f'localsgd_{name}_{label}'
    if not args.notrain:
        print('------------------- alpha = {} --------------------'.format(args.alpha))
        model = MasterNode(args.n_workers, args.iid, args.ratio, args.alpha, alg, args.dataset, logreg, True, 500,
                           cluster=args.clusters, regularization=0.1)
        print('Running LocalSGD...')
        w_localsgd = model.run_localsgd(n_iter=args.max_it, mb_size=10, local_epoch=int(label),
                                        exp_name=exp_localsgd)

    # Read back the run stored under exp_localsgd (presumably written by
    # run_localsgd via exp_name, either just above or in an earlier invocation
    # when --notrain is given).
    run_localsgd = read_run(exp_localsgd, [args.alpha] * args.n_workers, args.dataset, logreg)

    fvals = run_localsgd['fval'][:n_iter_shown]
    gnorms = run_localsgd['grad'][:n_iter_shown]
    # Aim for ~20 markers per curve; the step must be >= 1, so runs shorter than
    # 20 points would otherwise produce an invalid step of 0.
    markevery = max(1, int(fvals.size / 20))
    # Offset each curve's first marker so overlapping curves stay distinguishable.
    mark_spec = (markevery + 2 * ind, markevery)

    axs[0].plot(fvals, marker=markers[ind], markevery=mark_spec,
                markersize=10, label=label, color=colors[ind])
    axs[1].plot(gnorms, marker=markers[ind], markevery=mark_spec,
                markersize=10, label=label, color=colors[ind])


# Decorate both panels, then write the figure to PLOT_PATH and display it.
top, bottom = axs

top.set_title(args.dataset)
top.set_ylabel(r'$f(x)-f^{\star}$')
top.legend()

bottom.set_xlabel('Communication rounds')
bottom.set_ylabel(r'$\|\nabla f(x)- \nabla f(x^{\star})\|^2$')

# Both quantities span orders of magnitude, so show them on log axes.
for panel in (top, bottom):
    panel.set_yscale('log')

alg = get_alg(logreg)
create_plot_dir()
plt.savefig(f"{PLOT_PATH}/{name}_minibatch_sppm_round_localsteps.pdf")
plt.show()


