#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import json
from typing import Counter
import torch.nn.functional as F
from utils import get_dataset, average_weights, exp_details, dp_process
from models import CNNMnist, CNNCifar, ResNet18_Cifar, ResNet50_Clothing
from update import LocalUpdate
from options import args_parser
# from tensorboardX import SummaryWriter
import torch
import copy
import os
import shutil
import warnings
import datetime
import pickle as pkl

import numpy as np
from tqdm import tqdm
import wandb

warnings.filterwarnings('ignore')


class NpEncoder(json.JSONEncoder):
    """JSON encoder that additionally serializes NumPy scalars and arrays.

    ``np.integer`` values become ``int``, ``np.floating`` values become
    ``float`` and ``np.ndarray`` objects become (nested) Python lists.
    Any other object is handed to the base ``JSONEncoder.default``, which
    raises ``TypeError`` for types it does not know.
    """

    def default(self, obj):
        # Checked in order; the first matching NumPy type wins.
        conversions = (
            (np.integer, int),
            (np.floating, float),
            (np.ndarray, np.ndarray.tolist),
        )
        for numpy_type, convert in conversions:
            if isinstance(obj, numpy_type):
                return convert(obj)
        return super().default(obj)


def inference(model, test_loader, device=None):
    """Evaluate a classifier on a test set.

    Args:
        model: Module mapping a batch of inputs to per-class logits.
        test_loader: Iterable of ``(data, target)`` batches whose
            ``.dataset`` length is taken as the number of test samples.
        device: Device each batch is moved to. Defaults to the device of
            the model's first parameter (CPU for a parameter-free model),
            which matches the previous ``.cuda()`` behavior for the CUDA
            models built in this script.

    Returns:
        ``(acc, test_loss)``: accuracy in percent and the mean per-sample
        cross-entropy loss over the whole dataset.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    test_loss = 0.0
    correct = 0.0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum (not mean) per batch so that dividing by the dataset size
            # below gives the true per-sample average, including when the
            # last batch is smaller than the others.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = torch.max(output, 1)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    acc = 100. * correct / num_samples
    return acc, test_loss


def prepare_folders(cur_path, store_name=None):
    """Create the per-experiment log and checkpoint directories.

    Ensures ``<cur_path>/logs/<store_name>`` and
    ``<cur_path>/checkpoints/<store_name>`` exist, creating any missing
    parent directories, and prints a message for each folder it creates.

    Args:
        cur_path: Root directory holding the ``logs`` and ``checkpoints``
            trees.
        store_name: Experiment sub-directory name. Defaults to the global
            ``args.store_name`` set in the ``__main__`` block, so the
            original ``prepare_folders(cur_path)`` call keeps working.
    """
    if store_name is None:
        store_name = args.store_name
    folders_util = [
        os.path.join(cur_path + '/logs', store_name),
        os.path.join(cur_path + '/checkpoints', store_name)]
    for folder in folders_util:
        if not os.path.exists(folder):
            print('creating folder ' + folder)
            # makedirs (not mkdir) so a fresh cur_path without logs/ or
            # checkpoints/ still works; exist_ok tolerates another process
            # creating the folder between the check and this call.
            os.makedirs(folder, exist_ok=True)


def save_checkpoint(state, cur_path, is_best, store_name=None):
    """Save ``state`` as the latest checkpoint and keep a copy of the best.

    Writes ``<cur_path>/checkpoints/<store_name>/ckpt.pth.tar`` with
    ``torch.save``; when ``is_best`` is true the file is also copied to
    ``ckpt.best.pth.tar`` in the same directory. The directory must
    already exist.

    Args:
        state: Object passed to ``torch.save`` (here a model state_dict).
        cur_path: Root directory holding the ``checkpoints`` tree.
        is_best: Whether this checkpoint is the best seen so far.
        store_name: Experiment sub-directory name. Defaults to the global
            ``args.store_name`` set in the ``__main__`` block.
    """
    if store_name is None:
        store_name = args.store_name
    ckpt_root = cur_path + '/checkpoints'
    filename = '{}/{}/ckpt.pth.tar'.format(ckpt_root, store_name)
    torch.save(state, filename)
    if is_best:
        # Build the best path explicitly instead of str.replace over the
        # whole path, which would also rewrite a directory name that
        # happens to contain 'pth.tar'.
        best_filename = '{}/{}/ckpt.best.pth.tar'.format(ckpt_root, store_name)
        shutil.copyfile(filename, best_filename)


# def default_json(t):
#     return f'{t}'


if __name__ == '__main__':
    args = args_parser()
    args.type = 'iid' if args.iid == 1 else 'non-iid'
    # Experiment identifier; names the wandb run and the log/checkpoint folders.
    args.store_name = '_'.join([
        args.dataset, args.model, args.type, 'lr-' + str(args.lr),
        args.method, 'noisy-' + str(args.noise_ratio),
        'alpha-' + str(args.alpha), 'num_c-' + str(args.num_users),
        'seed-' + str(args.seed), 'random-' + str(args.random),
        'power-' + str(args.power)])
    cur_path = './src'
    prepare_folders(cur_path)
    exp_details(args)

    # Record the run's hyperparameters; the previous empty tuple meant the
    # wandb run carried no config.
    wandb.init(
        project='fedpeer',
        name=args.store_name,
        config=vars(args)
    )
    wandb.run.name = args.store_name

    # Load dataset and per-client sample indices.
    train_dataset, test_dataset, user_groups, num_classes, client_openset = get_dataset(
        args)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128,
                                              shuffle=False, num_workers=4)

    # BUILD MODEL: two global models of the same architecture; both are
    # copied to every selected client and averaged separately each round.
    if args.dataset == 'mnist':
        global_model = CNNMnist(args).cuda()
        global_model2 = CNNMnist(args).cuda()
    elif args.dataset == 'cifar10':
        # CNNCifar(num_classes) is an available lighter alternative.
        global_model = ResNet18_Cifar(num_classes).cuda()
        global_model2 = ResNet18_Cifar(num_classes).cuda()
    elif args.dataset == 'cifar100' or args.dataset == 'cifar100-N':
        global_model = ResNet18_Cifar(num_classes).cuda()
        global_model2 = ResNet18_Cifar(num_classes).cuda()
    elif args.dataset == 'clothing1m':
        # Output size is hard-coded to 14 here rather than num_classes.
        global_model = ResNet50_Clothing(14).cuda()
        global_model2 = ResNet50_Clothing(14).cuda()
    elif 'cifar10-N' in args.dataset:
        global_model = ResNet18_Cifar(num_classes).cuda()
        global_model2 = ResNet18_Cifar(num_classes).cuda()
    else:
        # Fail fast instead of a NameError on global_model further down.
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))

    bst_acc = -1
    description = "inference acc={:.4f}% loss={:.4f}, best_acc = {:.4f}%"
    log_record = dict()
    log_record['config'] = vars(args)
    log_record['test_acc'] = []
    log_record['best_test'] = (-1, -1)
    log_record['users_group'] = user_groups
    log_record['client_dist'] = client_openset

    label_cache_dist = None
    if args.method == 'peer':
        # Differential Privacy: gather the labels of every client's samples
        # and pass them to dp_process (see utils); the result is handed to
        # every LocalUpdate created below.
        label_cache_idxs = []
        for idx in range(args.num_users):
            local_model = LocalUpdate(
                args=args, dataset=train_dataset, idxs=user_groups[idx], num_classes=num_classes)
            label_cache_idxs.extend(local_model.trainloader.dataset.idxs)

        label_cache = [train_dataset.targets[i] for i in label_cache_idxs]
        label_cache_dist = dp_process(args, label_cache, num_classes)

    # One num_classes x num_classes matrix per client, starting at identity.
    Ts = [torch.eye(num_classes) for _ in range(args.num_users)]

    # Number of clients sampled per round (at least one).
    m = max(int(args.frac * args.num_users), 1)
    log_path = os.path.join(cur_path + '/logs', args.store_name, 'log.txt')
    with open(log_path, 'w') as logger_file:
        for epoch in tqdm(range(args.epochs)):
            local_weights = []
            local_weights2 = []

            global_model.train()
            idxs_users = np.random.choice(range(args.num_users), m, replace=False)

            # Halfway through training, re-estimate every client's T matrix
            # (presumably a noise transition matrix -- see LocalUpdate.estimate_T).
            if args.method == 'correction' or args.method == 'revision':
                if epoch == int(args.epochs / 2):
                    for idx in range(args.num_users):
                        local_model = LocalUpdate(
                            args=args, dataset=train_dataset, idxs=user_groups[idx],
                            label_cache_dist=label_cache_dist, num_classes=num_classes)
                        Ts[idx] = local_model.estimate_T(
                            model=copy.deepcopy(global_model))

            for idx in idxs_users:
                local_model = LocalUpdate(args=args, dataset=train_dataset,
                                          idxs=user_groups[idx], label_cache_dist=label_cache_dist, num_classes=num_classes)
                w, w2 = local_model.update_weights(
                    model=copy.deepcopy(global_model), model2=copy.deepcopy(global_model2),
                    cur_epoch=epoch, T=Ts[idx])
                print(
                    f'Epoch: [{epoch}]/[{args.epochs}] User: [{idx}]/[{len(idxs_users)}] Method: {args.method}')

                local_weights.append(copy.deepcopy(w))
                local_weights2.append(copy.deepcopy(w2))

            # Aggregate the client weights and load them into both global models.
            global_weights = average_weights(local_weights)
            global_weights2 = average_weights(local_weights2)
            global_model.load_state_dict(global_weights)
            global_model2.load_state_dict(global_weights2)

            test_acc, test_loss = inference(global_model, test_loader)

            output_log = 'After {} global rounds, Test acc: {}, inference loss: {}'.format(
                epoch + 1, test_acc, test_loss)
            wandb.log({
                'test_acc': test_acc,
                'test_loss': test_loss
            })

            logger_file.write(output_log + '\n')
            logger_file.flush()

            is_best = test_acc > bst_acc
            bst_acc = max(bst_acc, test_acc)
            print(description.format(test_acc, test_loss, bst_acc))

            log_record['test_acc'].append(test_acc)
            if is_best:
                log_record['best_test'] = (epoch, bst_acc)

            save_checkpoint(global_model.state_dict(), cur_path, is_best)

    cur_dt = datetime.datetime.now()
    with open(os.path.join(cur_path + '/logs', args.store_name, f'log_{cur_dt.strftime("%Y%m%d%H%M%S")}.pkl'), 'wb') as f:
        pkl.dump(log_record, f)

    wandb.finish()

"""
python3 federated_main.py --model=cnn --dataset=cifar --iid=1 --epochs=300 --lr=0.01 --local_ep=5 --local_bs=32

python3 federated_main.py --model=cnn --dataset=mnist --iid=1 --epochs=100 --lr=0.01 --local_ep=5 --local_bs=32

"""
