import argparse

from workers import MasterNode
from models import LinReg, LogReg, LogRegNoncvx, NN_1d_regression
from utils import read_run, get_alg, create_plot_dir, PLOT_PATH
from sklearn.datasets import dump_svmlight_file

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from prep_data import number_of_features
import math
import torch

from numpy.random import default_rng
from numpy import linalg as la
from prep_data import DATASET_PATH
import copy

# Global figure styling shared by every plot produced by this script:
# Computer Modern math text, font type 42 for PDF/PS output, thicker lines
# and enlarged tick/axis/legend/title text.
plt.style.use('fast')
mpl.rcParams.update({
    'mathtext.fontset': 'cm',  # alternative previously tried: 'dejavusans'
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'lines.linewidth': 2.0,
    'legend.fontsize': 'large',
    'axes.titlesize': 'xx-large',
    'xtick.labelsize': 'x-large',
    'ytick.labelsize': 'x-large',
    'axes.labelsize': 'xx-large',
})

# Marker and colour palettes, indexed by the position of a curve in a plot.
markers = ['x', '.', '+', '1', 'p', '*', 'D', '.', 's']
colors = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
]

# Command-line interface. Options are forwarded to MasterNode and its SPPM
# solvers below; --notrain skips training and only re-plots saved runs.
parser = argparse.ArgumentParser(description="Evaluate Sigma2")
parser.add_argument("--dataset", default="a6a", type=str)
parser.add_argument("--iid", default="clusters", type=str)
parser.add_argument("--alpha", default=1.0, type=float)
parser.add_argument("--n_workers", default=100, type=int)
parser.add_argument("--ratio", default=0.3, type=float, help="Dirichlet ratio")
parser.add_argument("--max_it", default=2500, type=int)
# type=int so a value given on the command line is not forwarded to
# MasterNode(cluster=...) as a string.
parser.add_argument("--clusters", default=10, type=int)
parser.add_argument("--lr", default=0.1, type=float)
parser.add_argument("--mute_print", default=False, action="store_true")
parser.add_argument("--strategy", default=1, type=int, help="Choosing strategies for minibatch SPPM")
parser.add_argument("--epsilon", default=1e-6, type=float, help="Inexact proximal tolerance")
parser.add_argument("--inexact", default=False, action="store_true")
parser.add_argument("--notrain", default=False, action="store_true")
parser.add_argument("--opt_iter", default=5, type=int)
parser.add_argument("--opt_criterion", default="epsilon", type=str)
parser.add_argument("--opt_method", default="BFGS", type=str, help="BFGS, CG, Newton-CG")
parser.add_argument("--vr", default=False, action="store_true")
# `args.p` is read by the variance-reduced training branches
# (vr_sppm / vr_minibatch_sppm); it must be declared or --vr raises
# AttributeError. No default is invented here: it is required with --vr.
parser.add_argument("--p", default=None, type=float,
                    help="Parameter p passed to vr_sppm / vr_minibatch_sppm (required with --vr)")

args = parser.parse_args()
if args.vr and not args.notrain and args.p is None:
    parser.error("--p is required when --vr is used for training")

# NOTE(review): dataset_names is not used below; everything reads
# args.dataset directly, so "--dataset all" is not actually expanded.
if args.dataset == "all":
    dataset_names = ["mushrooms", "ijcnn1.bz2", "a6a", "w6a"]
else:
    dataset_names = [args.dataset]

# Short aliases for frequently used options.
ratio, max_it = args.ratio, args.max_it

# Identifier shared by every run of this invocation, built from the solver
# configuration; each algorithm's experiment name prefixes it.
# (An older scheme omitted opt_method and vr.)
name = "_".join(str(part) for part in (
    args.opt_method, args.iid, args.n_workers, args.dataset, args.vr,
    args.inexact, args.epsilon, args.opt_iter, args.lr,
))
exp_sppm = 'sppm_' + name
exp_minibatch_sppm = 'minibatch_sppm_' + name
inexact_exp_sppm = 'inexact_' + exp_sppm
inexact_exp_minibatch_sppm = 'inexact_' + exp_minibatch_sppm

