import os

from tqdm import tqdm

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter
from torch.optim.lr_scheduler import MultiStepLR

import numpy as np
import copy

from datetime import datetime

from utils.options import args_parser
from utils.train_utils import get_model, get_data
from utils.tinyimagenet import TinyImageNet
from models.Update import LocalUpdate

## test global accuracy, based on the test dataset with all classes
def test(args, model, dataloader):
    """Return the top-1 accuracy (in percent) of ``model`` over ``dataloader``.

    Args:
        args: Options object; only ``args.device`` is read. Every batch is
            moved to that device before the forward pass.
        model: Module called as ``model(images)``. The predicted class is the
            argmax of its output over dimension 1.
        dataloader: Iterable of ``(images, labels)`` batches that also exposes
            ``dataloader.dataset``; the sample count is ``len`` of that.

    Returns:
        ``100 * correct / len(dataloader.dataset)`` as a Python float.

    Note:
        The model is switched to eval mode and is left in that mode.
    """
    model.eval()
    correct = 0
    # Evaluation never back-propagates, so disable autograd here instead of
    # relying on every caller to wrap the call in ``torch.no_grad()``.
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(args.device), labels.to(args.device)
            predictions = model(images).argmax(dim=1, keepdim=True)
            correct += predictions.eq(labels.view_as(predictions)).sum().item()

    return 100.0 * correct / len(dataloader.dataset)


# The name matches pytest's default ``test*`` pattern; opt out of collection so
# importing this helper into a test module does not turn it into a test case.
test.__test__ = False

## FL training
def train_fl(args, net_glob, dataset, dict_users_train, w_locals, new_w_locals):
    """Run one round of local training on a random subset of clients.

    Args:
        args: Options object; reads ``frac``, ``num_users``, ``m_tr``,
            ``device`` and ``lr``.
        net_glob: Current global model. It is deep-copied for every selected
            client, and only the copy is moved to ``args.device`` and trained.
        dataset: Training dataset handed to ``LocalUpdate``.
        dict_users_train: Maps a client index to its sample indices; only the
            first ``args.m_tr`` of them are passed to ``LocalUpdate``.
        w_locals: Not used by this function; kept in the signature so existing
            callers keep working.
        new_w_locals: Dict that is filled in place, keyed by client index, with
            a deep copy of the first value returned by ``LocalUpdate.train``.

    Returns:
        Tuple ``(loss_avg, new_w_locals, idxs_users)`` where ``loss_avg`` is
        the mean of the second value returned by ``LocalUpdate.train`` across
        the selected clients and ``idxs_users`` is the array of selected
        client indices.
    """
    # At least one client is always selected, so the average below is defined.
    num_selected = max(int(args.frac * args.num_users), 1)
    idxs_users = np.random.choice(range(args.num_users), num_selected, replace=False)

    loss_locals = []
    for idx in tqdm(idxs_users):
        local = LocalUpdate(args=args, dataset=dataset, idxs=dict_users_train[idx][:args.m_tr])
        # Each client starts from its own copy of the current global model; the
        # copy already holds the global weights, so no state-dict reload is needed.
        net_local = copy.deepcopy(net_glob)
        w_local, loss = local.train(net=net_local.to(args.device), lr=args.lr)

        loss_locals.append(copy.deepcopy(loss))
        new_w_locals[idx] = copy.deepcopy(w_local)

    loss_avg = sum(loss_locals) / len(loss_locals)

    return loss_avg, new_w_locals, idxs_users

## model aggregation
def fedavg(net_glob, w_locals, idxs_users):
    """Load the unweighted mean of the selected clients' weights into ``net_glob``.

    Every selected client contributes equally; the sum for each entry is
    divided by the number of selected clients.

    Args:
        net_glob: Global model. Its ``state_dict()`` keys define which entries
            are averaged, and the result is loaded back into it in place.
        w_locals: Maps a client index to a mapping from state-dict key to
            tensor. Only the clients listed in ``idxs_users`` are read, and
            their tensors are not modified.
        idxs_users: Sequence of client indices to average over.

    Returns:
        ``net_glob`` itself, after loading the averaged weights.

    Raises:
        ValueError: If ``idxs_users`` is empty.
    """
    num_clients = len(idxs_users)
    if num_clients == 0:
        raise ValueError("fedavg needs at least one client in idxs_users")

    w_glob = {}
    for key in net_glob.state_dict():
        # ``sum`` starts from 0, so the very first addition already produces a
        # new tensor and the clients' own tensors are never written to.
        w_glob[key] = sum(w_locals[idx][key] for idx in idxs_users) / num_clients

    net_glob.load_state_dict(w_glob)
    return net_glob

