#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar,CNNCifarBN,CNNCifarV,modelC,VGG
from utils import get_dataset, get_dataset_test, get_dataset_3,get_dataset_5,get_dataset_fair,average_weights, exp_details
import pdb

#np.random.seed(3)
#torch.manual_seed(3)
#if torch.cuda.is_available(): torch.cuda.manual_seed_all(3)


def _build_model(args, train_dataset):
    """Instantiate the global model selected by ``args.model`` / ``args.dataset``.

    Args:
        args: parsed command-line options (reads ``model``, ``dataset`` and,
            for the MLP, ``num_classes``).
        train_dataset: indexable dataset; for the MLP, ``train_dataset[0][0]``
            is used to infer the flattened input width.

    Returns:
        The constructed (untrained) model.

    Raises:
        SystemExit: if the model name, or the dataset for the ``cnn`` model,
            is not recognized. Previously an unknown dataset with ``cnn`` fell
            through and crashed later with a NameError.
    """
    if args.model == 'cnn':
        # Convolutional neural network; the architecture depends on the dataset.
        if args.dataset == 'mnist':
            return CNNMnist(args=args)
        if args.dataset == 'fmnist':
            return CNNFashion_Mnist(args=args)
        if args.dataset == 'cifar':
            return CNNCifarV(args=args)
        raise SystemExit('Error: unrecognized dataset')

    if args.model == 'mlp':
        # Multi-layer perceptron: the input width is the product of one
        # sample's dimensions (the sample is flattened).
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        return MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes)

    raise SystemExit('Error: unrecognized model')


def _evaluate_all_users(args, model, train_dataset, user_groups, logger):
    """Evaluate ``model`` once per user, over all ``args.num_users`` users.

    Puts ``model`` into eval mode. Each user is evaluated through
    ``LocalUpdate(...).inference`` on the indices ``user_groups[user]``. Which
    split of those indices ``inference`` actually uses is decided inside
    ``LocalUpdate`` and is not visible here.

    Returns:
        ``(list_acc, list_loss)``: per-user accuracies and losses, in user order.
    """
    model.eval()
    list_acc, list_loss = [], []
    for cu in range(args.num_users):
        local_model = LocalUpdate(args=args, dataset=train_dataset,
                                  idxs=user_groups[cu], logger=logger)
        acc, loss = local_model.inference(model=model)
        list_acc.append(acc)
        list_loss.append(loss)
    return list_acc, list_loss


if __name__ == '__main__':
    # Federated training driver. Each round:
    #   1. sample a fraction of users;
    #   2. train a copy of the global model locally on each sampled user;
    #   3. average the local weights into the global model;
    #   4. evaluate on every user and on the test set.
    # Per-round metrics are pickled at the end.
    start_time = time.time()

    # TensorBoard writer, handed to every LocalUpdate.
    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    # The f-string accepts either a str or an int GPU id. A falsy value
    # (e.g. None) selects the CPU.
    device = torch.device(f'cuda:{args.gpu}') if args.gpu else 'cpu'

    # Seed every RNG the run touches, for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cudnn.deterministic = True

    # Load the dataset and the per-user index groups. Alternative partitioners
    # imported from utils: get_dataset, get_dataset_3, get_dataset_test,
    # get_dataset_5.
    train_dataset, test_dataset, user_groups = get_dataset_fair(args)

    # Build the model, move it to the device, and show its architecture.
    global_model = _build_model(args, train_dataset)
    global_model.to(device)
    global_model.train()
    print(global_model)

    # Per-round metrics.
    train_loss = []           # mean local training loss of the sampled users
    train_accuracy = []       # mean per-user accuracy of the global model
    train_accuracy_var = []   # variance of the per-user accuracies
    all_list_acc = []         # full list of per-user accuracies
    external_test_acc = []    # global-model accuracy on test_dataset
    print_every = 2

    # Number of users sampled per round (at least one); loop-invariant.
    m = max(int(args.frac * args.num_users), 1)

    for epoch in tqdm(range(args.epochs)):
        print(f'\n | Global Training Round : {epoch+1} |\n')

        global_model.train()
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        # Local training on each sampled user, starting from a fresh copy of
        # the current global model.
        local_weights, local_losses = [], []
        for idx in idxs_users:
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[idx], logger=logger)
            w, loss = local_model.update_weights(
                model=copy.deepcopy(global_model), global_round=epoch)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        # Aggregate the local weights into the global model.
        global_model.load_state_dict(average_weights(local_weights))

        train_loss.append(sum(local_losses) / len(local_losses))

        # Evaluate the aggregated model on every user, not only the sampled ones.
        list_acc, _ = _evaluate_all_users(args, global_model, train_dataset,
                                          user_groups, logger)
        all_list_acc.append(list_acc)
        train_accuracy.append(sum(list_acc) / len(list_acc))
        train_accuracy_var.append(np.var(list_acc))

        test_acc, _ = test_inference(args, global_model, test_dataset)
        external_test_acc.append(test_acc)

        # Print the global training stats every `print_every` rounds.
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))

    # The model has not changed since the last round's test evaluation, so
    # reuse that result. Only run inference when no round was executed.
    if external_test_acc:
        test_acc = external_test_acc[-1]
    else:
        test_acc, _ = test_inference(args, global_model, test_dataset)

    print(f' \n Results after {args.epochs} global rounds of training:')
    if train_accuracy:  # empty when args.epochs == 0
        print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

    # Save the per-round metrics. The pickle layout is unchanged:
    # [external_test_acc, train_accuracy, train_accuracy_var, all_list_acc]
    file_name = '../save/plots/ablation/local_epochs/fair/base/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_Lr[{}].pkl'.\
        format(args.dataset, args.model, args.epochs, args.frac, args.iid,
               args.local_ep, args.local_bs, args.lr)

    # Create the output directory up front. Otherwise a missing directory
    # makes open() fail only after the entire training run has finished.
    os.makedirs(os.path.dirname(file_name), exist_ok=True)
    with open(file_name, 'wb') as f:
        pickle.dump([external_test_acc, train_accuracy, train_accuracy_var, all_list_acc], f)

    # Flush and release the TensorBoard event writer.
    logger.close()

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))
