#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import argparse
import copy
import datetime
import json
import os
import random
import shutil
import sys
import time
import warnings
from hps import *

import math
import torchvision.models as models
import numpy as np
from tqdm import tqdm
import pdb

from helpers.datasets import partition_data
from helpers.utils import get_dataset, mean_average_weights, DatasetSplit, KLDiv, setup_seed, test, \
    federated_average_weights
from loop_df_fl import get_model, LocalUpdate, Ensemble, args_parser
from models.nets import CNNCifar, CNNMnist, CNNCifar100, CNNPACS
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torch.utils.data.dataset import random_split

from models.resnet import resnet18
from models.vit import deit_tiny_patch16_224
from warmup_config import warmup_config
# import wandb

# Install a global "ignore" filter so Python warnings are not printed.
warnings.filterwarnings(action='ignore')
# Module-level nearest-neighbour upsampling layer with a scale factor of 7.
upsample = torch.nn.Upsample(scale_factor=7, mode='nearest')

if __name__ == '__main__':
    # Script entry point. Overall flow:
    #   1. parse arguments, seed RNGs and build the global model;
    #   2. warm up on client 0's data, then train a pool of models client by
    #      client, collapsing the pool to its mean after every client;
    #   3. save the collected weights, evaluate the averaged model and the
    #      ensemble, and dump a JSON summary of the run.
    args = args_parser()
    print(args)
    if not torch.cuda.is_available():
        args.device = "cpu"
        print("CUDA is not available, use CPU.")
    setup_seed(args.seed)

    # BUILD MODEL
    start_time = time.time()
    global_model = get_model(args)
    # Snapshot of the freshly initialised weights, restored before warm-up.
    init_weights = copy.deepcopy(global_model.state_dict())
    global_model.train()

    # An id of "0" selects a timestamp-based run id; any other value is
    # used verbatim.
    if args.id == "0":
        run_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
    else:
        run_id = args.id
    print("id: {}".format(run_id))
    time.sleep(3)

    # `hyperparameters_one_shot` / `hyperparameters` are not defined in this
    # file; presumably they come from `from hps import *` -- TODO confirm.
    if args.fedavgEpochs == 1:
        hps = hyperparameters_one_shot[args.dataset]
    else:
        hps = hyperparameters[args.dataset]
    # Only the first hyperparameter entry is used, for warm-up and for every
    # client alike.
    hyperparameter = hps[0]

    # Reported in the JSON summary; nothing in this script ever appends to
    # them, so they are always written out as empty lists.
    client_losses = []
    val_accs = []

    # Every pool produced for every client, saved to disk as a checkpoint.
    saved_model_weights_pool = []
    # Working pool handed to (and returned by) each client's local training.
    model_weights_pool = []
    # ===============================================
    for epoch in range(args.fedavgEpochs):  # FEDAVG TEST
        if epoch == 0:
            # Client 0 warmup
            if args.warmup_epochs != -1:
                warmup_epochs = args.warmup_epochs
            else:
                # -1 means "use the configured default" for this dataset/model.
                warmup_epochs = warmup_config[args.dataset][args.model][0]
            train_dataset, val_dataset, test_dataset, user_groups, val_user_groups, training_data_cls_counts = partition_data(
                args.dataset, args.partition, beta=args.betas, num_users=args.num_users,
                transform=hyperparameter["transform"], order=args.order)
            test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256,
                                                      shuffle=False, num_workers=4)
            local_model = LocalUpdate(args=args, dataset=train_dataset, val_dataset=val_dataset,
                                      idxs=user_groups[0], val_idxs=val_user_groups[0], test_loader=test_loader)
            # Return value is unused here; the call is kept in case it has
            # side effects on `local_model` -- TODO confirm and drop if not.
            local_model.get_datasets()
            global_model.load_state_dict(init_weights)
            print("Start Warm Up")
            warmup_weights, _, _, _, _ = local_model.update_weights(
                copy.deepcopy(global_model), args.device, hyperparameter, local_ep=warmup_epochs, optimize=False,
                args=args)
            # Seed the pool with an independent copy of the warm-up weights.
            model_weights_pool.append(copy.deepcopy(warmup_weights))

        for idx in range(args.num_users):
            # NOTE(review): the data is re-partitioned with identical
            # arguments for every client. This looks loop-invariant, but
            # hoisting it could change results if `partition_data` consumes
            # random state, so it is left in place -- confirm before moving.
            train_dataset, val_dataset, test_dataset, user_groups, val_user_groups, training_data_cls_counts = partition_data(
                args.dataset, args.partition, beta=args.betas, num_users=args.num_users,
                transform=hyperparameter["transform"], order=args.order)
            test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256,
                                                      shuffle=False, num_workers=4)
            for m in range(args.num_models):
                print("Now training model {} for user {}".format(m, idx))
                local_model = LocalUpdate(args=args, dataset=train_dataset, val_dataset=val_dataset,
                                          idxs=user_groups[idx], val_idxs=val_user_groups[idx],
                                          test_loader=test_loader)
                # Only the updated pool is used; the other four return values
                # are discarded.
                model_weights_pool, _, _, _, _ = local_model.update_weights_model_pool(
                    copy.deepcopy(global_model), args.device, hyperparameter, model_weights_pool,
                    random_position=args.random_position, args=args)
            # Record this client's pool, then collapse it to a single averaged
            # entry that becomes the starting pool for the next client.
            saved_model_weights_pool.extend(model_weights_pool)
            model_weights_pool = [mean_average_weights(model_weights_pool)]

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(saved_model_weights_pool,
               'checkpoints/{}_{}clients_{}_{}_{}_{}.pkl'.format(args.dataset, args.num_users, args.betas,
                                                                 args.partition, args.model, run_id))

    global_weights = mean_average_weights(model_weights_pool)
    global_model.load_state_dict(global_weights)
    print("One-Shot MeanAvg Accuracy:")
    meanavg_test_acc, meanavg_test_loss = test(global_model, test_loader, args.device)

    # Build one network per pool entry for ensemble evaluation. After the
    # training loop the pool has been collapsed to a single averaged entry.
    model_list = []
    for pool_weights in model_weights_pool:
        net = copy.deepcopy(global_model)
        net.load_state_dict(pool_weights)
        model_list.append(net)
    ensemble_model = Ensemble(model_list)
    print("Ensemble Accuracy:")
    ensemble_test_acc, ensemble_test_loss = test(ensemble_model, test_loader, args.device)

    output = {
        "id": run_id,
        "seed": args.seed,
        "fedavgEpochs": args.fedavgEpochs,
        "time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "dataset": args.dataset,
        "model": args.model,
        "num_users": args.num_users,
        "betas": args.betas,
        "partition": args.partition,
        "warmup_config": warmup_config[args.dataset],
        "file": os.path.basename(__file__),
        "meanavg_test_acc": meanavg_test_acc,
        "meanavg_test_loss": meanavg_test_loss,
        "ensemble_test_acc": ensemble_test_acc,
        "ensemble_test_loss": ensemble_test_loss,
        "hyperparameters": hps,
        "data_cls_counts": str(training_data_cls_counts),
        "args": str(args),
        "local_bs": args.local_bs,
        "validation_ratio": args.validation_ratio,
        "note": args.note,
        "order": args.order,
        "client_losses": client_losses,
        "val_accs": val_accs,
        "alpha": args.alpha,
        "beta": args.beta,
    }
    os.makedirs('results', exist_ok=True)
    results_path = 'results/{}_{}clients_{}_{}_{}.json'.format(
        args.dataset, args.num_users, args.betas, args.model, run_id)
    # Context manager guarantees the summary is flushed and the handle closed.
    with open(results_path, 'w') as results_file:
        json.dump(output, results_file)
    print(json.dumps(output, indent=4))

    # The averaged model is evaluated a second time after the summary is
    # printed, as in the original script.
    print("One-Shot MeanAvg Accuracy:")
    meanavg_test_acc, meanavg_test_loss = test(global_model, test_loader, args.device)
    end_time = time.time()
    print("Total Time Cost: {:.2f}s".format(end_time - start_time))
    # ===============================================
    # ===============================================

