import torch
import argparse

from utils import *
from graph import *
from models import *
from server import *
from dataset import *


#### ================= Open Float32 in A100 ================= ####
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
#### ================= Open ignore warining ================= ####
import warnings
warnings.filterwarnings('ignore')
#### ======================================================== ####
print("##=============================================##")
print("##     Federated Learning Simulator Starts     ##")
print("##=============================================##")

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', choices=['CIFAR10'], type=str, default='CIFAR10')                         # select dataset
parser.add_argument('--model', choices=['ResNet18'], type=str, default='ResNet18')                         # select model
parser.add_argument('--non-iid', action='store_true', default=False)                                       # activate if use heterogeneous dataset 
parser.add_argument('--split-rule', choices=['Dirichlet'], type=str, default='Dirichlet')                  # select the dataset splitting rule
parser.add_argument('--split-coef', default=0.1, type=float)                                                 # --> if Dirichlet: select the Dirichlet coefficient (i.e. 0.1, 0.3, 0.6, 1)
                                                                                                              
parser.add_argument('--active-ratio', default=0.1, type=float)                                             # select the partial participating ratio (i.e. 0.1, 0.05)
parser.add_argument('--total-client', default=100, type=int)                                               # select the total number of clients (i.e. 100, 500)
parser.add_argument('--keep-rate', default=1.0, type=float)                                                # select the ratio of active samples (drop the others)
parser.add_argument('--comm-rounds', default=1000, type=int)                                               # select the global communication rounds T
parser.add_argument('--local-epochs', default=5, type=int)                                                 # select the local interval K
parser.add_argument('--batchsize', default=20, type=int)                                                   # select the batchsize
parser.add_argument('--weight-decay', default=0.001, type=float)                                           # select the weight-decay (i.e. 0.01, 0.001)
parser.add_argument('--local-learning-rate', default=0.1, type=float)                                      # select the local learning rate (generally 0.1 expect for local-adaptive-based)
parser.add_argument('--global-learning-rate', default=1.0, type=float)                                     # select the global learning rate (generally 1.0 expect for global-adaptive-based)
parser.add_argument('--lr-decay', default=0.998, type=float)                                               # select the learning rate decay (generally 0.998 expect for proxy-based)
parser.add_argument('--seed', default=42, type=int)                                                        # select the random seed
parser.add_argument('--cuda', default=0, type=int)                                                         # select the cuda ID
parser.add_argument('--data-file', default='./', type=str)                                                 # select the path of the root of Dataset
parser.add_argument('--out-file', default='out/', type=str)                                                # select the path of the log files
parser.add_argument('--save-model', action='store_true', default=False)                                    # activate if save the model
parser.add_argument('--use-RI', action='store_true', default=False)                                        # activate if use relaxed initialization (RI)

parser.add_argument('--decentralized', action='store_true', default=False)                                 # activate decentralized training
parser.add_argument('--adjacent_mode', choices=['ring', 'full', 'star', \
                                                'grid', 'exp'], type=str, default='ring')                  # select the adjacent matrix

parser.add_argument('--beta', default=0.1, type=float)                                                     # select the coefficient for relaxed initialization 
parser.add_argument('--epsilon', default=0.01, type=float)                                                 # select the minimal value for avoiding zero-division

parser.add_argument('--method', choices=['FedAvg'], type=str, default='FedAvg')
                                         
args = parser.parse_args()
# print(args)

if torch.cuda.is_available():
    device = torch.device(args.cuda)
else:
    device = torch.device("cpu")

if __name__=='__main__':
    ### Generate IID or Heterogeneous Dataset
    args.total_client = int(args.total_client * args.keep_rate)

    if not args.non_iid:
        data_obj = DatasetObject(dataset=args.dataset, n_client=args.total_client, seed=args.seed, unbalanced_sgm=0, rule='iid',
                                     data_path=args.data_file, keep_rate=args.keep_rate)
        print("Initialize the Dataset     --->  {:s} {:s} {:d} clients".format(args.dataset, 'IID', args.total_client))
    else:
        data_obj = DatasetObject(dataset=args.dataset, n_client=args.total_client, seed=args.seed, unbalanced_sgm=0, rule=args.split_rule,
                                     rule_arg=args.split_coef, data_path=args.data_file, keep_rate=args.keep_rate)
        print("Initialize the Dataset     --->  {:s} {:s}-{:s} {:d} clients".format(args.dataset, args.split_rule, str(args.split_coef), args.total_client))
  

    ### Generate Model Function
    model_func = lambda: client_model(args.model)
    print("Initialize the Model Func  --->  {:s} model".format(args.model))
    init_model = model_func()

    total_trainable_params = sum(p.numel() for p in init_model.parameters() if p.requires_grad)
    print("                           --->  {:d} parameters".format(total_trainable_params))
    init_par_list = get_mdl_params(init_model)
    
    ### Generate Server
    server_func = None
    if args.method == 'FedAvg':
        server_func = FedAvg
    else:
        raise NotImplementedError('not implemented method yet')
    
    if not args.decentralized:
        Adjacent = None
    else:
        Adjacent = graph_generator(args.adjacent_mode, args.total_client)
    
    _server = server_func(device=device, model_func=model_func, init_model=init_model, init_par_list=init_par_list,
                          datasets=data_obj, method=args.method, args=args, Adjacent=Adjacent)
    _server.train()