import argparse
from GAN.args import add_dict_to_argparser


def create_argparser():
    """Build the command-line parser holding this model's default settings.

    The defaults are assembled from several themed groups and merged (in a
    fixed order) into one mapping. That mapping is handed to
    ``add_dict_to_argparser`` from ``GAN.args``, which presumably registers
    each entry as an option on a fresh ``argparse.ArgumentParser``.

    Returns:
        argparse.ArgumentParser: the parser after ``add_dict_to_argparser``
        has been applied to it.
    """
    # Network shape and training-loop settings.
    network_opts = {
        "batch_size": 3,
        "image_size": 256,
        "n_channels": 64,
        "n_blocks": 5,
        "input_ch_num": 12,
        "output_ch_num": 9,
        "netD_ch_num": 3,
        "checkpoint": 3,
        "class_cond": False,
    }

    # Loss configuration and learning rates (lr_G / lr_D, by their names the
    # generator and discriminator rates). The loss_func values look like
    # per-term weights; earlier standalone options described MSE, SSIM and
    # discriminator ("dis") weights. A new dict is built on every call.
    # NOTE(review): if add_dict_to_argparser derives an argparse `type` from
    # each default value, this dict default cannot be overridden from the
    # command line -- confirm against GAN.args.
    optim_opts = {
        "loss_func": {"MSE": 0.0, "L1": 10.0, "dis": 0.1, "SSIM": 0.5},
        "lr_G": 1e-4,
        "lr_D": 1e-5,
    }

    # Data loading.
    data_opts = {
        "n_threads": 8,
        "crop_size": 256,
        "max_epoch": 500,
        "epoch_len": 500,
        "data_queue_len": 10000,
        "patch_per_tile": 10,
        "color_space": "RGB",
    }

    # Wavelet-related settings (meaning of `depth` not visible here -- TODO confirm).
    wavelet_opts = {
        "depth": 2,
        "wavelet_channel": 3,
    }

    # Large / small size pair.
    size_opts = {
        "large_size": 256,
        "small_size": 64,
    }

    # Merge order reproduces the original key order exactly.
    defaults = {
        **network_opts,
        **optim_opts,
        **data_opts,
        **wavelet_opts,
        **size_opts,
    }

    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser