import numpy as np
import matplotlib.pyplot as plt
import math
import copy
import pathlib
import csv
import logging
import pandas as pd
from torchvision import datasets, transforms
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from params import args_parser
import util as util
import algos as al
import seaborn as sns
sns.set()
args = args_parser()

# Run identifier built from the parsed hyper-parameters; it names the
# plain-text log opened below. Note that `args.ensize` is read here exactly as
# args_parser() returned it.
title = 'new_isclust{}_isloctune{}_isbetter{}_isleave{}_a{:.2f}_tr{:.1f}_vr{:.1f}_en{}_n{}_bo{}_rnd{}_lr{:.4f}_bs{}_cp{}_q{}_c{}_e{}'.format(
    args.isclust, args.local_tune, args.better, args.leave,
    args.alpha, args.train_ratio, args.val_ratio, args.ensize, args.size, args.local_train_ep,
    args.rounds, args.eta, args.bs, args.train_ep,
    args.q, args.mwfed_c, args.seed)

# NOTE(review): no logging configuration is visible in this file; unless one of
# the project modules sets it up, INFO records may be filtered out.
logging.info(title)
# Silence matplotlib's font_manager logger.
logging.getLogger('matplotlib.font_manager').disabled = True

# Plain-text run log. The output directory is created first so that a fresh
# <dir>/<name>/<dataset> tree does not make open() fail with FileNotFoundError.
pathlib.Path(args.dir+'/'+args.name+'/'+args.dataset).mkdir(parents=True, exist_ok=True)
f = open(args.dir+'/'+args.name+'/'+args.dataset+'/'+title+'.txt', 'w+')

# Per-client train/val/test splits for the participating ("seen") clients and
# for the held-out ("unseen") clients, plus the untrained reference model.
(dataset_train, dataset_val, dataset_test,
 dataset_unseentrain, dataset_unseenval, dataset_unseentest) = util.load_dataset(args)
net_glob_org = util.select_model(args, args.model)

# The number of clients is taken from the data, overriding the parsed value.
args.ensize = len(dataset_train)
logging.info('####### Ensize for '+args.dataset+': '+str(args.ensize))

# Local-only baselines, one entry per seen client.
local_test_loss, local_val_loss, local_test_acc, local_model = [], [], [], []
# Same baselines after the extra 300-epoch run (filled only when args.better == 1).
better_local_test_loss, better_local_val_loss, better_local_test_acc = [], [], []

for client_id in range(args.ensize):
    # Every client starts from a fresh copy of the reference model.
    net_glob = copy.deepcopy(net_glob_org)
    net_glob.train()

    acc, val_loss, test_loss = al.train_and_test(
        net_glob, dataset_train[client_id], dataset_val[client_id],
        dataset_test[client_id], args)

    local_test_acc.append(acc)
    local_val_loss.append(val_loss)
    local_test_loss.append(test_loss)
    # Snapshot the model before any further training touches it.
    local_model.append(copy.deepcopy(net_glob))

    if args.better != 1:
        continue

    # Keep training the same model object with local_epochs=300.
    acc, val_loss, test_loss = al.train_and_test(
        net_glob, dataset_train[client_id], dataset_val[client_id],
        dataset_test[client_id], args, local_epochs=300)

    better_local_test_acc.append(acc)
    better_local_val_loss.append(val_loss)
    better_local_test_loss.append(test_loss)

# Local-only baselines for the held-out ("unseen") clients.
unseen_local_test_loss, unseen_local_val_loss = [], []
unseen_local_test_acc, unseen_local_model = [], []
num_unseen_clients = len(dataset_unseentrain)

for unseen_id in range(num_unseen_clients):
    # Each unseen client also trains from a fresh copy of the reference model.
    net_glob = copy.deepcopy(net_glob_org)
    net_glob.train()

    acc, val_loss, test_loss = al.train_and_test(
        net_glob, dataset_unseentrain[unseen_id], dataset_unseenval[unseen_id],
        dataset_unseentest[unseen_id], args)

    unseen_local_test_acc.append(acc)
    unseen_local_val_loss.append(val_loss)
    unseen_local_test_loss.append(test_loss)
    unseen_local_model.append(copy.deepcopy(net_glob))

