import pdb
import sys
sys.path.append("../")

# configuration
from munch import Munch
from fedlab.models.mlp import MLP, MLP_prob
from fedlab.models.cnn import CNN_MNIST_prob, CNN_MNIST, CNN_CIFAR10, AlexNet_CIFAR10, AlexNet_CIFAR100, CNN_FEMNIST

from torchvision import transforms, models
from fedlab.contrib.dataset.partitioned_mnist import PartitionedMNIST
from fedlab.contrib.dataset.partitioned_femnist import PartitionedFEMNIST
from fedlab.contrib.dataset.femnist import FEMNIST

from fedlab.contrib.dataset.partitioned_cifar import PartitionCIFAR
# client
from fedlab.contrib.algorithm.fedavg import FedAvgSerialClientTrainer 

# server
from fedlab.contrib.algorithm.fedavg import FedAvgServerHandler 


import numpy as np
from fedlab.utils.functional import evaluate
from fedlab.core.standalone import StandalonePipeline

from torch import nn
from torch.utils.data import DataLoader
import torchvision
import os

# Fix the global RNGs so repeated runs are reproducible. torch is seeded too,
# because the model below is randomly initialised right after this point.
np.random.seed(886)
import random
random.seed(886)
import torch
torch.manual_seed(886)

# Run configuration. `args` must be a Munch *instance* (not the Munch class
# itself) so the settings live on this object rather than on the class.
# Edit these defaults (or overwrite attributes on `args`) to pick an experiment.
args = Munch(
    seed=42,              # forwarded as `seed=` to the client-partition builders
    dataset='CIFAR_10',   # FEMNIST # MNIST # CIFAR_10 # CIFAR_100
    iid=True,
    alpha=0.5,            # forwarded as `dir_alpha=` for non-IID partitions
    total_client=50,
    preprocess=True,
    cuda=True,
    # local train configuration
    epochs=5,
    batch_size=128,
    lr=0.01,
    # global configuration
    com_round=100,
    sample_ratio=1.0,
    # calibration tag; in this script it only appears in the result file name
    # CE # focal # LS # NLLMDCA # BS # DCA # MMCE # FLSD # LM
    # ourDCA_COS # ourDCA_LCKA # ourDCA_rbfCKA # ourMDCA_COS # ourMDCA_LCKA # ourMDCA_rbfCKA
    calib='ourDCA_COS',
)


# Pick the network architecture for the selected dataset.
if args.dataset == 'FEMNIST':
    model = CNN_FEMNIST()
elif args.dataset == 'MNIST':
    model = CNN_MNIST()
elif args.dataset == 'CIFAR_10':
    model = AlexNet_CIFAR10()
else:  # CIFAR_100 (any other value also lands here)
    model = models.resnet34(pretrained=False, num_classes=100)

# NOTE(review): w_decay and momentum are defined but never handed to the
# trainer in this script (setup_optim only receives epochs, batch_size, lr),
# so they currently have no effect. Confirm whether the optimizer should use them.
w_decay = args.lr * 0.1
momentum = 0.9

# Per-dataset results directory; created if missing (no check-then-create race).
save_folder = "../results/{}_fedavg/".format(args.dataset)
os.makedirs(save_folder, exist_ok=True)

# The result file name encodes the run's hyper-parameters. The prefix is shared
# by both split types; the suffix depends on whether the split is IID.
# NOTE(review): the name says "SGDmoment", but this script does not pass a
# momentum value to the optimizer. Confirm the tag is still accurate.
run_prefix = "TT_Client{}_Round{}_LocalE{}_Bat{}_SGDmoment_lr{}_".format(
    args.total_client, args.com_round, args.epochs, args.batch_size, args.lr)
if args.iid:
    run_suffix = "IID_part{}_{}".format(args.sample_ratio, args.calib)
else:
    run_suffix = "nonIID_alpha{}_part{}_{}".format(args.alpha, args.sample_ratio, args.calib)
file_path = save_folder + run_prefix + run_suffix



