import pdb
import sys
sys.path.append("../")

# configuration
from munch import Munch
from fedlab.models.mlp import MLP, MLP_prob

from torchvision import transforms
from fedlab.contrib.dataset.partitioned_mnist import PartitionedMNIST
# FedNova server handler and serial client trainer
from fedlab.contrib.algorithm.fednova import FedNovaServerHandler, FedNovaSerialClientTrainer
# server
from fedlab.contrib.algorithm.basic_server import SyncServerHandler

import numpy as np

from fedlab.utils.functional import evaluate
from fedlab.core.standalone import StandalonePipeline

from torch import nn
from torch.utils.data import DataLoader
import torchvision
import os


# Experiment configuration. Edit the values below to choose the run.
# `Munch()` creates a dict instance with attribute access; assigning to the bare
# `Munch` class would instead set class attributes shared by every Munch object.
args = Munch(
    seed=42,
    # data / partition configuration
    dataset='CIFAR_10',   # one of: FEMNIST, MNIST, CIFAR_10, CIFAR_100
    iid=True,             # False -> non-IID partition (see the dataset section)
    alpha=0.5,            # passed as `dir_alpha` for non-IID partitions
    total_client=50,
    preprocess=True,
    cuda=True,
    # local train configuration
    epochs=5,
    batch_size=128,
    lr=0.01,
    # global configuration
    com_round=100,
    sample_ratio=1.0,
    # Calibration method tag; in this script it is only used in the result file name.
    # Options: CE | focal | LS | NLLMDCA | BS | DCA | MMCE | FLSD | LM |
    #          ourDCA_COS | ourDCA_LCKA | ourDCA_rbfCKA |
    #          ourMDCA_COS | ourMDCA_LCKA | ourMDCA_rbfCKA
    calib='ourDCA_COS',
)


# Build the global model for the selected dataset. Imports are done per branch
# so only the architecture actually needed has to be importable.
if args.dataset == 'FEMNIST':
    from fedlab.models.cnn import CNN_FEMNIST
    model = CNN_FEMNIST()
elif args.dataset == 'MNIST':
    from fedlab.models.cnn import CNN_MNIST
    model = CNN_MNIST()
elif args.dataset == 'CIFAR_10':
    from fedlab.models.cnn import AlexNet_CIFAR10
    model = AlexNet_CIFAR10()
else:  # CIFAR_100 (any other dataset name also falls through to here)
    # Randomly initialised ResNet-34 with a 100-way output layer. Omitting the
    # pretrained/weights argument keeps the no-pretraining default on both old
    # and new torchvision versions without the `pretrained=` deprecation warning.
    model = torchvision.models.resnet34(num_classes=100)

# NOTE(review): w_decay and momentum are not referenced anywhere else in this
# script (trainer.setup_optim only receives epochs, batch_size and lr), so they
# currently have no effect on training.
w_decay = args.lr * 0.1
momentum = 0.9

# Directory that receives this experiment's results; created if missing.
# exist_ok=True avoids the check-then-create race of a separate exists() test.
save_folder = "../results/{}_fednova/".format(args.dataset)
os.makedirs(save_folder, exist_ok=True)

# Partition tag embedded in the file name: "IID", or "nonIID_alpha<alpha>".
if args.iid:
    split_tag = "IID"
else:
    split_tag = "nonIID_alpha{}".format(args.alpha)

# Result file prefix that encodes the run's hyper-parameters. The calibration
# tag comes from args.calib (a bare `calib` name is not defined in this script).
file_path = save_folder + "TT_Client{}_Round{}_LocalE{}_Bat{}_SGDmoment_lr{}_{}_part{}_{}".format(
    args.total_client, args.com_round, args.epochs, args.batch_size,
    args.lr, split_tag, args.sample_ratio, args.calib)



