# Modified from: https://github.com/pliang279/LG-FedAvg/blob/master/main_fed.py
# credit goes to: Paul Pu Liang

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import copy
import itertools
import numpy as np
import pandas as pd
import torch
from torch import nn

from utils.options import args_parser
from utils.train_utils import get_data, get_model, read_data
from models.Update import LocalUpdate
from models.test import test_img_local_all

import time

def main(args):
    """Run one federated training experiment and log per-round metrics.

    Two dataset splits are obtained from ``get_data`` (suffixes ``_g`` and
    ``_c``; they are evaluated with ``color=False`` and ``color=True``
    respectively). Users with index below ``args.num_users // 2`` train on the
    ``_g`` split, the rest on the ``_c`` split. Every round a fraction of the
    users trains locally through ``LocalUpdate.train``; the returned weights
    of ``net_glob`` are averaged and loaded back into ``net_glob``. One extra
    final round (index ``args.epochs``) runs over all users and does not
    update ``net_glob``.

    Metrics are printed and also written to
    ``<save_path><alg>/<dataset>/<num_users>_<shard_per_user>_<local_bs>_<frac>_<times>.txt``.

    Args:
        args: Parsed options namespace. Attributes read here: ``num_users``,
            ``alg``, ``load_fed``, ``dataset``, ``lr``, ``local_bs``,
            ``save_path``, ``shard_per_user``, ``frac``, ``times``,
            ``epochs``, ``m_tr``, ``m_ft``, ``device`` and ``test_freq``.
            The object is also forwarded to the project helpers.

    Raises:
        ValueError: If ``args.alg`` is not one of ``fedAKIE``, ``fedper``,
            ``lg``, ``fedavg`` or ``prox``.
    """
    import os

    # Per-user aggregation weights; all ones, i.e. a plain (unweighted) mean.
    lens = np.ones(args.num_users)

    (dataset_train_g, dataset_test_g, dict_users_train_g, dict_users_test_g,
     dataset_train_c, dataset_test_c, dict_users_train_c, dict_users_test_c) = get_data(args)
    # Shuffle each user's sample indices in place so that the [:m_tr] / [:m_ft]
    # prefixes taken later are random subsets.
    for idx in dict_users_train_g.keys():
        np.random.shuffle(dict_users_train_g[idx])
    for idx in dict_users_train_c.keys():
        np.random.shuffle(dict_users_train_c[idx])

    print(args.alg)

    # build model
    net_glob, net_trans = get_model(args)
    net_glob.train()
    net_trans.train()
    if args.load_fed != 'n':
        fed_model_path = './save/' + args.load_fed + '.pt'
        net_glob.load_state_dict(torch.load(fed_model_path))

    total_num_layers = len(net_glob.state_dict().keys()) + len(net_trans.state_dict().keys())
    print(net_glob.state_dict().keys())
    net_keys = [*net_glob.state_dict().keys()] + [*net_trans.state_dict().keys()]

    # specify the representation parameters (in w_glob_keys) and head parameters (all others)
    # NOTE(review): entries of net_glob.weight_keys are used as-is; the
    # membership tests below only match if they are plain parameter names
    # (not nested lists) -- confirm against the model definition.
    if args.alg in ('fedAKIE', 'fedper'):
        if 'cifar' in args.dataset or 'mnist' in args.dataset:
            # 'mnist' also matches 'femnist', so all image datasets land here.
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2]]
        elif 'sent140' in args.dataset:
            w_glob_keys = [net_keys[i] for i in [0, 1, 2, 3, 4, 5]]
        else:
            w_glob_keys = net_keys[:-2]
    elif args.alg == 'lg':
        if 'cifar' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [1, 2]]
        elif 'mnist' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [2, 3]]
        elif 'sent140' in args.dataset:
            w_glob_keys = [net_keys[i] for i in [0, 6, 7]]
        else:
            w_glob_keys = net_keys[total_num_layers - 2:]
    elif args.alg in ('fedavg', 'prox'):
        w_glob_keys = []
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError('Unsupported algorithm: {!r}'.format(args.alg))

    print(total_num_layers)
    print(w_glob_keys)
    print(net_keys)
    if args.alg in ('fedAKIE', 'fedper', 'lg'):
        num_param_glob = 0
        num_param_local = 0
        for key, tensor in net_glob.state_dict().items():
            num_param_local += tensor.numel()
            print(num_param_local)
            if key in w_glob_keys:
                num_param_glob += tensor.numel()

        for tensor in net_trans.state_dict().values():
            num_param_local += tensor.numel()

        percentage_param = 100 * float(num_param_glob) / num_param_local
        print('# Params: {} (local), {} (global); Percentage {:.2f} ({}/{})'.format(
            num_param_local, num_param_glob, percentage_param, num_param_glob, num_param_local))
    print("learning rate, batch size: {}, {}".format(args.lr, args.local_bs))

    # Per-user weight table, seeded with the initial weights of both networks
    # (net_trans entries win on a key clash, as before).
    init_glob_state = net_glob.state_dict()
    init_trans_state = net_trans.state_dict()
    w_locals = {}
    for user in range(args.num_users):
        w_local_dict = {}
        w_local_dict.update(init_glob_state)
        w_local_dict.update(init_trans_state)
        w_locals[user] = w_local_dict

    # training
    indd = None      # indices of embedding for sent140
    accs = []
    times = []
    accs10 = 0
    accs10_glob = 0
    start = time.time()

    out_path = args.save_path + args.alg + '/' + args.dataset + '/'
    os.makedirs(out_path, exist_ok=True)
    out_name = '{}_{}_{}_{}_{}.txt'.format(
        args.num_users, args.shard_per_user, args.local_bs, args.frac, args.times)

    # The context manager guarantees the log file is closed even if a round
    # raises; main() is invoked many times per process.
    with open(os.path.join(out_path, out_name), "w") as out_file:
        for rnd in range(args.epochs + 1):
            print('The training communication is ', rnd)
            is_final = rnd == args.epochs
            w_glob = {}
            loss_locals = []
            # The final round uses every user; otherwise a `frac` sample.
            m = args.num_users if is_final else max(int(args.frac * args.num_users), 1)

            idxs_users = np.random.choice(range(args.num_users), m, replace=False)
            times_in = []
            total_len = 0
            all_replace, all_total = 0, 0
            for idx in idxs_users:
                start_in = time.time()
                # Upper half of the user indices trains on the `_c` split.
                if idx >= (args.num_users // 2):
                    dataset_train = dataset_train_c
                    dict_users_train = dict_users_train_c
                else:
                    dataset_train = dataset_train_g
                    dict_users_train = dict_users_train_g

                n_samples = args.m_ft if is_final else args.m_tr
                local = LocalUpdate(args=args, dataset=dataset_train,
                                    idxs=dict_users_train[idx][:n_samples])

                net_local = copy.deepcopy(net_glob)
                net_local_trans = copy.deepcopy(net_trans)
                # NOTE(review): as written, the per-user weights stored in
                # w_locals are never loaded into the networks that get trained.
                # The overrides are written into `w_local_trans` (built from
                # net_trans, and receiving net_glob keys as well), and
                # `state_dict(...)` below only exports parameters -- it is not
                # `load_state_dict`. Both local models therefore start every
                # round from the current net_glob / net_trans. Behaviour is
                # kept unchanged here because a fix alters training results;
                # confirm the intended personalisation before changing it.
                w_local = net_local.state_dict()
                w_local_trans = net_trans.state_dict()
                if args.alg != 'fedavg' and args.alg != 'prox':
                    for k in w_locals[idx].keys():
                        if k not in w_glob_keys:
                            w_local_trans[k] = w_locals[idx][k]
                net_local.load_state_dict(w_local)
                net_local_trans.state_dict(w_local_trans)

                w_local, w_local_trans, loss, indd, cur_replace, cur_total = local.train(
                    net=net_local.to(args.device),
                    net_trans=net_local_trans.to(args.device),
                    net_global=net_glob.to(args.device),
                    idx=idx,
                    w_glob_keys=w_glob_keys,
                    lr=args.lr, last=is_final)
                all_total += cur_total
                all_replace += cur_replace
                loss_locals.append(copy.deepcopy(loss))
                total_len += lens[idx]

                # Accumulate the weighted sum of every net_glob entry and
                # remember this user's freshly trained weights.
                first_user = len(w_glob) == 0
                if first_user:
                    w_glob = copy.deepcopy(w_local)
                for key in net_glob.state_dict().keys():
                    if first_user:
                        w_glob[key] = w_glob[key] * lens[idx]
                    else:
                        w_glob[key] += w_local[key] * lens[idx]
                    w_locals[idx][key] = w_local[key]

                for key in net_trans.state_dict().keys():
                    w_locals[idx][key] = w_local_trans[key]

                times_in.append(time.time() - start_in)
            loss_avg = sum(loss_locals) / len(loss_locals)

            # get weighted average for global weights
            for k in net_glob.state_dict().keys():
                w_glob[k] = torch.div(w_glob[k], total_len)

            # The final round must not overwrite the global model.
            if not is_final:
                net_glob.load_state_dict(w_glob)

            if rnd % args.test_freq == args.test_freq - 1:
                # Cumulative wall-clock time, counting the slowest user per round.
                slowest = max(times_in)
                times.append(slowest if not times else times[-1] + slowest)

                acc_test_g, loss_test_g = test_img_local_all(
                    net_glob, net_trans, args, dataset_test_g, dict_users_test_g,
                    w_glob_keys=w_glob_keys, w_locals=w_locals, indd=indd,
                    dataset_train=dataset_train_g, dict_users_train=dict_users_train_g,
                    return_all=False, color=False)

                acc_test_c, loss_test_c = test_img_local_all(
                    net_glob, net_trans, args, dataset_test_c, dict_users_test_c,
                    w_glob_keys=w_glob_keys, w_locals=w_locals, indd=indd,
                    dataset_train=dataset_train_c, dict_users_train=dict_users_train_c,
                    return_all=False, color=True)

                # Report the mean over the two splits.
                acc_test = (acc_test_g + acc_test_c) / 2
                loss_test = (loss_test_g + loss_test_c) / 2
                accs.append(acc_test)

                if not is_final:
                    log_str = 'Round {:3d}, Train loss: {:.3f}, Test loss: {:.3f}, Test accuracy: {:.2f}, Mask Ratio: {:.2f}'.format(
                        rnd, loss_avg, loss_test, acc_test, (all_replace/m)/(all_total/m))
                else:
                    # Final round: every user took part and trained on its
                    # first m_ft samples with last=True.
                    log_str = 'Final Round, Train loss: {:.3f}, Test loss: {:.3f}, Test accuracy: {:.2f}'.format(
                        loss_avg, loss_test, acc_test)
                print(log_str)

                out_file.write(log_str + "\n")
                out_file.flush()

                # Running mean over the evaluations in the last 10 regular
                # rounds (assumes one evaluation per round in that window).
                if rnd >= args.epochs - 10 and not is_final:
                    accs10 += acc_test / 10
                    accs10_glob += acc_test / 10

    print('Average accuracy final 10 rounds: {}'.format(accs10))
    if args.alg == 'fedavg' or args.alg == 'prox':
        print('Average global accuracy final 10 rounds: {}'.format(accs10_glob))
    end = time.time()
    print(end-start)
    print(times)
    print(accs)


if __name__ == '__main__':
    # Parse the command-line options once and reuse the namespace for all runs.
    args = args_parser()

    # Use the requested GPU only when CUDA is available and a GPU was selected
    # (gpu == -1 means CPU).
    use_cuda = torch.cuda.is_available() and args.gpu != -1
    args.device = torch.device('cuda:{}'.format(args.gpu) if use_cuda else 'cpu')

    # Each entry is [num_users, shard_per_user].
    # Alternative sweep: [[100, 2], [100, 3], [100, 4], [100, 5]]
    points = [[200, 3], [200, 4], [200, 5]]

    # Repeat every configuration 10 times; `args.times` is the repetition index
    # and becomes part of the log file name written by main().
    for num_users, shard_per_user in points:
        for t in range(10):
            args.times = t
            args.num_users = num_users
            args.shard_per_user = shard_per_user
            main(args)