logging.info('########Finished Local Trainiing')

# Mean local-baseline losses across clients (seen, then unseen).
metric_local_test_loss = np.average(local_test_loss)
metric_local_val_loss = np.average(local_val_loss)
unseen_metric_local_test_loss = np.average(unseen_local_test_loss)
unseen_metric_local_val_loss = np.average(unseen_local_val_loss)

# Per-algorithm result tables for the seen clients, keyed by algorithm name.
test_acc_global, test_loss_global, pos_test_loss_global = {}, {}, {}
train_loss_global, parti_rate_global, eta_global = {}, {}, {}
weights_algo = {}
test_acc_parti_global, test_loss_parti_global = {}, {}

# The same tables for the unseen clients.
unseen_test_acc_global, unseen_test_loss_global = {}, {}
unseen_pos_test_loss_global, unseen_train_loss_global = {}, {}
unseen_parti_rate_global = {}
unseen_test_acc_parti_global, unseen_test_loss_parti_global = {}, {}

##################################
### Main: Server side optimization and testing
#################################

# Number of seen clients, used throughout the server-side loop.
n = args.ensize
# Display names of the federated algorithms known to this script.
s0, s1, s3 = 'FedAvg', 'Sigmoid', 'Leave'
s4, s5 = 'Scaffold', 'FedProx'
s6, s7, s8 = 'qFFL', 'PerFedAvg', 'MW-Fed'

# Algorithms actually run, in execution order ('Leave' is defined but not run).
incentive_algs = [s1, s7, s0, s4, s5, s6, s8]

# Column stores for the result tables; each one is a separate dict.
data_csv, data_csv2, data_csv3, data_csv4 = {}, {}, {}, {}