# Model class and loss flag handed to MasterNode, read_run and get_alg.
alg = LogReg
logreg = True

# Legend entries, in the order the three runs are plotted.
labels = ['SPPM', 'minibatch-SPPM-s1', 'minibatch-SPPM-s2']
# dataset_names = ["mushrooms", "ijcnn1.bz2", "a6a", "w6a"]
# for dataset_name in dataset_names:
# args.dataset = dataset_name
# Train SPPM and the two minibatch-SPPM strategies unless --notrain is set.
# Each solver receives an exp_name; read_run below reads runs back under
# the same names.
if not args.notrain:
    print('------------------- alpha = {} --------------------'.format(args.alpha))
    model = MasterNode(args.n_workers, args.iid, args.ratio, args.alpha, alg, args.dataset, logreg, True, 500,
                       cluster=args.clusters, regularization=0.1)
    print('Running SPPM...')

    # Keyword arguments common to all three solver calls of this configuration.
    _solver_kwargs = {'lr': args.lr, 'n_iter': args.max_it}
    if args.inexact:
        _solver_kwargs['epsilon'] = args.epsilon
        _single_exp, _batch_exp = inexact_exp_sppm, inexact_exp_minibatch_sppm
    else:
        _single_exp, _batch_exp = exp_sppm, exp_minibatch_sppm

    if args.vr:
        # Variance-reduced variants take `p` and no inner-optimizer options.
        _solver_kwargs['p'] = args.p
        _single_solver, _batch_solver = model.vr_sppm, model.vr_minibatch_sppm
    else:
        if args.inexact:
            # Inner proximal-solver settings apply only to the inexact plain variants.
            _solver_kwargs.update(opt_iter=args.opt_iter,
                                  opt_criterion=args.opt_criterion,
                                  opt_method=args.opt_method)
        _single_solver, _batch_solver = model.sppm, model.minibatch_sppm

    w_sppm = _single_solver(exp_name=_single_exp, **_solver_kwargs)
    w_minibatch_sppm = _batch_solver(strategy=1, exp_name=_batch_exp, **_solver_kwargs)
    w_minibatch_sppm2 = _batch_solver(strategy=2, exp_name=_batch_exp, **_solver_kwargs)

# Maximum number of recorded iterations drawn on the plots.
n_iter_shown = 5000

# Load the three saved runs. The two minibatch runs share an experiment name
# and are told apart by the `strategy` argument of read_run.
_run_exp_single = inexact_exp_sppm if args.inexact else exp_sppm
_run_exp_batch = inexact_exp_minibatch_sppm if args.inexact else exp_minibatch_sppm
run_sppm = read_run(_run_exp_single, [args.alpha] * args.n_workers, args.dataset, logreg)
run_minibatch_sppm = read_run(_run_exp_batch, [args.alpha] * args.n_workers, args.dataset, logreg,
                              strategy=1)
run_minibatch_sppm2 = read_run(_run_exp_batch, [args.alpha] * args.n_workers, args.dataset, logreg,
                               strategy=2)

# Plot the 'fval' series (top, f(x) - f*) and the 'grad' series (bottom)
# against communication rounds for the three runs, then save and show the
# figure. Entry order matches `labels`, `markers` and `colors`.
_runs_to_plot = [
    ("sppm", run_sppm),
    ("minibatch_sppm", run_minibatch_sppm),
    ("minibatch_sppm2", run_minibatch_sppm2),
]

