import numpy as np

from optimizer import *
from FedAvg import *
from WeightedFedAvg import *
from FedProx import *
from local_training import *
from params import args_parser
import csv
import matplotlib
import matplotlib.pylab as plt
import matplotlib.pylab as pylab
import seaborn as sns
sns.set_style('whitegrid')
sns.set_palette("bright")


# Seen-client data: the training and test splits are read from one directory.
train_data_dir = test_data_dir = './data/'

args = args_parser()

# Hyper-parameters taken from the command line, plus two hard-coded settings.
lr, bs = args.lr, args.bs
glr = 1             # hard-coded; passed positionally to the trainers (presumably a global/server lr)
random_cp = False   # hard-coded; passed positionally to the trainers
cp, bo_cp, total_rnd = args.cp, args.bo_cp, args.total_rnd

# Per-algorithm result columns that are written to CSV at the end of the run
# (data_csv: seen clients, data_csv2: unseen clients).
data_csv = dict()
data_csv2 = dict()

# NOTE(review): none of the containers below are read anywhere in this script;
# they are kept only so the module namespace is unchanged.
test_acc_global = dict()
test_loss_global = dict()
pos_test_loss_global = dict()
train_loss_global = dict()
parti_rate_global = dict()
eta_global = dict()
weights_algo = dict()

test_acc_parti_global = dict()
test_loss_parti_global = dict()

unseen_test_acc_global = dict()
unseen_test_loss_global = dict()
unseen_pos_test_loss_global = dict()
unseen_train_loss_global = dict()
unseen_parti_rate_global = dict()
unseen_test_acc_parti_global = dict()
unseen_test_loss_parti_global = dict()

# Algorithm registry, one row per algorithm:
#   (key, display name, color, line style, marker, marker size)
# `algs` maps each key to the remaining 5-tuple, which the main loop unpacks
# as (alg, color, lstyle, mark, marker_size).
_alg_rows = (
    ('fedavg',    'FedAvg',    'C0', '-.', None, None),
    ('sigm',      'IncFL',     'C2', '-',  None, None),
    ('fedprox',   'FedProx',   'C4', ':',  'x',  70),
    ('scaffold',  'SCAFFOLD',  'C1', '-.', 'o',  50),
    ('perfedavg', 'PerFedAvg', 'C5', '-.', 'x',  70),
    ('ditto',     'Ditto',     'C6', '-',  'x',  70),
    ('qFFL',      'qFFL',      'C7', '-',  'x',  70),
    ('MW-Fed',    'MW-Fed',    'C8', '-',  'x',  70),
    ('stay',      'stay',      'C3', '--', '+',  70),
)
algs = {row[0]: row[1:] for row in _alg_rows}

# Keys of `algs` that are actually run by this script.
run_algs_keys = ['MW-Fed']
print('     Starting Local Training')
alg = None


def _burn_out(train_dir, test_dir):
    """Create a local_train optimizer over the given data directories and run
    its burn-out pass.

    Returns (optimizer, result), where result is whatever
    burn_out_local_update() returns; the callers below unpack it into four values.
    """
    trainer = local_train(args, None, lr, bs, glr, train_dir, test_dir,
                          random_cp, bo_cp, cp, sample_ratio=args.sample_ratio, etamu=0)
    return trainer, trainer.burn_out_local_update()


# Seen clients first, then the unseen pool stored under ./unseendata/.
local_opt, (local_losses, val_losses, local_models, better_local_models) = \
    _burn_out(train_data_dir, test_data_dir)
unseen_local_opt, (unseen_local_losses, unseen_val_losses, unseen_local_models, unseen_better_local_models) = \
    _burn_out('./unseendata/', './unseendata/')


ftsize = 35
littleft = 25
# Legend, axis labels and titles use the large size; tick labels the smaller one.
params = {name: ftsize for name in ('legend.fontsize', 'axes.labelsize', 'axes.titlesize')}
params.update({name: littleft for name in ('xtick.labelsize', 'ytick.labelsize')})

# Font type 42 (TrueType) for PDF and PostScript output.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})
pylab.rcParams.update(params)
lw = 3  # presumably a plot line width; not used elsewhere in this script
plt.rcParams.update({"font.family": "Times New Roman",
                     "legend.handlelength": 1})