# Main experiment loop: for every algorithm in `incentive_algs`, restart from a
# copy of the reference model, run `args.rounds` communication rounds, evaluate
# the global model on all seen and unseen clients after every round, store the
# per-round curves in the result tables and, optionally, evaluate a locally
# fine-tuned copy of the final global model.
for alg in incentive_algs:

    logging.info("Starting "+alg)
    f.write("Starting "+alg + '\n')
    net_glob = copy.deepcopy(net_glob_org)

    # Initialize params for Scaffold
    # NOTE(review): np.size() is applied to the model object itself, not to its
    # flattened parameter vector — confirm that the resulting length is intended.
    if alg == 'Scaffold':
        control_c = np.zeros(np.size(net_glob))
        control_locals = {i: np.zeros(np.size(net_glob)) for i in range(n)}
        control_local_lists = []

    # MW-Fed keeps one multiplicative weight per client, starting uniform.
    if alg == 'MW-Fed':
        scale_lists = [1/args.ensize for _ in range(args.ensize)]

    net_glob.train()

    # copy weights
    w_glob = net_glob.state_dict()

    # Per-round curves for the seen clients (one value appended per round).
    test_acc = []
    test_loss = []
    pos_test_loss = []
    parti_rate = []
    train_loss = []
    test_acc_parti = []
    test_loss_parti = []

    # Per-round curves for the unseen clients.
    unseen_test_acc = []
    unseen_test_loss = []
    unseen_pos_test_loss = []
    unseen_parti_rate = []
    unseen_train_loss = []
    unseen_test_acc_parti = []
    unseen_test_loss_parti = []

    for t in range(args.rounds):

        grad_locals, val_loss_locals = [], []

        # for i in range(n):
        # Re-seeding NumPy's global RNG with the round index makes the sampled
        # client set for round t identical for every algorithm.
        np.random.seed(t)
        ind = np.random.choice(n, args.size, replace=False)
        # ind = [i for i in range(n)]

        if alg == 'MW-Fed':
            # Multiply a sampled client's weight by args.mwfed_c whenever the
            # global model's validation loss is not below that client's local
            # baseline, then normalise the sampled weights to sum to one.
            scale_tmp1 = []
            for client in ind:
                _, val_loss = al.test_img(net_glob, dataset_val[client], args)
                if val_loss>=local_val_loss[client]:
                    scale_lists[client] *= args.mwfed_c

                scale_tmp1.append(scale_lists[client])

            scale_tmp1 = np.array(scale_tmp1)/sum(scale_tmp1)


        # Normalisers accumulated over the sampled clients (qFFL / MW-Fed only).
        scale_qffl = 0.0
        mw_scale_sum = 0.0
        for cli_idx, i in enumerate(ind):
            # Validation loss of the current global model on client i.
            val_acc, val_loss = al.test_img(net_glob, dataset_val[i], args)
            val_loss_locals.append(val_loss)

            # Each client trains on a deep copy, so net_glob itself is only
            # changed by the aggregation step further down.
            local = al.LocalUpdate(args=args, dataset=dataset_train[i])
            if alg == 'FedProx':
                grad = local.train_and_sketch_fedprox(copy.deepcopy(net_glob))

            elif alg == 'PerFedAvg':
                grad = local.train_and_sketch_perfedavg(copy.deepcopy(net_glob))

            elif alg == 'MW-Fed':
                grad, mw_scale = local.train_and_sketch_mwfed(copy.deepcopy(net_glob), scale_tmp1[cli_idx])
                mw_scale_sum += mw_scale

            elif alg == 'qFFL':
                # The update is rescaled by loss**q, and scale_qffl collects the
                # matching denominator applied after aggregation.
                _, train_loss_qffl = al.test_img(net_glob, dataset_train[i], args)
                grad = local.train_and_sketch_qffl(copy.deepcopy(net_glob))
                scale_qffl += args.q*(train_loss_qffl)**(args.q-1)*np.linalg.norm(grad)**2\
                              +(1/args.eta)*(train_loss_qffl)**args.q
                grad *= (train_loss_qffl)**args.q

            else:
                grad = local.train_and_sketch(copy.deepcopy(net_glob))

            if alg == 'Scaffold':
                # Correct the update with the control variates, refresh this
                # client's control variate and record how much it changed.
                prev_ci = copy.deepcopy(control_locals[i])
                grad += args.eta*args.train_ep*(control_locals[i]-control_c)
                control_locals[i] = prev_ci-control_c+(1/(args.eta*args.train_ep))*(-grad)
                control_local_lists.append(control_locals[i] - prev_ci)

            grad_locals.append(copy.deepcopy(grad))


        with torch.no_grad():
            #control_local_lists
            # NOTE(review): control_local_lists is never cleared between rounds
            # and np.average() without an axis reduces it to one scalar —
            # confirm this is the intended server control-variate update.
            if alg == 'Scaffold':
                control_c += 1/n*(np.average(control_local_lists))

            # Second return value (scale_tmp) is not used in the visible code.
            w_avg, scale_tmp = al.aggr_func(alg, ind, grad_locals, val_loss_locals, local_val_loss)

            if alg == 'qFFL':
                w_avg = w_avg/scale_qffl

            if alg == 'MW-Fed':
                w_avg = w_avg/mw_scale_sum

            # Add the aggregated update to the flattened parameters and write
            # the result back into net_glob.
            w = torch.tensor(w_avg).type(torch.FloatTensor)
            w_vec_estimate = parameters_to_vector(net_glob.parameters()) + 1 * w.to(args.device)
            vector_to_parameters(w_vec_estimate, net_glob.parameters())


        # `t % 1` is always 0, so metrics are recorded after every round.
        if (t % 1 == 0):        # Recording frequencies

            # Evaluate for seen clients
            test_acc_tmp = []
            test_loss_tmp = []
            train_loss_tmp = []

            count = 0
            pc = []
            sum_pos_test_loss = 0
            test_acc_par = 0
            test_loss_par = 0

            all_test_accs = []
            part_test_accs = []
            all_part_test_accs = []
            for i in range(n):
                test_acc_i, test_loss_i = al.test_img(net_glob, dataset_test[i], args)
                train_acc_i, train_loss_i = al.test_img(net_glob, dataset_train[i], args)

                all_test_accs.append(test_acc_i)

                # A client counts as participating when the global model has a
                # strictly lower test loss than its own local baseline.
                if (test_loss_i < local_test_loss[i]):
                    count = count + 1
                    pc.append(i)
                    test_acc_par += test_acc_i
                    test_loss_par += test_loss_i
                    part_test_accs.append(test_acc_i)
                    all_part_test_accs.append(test_acc_i)

                else:
                    # Otherwise the client is scored with its local baseline and
                    # the loss excess of the global model is accumulated.
                    sum_pos_test_loss += test_loss_i - local_test_loss[i]
                    test_acc_par += local_test_acc[i]
                    test_loss_par += local_test_loss[i]
                    all_part_test_accs.append(local_test_acc[i])

                test_acc_tmp.append(test_acc_i)
                test_loss_tmp.append(test_loss_i)
                train_loss_tmp.append(train_loss_i)

            # Averages over all n seen clients.
            avg_acc_test = np.sum(test_acc_tmp) / n             # THIS 1
            avg_loss_test = np.sum(test_loss_tmp) / n
            avg_loss_train = np.sum(train_loss_tmp) / n
            sum_pos_test_loss = sum_pos_test_loss / n

            test_loss_par = test_loss_par / n
            test_acc_par = test_acc_par / n                     # THIS 2

            # Participation rate: fraction of clients the global model beats.
            pr = count / n                                      # THIS 3

            test_acc.append(avg_acc_test)                       # 1
            test_loss.append(avg_loss_test)
            pos_test_loss.append(sum_pos_test_loss)
            train_loss.append(avg_loss_train)
            parti_rate.append(pr)                               # 3

            test_acc_parti.append(test_acc_par)                 # 2
            test_loss_parti.append(test_loss_par)

            # Same evaluation for the unseen clients.
            unseen_test_acc_tmp = []
            unseen_test_loss_tmp = []
            unseen_train_loss_tmp = []

            unseen_count = 0
            # count_acc = 0
            unseen_pc = []
            unseen_sum_pos_test_loss = 0
            unseen_test_acc_par = 0
            unseen_test_loss_par = 0

            for i in range(num_unseen_clients):

                test_acc_i, test_loss_i = al.test_img(net_glob, dataset_unseentest[i], args)
                train_acc_i, train_loss_i = al.test_img(net_glob, dataset_unseentrain[i], args)


                if (test_loss_i < unseen_local_test_loss[i]):
                    unseen_count = unseen_count + 1
                    unseen_pc.append(i)
                    unseen_test_acc_par += test_acc_i
                    unseen_test_loss_par += test_loss_i


                else:
                    unseen_sum_pos_test_loss += test_loss_i - unseen_local_test_loss[i]
                    unseen_test_acc_par += unseen_local_test_acc[i]
                    unseen_test_loss_par += unseen_local_test_loss[i]

                unseen_test_acc_tmp.append(test_acc_i)
                unseen_test_loss_tmp.append(test_loss_i)
                unseen_train_loss_tmp.append(train_loss_i)


            unseen_test_acc.append(np.sum(unseen_test_acc_tmp) / num_unseen_clients)  # 1
            unseen_test_loss.append(np.sum(unseen_test_loss_tmp) / num_unseen_clients)
            unseen_pos_test_loss.append(unseen_sum_pos_test_loss / num_unseen_clients)
            unseen_train_loss.append(np.sum(unseen_train_loss_tmp) / num_unseen_clients)
            unseen_parti_rate.append(unseen_count / num_unseen_clients)  # 3

            unseen_test_acc_parti.append(unseen_test_acc_par / num_unseen_clients)  # 2
            unseen_test_loss_parti.append(unseen_test_loss_par / num_unseen_clients)

    # Store this algorithm's per-round curves in the global result tables.
    test_acc_global[alg] = test_acc                     # 1
    test_loss_global[alg] = test_loss
    pos_test_loss_global[alg] = pos_test_loss
    train_loss_global[alg] = train_loss
    parti_rate_global[alg] = parti_rate                 # 3

    test_acc_parti_global[alg] = test_acc_parti         # 2
    test_loss_parti_global[alg] = test_loss_parti

    # Keep a copy of the final global model for this algorithm.
    weights_algo[alg] = copy.deepcopy(net_glob)

    # Columns of the seen-client table.
    # NOTE(review): the '_parti' column is assigned twice with the same value.
    data_csv[alg+'_test'] = test_acc_global[alg]
    data_csv[alg + '_testloss'] = test_loss_global[alg]
    data_csv[alg + '_trainloss'] = train_loss_global[alg]
    data_csv[alg + '_partest'] = test_acc_parti_global[alg]
    data_csv[alg+'_parti'] = parti_rate_global[alg]
    data_csv[alg + '_parti'] = parti_rate_global[alg]
    data_csv4[alg + '_metricloctest'] = metric_local_test_loss
    data_csv4[alg + '_metriclocval'] = metric_local_val_loss

    # Per-client accuracies from the last recorded round only.
    # NOTE(review): these lists are created inside the round loop, so they do
    # not exist when args.rounds == 0.
    data_csv2[alg + '_allacc'] = all_test_accs
    data_csv2[alg + '_partacc'] = part_test_accs
    data_csv2[alg + '_allpartacc'] = all_part_test_accs

    # Unseen Logging
    unseen_test_acc_global[alg] = unseen_test_acc  # 1
    unseen_test_loss_global[alg] = unseen_test_loss
    unseen_pos_test_loss_global[alg] = unseen_pos_test_loss
    unseen_train_loss_global[alg] = unseen_train_loss
    unseen_parti_rate_global[alg] = unseen_parti_rate  # 3

    unseen_test_acc_parti_global[alg] = unseen_test_acc_parti  # 2
    unseen_test_loss_parti_global[alg] = unseen_test_loss_parti

    # Columns of the unseen-client table (no '_testloss' column is stored here).
    data_csv3[alg + '_test'] = unseen_test_acc_global[alg]
    data_csv3[alg + '_trainloss'] = unseen_train_loss_global[alg]
    data_csv3[alg + '_partest'] = unseen_test_acc_parti_global[alg]
    data_csv3[alg + '_parti'] = unseen_parti_rate_global[alg]

    # Optional: fine-tune a copy of the final global model on each client's own
    # training data and repeat the participation evaluation with the tuned copy.
    if args.local_tune == 1:

        test_acc_tmp = []
        test_loss_tmp = []
        train_loss_tmp = []

        count = 0
        pc = []
        sum_pos_test_loss = 0
        test_acc_par = 0
        test_loss_par = 0

        for i in range(n):

            local = al.LocalUpdate(args=args, dataset=dataset_train[i])
            # The returned object is evaluated as a model by al.test_img below.
            tuned_model = local.train_and_sketch_local(copy.deepcopy(net_glob))

            test_acc_i, test_loss_i = al.test_img(tuned_model, dataset_test[i], args)
            train_acc_i, train_loss_i = al.test_img(tuned_model, dataset_train[i], args)

            if (test_loss_i < local_test_loss[i]):
                count = count + 1
                pc.append(i)
                test_acc_par += test_acc_i
                test_loss_par += test_loss_i

            else:
                sum_pos_test_loss += test_loss_i - local_test_loss[i]
                test_acc_par += local_test_acc[i]
                test_loss_par += local_test_loss[i]

            test_acc_tmp.append(test_acc_i)
            test_loss_tmp.append(test_loss_i)
            train_loss_tmp.append(train_loss_i)


        # Only pr, test_acc_par and avg_acc_test are reported below; the loss
        # averages are computed but not written anywhere in the visible code.
        avg_acc_test = np.sum(test_acc_tmp) / n             # THIS 1
        avg_loss_test = np.sum(test_loss_tmp) / n
        avg_loss_train = np.sum(train_loss_tmp) / n
        sum_pos_test_loss = sum_pos_test_loss / n

        test_loss_par = test_loss_par / n
        test_acc_par = test_acc_par / n                     # THIS 2

        pr = count / n                                      # THIS 3
        #count_acc_pr = count_acc / n
        local_tune_print = "LOCAL TUNED PR {} PRA {} ACC {}".format(pr, test_acc_par, avg_acc_test)
        print(local_tune_print )
        f.write("-------------------" +local_tune_print + '\n')
        # Same locally tuned evaluation for the unseen clients.
        unseen_test_acc_tmp = []
        unseen_test_loss_tmp = []
        unseen_train_loss_tmp = []

        unseen_count = 0
        # count_acc = 0
        unseen_pc = []
        unseen_sum_pos_test_loss = 0
        unseen_test_acc_par = 0
        unseen_test_loss_par = 0

        for i in range(num_unseen_clients):

            local = al.LocalUpdate(args=args, dataset=dataset_unseentrain[i])
            tuned_model = local.train_and_sketch_local(copy.deepcopy(net_glob))

            test_acc_i, test_loss_i = al.test_img(tuned_model, dataset_unseentest[i], args)
            train_acc_i, train_loss_i = al.test_img(tuned_model, dataset_unseentrain[i], args)

            if (test_loss_i < unseen_local_test_loss[i]):
                unseen_count = unseen_count + 1
                unseen_pc.append(i)
                unseen_test_acc_par += test_acc_i
                unseen_test_loss_par += test_loss_i

            else:
                unseen_sum_pos_test_loss += test_loss_i - unseen_local_test_loss[i]
                unseen_test_acc_par += unseen_local_test_acc[i]
                unseen_test_loss_par += unseen_local_test_loss[i]

            unseen_test_acc_tmp.append(test_acc_i)
            unseen_test_loss_tmp.append(test_loss_i)
            unseen_train_loss_tmp.append(train_loss_i)


        unseen_avg_acc_test = np.sum(unseen_test_acc_tmp) / num_unseen_clients             # THIS 1
        unseen_test_acc_par = unseen_test_acc_par / num_unseen_clients                     # THIS 2
        unseen_pr = unseen_count / num_unseen_clients                                    # THIS 3

        unseen_local_write = "UNSEEN LOCAL TUNED PR {} PRA {} ACC {}".format(unseen_pr,  unseen_test_acc_par, unseen_avg_acc_test)
        print(unseen_local_write)
        f.write("-------------------" + unseen_local_write + '\n')

