import argparse
import os
#from solver import Solver
from torch.backends import cudnn
from data_loader import get_loader
import pdb

def str2bool(v):
    """Parse a command-line boolean flag value.

    Returns True for the (case-insensitive, whitespace-stripped) tokens
    'true', 't', 'yes', 'y' and '1'; every other string yields False.
    An actual ``bool`` is returned unchanged so the function is also safe
    to call on already-parsed values.
    """
    if isinstance(v, bool):
        return v
    # Must be a tuple: the previous ``in ('true')`` was a substring test
    # against the string 'true', so '', 'e', 'rue', ... were parsed as True.
    return v.strip().lower() in ('true', 't', 'yes', 'y', '1')

def main(config):
    """Build the data loaders and the selected solver, then train or test.

    ``config.geometry`` selects the module that provides ``Solver``:
    0: solver_MI2, 1: solver_MI_cycle2, 2: solver_share_MI_vf,
    3: solver_share_MI_rot, 4: solver_MI.
    ``config.exp_id`` selects the dataset pair: 0 is SVHN -> MNIST, any
    other value is MNIST-M -> MNIST.
    ``config.mode`` must be 'train' or 'test'.

    Raises:
        ValueError: if ``config.geometry`` or ``config.mode`` is unsupported.
    """
    import importlib

    # geometry id -> module defining the matching ``Solver`` class.
    solver_modules = {
        0: 'solver_MI2',
        1: 'solver_MI_cycle2',
        2: 'solver_share_MI_vf',
        3: 'solver_share_MI_rot',
        4: 'solver_MI',
    }
    # Fail fast, before loading any data. Previously an unknown geometry
    # surfaced later as an UnboundLocalError on ``Solver``, and an unknown
    # mode silently did nothing.
    if config.geometry not in solver_modules:
        raise ValueError('unsupported geometry {!r}; expected one of {}'.format(
            config.geometry, sorted(solver_modules)))
    if config.mode not in ('train', 'test'):
        raise ValueError("unsupported mode {!r}; expected 'train' or 'test'".format(
            config.mode))

    (svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader,
     usps_loader, usps_test_loader,
     mnistm_loader, mnistm_test_loader) = get_loader(config)
    # NOTE(review): the USPS loaders are built but never used here -- confirm
    # whether a USPS experiment (exp_id) was intended.

    # Select the source/target pair once, for both training and evaluation.
    if config.exp_id == 0:
        source_train, target_train = svhn_loader, mnist_loader
        source_test, target_test = svhn_test_loader, mnist_test_loader
    else:
        source_train, target_train = mnistm_loader, mnist_loader
        source_test, target_test = mnistm_test_loader, mnist_test_loader

    Solver = importlib.import_module(solver_modules[config.geometry]).Solver
    solver = Solver(config, source_train, target_train)
    cudnn.benchmark = True

    # Create output directories; exist_ok avoids the check-then-create race.
    os.makedirs(config.model_path, exist_ok=True)
    os.makedirs(config.sample_path, exist_ok=True)

    # The test loaders are passed in both modes (train uses them for evaluation).
    if config.mode == 'train':
        solver.train(source_test, target_test)
    else:
        solver.test(source_test, target_test)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Solver variant, i.e. the module main() imports ``Solver`` from:
    # 0: solver_MI2, 1: solver_MI_cycle2, 2: solver_share_MI_vf,
    # 3: solver_share_MI_rot, 4: solver_MI
    parser.add_argument('--geometry', type=int, default=4, choices=[0, 1, 2, 3, 4])
    # model hyper-parameters
    parser.add_argument('--image_size', type=int, default=32)
    parser.add_argument('--g_conv_dim', type=int, default=64)
    parser.add_argument('--d_conv_dim', type=int, default=64)
    # Boolean flags are parsed with str2bool, e.g. ``--use_MIR false``.
    parser.add_argument('--use_reconst_loss', required=False, type=str2bool, default=False)
    parser.add_argument('--use_distance_loss', required=False, type=str2bool, default=False)
    parser.add_argument('--num_classes', type=int, default=10)

    # training hyper-parameters
    parser.add_argument('--train_iters', type=int, default=40000)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--lambda_distance_A', type=float, default=0.05)
    parser.add_argument('--lambda_distance_B', type=float, default=0.1)
    parser.add_argument('--use_self_distance', required=False, type=str2bool, default=False)
    parser.add_argument('--use_MIR', required=False, type=str2bool, default=True)
    parser.add_argument('--max_items', type=int, default=400)
    parser.add_argument('--unnormalized_distances', required=False, type=str2bool, default=False)
    parser.add_argument('--lambda_gc', type=float, default=2.0)
    parser.add_argument('--lambda_MI', type=float, default=2.0)

    parser.add_argument('--ndf', type=int, default=64, help='number of discriminator filters')
    parser.add_argument('--leakiness', type=float, default=0.2, help='leaky relu leakiness')
    parser.add_argument('--D_conv_block_size', type=int, default=1, help='discriminator conv block size')
    parser.add_argument('--D_projection_size', type=int, default=4, help='discriminator image size after conv layers')
    parser.add_argument('--D_keep_prob', type=float, default=0.9, help='dropout keep probability')
    parser.add_argument('--D_noise_mean', type=float, default=0.0, help='discriminator external noise mean')
    parser.add_argument('--D_noise_stddev', type=float, default=0.2, help='discriminator external noise stddev')
    parser.add_argument('--ngf', type=int, default=64, help='number of generator filters')
    parser.add_argument('--G_residual_blocks', type=int, default=6, help='generator number of residual blocks')
    parser.add_argument('--G_noise_channels', type=int, default=1, help='generator number of noise channels')
    parser.add_argument('--G_noise_dim', type=int, default=10, help='generator noise dimension')
    parser.add_argument('--classifier_id', type=int, default=0)  # 0: usual classifier, 1: PixelDA classifier
    # misc
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
    parser.add_argument('--model_path', type=str, default='./models')
    parser.add_argument('--sample_path', type=str, default='./samples')
    parser.add_argument('--mnist_path', type=str, default='./mnist')
    parser.add_argument('--log_path', type=str, default='./logs')
    parser.add_argument('--svhn_path', type=str, default='./svhn')
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=500)
    # Dataset pair used by main(): 0 = SVHN -> MNIST, any other value =
    # MNIST-M -> MNIST.
    # NOTE(review): an older comment listed 1: USPS -> MNIST and
    # 2: MNIST -> USPS, but main() never uses the USPS loaders -- confirm intent.
    parser.add_argument('--exp_id', type=int, default=0)

    config = parser.parse_args()
    print(config)
    main(config)
