import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch
import torch.nn as nn
import time
import datetime
import torchvision.models as models
import random

#from resnet import resnet20
from torchvision.datasets import MNIST, FashionMNIST

from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid, mnist_partial_noniid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import linearRegression, MLP, CNNMnist, CNNEMnist, CNNCifar, ModelCNNCifar10
from models.Fed import FedAvg, WAvg
from models.test import test_img

from data_reader import femnist, celeba

def load_datasets(args):
    """Load the train/test datasets selected by ``args.dataset`` and split the
    training set across clients.

    Training sets use random augmentation where the original recipe did; test
    sets always use a deterministic transform so evaluation is reproducible.

    Returns:
        ``(dataset_train, dataset_test, dict_users)``. ``dict_users[i]`` is the
        ``idxs`` argument handed to ``LocalUpdate`` for client ``i``
        (presumably that client's sample indices -- confirm in utils.sampling).

    Side effects:
        For non-IID FEMNIST and CelebA, ``args.num_users`` is overwritten with
        the number of clients reported by the dataset.

    Raises:
        NotImplementedError: IID partitioning requested for CelebA.
        SystemExit: ``args.dataset`` is not recognised.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    if args.dataset in ('mnist', 'fmnist'):
        dataset_cls = MNIST if args.dataset == 'mnist' else FashionMNIST
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset_train = dataset_cls('../data/mnist/', train=True, download=True, transform=trans_mnist)
        dataset_test = dataset_cls('../data/mnist/', train=False, download=True, transform=trans_mnist)
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_partial_noniid(dataset_train, args.num_users, args.portion)

    elif args.dataset == 'cifar':
        trans_train = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ])
        # No random flips/crops at evaluation time.
        trans_test = transforms.Compose([transforms.ToTensor(), normalize])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_train)
        dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_test)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            dict_users = cifar_noniid(dataset_train, args.num_users, args.portion)

    elif args.dataset == 'cinic':
        # Download the cinic dataset before using it.
        cinic_mean = [0.47889522, 0.47227842, 0.43047404]
        cinic_std = [0.24205776, 0.23828046, 0.25874835]
        cinic_normalize = transforms.Normalize(mean=cinic_mean, std=cinic_std)
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            cinic_normalize,
        ])
        # No random flips/crops at evaluation time.
        test_transform = transforms.Compose([transforms.ToTensor(), cinic_normalize])
        dataset_train = datasets.ImageFolder(root='data/cinic-10/train', transform=train_transform)
        dataset_test = datasets.ImageFolder(root='data/cinic-10/test', transform=test_transform)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            dict_users = cifar_noniid(dataset_train, args.num_users, args.portion)

    elif args.dataset == 'femnist':
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        data_path = os.path.join(os.path.dirname(__file__), 'dataset_files', 'femnist')
        dataset_train = femnist.FEMNIST(data_path, train=True, download=True, transform=trans_mnist)
        dataset_test = femnist.FEMNIST(data_path, train=False, download=True, transform=trans_mnist)
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            # Non-IID FEMNIST uses the dataset's own client split.
            dict_users = dataset_train.get_dict_clients()
            print('** resetting number of users according to the actual value in the dataset for FEMNIST non-IID **')
            args.num_users = len(dict_users.keys())
            print('number of users:', args.num_users)

    elif args.dataset == 'celeba':
        trans_celeba_train = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        # No random flips at evaluation time.
        trans_celeba_test = transforms.Compose([transforms.ToTensor(), normalize])
        data_path = os.path.join(os.path.dirname(__file__), 'dataset_files', 'celeba')
        dataset_train = celeba.CelebA(data_path, train=True, download=True, transform=trans_celeba_train)
        dataset_test = celeba.CelebA(data_path, train=False, download=True, transform=trans_celeba_test)
        if args.iid:
            raise NotImplementedError('iid case not implemented')
        # CelebA always uses the dataset's own client split.
        dict_users = dataset_train.get_dict_clients()
        print('** resetting number of users according to the actual value in the dataset for CelebA non-IID **')
        args.num_users = len(dict_users.keys())
        print('number of users:', args.num_users)

    else:
        raise SystemExit('Error: unrecognized dataset')

    return dataset_train, dataset_test, dict_users


def build_model(args, img_size):
    """Construct the global model selected by ``args.model``/``args.dataset``.

    Args:
        args: parsed command-line options.
        img_size: shape of one input sample; its element count is the input
            width of the ``mlp`` and ``linear`` models.

    Returns:
        The model, moved to ``args.device``.

    Raises:
        SystemExit: unsupported model/dataset combination.
    """
    if args.model == 'cnn' and args.dataset == 'cifar':
        net = ModelCNNCifar10()
    elif args.model == 'vgg':
        net = models.vgg16(pretrained=True)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net = CNNMnist(args=args)
    elif args.model == 'cnn' and args.dataset == 'femnist':
        net = CNNEMnist(args=args)
    elif args.model in ('mlp', 'linear'):
        # Flattened input size, e.g. C*H*W for an image tensor.
        len_in = int(np.prod(img_size))
        if args.model == 'mlp':
            net = MLP(dim_in=len_in, dim_hidden=args.width, dim_out=args.num_classes)
        else:
            net = linearRegression(inputsize=len_in, outputsize=args.num_classes)
    else:
        raise SystemExit('Error: unrecognized model')
    return net.to(args.device)


def clients_per_round(num_users, frac):
    """Return the number of clients sampled each round: ``int(num_users * frac)``.

    The result is clamped to at least 1. Without that, a small ``frac`` would
    sample no clients and the global model would silently never be updated.
    """
    return max(int(num_users * frac), 1)


def result_filenames(args, glob_lr):
    """Return the (train loss, train acc, test loss, test acc) output filenames.

    Every name shares the prefix
    ``<dataset>-<local_period>-<frac>-<glob_lr>-<lr>``.
    """
    prefix = '-'.join(str(part) for part in (args.dataset, args.local_period, args.frac, glob_lr, args.lr))
    return tuple(prefix + suffix for suffix in ('-loss.txt', '-acc.txt', '-test_loss.txt', '-test_acc.txt'))


def append_metrics(filenames, values):
    """Append each value as one line to its matching file (same order)."""
    for filename, value in zip(filenames, values):
        with open(filename, 'a') as fh:
            fh.write(str(value))
            fh.write('\n')


if __name__ == '__main__':
    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
    # Set before the model is built, so its parameters are created as float64.
    torch.set_default_dtype(torch.float64)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # load dataset and split users (may reset args.num_users for FEMNIST/CelebA)
    dataset_train, dataset_test, dict_users = load_datasets(args)
    img_size = dataset_train[0][0].shape

    # build model
    net_glob = build_model(args, img_size)
    print(net_glob)

    # w_glob accumulates this round's aggregate. w_glob_old is the snapshot
    # from the start of the round; each client's delta is measured against it.
    w_glob = net_glob.state_dict()
    w_glob_old = copy.deepcopy(w_glob)

    glob_lr = args.lg
    # Computed after load_datasets, since that may change args.num_users.
    num_sampled = clients_per_round(args.num_users, args.frac)
    result_files = result_filenames(args, glob_lr)

    # training
    for iters in range(args.epochs):
        # Sample distinct clients uniformly; each gets equal aggregation weight.
        sampled_clients = np.random.choice(args.num_users, size=num_sampled, replace=False)
        lweight = np.ones(num_sampled) / num_sampled

        # local updates
        for j, idx in enumerate(sampled_clients):
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], iters=args.local_period)
            net_local = copy.deepcopy(net_glob).to(args.device)
            w = local.train(net=net_local, c_local=0, c_global=0)

            # Server step: add the client's weighted delta, scaled by the global LR.
            for k in w.keys():
                w_glob[k] = w_glob[k] + glob_lr * (w[k] - w_glob_old[k]) * lweight[j]

        # Snapshot for the next round's deltas.
        w_glob_old = copy.deepcopy(w_glob)

        # load_state_dict copies values into the model's own tensors, so no extra copy is needed.
        net_glob.load_state_dict(w_glob)

        # compute training/test accuracy/loss
        if (iters + 1) % args.interval == 0:
            # NOTE(review): training metrics are computed on dataset_train,
            # which uses the augmented training transform where one is defined.
            with torch.no_grad():
                acc_avg, loss_avg = test_img(net_glob, dataset_train, args)
                acc_test, loss_test = test_img(net_glob, dataset_test, args)
            print('Round {:3d}, Training loss {:.3f}'.format(iters, loss_avg), flush=True)
            print('Round {:3d}, Training acc {:.3f}'.format(iters, acc_avg), flush=True)
            print('Round {:3d}, Test loss {:.3f}'.format(iters, loss_test), flush=True)
            print('Round {:3d}, Test acc {:.3f}'.format(iters, acc_test), flush=True)

            # write into files
            append_metrics(result_files, (loss_avg, acc_avg, loss_test, acc_test))


