#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm
import random
from collections import deque
import datetime

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar, CharLSTM
from utils import get_dataset, average_weights, exp_details
from shkp import ShakeSpeare



if __name__ == '__main__':
    # Script entry point: edge-to-edge federated training.
    #
    # Each global round samples `args.select_edges` edge servers. The sampled
    # edges are visited one after another; on every edge the clients run local
    # updates which are averaged `args.edge_ep` times, and the resulting weights
    # are handed on to the next edge. After the last edge the weights are loaded
    # into the global model, which is evaluated on the test set. The test
    # accuracies of all rounds are appended to `result.txt`.
    start_time = time.time()  # only read by the disabled run-time report below

    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    # NOTE(review): any falsy `args.gpu` selects the CPU. If the parser yields
    # an integer id, GPU 0 would therefore never be used -- confirm against
    # options.args_parser before relying on `--gpu 0`.
    if args.gpu:
        # If this fails, no suitable torch build is installed; see
        # https://pytorch.org/get-started/locally/
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if args.gpu else 'cpu'

    # Load dataset and user groups.
    if args.dataset in ['mnist', 'cifar', 'fmnist']:
        train_dataset, test_dataset, user_groups = get_dataset(args)
    elif args.dataset == 'shakespeare':
        train_dataset = ShakeSpeare(train=True)
        test_dataset = ShakeSpeare(train=False)
        user_groups = train_dataset.get_client_dic()
        # The client count comes from the dataset itself (139 users according
        # to the original author's note), not from the command line.
        args.num_users = len(user_groups)
        print(args.num_users)
        if args.iid:
            raise SystemExit('Error: ShakeSpeare dataset is naturally non-iid')
        print("Warning: The ShakeSpeare dataset is naturally non-iid, you do not need to specify iid or non-iid")
    else:
        raise SystemExit('Error: unrecognized dataset')

    # For MNIST and CIFAR10, each edge has 5 clients.
    # For Shakespeare, each edge has 14 clients and the last edge has 13.
    num_edges = args.num_edges
    step = args.step  # number of clients per edge (clients / num_edges)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
        else:
            # Previously this fell through and crashed later with a NameError
            # on `global_model`; fail early with an explicit message instead.
            raise SystemExit('Error: the cnn model is not available for dataset ' + str(args.dataset))

    elif args.model == 'mlp':
        # Multi-layer perceptron: the input size is the product of all
        # dimensions of one training sample.
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # Build the model once, after the full input size is known (it used to
        # be rebuilt inside the loop for every dimension).
        global_model = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)

    elif args.model == 'lstm':
        global_model = CharLSTM(args=args)
        # NOTE(review): hard-coded override of args.step, sized for
        # 139 clients / 10 edge servers.
        step = 14
    else:
        raise SystemExit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # Training bookkeeping. `train_loss`, `train_accuracy` and `print_every`
    # are only referenced by the disabled reporting code kept further down.
    train_loss, train_accuracy = [], []
    print_every = 2
    test_accuracy = []

    # Write the run header to the result file. The empty field keeps the
    # double underscore in front of the timestamp that existing result files use.
    result_fn = "_".join([
        args.pattern,
        args.dataset,
        str(args.hety),
        str(args.pg),
        str(args.local_ep),
        str(args.epochs),
        str(args.num_users),
        "",
        str(datetime.datetime.now()),
    ])
    with open('result.txt', 'a') as sr:
        sr.write(result_fn)
        sr.write('\n')

    # JJ initial queue for global model candidates (not used further down yet)
    global_models_q = deque()
    global_models_q.append(copy.deepcopy(global_model))

    # JJ initial models for all edge servers (not used further down yet).
    # One independent copy per edge; `[model] * num_edges` would alias a single
    # object num_edges times.
    edge_models = [copy.deepcopy(global_model) for _ in range(num_edges)]

    edge_model = copy.deepcopy(global_model)
    e2e_weights = global_model.state_dict()
    num_clients = len(user_groups)

    # Start training
    for epoch in tqdm(range(args.epochs)):
        global_model.train()
        print(f'\n | Global Training Round : {epoch+1} |\n')

        # JJ select some edges to do training (indices 0 .. num_edges - 1)
        selected_edges = random.sample(range(num_edges), args.select_edges)
        print(selected_edges)
        for edge_index in selected_edges:
            # Clients of this edge: a contiguous index window, clipped to the
            # number of available clients.
            idxs_users = range(edge_index * step, min(num_clients, (edge_index + 1) * step))
            if not idxs_users:
                # Without clients there is nothing to average; keep the
                # weights handed over by the previous edge.
                print(f'Warning: edge {edge_index} has no clients, skipping it')
                continue

            if args.pattern == 'ece':
                # Continue from the weights produced by the previous edge.
                edge_model.load_state_dict(e2e_weights)
            edge_model.train().to(device)

            # FedAvg over all clients of this edge server, repeated edge_ep times.
            for ed_ep in range(args.edge_ep):
                local_weights = []
                for idx in idxs_users:
                    local_model = LocalUpdate(args=args, dataset=train_dataset,
                                              idxs=user_groups[idx], logger=logger)
                    # Every client trains on its own copy of the edge model.
                    w, loss = local_model.update_weights(model=copy.deepcopy(edge_model),
                                                         global_round=epoch)
                    local_weights.append(copy.deepcopy(w))
                edge_weights = average_weights(local_weights)  # FedAvg weights
                edge_model.load_state_dict(edge_weights)  # update edge server

            # Hand the trained weights over to the next edge / the global model.
            e2e_weights = edge_weights
            print('finish training for one edge!')
        global_model.load_state_dict(e2e_weights)

        # Test inference after completion of this round's training
        test_acc, test_loss = test_inference(args, global_model, test_dataset)

        print('| Global Round : {}|---- Test Accuracy: {:.2f}%'.format(epoch,100*test_acc))
        test_accuracy.append(test_acc)

    # Append the per-round test accuracies as one comma-terminated line.
    with open('result.txt', 'a') as sr:
        sr.write(''.join(str(item) + ',' for item in test_accuracy))
        sr.write('\n\n')

    # Flush and release the tensorboard event writer.
    logger.close()

    # The two string literals below are disabled code kept from the original
    # script; as bare strings they are never executed.
    # NOTE(review): the disabled loop indexes `user_groups` with `idx` although
    # its loop variable is `c` -- fix that before re-enabling it.
    """
        # Calculate avg training accuracy over all users at every epoch
        list_acc, list_loss = [], []
        global_model.eval()
        for c in range(args.num_users):
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                    idxs=user_groups[idx], logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc)/len(list_acc))

        # print global training loss after every 'i' rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))
    """

    """
    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

    # Saving the objects train_loss and train_accuracy:
    file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
        format(args.dataset, args.model, args.epochs, args.frac, args.iid,
               args.local_ep, args.local_bs)

    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))
    """
    #PLOTTING (optional)
    # import matplotlib
    # import matplotlib.pyplot as plt
    # matplotlib.use('Agg')

    # Plot Loss curve
    # plt.figure()
    # plt.title('Training Loss vs Communication rounds')
    # plt.plot(range(len(train_loss)), train_loss, color='r')
    # plt.ylabel('Training loss')
    # plt.xlabel('Communication Rounds')
    # plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.
    #             format(args.dataset, args.model, args.epochs, args.frac,
    #                    args.iid, args.local_ep, args.local_bs))
    #
    # Plot Average Accuracy vs Communication rounds
    # plt.figure()
    # plt.title('Average Accuracy vs Communication rounds')
    # plt.plot(range(len(test_accuracy)), test_accuracy, color='k')
    # plt.ylabel('Average Accuracy')
    # plt.xlabel('Communication Rounds')
    # plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.
    #             format(args.dataset, args.model, args.epochs, args.frac,
    #                    args.iid, args.local_ep, args.local_bs))
