import argparse

from solo.args.dataset import custom_dataset_args, dataset_args


def parse_args_lista() -> argparse.Namespace:
    """Parses arguments for LISTA autoencoder.

    Returns:
        argparse.Namespace: a namespace containing all args needed for LISTA autoencoder.
    """

    parser = argparse.ArgumentParser()

    # add knn args
    parser.add_argument("--pretrained_checkpoint_dir", type=str)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_workers", type=int, default=10)

    # add lista args
    parser.add_argument("--lista_hidden_dim", type=int, default=512)
    parser.add_argument("--lista_num_layers", type=int, default=6)
    parser.add_argument("--lista_epochs", type=int, default=10)
    parser.add_argument("--lambda_l1", type=float, default=1e-3)
    parser.add_argument("--lista_lr", type=float, default=1e-3)

    # add wandb args
    parser.add_argument("--wandb_entity", type=str, default="wandb-entity")
    parser.add_argument("--wandb_project", type=str, default="lista-autoencoder")
    parser.add_argument("--wandb_name", type=str, default="lista-autoencoder")
    parser.add_argument("--use_wandb", type=eval, choices=[True, False], default=True)

    # add shared arguments
    dataset_args(parser)
    custom_dataset_args(parser)

    # parse args
    args = parser.parse_args()

    return args