# Build the federated training partition (`mnist_train_partition`, used for
# every dataset despite its name) and the centralised test set (`mnist_test`).
# NOTE(review): PartitionedFEMNIST, FEMNIST and PartitionCIFAR are not imported
# in this file; they must be brought into scope before the FEMNIST / CIFAR
# branches can run.
if args.dataset == 'MNIST':
    if args.iid:
        split_kwargs = dict(partition="iid")
    else:
        split_kwargs = dict(partition="noniid-labeldir", dir_alpha=args.alpha)
    mnist_train_partition = PartitionedMNIST(
        root="datasets/mnist/",
        path="datasets/mnist/fedmnist0/",
        num_clients=args.total_client,
        seed=args.seed,
        preprocess=args.preprocess,
        download=True,
        verbose=True,
        transform=transforms.Compose(
            [transforms.ToPILImage(), transforms.ToTensor()]),
        **split_kwargs
    )
    # download=True only fetches the files if they are missing from root.
    mnist_test = torchvision.datasets.MNIST(root="datasets/mnist/",
                                            train=False,
                                            download=True,
                                            transform=transforms.ToTensor())

elif args.dataset == 'FEMNIST':
    if args.iid:
        split_kwargs = dict(partition="iid")
    else:
        split_kwargs = dict(partition="noniid-labeldir", dir_alpha=args.alpha)
    mnist_train_partition = PartitionedFEMNIST(
        root="datasets/femnist/",
        path="datasets/femnist/fedmnist0/",
        num_clients=args.total_client,
        seed=args.seed,
        # preprocess is hard-coded to False here; args.preprocess is not used
        # for FEMNIST.
        preprocess=False,
        download=True,
        verbose=True,
        transform=None,
        **split_kwargs
    )
    mnist_test = FEMNIST(root="datasets/femnist/",
                         train=False,
                         transform=transforms.Compose([transforms.ToTensor()]))

else:
    # CIFAR_10, or CIFAR-100 for any other dataset name. Both share the same
    # pipeline and differ only in dataset name and directories.
    is_cifar10 = args.dataset == 'CIFAR_10'
    cifar_name = 'cifar10' if is_cifar10 else 'cifar100'
    cifar_root = "datasets/{}/".format(cifar_name)

    # Train and test must be normalised with the SAME statistics; previously the
    # test set used a different mean/std, so evaluation inputs were scaled
    # differently from what the model was trained on.
    # NOTE(review): these look like CIFAR-10 channel statistics and are also
    # applied to CIFAR-100 — confirm whether dataset-specific values are wanted.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    TT_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    TT_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    if args.iid:
        split_kwargs = dict(partition="iid", dir_alpha=args.alpha)
    else:
        split_kwargs = dict(partition="dirichlet", balance=None, dir_alpha=args.alpha)
    mnist_train_partition = PartitionCIFAR(
        root=cifar_root,
        # Each dataset gets its own partition directory; CIFAR-100 previously
        # pointed at the CIFAR-10 one ("datasets/cifar10/cifar3/").
        path=cifar_root + "cifar3/",
        dataname=cifar_name,
        num_clients=args.total_client,
        seed=args.seed,
        preprocess=args.preprocess,
        download=True,
        verbose=True,
        transform=TT_train,
        **split_kwargs
    )
    test_dataset_cls = (torchvision.datasets.CIFAR10 if is_cifar10
                        else torchvision.datasets.CIFAR100)
    mnist_test = test_dataset_cls(root=cifar_root, train=False,
                                  download=True, transform=TT_test)

# Centralised evaluation loader over the held-out test split.
test_loader = DataLoader(mnist_test, batch_size=1024)

# Server side: FedNova handler for args.com_round rounds, sampling
# args.sample_ratio of the clients per round.
handler = FedNovaServerHandler(model=model,
                               global_round=args.com_round,
                               sample_ratio=args.sample_ratio)
handler.setup_optim(option="weighted_scale")

# Client side: one trainer object constructed for all args.total_client clients.
# The device flag follows the `cuda` configuration option; it falls back to True
# (the previous hard-coded value) when that option is absent.
trainer = FedNovaSerialClientTrainer(model, args.total_client,
                                     cuda=getattr(args, "cuda", True))
trainer.setup_optim(args.epochs, args.batch_size, args.lr)
trainer.setup_dataset(mnist_train_partition)

# Run the standalone simulation; results are written under file_path.
pipeline = StandalonePipeline(handler, trainer, test_loader=test_loader, save_path=file_path)
pipeline.main()