# Main experiment loop. For each selected algorithm key: build the FedAvg
# driver, run total_rnd rounds of local update + aggregation, record the
# per-round metrics for seen and unseen clients, then store them in
# data_csv / data_csv2 under "<display name>_<metric>" keys.
for key in run_algs_keys:
    # Re-seed numpy's global RNG so every algorithm starts from the same state.
    np.random.seed(12345)
    # Display name and plot style; only `alg` is used in this loop.
    alg, color, lstyle, mark, marker_size = algs[key]
    print('     Starting ' + alg)

    # One FedAvg driver is used for every algorithm; the raw key is passed as
    # its second argument (presumably selecting the variant -- confirm). It is
    # seeded with the burn-out validation losses and local models of both the
    # seen and the unseen ('./unseendata/') client pools.
    opt = FedAvg(args, key, lr, bs, glr, train_data_dir, test_data_dir, random_cp, bo_cp, cp, sample_ratio = args.sample_ratio,
                 etamu=0, val_losses = val_losses,
                 local_models= local_models, unseen_val_losses = unseen_val_losses, unseen_local_models = unseen_local_models)

    # Rename `key` to a display name; after this point it is only used in the
    # local-tuning printouts below.
    if key == 'stay':
        key = 'Leave'
    elif key == 'sigm':
        key = 'IncFL'
    elif key == 'fedavg':
        key = 'FedAvg'
    elif key == 'scaffold':
        key = 'Scaffold'
    elif key == 'fedprox':
        key = 'FedProx'

    # Per-round metric histories.
    # NOTE(review): errors, pr_errors, losses and val_lossis are collected but
    # never written out by this script.
    errors, pr_errors, prs, losses, val_lossis = list(), list(), list(), list(), list()
    accs = []
    all_accs = []
    all_losses = list()

    unseen_accs, unseen_all_accs, unseen_all_losses, unseen_prs = [], [], [], []



    for rnd in range(total_rnd):
        # MW-Fed and qFFL use a separate local-update routine.
        if alg == 'MW-Fed' or alg == 'qFFL':
            Delta = opt.local_update_other()
        else:
            Delta = opt.local_update()
        # Presumably merges the local updates into the server model -- confirm.
        opt.aggregate(Delta)

        # unseen_pr_cli is consumed by evaluate_unseentest below, so this call must come first.
        pr, pr_cli, unseen_pr, unseen_pr_cli = opt.evaluate_pr()      # evaluate participation over validation loss
        loss, all_loss, acc, all_acc = opt.evaluate_test()  # evaluate test loss over respective model
        unseen_loss, unseen_all_loss, unseen_acc, unseen_all_acc = opt.evaluate_unseentest(unseen_pr_cli)
        val_loss, diff = opt.evaluate_val()  # validation loss; `diff` is not used
        error = opt.evaluate()      # training losses over the server model
        pr_error = opt.evaluate_prtrain()  # presumably training loss over participating clients -- confirm

        errors.append(error)
        pr_errors.append(pr_error)
        prs.append(pr)
        losses.append(loss)
        val_lossis.append(val_loss)
        all_losses.append(all_loss)
        all_accs.append(all_acc)
        accs.append(acc)

        unseen_all_accs.append(unseen_all_acc)
        unseen_accs.append(unseen_acc)
        unseen_prs.append(unseen_pr)
        unseen_all_losses.append(unseen_all_loss)


    # Columns for the CSV output, one list entry per round.
    # NOTE(review): '_trainloss' stores all_loss returned by evaluate_test(),
    # i.e. presumably a test loss -- the column name may be misleading.
    data_csv[alg + '_test'] = all_accs
    data_csv[alg + '_trainloss'] = all_losses
    data_csv[alg + '_partest'] = accs
    data_csv[alg + '_parti'] = prs

    # Same columns for the unseen client pool.
    data_csv2[alg + '_test'] = unseen_all_accs
    data_csv2[alg + '_trainloss'] = unseen_all_losses
    data_csv2[alg + '_partest'] = unseen_accs
    data_csv2[alg + '_parti'] = unseen_prs

    if args.local_tune == 1:         # evaluate with local models

        # Results are only printed, not stored.
        pr_local, loss_local, local_all_loss, pr_acc = opt.local_tuning()       # compare locally tuned model with the local model
        unseen_pr_local, unseen_loss_local, unseen_local_all_loss, unseen_pr_acc = opt.unseen_local_tuning()
        print('-----------> {} Local Tuned: PR {} parloss{} paracc{} loss{}'.format(key, pr_local, loss_local, pr_acc,  local_all_loss))
        print('-----------> {} Unseen Local Tuned: PR {} parloss{} paracc{} loss{}'.format(key, unseen_pr_local, unseen_loss_local, unseen_pr_acc,
                                                                                    unseen_local_all_loss))


def _results_file_name(stem):
    """Return the results CSV path for one client pool.

    stem: file-name prefix placed after the directory, e.g. "mwf_leave" for
    the seen clients or "mwf_unseen_leave" for the unseen ones.

    The directory is "data" + args.hetero. The file name encodes the run
    settings, with tau = cp, Tl = args.better_cp, Ts = bo_cp and Tg = total_rnd.
    """
    # Only the template goes through str.format; "data" + args.hetero is
    # prepended verbatim, so braces in args.hetero are never interpreted.
    # NOTE(review): "binc{}, , q{}" contains a doubled comma. It is kept so the
    # produced file names match those of earlier runs -- confirm before fixing.
    template = ("/" + stem + "{}_nu{}_sd{}_opt{}, trainr{}, valr{}, testr{}, spr{}, eta{}, b{}, binc{}, "
                ", q{}, c{}, tau{}, Tl={}, Ts={}, Tg={}, seed {}.csv")
    return "data" + args.hetero + template.format(
        args.leave, args.numusers, args.samedata, args.optype,
        args.train_ratio, args.val_ratio, args.test_ratio, args.sample_ratio,
        lr, bs, args.inc_bs, args.q, args.mwfed_c,
        cp, args.better_cp, bo_cp, total_rnd, args.seed)


def _write_columns_csv(path, columns):
    """Write a dict of columns to *path* as CSV.

    The header row is the dict's keys in insertion order. Each following row
    holds the i-th entry of every column; zip() stops at the shortest column.
    The parent directory is created if missing, so finished training results
    are not lost to a missing output folder.
    """
    import os

    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    names = list(columns)
    # The csv module requires newline='' so it controls line endings itself;
    # without it, rows are separated by blank lines on Windows.
    with open(path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(names)
        writer.writerows(zip(*[columns[name] for name in names]))


# Seen-client and unseen-client results go to separate files.
_write_columns_csv(_results_file_name("mwf_leave"), data_csv)
_write_columns_csv(_results_file_name("mwf_unseen_leave"), data_csv2)