fig, axs = plt.subplots(2, figsize=(5, 7), constrained_layout=True)
markevery = None
for ind, (tag, run) in enumerate(_runs_to_plot):
    fvals = run['fval'][:n_iter_shown]
    gnorms = run['grad'][:n_iter_shown]
    print(tag, fvals, gnorms)
    if markevery is None:
        # Marker spacing comes from the first (SPPM) curve, aiming for ~20
        # markers. Clamp to 1: for curves shorter than 20 points the step
        # would be 0, and matplotlib raises on a zero markevery step when drawing.
        markevery = max(1, len(fvals) // 20)
    # Offset each curve's first marker by 2*ind so markers do not overlap.
    line_style = dict(marker=markers[ind], markevery=(markevery + 2 * ind, markevery),
                      markersize=10, label=labels[ind], color=colors[ind])
    axs[0].plot(fvals, **line_style)
    axs[1].plot(gnorms, **line_style)

axs[0].legend()
axs[0].set_yscale('log')
axs[1].set_yscale('log')
axs[1].set_xlabel('Communication rounds')
axs[0].set_ylabel(r'$f(x)-f^{\star}$')
axs[1].set_ylabel(r'$\|\nabla f(x)- \nabla f(x^{\star})\|^2$')

axs[0].set_title(args.dataset)
alg = get_alg(logreg)
create_plot_dir()
plt.savefig(PLOT_PATH + '/' + name + '_minibatch_sppm_' + 'round_' + '.pdf')
plt.show()


#
# # Plot w.r.t. communication costs: (1+tau*alpha), beta=0.5
# fvals = run_sppm['fval'][:n_iter_shown]
# # dists = run_sppm['dist'][:n_iter_shown]
# gnorms = run_sppm['grad'][:n_iter_shown]

# beta = 0.5
# x_axis = np.arange(len(fvals))
# max_index = x_axis.index(max(x_axis))
# x_axis_cost = np.arange(len(fvals)) * (1 + beta * int(args.n_workers/args.n_clusters))
# x_axis_2 = x_axis_cost[:max_index + 1]
# fig, axs = plt.subplots(2, figsize=(7, 10), constrained_layout=True)
# ind = 0
# axs[0].plot(fvals, marker=markers[ind], markevery=(markevery + 2 * ind, markevery),
#             markersize=10, label=labels[ind], color=colors[ind])
# # axs[1].plot(dists, marker=markers[ind], markevery=(markevery + 2 * ind, markevery), markersize=10)
# axs[1].plot(gnorms, marker=markers[ind], markevery=(markevery + 2 * ind, markevery),
#             markersize=10, label=labels[ind], color=colors[ind])
#
# fvals = run_minibatch_sppm['fval'][:n_iter_shown]
# # dists = run_minibatch_sppm['dist'][:n_iter_shown]
# gnorms = run_minibatch_sppm['grad'][:n_iter_shown]
#
# ind = 1
# axs[0].plot(fvals, marker=markers[ind], markevery=(markevery + 2 * ind, markevery),
#             markersize=10, label=labels[ind], color=colors[ind])
# # axs[1].plot(dists, marker=markers[ind], markevery=(markevery + 2 * ind, markevery), markersize=10)
# axs[1].plot(gnorms, marker=markers[ind], markevery=(markevery + 2 * ind, markevery),
#             markersize=10, label=labels[ind], color=colors[ind])
#
# print(len(gnorms), gnorms)
# axs[0].legend()
# axs[0].set_yscale('log')
# axs[1].set_yscale('log')
# axs[1].set_xlabel('Communication rounds')
# # axs[0].set_ylabel('Squared distance')
# # axs[0].set_ylabel(r'$\|f(x)-f^{\star}\|^2$')
# axs[0].set_ylabel(r'$f(x)-f^{\star}$')
# # axs[1].set_ylabel('Loss')
# axs[1].set_ylabel(r'$\|\nabla f(x)- \nabla f(x^{\star})\|^2$')
#
# axs[0].set_title(args.dataset)
# alg = get_alg(logreg)
# # name = exp + '_' + alg + '_' + args.dataset
# name = alg + '_' + args.dataset
# create_plot_dir()
# plt.savefig(PLOT_PATH + '/' + name + '_minibatch_sppm_' + 'cost' + str(beta) + '.pdf')
# plt.show()