import argparse

from workers import MasterNode
from models import LinReg, LogReg, LogRegNoncvx, NN_1d_regression
from utils import read_run, get_alg, create_plot_dir, PLOT_PATH
from sklearn.datasets import dump_svmlight_file

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from prep_data import number_of_features
import math
import torch

from numpy.random import default_rng
from numpy import linalg as la
from prep_data import DATASET_PATH
import copy
from scipy.signal import savgol_filter

# Global matplotlib look-and-feel applied to every figure drawn by this script.
plt.style.use('fast')
mpl.rcParams.update({
    # Computer Modern math glyphs ('dejavusans' was the alternative tried here).
    'mathtext.fontset': 'cm',
    # Same font type for both vector back-ends (PDF and PostScript).
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'lines.linewidth': 2.0,
    'legend.fontsize': 'large',
    'axes.titlesize': 'xx-large',
    'xtick.labelsize': 'x-large',
    'ytick.labelsize': 'x-large',
    'axes.labelsize': 'xx-large',
})

# Marker glyphs and line colours; the plotting code picks one entry of each
# with the same index it uses for `labels`.
markers = ['x', '.', '+', '1', 'p', '*', 'D', '.', 's']
colors = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
]

# Command-line interface of the script. Arguments are parsed at import time and
# exposed through the module-level `args` namespace.
parser = argparse.ArgumentParser(description="Evaluate Sigma2")

# Data set and how it is split across workers (forwarded to MasterNode / read_run).
parser.add_argument("--dataset", default="mushrooms", type=str)
parser.add_argument("--iid", default="clusters", type=str)
parser.add_argument("--alpha", default=1.0, type=float)
parser.add_argument("--n_workers", default=100, type=int)
parser.add_argument("--ratio", default=0.3, type=float, help="Dirichlet ratio")
parser.add_argument("--max_it", default=2500, type=int)
# type=int keeps a value given on the command line consistent with the integer
# default; without it "--clusters 5" would be forwarded as the string "5".
parser.add_argument("--clusters", default=10, type=int)
parser.add_argument("--lr", default=0.1, type=float)

# NOTE(review): several of the options below are not read anywhere in the visible
# part of this script; presumably they are kept for parity with sibling scripts.
parser.add_argument("--mute_print", default=False, action="store_true")
parser.add_argument("--strategy", default=1, type=int, help="Choosing strategies for minibatch SPPM")
parser.add_argument("--epsilon", default=1e-6, type=float, help="Inexact proximal tolerance")
parser.add_argument("--inexact", default=False, action="store_true")
# When set, the training step is skipped and only the stored run is read and plotted.
parser.add_argument("--notrain", default=False, action="store_true")
parser.add_argument("--p", default=0.01, type=float)
parser.add_argument("--opt_iter", default=1, type=int)
parser.add_argument("--opt_criterion", default="maxiter", type=str)
parser.add_argument("--vr", default=False, action="store_true")
parser.add_argument("--local_epoch", default=5, type=int)

args = parser.parse_args()

# "--dataset all" expands to the four benchmark data sets.
# NOTE(review): `dataset_names` is not consumed by the visible code, which keeps
# using `args.dataset` directly - confirm "all" is actually supported downstream.
if args.dataset == "all":
    dataset_names = ["mushrooms", "ijcnn1.bz2", "a6a", "w6a"]
else:
    dataset_names = [args.dataset]

ratio = args.ratio
max_it = args.max_it
# Upper bound on the number of recorded iterations that are plotted.
n_iter_shown = 3000

