import argparse

from workers import MasterNode
from models import LinReg, LogReg, LogRegNoncvx, NN_1d_regression
from utils import read_run, get_alg, create_plot_dir, PLOT_PATH
from sklearn.datasets import dump_svmlight_file

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from prep_data import number_of_features
import math
import torch

from numpy.random import default_rng
from numpy import linalg as la
from prep_data import DATASET_PATH
import copy

# Global matplotlib styling applied to every figure this script produces.
plt.style.use('fast')
mpl.rcParams.update({
    # Computer Modern math text ('dejavusans' is the other option tried here).
    'mathtext.fontset': 'cm',
    # Font type 42 embeds TrueType fonts in PDF/PS output.
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'lines.linewidth': 2.0,
    'legend.fontsize': 'large',
    'axes.titlesize': 'xx-large',
    'xtick.labelsize': 'x-large',
    'ytick.labelsize': 'x-large',
    'axes.labelsize': 'xx-large',
})

# Marker and colour cycles; the plotting loop indexes both by curve position.
markers = ['x', '.', '+', '1', 'p', '*', 'D', '.', 's']
colors = [
    '#1f77b4',
    '#ff7f0e',
    '#2ca02c',
    '#d62728',
    '#9467bd',
    '#8c564b',
    '#e377c2',
    '#7f7f7f',
    '#bcbd22',
    '#17becf',
]

# Command-line interface of the plotting script. Every option has a default,
# so the script can be launched without arguments.
parser = argparse.ArgumentParser(description="Evaluate Sigma2")

# Experiment identification: these values are used to build the names of the
# stored runs that get loaded and plotted.
parser.add_argument("--dataset", default="mushrooms", type=str)
parser.add_argument("--iid", default="clusters", type=str)
parser.add_argument("--alpha", default=1.0, type=float)
parser.add_argument("--n_workers", default=100, type=int)
parser.add_argument("--ratio", default=0.3, type=float, help="Dirichlet ratio")
parser.add_argument("--max_it", default=1200, type=int)
# type=int keeps a value given on the command line consistent with the
# integer default; without it argparse would hand back a string.
parser.add_argument("--clusters", default=10, type=int)
parser.add_argument("--lr", default=0.1, type=float)

# Behaviour switches (store_true flags default to False).
parser.add_argument("--mute_print", default=False, action="store_true")
parser.add_argument("--strategy", default=1, type=int, help="Choosing strategies for minibatch SPPM")
parser.add_argument("--epsilon", default=1e-6, type=float, help="Inexact proximal tolerance")
parser.add_argument("--inexact", default=False, action="store_true")
parser.add_argument("--notrain", default=False, action="store_true")

# Parse the process command line. The special dataset value "all" expands to
# the fixed list of four datasets; any other value becomes a one-element list.
args = parser.parse_args()
dataset_names = (
    ["mushrooms", "ijcnn1.bz2", "a6a", "w6a"]
    if args.dataset == "all"
    else [args.dataset]
)

ratio = args.ratio
max_it = args.max_it

# Proximal tolerances to compare; `labels` holds the legend entry for each
# tolerance at the same index (the 0.0 tolerance is labelled 'exact').
epsilons = [0.1, 0.01, 0.001, 0.0001, 0.0]
methods = ['SPPM', 'minibatch-SPPM-s1', 'minibatch-SPPM-s2']
# We should consider another exact
labels = [1e-1, 1e-2, 1e-3, 1e-4, 'exact']

# Only the first `n_iter_shown` recorded points of each run are plotted.
n_iter_shown = 5000
# Passed through to read_run unchanged.
logreg = True

# method name -> (stored-experiment prefix, extra keyword arguments for read_run)
_METHOD_RUNS = {
    'SPPM': ('sppm', {}),
    'minibatch-SPPM-s1': ('minibatch_sppm', {'strategy': 1}),
    'minibatch-SPPM-s2': ('minibatch_sppm', {'strategy': 2}),
}


def _run_name(epsilon):
    """Return the experiment-name suffix for one tolerance value."""
    return f"{args.iid}_{args.n_workers}_{args.dataset}_{args.inexact}_{epsilon}_{args.lr}"


def _load_run(method, name):
    """Load the single stored run that `method` needs for experiment `name`.

    With --inexact the experiment prefix gains an 'inexact_' prefix.
    """
    prefix, extra = _METHOD_RUNS[method]
    if args.inexact:
        prefix = 'inexact_' + prefix
    return read_run(f'{prefix}_{name}', [args.alpha] * args.n_workers,
                    args.dataset, logreg, **extra)


# One figure per method: the top panel shows the 'fval' series, the bottom
# panel the 'grad' series, with one curve per tolerance.
for method in methods:
    # Plot w.r.t. communication rounds
    fig, axs = plt.subplots(2, figsize=(5, 7), constrained_layout=True)

    for ind, (epsilon, label) in enumerate(zip(epsilons, labels)):
        run = _load_run(method, _run_name(epsilon))
        # NOTE(review): `.size` below assumes read_run returns NumPy arrays.
        fvals = run['fval'][:n_iter_shown]
        gnorms = run['grad'][:n_iter_shown]

        # Roughly 20 markers per curve. The stride must be at least 1: a zero
        # stride (runs shorter than 20 points) makes matplotlib fail at draw
        # time with "slice step cannot be zero".
        stride = max(1, fvals.size // 20)
        curve_style = dict(
            marker=markers[ind],
            # Shift the first marker per curve so overlapping curves stay
            # distinguishable.
            markevery=(stride + 2 * ind, stride),
            markersize=10,
            label=label,
            color=colors[ind],
        )
        axs[0].plot(fvals, **curve_style)
        axs[1].plot(gnorms, **curve_style)

    axs[0].legend()
    axs[0].set_yscale('log')
    axs[1].set_yscale('log')
    axs[1].set_xlabel('Communication rounds')
    axs[0].set_ylabel(r'$f(x)-f^{\star}$')
    axs[1].set_ylabel(r'$\|\nabla f(x)- \nabla f(x^{\star})\|^2$')
    axs[0].set_title(args.dataset)

    create_plot_dir()
    # The file stem is built from the name of the last tolerance in
    # `epsilons`, which keeps the output paths identical to earlier runs of
    # this script.
    out_stem = _run_name(epsilons[-1])
    fig.savefig(PLOT_PATH + '/' + out_stem + method + '_minibatch_sppm_' + 'round_epsilon_' + '.pdf')
    plt.show()
    # Release the figure so figures do not accumulate across methods on
    # non-interactive backends.
    plt.close(fig)