import argparse
import torch.nn as nn
import pickle as pk
import os

from datas import get_feddata
from trainers import fedtrainer


def make_parser(parser):
    """Register every command-line option understood by this training script.

    Options are added in a fixed order (this order is what ``--help`` shows).
    Valued options are declared as ``(flag, default, type)`` triples. Boolean
    switches use ``action='store_true'``. ``--cuda_devices`` is the only
    option that takes a list.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible object with an
            ``add_argument`` method); it is populated in place.

    Returns:
        The same ``parser`` object, so callers can chain
        ``make_parser(argparse.ArgumentParser())``.
    """
    def add_valued(*specs):
        # Each spec is (flag, default, type); the type only converts
        # command-line strings, so non-string defaults are kept as written.
        for flag, default, kind in specs:
            parser.add_argument(flag, default=default, type=kind)

    def add_switches(*flags):
        for flag in flags:
            parser.add_argument(flag, action='store_true')

    # Optimisation / federated-training hyper-parameters.
    add_valued(
        ('--n_epoch', 20, int),
        ('--batch_size', 64, int),
        ('--num_workers', 64, int),
        ('--mu', 1e-5, float),
        ('--eta', 1e-4, float),
        ('--gamma', 0.9995, float),
        ('--k', 5, int),
        ('--K', 5, int),
        ('--f', 2, int),
        ('--alpha', -1, float),
        ('--weight_decay', 0, float),
        ('--slow_weight', 0.01, float),
        ('--optimizer_name', 'Adam', str),
        ('--scheduler_name', 'constant', str),
    )
    parser.add_argument("--cuda_devices", type=int, nargs='+', default=[0], help="CUDA device ids")
    add_switches('--classifier_only')

    # Model / data / run identification.
    add_valued(
        ('--model_name', 'google/vit-base-patch16-224', str),
        ('--dataset_name', 'cifar10', str),
        ('--port', 4989, int),
        ('--random_seed', 4989, int),
        ('--byzantine', 0, int),
        ('--save_to', './', str),
        ('--comment', 'default', str),
    )

    # Training-variant switches.
    add_switches('--remotehis', '--binary', '--onebit')

    return parser


def _ensure_save_dir(path):
    """Make sure the output directory *path* exists, creating it if needed.

    ``os.makedirs(..., exist_ok=True)`` is used rather than ``os.mkdir`` for
    two reasons. It can create nested paths (e.g. ``runs/exp1`` when ``runs``
    is missing). It also avoids ``FileExistsError`` if the directory appears
    between the existence check and the call.

    Args:
        path: directory that the argument and log pickles are written into.
    """
    if not os.path.exists(path):
        print(f'{path} does not exists, creating..')
        os.makedirs(path, exist_ok=True)
    print(f'logs saving to {path}.')


def _dump_pickle(obj, path):
    """Pickle *obj* to *path*, overwriting any existing file.

    The ``with`` block closes the handle before returning. The bytes are
    therefore flushed to disk immediately, not whenever the file object
    happens to be garbage-collected.
    """
    with open(path, 'wb') as fh:
        pk.dump(obj, fh)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser = make_parser(parser)
    args = parser.parse_args()

    _ensure_save_dir(args.save_to)

    # NOTE(review): `processor` is returned but unused below, and `args.port`
    # is forwarded as `proxy` -- confirm both are intended.
    dltrain, dlvalid, model, processor = get_feddata(
        random_seed=args.random_seed,
        model_name=args.model_name,
        dataset_name=args.dataset_name,
        alpha=args.alpha,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        proxy=args.port,
    )

    # NOTE(review): --f (commented out below), --K and --scheduler_name are
    # parsed but never passed to the trainer.
    this_trainer = fedtrainer(
        mu=args.mu,
        eta=args.eta,
        # f=args.f,
        gamma=args.gamma,
        optimizer_name=args.optimizer_name,
        weight_decay=args.weight_decay,
        slow_weight=args.slow_weight,
        k=args.k,
        model=model,
        dltrain=dltrain,
        dlvalid=dlvalid,
        loss_func=nn.CrossEntropyLoss(),
        classifier_only=args.classifier_only,
        cuda_devices=args.cuda_devices,
        onebit=args.onebit,
        byzantine=args.byzantine,
    )

    # Store the full run configuration next to the logs.
    _dump_pickle(vars(args), os.path.join(args.save_to, f'{args.comment}_args.pk'))
    log_path = os.path.join(args.save_to, f'{args.comment}_log.pk')

    this_trainer.logger_init()
    for epoch in range(args.n_epoch):
        # Validation runs before each epoch's training step, so in epoch 0
        # it sees the model before any training done by this loop.
        # NOTE(review): the model produced by the final epoch's training is
        # never validated here -- add a trailing epoch_valid if that is wanted.
        this_trainer.epoch_valid(epoch)
        # --remotehis takes precedence over --binary; neither -> baseline.
        if args.remotehis:
            this_trainer.zo_epoch_train_proposed(epoch)
        elif args.binary:
            this_trainer.zo_epoch_train_binary(epoch)
        else:
            this_trainer.zo_epoch_train_baseline(epoch)

        # Overwrite the log pickle after every epoch. If logger_summary() is
        # cumulative (presumably -- confirm), results so far survive an
        # early stop.
        _dump_pickle(this_trainer.logger_summary(), log_path)