import configargparse


def make_base_parser():
    """Build the parser holding the options shared by every dataset and model.

    Returns:
        A new ``configargparse.ArgParser`` with the config-file, model, data,
        training and path options registered. Dataset- and model-specific
        options are expected to be added on top of it by the caller.
    """

    def _str2bool(value):
        """Parse a textual boolean.

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        True, so every non-empty value would enable the flag.
        """
        import argparse  # local import: only needed for the error type

        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', 'on', '1'):
            return True
        # The empty string stays falsy, as it was under ``bool('')``.
        if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean (true/false), got %r' % (value,))

    parser = configargparse.ArgParser()
    parser.add('-c', '--my-config', is_config_file=True, help='config file path')
    # Model Set Up
    parser.add('--eta', type=float, default=0.0001, help='initial eta, WAS 0.01')

    # Data Set Up
    parser.add('--data_path', type=str, default='../data', help='Path to data')
    parser.add('--img_dim', type=int, default=320, help='image size')
    parser.add('--noise_level', type=float, default=0.01)

    # Training Set Up
    parser.add("--device", type=str, default="cpu", help="cpu or cuda")
    parser.add('--batch_size', type=int, default=8)
    parser.add('--batch_size_val', type=int, default=8)
    parser.add('--lr', type=float, default=1e-4, help='5e-5 for gradient descent')  # mnist:5e-5, fmnist:5e-4, om:1e-4
    parser.add('--lr_gamma', type=float, default=0.5)  # new lr = lr_gamma * lr
    parser.add('--continue_train', action='store_true', help='if load utils and resume training')
    # NOTE(review): help text is identical to --continue_train and looks
    # copy-pasted; confirm what --pretrain is meant to do before rewording it.
    parser.add('--pretrain', action='store_true', help='if load utils and resume training')
    # store_false on dest='train': args.train defaults to True, --eval sets it to False.
    parser.add('--eval', action='store_false', dest='train', help='Train vs eval')
    # Parsed with _str2bool so that "--use_aux_loss False" really disables it.
    parser.add('--use_aux_loss', type=_str2bool, default=False, help='Use aux loss')

    # Pathing
    parser.add("--path", type=str, default="../saved_models/", help="saving directory")
    parser.add('--save_path', type=str, default='../saved_models/')  # double check this later
    parser.add('--load_path', type=str, default='')
    return parser


# Dataset Specific Arguments
def add_mri_args(parser):
    """Register the MRI dataset options and their training schedule.

    Args:
        parser: object exposing ``add(flag, **kwargs)``; every option below
            is registered on it, in order.

    Returns:
        The same ``parser`` object that was passed in.
    """
    option_table = (
        # Dataset description
        ('--dataset', {'type': str, 'default': 'MRI', 'help': 'MRI, CT or CelebA'}),
        ('--mri_acc', {'type': float, 'default': 8, 'help': 'MRI acceleration'}),
        ('--nc', {'type': int, 'default': 2, 'help': 'number of channels in an image'}),
        # Training
        ('--n_epochs', {'type': int, 'default': 200}),
        ('--sched_step', {'type': int, 'default': 50}),
    )
    for flag, settings in option_table:
        parser.add(flag, **settings)
    return parser


def add_celeba_args(parser):
    """Register the CelebA dataset options and their training schedule.

    Args:
        parser: object exposing ``add(flag, **kwargs)``; every option below
            is registered on it, in order.

    Returns:
        The same ``parser`` object that was passed in.
    """
    option_table = (
        # Dataset description
        ('--dataset', {'type': str, 'default': 'CelebA', 'help': 'MRI, CT or CelebA'}),
        ('--sigma', {'type': float, 'default': 5, 'help': 'variance if A is Gaussian'}),
        ('--kernel_size', {'type': int, 'default': 9, 'help': 'Kernel size if A is Gaussian'}),
        ('--nc', {'type': int, 'default': 3, 'help': 'number of channels in an image'}),
        # Training
        ('--n_epochs', {'type': int, 'default': 15}),
        ('--sched_step', {'type': int, 'default': 5}),
    )
    for flag, settings in option_table:
        parser.add(flag, **settings)
    return parser


def add_ct_args(parser):
    """Register the CT dataset options and their training schedule.

    Args:
        parser: object exposing ``add(flag, **kwargs)``; every option below
            is registered on it, in order.

    Returns:
        The same ``parser`` object that was passed in.
    """
    option_table = (
        # Dataset description
        ('--dataset', {'type': str, 'default': 'CT', 'help': 'MRI, CT or CelebA'}),
        ('--nc', {'type': int, 'default': 1, 'help': 'number of channels in an image'}),
        # Training
        ('--n_epochs', {'type': int, 'default': 15}),
        ('--sched_step', {'type': int, 'default': 5}),
    )
    for flag, settings in option_table:
        parser.add(flag, **settings)
    return parser


# Model Specific Arguments
def add_lu_args(parser):
    """Register the network-size and iteration-count options of the LU model.

    Args:
        parser: object exposing ``add(flag, **kwargs)``; the options are
            registered on it.

    Returns:
        The same ``parser`` object that was passed in.
    """
    register = parser.add

    # Network size
    register("--n_features", type=int, default=64, help="hidden channels in r")
    register("--n_layers", type=int, default=17, help="Number of DnCNN layers")

    # Iteration budget
    register("--maxiters", type=int, default=2, help="Main max iterations")

    return parser


def add_luser_args(parser):
    """Register the LUSER model options, including the Anderson solver knobs.

    Args:
        parser: object exposing ``add(flag, **kwargs)``; the options are
            registered on it.

    Returns:
        The same ``parser`` object that was passed in.
    """

    def _str2bool(value):
        """Parse a textual boolean.

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        True, so ``--diffW False`` could never switch the option off.
        """
        import argparse  # local import: only needed for the error type

        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', 'on', '1'):
            return True
        # The empty string stays falsy, as it was under ``bool('')``.
        if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean (true/false), got %r' % (value,))

    parser.add('--maxiters', type=int, default=8, help='Main max iterations')
    # Defaults to True; pass "--diffW false" (or 0/no/off) to disable it.
    parser.add('--diffW', type=_str2bool, default=True)
    # Anderson solver settings
    parser.add('--and_m', type=int, default=5, help='Anderson m, was 5')
    parser.add('--and_beta', type=float, default=3, help='Anderson beta, WAS 2')
    parser.add('--and_maxiters', type=int, default=40, help='Anderson max iters, WAS 20')
    parser.add('--and_tol', type=float, default=1e-6, help='Anderson tolerance, was 1e-3')
    return parser


def add_deqip_args(parser):
    """Register the DEQ-IP model options: network size plus Anderson solver knobs.

    Args:
        parser: object exposing ``add(flag, **kwargs)``; the options are
            registered on it.

    Returns:
        The same ``parser`` object that was passed in.
    """
    register = parser.add

    # Network
    register("--n_features", type=int, default=64, help="hidden channels in r")
    register("--n_layers", type=int, default=17, help="Number of DnCNN layers")
    register("--shared_eta", action="store_true", help="Share eta across iterations K")

    # Anderson solver settings
    register("--and_m", type=int, default=5, help="Anderson m, was 5")
    register("--and_beta", type=float, default=3, help="Anderson beta, WAS 2")
    register("--and_maxiters", type=int, default=40, help="Anderson max iters, WAS 20")
    register("--and_tol", type=float, default=1e-6, help="Anderson tolerance, was 1e-3")

    return parser
#     parser.add('--tol', type=float, default=1e-6, help='Main tolerance, MNIST was 1e-3')