# ---------------------------------------------------------------------------
# Datasets: a per-client training partition plus a centralized test split.
# Every branch defines `mnist_train_partition` and `mnist_test`. The names are
# historical; they hold whichever dataset is selected.
# ---------------------------------------------------------------------------
if args.dataset == 'MNIST':
    # IID splits take no Dirichlet parameter; label-Dirichlet splits need one.
    if args.iid:
        split_kwargs = dict(partition="iid")
    else:
        split_kwargs = dict(partition="noniid-labeldir", dir_alpha=args.alpha)
    mnist_train_partition = PartitionedMNIST(root="datasets/mnist/",
                                             path="datasets/mnist/fedmnist0/",
                                             num_clients=args.total_client,
                                             seed=args.seed,
                                             preprocess=args.preprocess,
                                             download=True,
                                             verbose=True,
                                             transform=transforms.Compose(
                                                 [transforms.ToPILImage(), transforms.ToTensor()]),
                                             **split_kwargs)
    mnist_test = torchvision.datasets.MNIST(root="datasets/mnist/",
                                            train=False,
                                            transform=transforms.ToTensor())

elif args.dataset == 'FEMNIST':
    if args.iid:
        split_kwargs = dict(partition="iid")
    else:
        split_kwargs = dict(partition="noniid-labeldir", dir_alpha=args.alpha)
    # preprocess is pinned to False here, independent of args.preprocess.
    mnist_train_partition = PartitionedFEMNIST(root="datasets/femnist/",
                                               path="datasets/femnist/fedmnist0/",
                                               num_clients=args.total_client,
                                               seed=args.seed,
                                               preprocess=False,
                                               download=True,
                                               verbose=True,
                                               transform=None,
                                               **split_kwargs)
    mnist_test = FEMNIST(root="datasets/femnist/",
                         train=False,
                         transform=transforms.Compose([transforms.ToTensor()]))

else:
    # CIFAR_10 or CIFAR_100. Any other value is treated as CIFAR_100, which
    # matches the model selection. Each dataset keeps its partitions under its
    # own directory (previously CIFAR-100 reused datasets/cifar10/cifar3/).
    # Train and test images are normalized with the same training-set statistics.
    if args.dataset == 'CIFAR_10':
        cifar_name = 'cifar10'
        cifar_test_cls = torchvision.datasets.CIFAR10
        # Per-channel mean/std of the CIFAR-10 training set.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    else:
        cifar_name = 'cifar100'
        cifar_test_cls = torchvision.datasets.CIFAR100
        # Per-channel mean/std of the CIFAR-100 training set.
        normalize = transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))
    cifar_root = "datasets/{}/".format(cifar_name)

    TT_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # No augmentation at test time, but the normalization must match training.
    TT_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # NOTE(review): if PartitionCIFAR caches partitions (with their transform)
    # on disk, rerun once with preprocess=True so the new path and
    # normalization take effect. Confirm against PartitionCIFAR.
    if args.iid:
        split_kwargs = dict(partition="iid")
    else:
        split_kwargs = dict(partition="dirichlet", balance=None)
    mnist_train_partition = PartitionCIFAR(root=cifar_root,
                                           path=cifar_root + "cifar3/",
                                           dataname=cifar_name,
                                           num_clients=args.total_client,
                                           dir_alpha=args.alpha,
                                           seed=args.seed,
                                           preprocess=args.preprocess,
                                           download=True,
                                           verbose=True,
                                           transform=TT_train,
                                           **split_kwargs)
    mnist_test = cifar_test_cls(root=cifar_root, train=False, transform=TT_test)
   
# ---- Wire up the standalone FedAvg simulation and run it. ----
# Centralized evaluation loader over the held-out test split.
TEST_BATCH_SIZE = 1024
test_loader = DataLoader(mnist_test, batch_size=TEST_BATCH_SIZE)

# Client side: a serial trainer constructed for `args.total_client` clients,
# given the partitioned training data and the local-training settings.
trainer = FedAvgSerialClientTrainer(model, args.total_client, cuda=args.cuda)
trainer.setup_dataset(mnist_train_partition)
trainer.setup_optim(args.epochs, args.batch_size, args.lr)

# Server side: configured with the number of global rounds and the client
# sample ratio (presumably the fraction of clients drawn per round; confirm
# against FedAvgServerHandler).
handler = FedAvgServerHandler(
    model=model,
    global_round=args.com_round,
    sample_ratio=args.sample_ratio,
    cuda=args.cuda,
)

# Drive the server/client loop, evaluating on `test_loader` and writing
# results to `file_path`.
pipeline = StandalonePipeline(
    handler,
    trainer,
    test_loader=test_loader,
    save_path=file_path,
)
pipeline.main()




