#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdateIBP, test_inference_ibp,test_inference_ibp_check
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar,modelC,VGG,IBPMLP,IBPCNN_V3
from utils import *
from ibp_torch import Server_MLP,Server_V3
import pdb



def _build_assembly(model_name, model_arch, conv_arch, weights, server,
                    bayes_info, global_forward=False):
    """Instantiate an IBP network (``IBPMLP_V3`` or ``IBPCNN_V3``) from server weights.

    Args:
        model_name: value of ``args.model``; only ``'ibpmlp'`` and ``'ibpcnn'``
            are handled.
        model_arch: per-layer architecture list built by the main script.
        conv_arch: convolution spec list; only used for ``'ibpcnn'``.
        weights: weights obtained from ``server.send_weights(...)``.
        server: the ``Server_MLP`` / ``Server_V3`` instance whose ``a_prior``,
            ``lambda_post``, ``lambda_prior`` and ``p_threshold`` are forwarded.
        bayes_info: per-client Bayesian state, or ``None``.
        global_forward: when True, ``global_forward=True`` is passed to the
            constructor; otherwise the keyword is omitted so the constructor's
            own default applies (exactly as the original call sites did).

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``model_name`` is not one of the IBP models.
    """
    if model_name not in ('ibpmlp', 'ibpcnn'):
        raise ValueError('unsupported model for IBP assembly: {}'.format(model_name))

    # Positional order (lambda_post before lambda_prior) matches the original
    # call sites and must not be swapped.
    hyper = (server.a_prior, server.lambda_post, server.lambda_prior,
             server.p_threshold, bayes_info)
    extra = {'global_forward': True} if global_forward else {}

    if model_name == 'ibpmlp':
        # NOTE(review): IBPMLP_V3 is not explicitly imported in this file (only
        # IBPMLP is); it presumably arrives via `from utils import *` — confirm.
        return IBPMLP_V3(model_arch, weights, *hyper, **extra)
    return IBPCNN_V3(model_arch, conv_arch, weights, *hyper, **extra)


