import os
import argparse
from solver_fedavg import FedSolver
from mnist_loader import MnistRotated
import torch
from torch.backends import cudnn
from torch.utils.data import DataLoader


def str2bool(v):
    """Parse a command-line string into a boolean.

    Args:
        v: The raw argument value. Normally a string supplied by argparse;
            an actual ``bool`` is passed through unchanged.

    Returns:
        True only when ``v`` equals ``'true'`` ignoring case, otherwise False.
    """
    if isinstance(v, bool):
        return v
    # Compare against the whole word. ``v.lower() in ('true')`` is a substring
    # test against the string 'true' (the parentheses do not make a tuple), so
    # values such as '', 't' or 'rue' were wrongly treated as True.
    return v.lower() == 'true'


def main(config):
    """Build the per-domain data loaders and run federated training.

    Args:
        config: Parsed command-line namespace. This function reads
            ``model_save_dir``, ``target_domain``, ``batch_size`` and
            ``num_workers``, sets ``config.source_domains``, and passes the
            namespace on to ``FedSolver``.
    """
    # For fast training.
    cudnn.benchmark = True

    # Create the checkpoint directory; exist_ok avoids the check-then-create
    # race of a separate os.path.exists() test.
    os.makedirs(config.model_save_dir, exist_ok=True)

    # Source domains are the base rotation angles minus the held-out target.
    # Filtering (instead of appending the target and calling list.remove once)
    # guarantees the target never remains among the training domains when it
    # is itself one of the base angles.
    base_domains = ['0', '15', '30', '45', '60']
    all_domains = [d for d in base_domains if d != config.target_domain]
    config.source_domains = all_domains

    data_loader_dict = {}  # data loader for each client, keyed by domain
    domain_idx = {}  # domain name -> integer domain label
    for i, domain in enumerate(all_domains):
        domain_idx[domain] = i
        # NOTE(review): the target domain is passed as a bare string here but
        # as a one-element list for the test set below -- confirm which form
        # MnistRotated expects.
        data_set = MnistRotated([domain], config.target_domain,
                                '../data/', train=True, mnist_subset='0')
        # Overwrite the dataset's domain labels with this client's index.
        data_set.train_domain = torch.ones_like(data_set.train_domain) * i
        data_loader_dict[domain] = DataLoader(
            data_set,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.num_workers,
        )

    test_set = MnistRotated(all_domains, [config.target_domain],
                            '../data/', train=False, mnist_subset='0')
    test_loader = DataLoader(test_set, batch_size=10000, shuffle=False,
                             num_workers=config.num_workers)

    # Solver that owns the federated training loop.
    solver = FedSolver(data_loader_dict, domain_idx, test_loader, config)

    solver.train()


if __name__ == '__main__':
    # Script entry point: declare the command-line interface, parse it,
    # echo the resulting configuration and hand it to main().
    cli = argparse.ArgumentParser()
    add = cli.add_argument  # shorthand for the many declarations below

    # --- Weights & Biases ---
    add('--project', default="flinb-stargan-fl", type=str, help="name of wandb project")
    add('--entity', default="", type=str, help='username in wandb')

    # --- Model ---
    add('--arch', default=None, type=str, help='architecture name')
    add('--c_dim', default=5, type=int, help='dimension of domain labels (1st dataset)')
    add('--image_size', default=28, type=int, help='image resolution')
    add('--g_conv_dim', default=64, type=int, help='number of conv filters in the first layer of G')
    add('--d_conv_dim', default=64, type=int, help='number of conv filters in the first layer of D')
    add('--g_repeat_num', default=6, type=int, help='number of residual blocks in G')
    add('--d_repeat_num', default=4, type=int, help='number of strided conv layers in D')
    add('--lambda_cls', default=1, type=float, help='weight for domain classification loss')
    add('--lambda_rec', default=10, type=float, help='weight for reconstruction loss')
    add('--lambda_gp', default=10, type=float, help='weight for gradient penalty')

    # --- Training ---
    add('--dataset', default='RotatedMnist', type=str)
    add('--target_domain', default='75', type=str)
    add('--batch_size', default=64, type=int, help='mini-batch size')
    add('--num_iters', default=200_000, type=int, help='number of total iterations for training D')
    add('--num_iters_decay', default=100_000, type=int, help='number of iterations for decaying lr')
    add('--g_lr', default=0.0001, type=float, help='learning rate for G')
    add('--d_lr', default=0.0001, type=float, help='learning rate for D')
    add('--n_critic', default=5, type=int, help='number of D updates per each G update')
    add('--beta1', default=0.5, type=float, help='beta1 for Adam optimizer')
    add('--beta2', default=0.999, type=float, help='beta2 for Adam optimizer')
    add('--resume_iters', default=None, type=int, help='resume training from this step')

    # --- Testing ---
    add('--test_iters', default=200_000, type=int, help='test model from this step')

    # --- Miscellaneous ---
    add('--num_workers', default=1, type=int)
    add('--mode', default='train', type=str, choices=['train', 'test'])
    add('--use_tensorboard', default=False, type=str2bool)
    add('--use_wandb', default=True, type=str2bool)
    add('--vis_trans', default=False, action='store_true')
    add('--run_name', default='notnamed', type=str)

    # --- Output directories ---
    add('--log_dir', default='fedstargan/logs', type=str)
    add('--model_save_dir', default='saved/fedstargan_model', type=str)

    # --- Step intervals and device ---
    add('--sync_step', default=1, type=int)
    add('--log_step', default=10, type=int)
    add('--sample_step', default=1000, type=int)
    add('--model_save_step', default=10_000, type=int)
    add('--lr_update_step', default=1000, type=int)
    add('--vis_step', default=10_000, type=int)
    add('--device_name', default='cuda:1', type=str)

    config = cli.parse_args()
    print(config)
    main(config)