# Calculate a varying standard deviation for each point
# This could be a moving standard deviation or based on some model of expected variance
# Here we use a simple moving standard deviation for demonstration purposes
def moving_stddev(data, window_size):
    """Return the moving standard deviation of a 1-D sequence.

    The input is extended at both ends by repeating its first/last value
    (``window_size // 2`` samples on the left, the remainder on the right) so
    that the result has exactly one entry per input sample.

    Args:
        data: 1-D array-like of numbers.
        window_size: Number of samples in each window; must be at least 1.

    Returns:
        A NumPy array with ``len(data)`` entries, where entry ``i`` is the
        population standard deviation (``np.std``) of the window around
        sample ``i``. An empty input yields an empty array.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be a positive integer, got {window_size!r}")

    values = np.asarray(data)
    if values.size == 0:
        # np.pad(mode='edge') cannot extend an empty axis; there is nothing to compute.
        return np.array([], dtype=float)

    # Pad the borders so every sample has a full window.
    left = window_size // 2
    right = window_size - 1 - left
    padded = np.pad(values, (left, right), mode='edge')

    return np.array([np.std(padded[start:start + window_size]) for start in range(len(values))])
# Savitzky-Golay smoothing parameters. `window_length` is also the window of the
# moving standard deviation; `window_length2` is used for the tail of the
# gradient-norm curve. Window lengths should be odd.
window_length = window_length2 = 5
polyorder = 3  # The polynomial order

# Legend labels; the plotting code indexes this list together with `markers`
# and `colors`.
labels = ['SPPM', 'MB-SPPM', 'SVRP', 'VR-MBSPPM', 'MB-GD', 'MB-LocalGD']

# One figure with two stacked panels, both plotted against communication rounds.
# (A loop over `dataset_names` was sketched here but is currently disabled.)
fig, axs = plt.subplots(nrows=2, figsize=(5, 7), constrained_layout=True)

######################## Baselines ####################################
# Local SGD (reported as "FedAvg" in the console output) on logistic regression.
logreg = True
alg = LogReg

# Experiment identifiers; `name` is reused later for the output file name.
name = f"baselines_localsgd_{args.iid}_{args.n_workers}_{args.local_epoch}_{args.lr}"
exp_localsgd = f'localsgd_{name}'

window_length = 5  # Window length should be odd

# Training is skipped with --notrain; the stored run is then only read back.
if not args.notrain:
    print('------------------- alpha = {} --------------------'.format(args.alpha))
    model = MasterNode(
        args.n_workers, args.iid, args.ratio, args.alpha, alg, args.dataset, logreg, True, 150,
        cluster=args.clusters, regularization=0.1,
    )
    print('Running FedAvg...')
    w_localsgd = model.run_localsgd(
        n_iter=args.max_it,
        mb_size=10,
        local_epoch=args.local_epoch,
        exp_name=exp_localsgd,
        lr=args.lr,
    )

# Every worker is given the same alpha.
# NOTE(review): read_run presumably loads what run_localsgd stored under
# `exp_localsgd` - confirm against utils.read_run.
worker_alphas = [args.alpha] * args.n_workers
run_localsgd = read_run(exp_localsgd, worker_alphas, args.dataset, logreg)


def _plot_smoothed_curve(ax, values, smoothed, ind, markevery):
    """Draw a smoothed curve on ``ax`` with a thin band around it.

    The band half-width is the moving standard deviation of the smoothing
    residual ``values - smoothed`` (window ``window_length``) divided by 100.
    Marker, label and colour are the ``ind``-th entries of the module-level
    ``markers``, ``labels`` and ``colors`` lists.
    """
    xs = np.arange(len(values))
    half_width = moving_stddev(values - smoothed, window_length) / 100
    ax.plot(xs, smoothed, marker=markers[ind], markevery=(markevery + 2 * ind, markevery),
            markersize=10, label=labels[ind], color=colors[ind])
    ax.fill_between(xs, smoothed - half_width, smoothed + half_width, color=colors[ind], alpha=0.2)


# Recorded function values and gradient norms, truncated to the plotted range.
fvals = run_localsgd['fval'][:n_iter_shown]
gnorms = run_localsgd['grad'][:n_iter_shown]
print("run_localsgd", fvals, gnorms)

# About 20 markers per curve. The step must be at least 1: with fewer than 20
# points the plain quotient is 0, and matplotlib would then subsample the
# markers with a zero slice step and fail when drawing.
markevery = max(1, fvals.size // 20)
ind = 5  # 'MB-LocalGD' in `labels`

# Top panel: function values, smoothed in a single pass.
y_smooth = savgol_filter(fvals, window_length, polyorder)
_plot_smoothed_curve(axs[0], fvals, y_smooth, ind, markevery)

# Bottom panel: gradient norms. The first `initial_points` samples and the
# remainder are smoothed separately and then stitched back together.
initial_window_length = 5  # Choose an odd number, larger for more smoothing
initial_polyorder = 3
initial_points = 10  # Number of points to smooth at the beginning
initial_smooth = savgol_filter(gnorms[:initial_points], initial_window_length, initial_polyorder)
rest_smooth = savgol_filter(gnorms[initial_points:], window_length2, polyorder)
y_smooth_combined = np.concatenate((initial_smooth, rest_smooth))
_plot_smoothed_curve(axs[1], gnorms, y_smooth_combined, ind, markevery)


# Final decoration of the two panels, then write the figure to disk and show it.
top, bottom = axs

# Only the top panel carries the legend and the title.
top.legend()
for panel in (top, bottom):
    panel.set_yscale('log')

top.set_title(args.dataset)
top.set_ylabel(r'$f(x)-f^{\star}$')
bottom.set_xlabel('Communication rounds')
bottom.set_ylabel(r'$\|\nabla f(x)- \nabla f(x^{\star})\|^2$')

# NOTE(review): the value returned by get_alg is not used below; the call is
# kept in case it matters elsewhere.
alg = get_alg(logreg)

create_plot_dir()
out_file = PLOT_PATH + '/' + name + '_minibatch_sppm_' + 'general_baselines' + '.pdf'
plt.savefig(out_file)
plt.show()