if __name__ == '__main__':
    start_time = time.time()

    # TensorBoard writer handed to every LocalUpdateIBP instance; closed at the end.
    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    # Seed every RNG used below so client sampling and training are reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cudnn.deterministic = True

    # format() works whether the GPU id was parsed as str or int.
    device = torch.device('cuda:{}'.format(args.gpu)) if args.gpu else 'cpu'

    # Load dataset and per-user index groups.
    if args.middleiid == 0:
        train_dataset, test_dataset, user_groups = get_dataset(args)
    elif args.middleiid == 1:
        train_dataset, test_dataset, user_groups = get_dataset_test(args)
    else:
        exit('Error: unrecognized middleiid value {}'.format(args.middleiid))

    # BUILD MODEL
    conv_arch = None  # only populated for the convolutional IBP model
    if args.model == 'cnn':
        # The federated loop below needs the IBP server API (send_weights,
        # plain_update_weights) and model_arch, which the plain CNN models
        # (CNNMnist, CNNFashion_Mnist, CNNCifar) do not provide.
        exit('Error: model "cnn" is not supported by this IBP script; use ibpmlp or ibpcnn')

    elif args.model == 'ibpmlp':
        # Two-layer IBP MLP; the +1 on input widths (785, 201) is presumably a
        # bias column — confirm against Server_MLP.
        model_arch = [(785, args.truncation, 200),
                      (201, args.truncation, args.num_classes)]
        global_model = Server_MLP(model_arch, args.a_prior, args.lambda_prior,
                                  args.lambda_post, args.p_threshold)

    elif args.model == 'ibpcnn':
        # conv_arch entries are (a, b, k) tuples; each becomes
        # ('conv', a * k**2, conv_truncation, b) in model_arch.
        if args.dataset == 'mnist':
            conv_arch = [(args.num_channels, 10, 5), (10, 20, 5)]
            dense_layers = [('mlp', 321, args.truncation, 50),
                            ('last', 50, None, args.num_classes)]
        elif args.dataset == 'fmnist':
            conv_arch = [(args.num_channels, 16, 5), (16, 32, 5)]
            dense_layers = [('last', 7 * 7 * 32, None, args.num_classes)]
        elif args.dataset == 'cifar':
            conv_arch = [(args.num_channels, 16, 3), (16, 16, 3)]
            dense_layers = [('mlp', 16 * 6 * 6 + 1, args.truncation, 120),
                            ('mlp', 121, args.truncation1, 84),
                            ('last', 84, None, args.num_classes)]
        else:
            exit('Error: unsupported dataset {} for ibpcnn'.format(args.dataset))

        model_arch = [('conv', a * k ** 2, args.conv_truncation, b)
                      for a, b, k in conv_arch] + dense_layers
        global_model = Server_V3(model_arch, args.a_prior, args.lambda_prior,
                                 args.lambda_post, args.p_threshold)
    else:
        exit('Error: unrecognized model')

    # Training bookkeeping
    train_loss, train_accuracy = [], []
    print_every = 2
    selected_users = set()          # every client that has trained at least once

    clients_info = dict()           # per-client info from update_client_info*
    clients_bayes_info = dict()     # per-client Bayesian state reused next time
    clients_masks_info = dict()     # per-client masks from update_weights

    train_accuracy_var = list()
    external_test_acc = list()
    external_test_loss = list()
    all_list_acc = list()

    for epoch in tqdm(range(args.epochs)):
        local_weights, local_losses = dict(), []
        print('\n | Global Training Round : {} |\n'.format(epoch+1))

        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx in idxs_users:
            # Server weights for this client (for all layers: Wa, Wb) plus the
            # client's Bayesian state from a previous round, if it has one.
            local_model_weights = global_model.send_weights(idx)
            local_bayes = clients_bayes_info.get(idx, None)
            local_assembly = _build_assembly(args.model, model_arch, conv_arch,
                                             local_model_weights, global_model,
                                             local_bayes)

            # Data construction in LocalUpdateIBP
            local_model = LocalUpdateIBP(args=args, dataset=train_dataset,
                                         idxs=user_groups[idx], logger=logger)
            w, loss, local_masks = local_model.update_weights(model=local_assembly,
                                                              global_round=epoch)

            copy_version = copy.deepcopy(w)

            if args.model == 'ibpmlp':
                current_client_info = update_client_info(copy_version, model_arch[:-1])
                current_client_bayes_info = update_client_bayes_info(copy_version, model_arch[:-1])
            else:
                current_client_info = update_client_info_general(copy_version, model_arch[:-1])
                current_client_bayes_info = update_client_bayes_info_general(copy_version, model_arch[:-1])

            clients_info[idx] = current_client_info
            clients_bayes_info[idx] = current_client_bayes_info
            clients_masks_info[idx] = local_masks

            local_weights[idx] = [copy_version, local_masks]
            local_losses.append(copy.deepcopy(loss))

        selected_users.update(idxs_users)

        # update global weights
        global_model.plain_update_weights(local_weights, epoch)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Evaluate every user on its own data: clients that have trained use
        # their own r values / Bayesian state / masks, the rest use the
        # server's global forward pass.
        list_acc = []
        for s_idx in range(args.num_users):
            local_model_weights = global_model.send_weights(s_idx)

            if s_idx in selected_users:
                rts = clients_info[s_idx]['rts']
                local_model_weights = (replace_with_local_r_only(local_model_weights[:-1], rts)
                                       + [local_model_weights[-1]])
                # Use this client's own Bayesian state, not the last trained client's.
                local_assembly = _build_assembly(args.model, model_arch, conv_arch,
                                                 local_model_weights, global_model,
                                                 clients_bayes_info[s_idx])
                if args.model == 'ibpmlp':
                    local_assembly.masks = clients_masks_info[s_idx]
                # NOTE(review): the ibpcnn assembly is not given the client's masks
                # here — confirm whether IBPCNN_V3 should receive them too.
            else:
                local_assembly = _build_assembly(args.model, model_arch, conv_arch,
                                                 local_model_weights, global_model,
                                                 None, global_forward=True)

            local_model = LocalUpdateIBP(args=args, dataset=train_dataset,
                                         idxs=user_groups[s_idx], logger=logger)
            acc, _ = local_model.inference(model=local_assembly)
            list_acc.append(acc)

        # Evaluate the global model on the held-out test set.
        global_assembly = _build_assembly(args.model, model_arch, conv_arch,
                                          global_model.send_weights(-1), global_model,
                                          None, global_forward=True)
        test_acc, test_loss, _, _ = test_inference_ibp_check(args, global_assembly.to(device), test_dataset)
        external_test_acc.append(test_acc)
        external_test_loss.append(test_loss)

        all_list_acc.append(list_acc)
        train_accuracy_var.append(np.var(list_acc))
        train_accuracy.append(sum(list_acc)/len(list_acc))
        print("local test acc list")
        print(train_accuracy)

        print("external test acc")
        print(external_test_acc)

        # print global training loss after every 'print_every' rounds
        if (epoch+1) % print_every == 0:
            print(' \nAvg Training Stats after {} global rounds:'.format(epoch+1))
            print('Training Loss : {}'.format(np.mean(np.array(train_loss))))
            print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))

    # Final evaluation of the global model.
    global_assembly = _build_assembly(args.model, model_arch, conv_arch,
                                      global_model.send_weights(-1), global_model,
                                      None, global_forward=True)
    test_acc, test_loss, accuracy_class_global, class_cnt = test_inference_ibp_check(args, global_assembly.to(device), test_dataset)

    print(' \n Results after {} global rounds of training:'.format(args.epochs))
    if train_accuracy:  # empty when args.epochs == 0
        print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))

    logger.close()