if __name__ == '__main__':
    # Parse options and pick the compute device: the requested GPU when CUDA
    # is available and a GPU was asked for (args.gpu != -1), otherwise the CPU.
    args = args_parser()
    use_cuda = torch.cuda.is_available() and args.gpu != -1
    args.device = torch.device('cuda:{}'.format(args.gpu) if use_cuda else 'cpu')

    # Seed the NumPy and PyTorch (CPU and CUDA) random number generators.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Centralised datasets used to measure the global model's accuracy.
    if args.dataset in ('cifar100', 'cifar10'):
        # NOTE(review): both CIFAR variants are normalised with the same
        # channel statistics. They are kept as-is so evaluation preprocessing
        # does not change; confirm they match what get_data() uses for training.
        cifar_mean = [0.507, 0.487, 0.441]
        cifar_std = [0.267, 0.256, 0.276]
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=cifar_mean, std=cifar_std),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cifar_mean, std=cifar_std),
        ])
        if args.dataset == 'cifar100':
            dataset_cls, dataset_root = datasets.CIFAR100, 'data/cifar100'
        else:
            dataset_cls, dataset_root = datasets.CIFAR10, 'data/cifar10'
        dataset_train = dataset_cls(dataset_root, train=True, download=True, transform=transform_train)
        dataset_test = dataset_cls(dataset_root, train=False, download=True, transform=transform_test)
    elif args.dataset == 'tinyimagenet':
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        dataset_train = TinyImageNet('data/tinyimagenet', split='train', download=True, transform=transform_train)
        dataset_test = TinyImageNet('data/tinyimagenet', split='val', download=True, transform=transform_test)
    else:
        # SystemExit with a message prints it to stderr and exits with status 1,
        # without depending on the site-provided ``exit`` helper.
        raise SystemExit('Error: unrecognized dataset')

    # The loaders are configured identically for every dataset.
    # NOTE(review): train_dataloader is not consumed below, and shuffling is not
    # needed to compute accuracy; both are left unchanged so the run's behaviour
    # (including any random-number consumption) stays the same.
    train_dataloader = DataLoader(dataset_train, batch_size=256, shuffle=True)
    test_dataloader = DataLoader(dataset_test, batch_size=256, shuffle=True)

    ## preparation for FL
    dataset_train_fl, dataset_test_fl, dict_users_train, dict_users_test = get_data(args)
    # Shuffle each client's sample indices in place.
    for user_indices in dict_users_train.values():
        np.random.shuffle(user_indices)
    net_glob_fl = get_model(args)

    # One weight dict per user, initialised from the global model.
    # NOTE: these are shallow copies; per the PyTorch docs, state_dict() holds
    # references to the module's tensors rather than clones of them.
    w_locals = {user: dict(net_glob_fl.state_dict()) for user in range(args.num_users)}

    all_global_accuracy = []
    # Rounds are numbered 0..args.epochs inclusive, i.e. args.epochs + 1 rounds.
    for round_idx in range(args.epochs + 1):
        # Local training on the sampled clients, then aggregation into the global model.
        new_w_locals = {}
        loss_avg, new_w_locals, idxs_users = train_fl(
            args, net_glob_fl, dataset_train_fl, dict_users_train, w_locals, new_w_locals)
        net_glob_fl = fedavg(net_glob_fl, new_w_locals, idxs_users)

        # Remember the latest weights of every client trained this round.
        for user, weights in new_w_locals.items():
            w_locals[user] = copy.deepcopy(weights)
        print("{} epoch, fl loss: {:.3f}".format(round_idx, loss_avg))

        if round_idx % args.test_freq == 0:
            with torch.no_grad():
                accuracy_global = test(args, net_glob_fl, test_dataloader)
            print("{} epoch, global accuracy: {:.3f}".format(round_idx, accuracy_global))
            all_global_accuracy.append(accuracy_global)