# The plain-text run log is complete.
f.close()

# Both result tables go to <dir>/<name>/<dataset>/; make sure it exists.
save_path = args.dir+'/'+args.name+'/'+args.dataset+'/'
pathlib.Path(save_path).mkdir(parents=True, exist_ok=True)

# Hyper-parameter suffix shared by both CSV names. It is formatted here rather
# than reused from earlier so that it reflects the current value of args.ensize.
run_suffix = 'isclust{}_isloctune{}_isbetter{}_isleave{}_a{:.2f}_tr{:.1f}_vr{:.1f}_en{}_n{}_bo{}_rnd{}_lr{:.4f}_' \
             'bs{}_cp{}_q{}_c{}_e{}.csv'.format(
    args.isclust, args.local_tune, args.better, args.leave,
    args.alpha, args.train_ratio, args.val_ratio, args.ensize, args.size, args.local_train_ep,
    args.rounds, args.eta, args.bs, args.train_ep,
    args.q, args.mwfed_c, args.seed)

# Seen-client table: one column per data_csv key, one row per recorded round.
# newline='' is what the csv module requires of the file object; without it,
# platforms that translate '\n' emit an extra blank line after every row.
file_name = save_path + 'new_' + run_suffix
keys = data_csv.keys()
with open(file_name, 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(keys)
    # zip(*columns) transposes the per-key lists into per-round rows.
    writer.writerows(zip(*[data_csv[key] for key in keys]))

# Unseen-client table, same layout, built from data_csv3.
unseen_file_name = save_path + 'new_unseen' + run_suffix
unseen_keys = data_csv3.keys()
with open(unseen_file_name, 'w', newline='') as unseen_csvfile:
    writer = csv.writer(unseen_csvfile)
    writer.writerow(unseen_keys)
    writer.writerows(zip(*[data_csv3[key] for key in unseen_keys]))