#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm
import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate
# from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar, LogisticRegression
from models import *
from stacked_lstm import LSTM
from stacked_lstm_s import LSTM_s
from utils import get_dataset, average_weights, exp_details
from pathlib import Path
# class Namespace:
#     def __init__(self, **kwargs):
#         self.__dict__.update(kwargs)

if __name__ == '__main__':
    # Wall-clock start of the whole run.
    start_time = time.time()

    # Project root (parent directory) and the TensorBoard writer; both paths
    # are relative to the current working directory.
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    # Experiment configuration comes from the command line.
    args = args_parser()
    print(args)

    # args.gpu, when given, selects the CUDA device; otherwise run on CPU.
    use_gpu = args.gpu is not None
    if use_gpu:
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if use_gpu else 'cpu'

    # Load the data. Judging by its uses below, train_datasets is indexed by
    # client id, test_dataset is one shared evaluation split, and group_ws
    # maps client id -> weight used when combining per-client results.
    train_datasets, test_dataset, user_groups, group_ws = get_dataset(args)

    # BUILD MODEL: choose the architecture for the dataset and record its
    # feature length in args.feature_len (presumably read by LocalUpdate or
    # the model code -- confirm). The dataset cases are mutually exclusive.
    if args.dataset == 'femnist':
        global_model = Twolayer(args=args)
        args.feature_len = 28*28
    elif args.dataset == 'femnist_private':
        global_model = Twolayer_private(args=args)
        args.feature_len = 10*10
    elif "synthetic" in args.dataset:
        global_model = LogisticRegression(args=args)
        args.feature_len = 20
    elif args.dataset == "sent140":
        global_model = LSTM(device=device)
        args.feature_len = 25
    elif args.dataset == "shakespeare":
        global_model = LSTM_s(device=device)
        args.feature_len = 80
    else:
        # Fail fast with a clear message instead of a NameError on
        # `global_model` a few lines further down.
        raise ValueError('Unsupported dataset: {!r}'.format(args.dataset))

    # One client per entry in train_datasets; overwrites any existing value.
    args.num_users = len(train_datasets)

    exp_details(args)
    # Move the model to the selected device and put it in training mode.
    global_model.to(device)
    global_model.train()

    # Initial global weights, passed to average_weights each round. Note that
    # state_dict() returns tensors sharing storage with the model parameters.
    global_weights = global_model.state_dict()

    def _evaluate(model):
        """Evaluate ``model`` on every client's training data and on the test set.

        Puts ``model`` in eval mode. Per-client training accuracy and loss are
        combined as a sum weighted by ``group_ws`` (keyed by client index).

        Returns:
            (test_accuracy, test_loss, train_accuracy, train_loss)
        """
        model.eval()
        list_acc, list_loss = [], []
        for c in range(args.num_users):
            local_model = LocalUpdate(args=args, dataset=train_datasets[c], logger=logger)
            acc, loss = local_model.inference(model=model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy = sum([list_acc[key] * group_ws[key] for key in group_ws.keys()])
        train_loss = sum([list_loss[key] * group_ws[key] for key in group_ws.keys()])
        test_model = LocalUpdate(args=args, dataset=test_dataset, logger=logger)
        test_accuracy, test_loss = test_model.inference(model=model)
        return test_accuracy, test_loss, train_accuracy, train_loss

    def _report(round_idx, metrics):
        """Write the metrics tuple returned by ``_evaluate`` for one round."""
        test_accuracy, test_loss, train_accuracy, train_loss = metrics
        tqdm.write('At round {} accuracy: {}'.format(round_idx, test_accuracy))
        tqdm.write('At round {} loss: {}'.format(round_idx, test_loss))
        tqdm.write('At round {} training accuracy: {}'.format(round_idx, train_accuracy))
        tqdm.write('At round {} training loss: {}'.format(round_idx, train_loss))

    s1 = time.time()

    # Round 0: the untrained global model.
    _report(0, _evaluate(global_model))

    # One fixed seed per round, for reproducible client sampling/training.
    seeds = np.arange(args.epochs)
    # Number of clients sampled in each training round (at least one).
    m = max(int(args.frac * args.num_users), 1)

    if args.algorithm == 'Upcycled':
        for epoch in range(args.epochs):
            if epoch % 2 == 0:
                # Even round: sample clients, train locally, aggregate.
                # Snapshot the pre-round weights as real copies: state_dict()
                # tensors alias the live parameters, so a plain reference can
                # change underneath `difference` (e.g. on load_state_dict or an
                # in-place write through global_weights in round 0).
                prev_weights = copy.deepcopy(global_model.state_dict())
                local_weights = []

                global_model.train()
                np.random.seed(seeds[epoch])
                # Seed torch as well, matching the non-Upcycled branch, so the
                # local training in this branch is reproducible too.
                torch.manual_seed(seeds[epoch])
                idxs_users = np.random.choice(range(args.num_users), m, replace=False)

                for idx in idxs_users:
                    local_model = LocalUpdate(args=args, dataset=train_datasets[idx], logger=logger)
                    w, _ = local_model.update_weights(model=copy.deepcopy(global_model))
                    local_weights.append((copy.deepcopy(w), group_ws[idx]))
                # update global weights
                global_weights = average_weights(local_weights, global_weights)

                # Change produced by this round; reused by the next odd round.
                difference = {key: value - prev_weights[key]
                              for key, value in global_weights.items()}
            else:
                # Odd round: no client training. Move the global weights a
                # further `upcycled_param` multiple of the previous round's
                # change. Integer `num_batches_tracked` buffers are skipped.
                for key in global_weights:
                    if 'num_batches_tracked' in key:
                        continue
                    global_weights[key] += args.upcycled_param * difference[key]

            # update global weights
            global_model.load_state_dict(global_weights)
            _report(epoch + 1, _evaluate(global_model))

    else:
        for epoch in range(args.epochs):
            local_weights = []
            global_model.train()
            np.random.seed(seeds[epoch])
            torch.manual_seed(seeds[epoch])
            idxs_users = np.random.choice(range(args.num_users), m, replace=False)
            for idx in idxs_users:
                local_model = LocalUpdate(args=args, dataset=train_datasets[idx], logger=logger)
                w, flag = local_model.update_weights(model=copy.deepcopy(global_model))
                # Under plain FedAvg, drop clients whose update reports
                # flag False.
                if args.algorithm == "FedAvg" and (flag is False):
                    continue
                local_weights.append((copy.deepcopy(w), group_ws[idx]))
            # If every sampled client was dropped, keep the current model.
            if local_weights:
                global_weights = average_weights(local_weights, global_weights)
                global_model.load_state_dict(global_weights)

            _report(epoch + 1, _evaluate(global_model))

    s2 = time.time()
    tqdm.write('Time: {}'.format(s2 - s1))

#"femnist_upcycled_mu_up_delta_sigma"