import argparse
from training.params import get_default_params, ParseKwargs


def add_base_args(parser):
    """Register the base training options on ``parser``.

    Covers data sources, logging, optimisation, checkpointing, model
    selection, tower locking, distributed training and remote sync.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible object exposing
            ``add_argument``) that receives the options.

    Returns:
        The same ``parser`` object, so calls can be chained.
    """
    parser.add_argument(
        "--train-data",
        type=str,
        default=None,
        help="Path to file(s) with training data. When using webdataset, multiple datasources can be combined using the `::` separator.",
    )
    parser.add_argument(
        "--train-data-upsampling-factors",
        type=str,
        default=None,
        help=(
            "When using multiple data sources with webdataset and sampling with replacement, this can be used to upsample specific data sources. "
            "Similar to --train-data, this should be a string with as many numbers as there are data sources, separated by `::` (e.g. 1::2::0.5). "
            "By default, datapoints are sampled uniformly regardless of the dataset sizes."
        )
    )
    parser.add_argument(
        "--val-data",
        type=str,
        default=None,
        help="Path to file(s) with validation data",
    )
    parser.add_argument(
        "--train-num-samples",
        type=int,
        default=None,
        help="Number of samples in dataset. Required for webdataset if not available in info file.",
    )
    parser.add_argument(
        "--val-num-samples",
        type=int,
        default=None,
        help="Number of samples in dataset. Useful for webdataset if not available in info file.",
    )
    parser.add_argument(
        "--dataset-type",
        choices=["webdataset", "csv", "synthetic", "auto"],
        default="auto",
        help="Which type of dataset to process."
    )
    parser.add_argument(
        "--dataset-resampled",
        default=False,
        action="store_true",
        help="Whether to use sampling with replacement for webdataset shard selection."
    )
    parser.add_argument(
        "--csv-separator",
        type=str,
        default="\t",
        help="For csv-like datasets, which separator to use."
    )
    parser.add_argument(
        "--csv-img-key",
        type=str,
        default="filepath",
        help="For csv-like datasets, the name of the key for the image paths."
    )
    parser.add_argument(
        "--csv-caption-key",
        type=str,
        default="title",
        help="For csv-like datasets, the name of the key for the captions."
    )
    parser.add_argument(
        "--imagenet-val",
        type=str,
        default=None,
        help="Path to imagenet val set for conducting zero shot evaluation.",
    )
    parser.add_argument(
        "--imagenet-v2",
        type=str,
        default=None,
        help="Path to imagenet v2 for conducting zero shot evaluation.",
    )
    parser.add_argument(
        "--logs",
        type=str,
        default="./logs/",
        help="Where to store tensorboard logs. Use None to avoid storing logs.",
    )
    parser.add_argument(
        "--log-local",
        action="store_true",
        default=False,
        help="log files on local master, otherwise global master only.",
    )
    parser.add_argument(
        "--name",
        type=str,
        default=None,
        help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
    )
    parser.add_argument(
        "--workers", type=int, default=4, help="Number of dataloader workers per GPU."
    )
    parser.add_argument(
        "--batch-size", type=int, default=64, help="Batch size per GPU."
    )
    parser.add_argument(
        "--epochs", type=int, default=32, help="Number of epochs to train for."
    )
    parser.add_argument(
        "--epochs-cooldown", type=int, default=None,
        help="When scheduler w/ cooldown used, perform cooldown from total_epochs - cooldown_epochs onwards."
    )
    # Optimiser settings left at None are expected to be filled in later by
    # the caller (parse_args substitutes per-model defaults for None values).
    parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
    parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
    parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
    parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
    parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
    parser.add_argument(
        "--warmup", type=int, default=10000, help="Number of steps to warmup for."
    )
    parser.add_argument(
        "--use-bn-sync",
        default=False,
        action="store_true",
        help="Whether to use batch norm sync.")
    parser.add_argument(
        "--skip-scheduler",
        action="store_true",
        default=False,
        help="Use this flag to skip the learning rate decay.",
    )
    parser.add_argument(
        "--lr-scheduler",
        type=str,
        default='cosine',
        help="LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown' (constant w/ cooldown). Default: "
             "cosine",
    )
    parser.add_argument(
        "--lr-cooldown-end", type=float, default=0.0,
        help="End learning rate for cooldown schedule. Default: 0"
    )
    parser.add_argument(
        "--lr-cooldown-power", type=float, default=1.0,
        help="Power for polynomial cooldown schedule. Default: 1.0 (linear decay)"
    )
    parser.add_argument(
        "--save-frequency", type=int, default=1, help="How often to save checkpoints."
    )
    parser.add_argument(
        "--save-most-recent",
        action="store_true",
        default=False,
        help="Always save the most recent model trained to epoch_latest.pt.",
    )
    parser.add_argument(
        "--zeroshot-frequency", type=int, default=1, help="How often to run zero shot."
    )
    parser.add_argument(
        "--val-frequency", type=int, default=1, help="How often to run evaluation with val data."
    )
    parser.add_argument(
        "--resume",
        default=None,
        type=str,
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "--precision",
        choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
        default="amp",
        help="Floating point precision."
    )
    parser.add_argument(
        "--model",
        type=str,
        default="RN50",
        help="Name of the vision backbone to use.",
    )
    parser.add_argument(
        "--pretrained",
        default='',
        type=str,
        help="Use a pretrained CLIP model weights with the specified tag or file path.",
    )
    parser.add_argument(
        "--pretrained-image",
        default=False,
        action='store_true',
        help="Load imagenet pretrained weights for image tower backbone if available.",
    )
    parser.add_argument(
        "--lock-image",
        default=False,
        action='store_true',
        help="Lock full image tower by disabling gradients.",
    )
    parser.add_argument(
        "--lock-image-unlocked-groups",
        type=int,
        default=0,
        help="Leave last n image tower layer groups unlocked.",
    )
    parser.add_argument(
        "--lock-image-freeze-bn-stats",
        default=False,
        action='store_true',
        help="Freeze BatchNorm running stats in image tower for any locked layers.",
    )
    parser.add_argument(
        '--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
        help='Override default image mean value of dataset')
    parser.add_argument(
        '--image-std', type=float, nargs='+', default=None, metavar='STD',
        help='Override default image std deviation of dataset')
    parser.add_argument('--aug-cfg', nargs='*', default={}, action=ParseKwargs)
    parser.add_argument(
        "--grad-checkpointing",
        default=False,
        action='store_true',
        help="Enable gradient checkpointing.",
    )
    parser.add_argument(
        "--local-loss",
        default=False,
        action="store_true",
        help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)"
    )
    parser.add_argument(
        "--gather-with-grad",
        default=False,
        action="store_true",
        help="enable full distributed gradient for feature gather"
    )
    parser.add_argument(
        '--force-image-size', type=int, nargs='+', default=None,
        help='Override default image size'
    )
    parser.add_argument(
        "--force-quick-gelu",
        default=False,
        action='store_true',
        help="Force use of QuickGELU activation for non-OpenAI transformer models.",
    )
    parser.add_argument(
        "--force-patch-dropout",
        default=None,
        type=float,
        help="Override the patch dropout during training, for fine tuning with no dropout near the end as in the paper",
    )
    parser.add_argument(
        "--force-custom-text",
        default=False,
        action='store_true',
        help="Force use of CustomTextCLIP model (separate text-tower).",
    )
    parser.add_argument(
        "--torchscript",
        default=False,
        action='store_true',
        help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
    )
    parser.add_argument(
        "--trace",
        default=False,
        action='store_true',
        help="torch.jit.trace the model for inference / eval only",
    )
    parser.add_argument(
        "--accum-freq", type=int, default=1, help="Update the model every --accum-freq steps."
    )
    # arguments for distributed training
    parser.add_argument(
        "--dist-url",
        default="env://",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--report-to",
        default='',
        type=str,
        help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']"
    )
    parser.add_argument(
        "--wandb-notes",
        default='',
        type=str,
        help="Notes if logging with wandb"
    )
    parser.add_argument(
        "--wandb-project-name",
        type=str,
        default='open-clip',
        help="Name of the project if logging with wandb.",
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="If true, more information is logged."
    )
    parser.add_argument(
        "--copy-codebase",
        default=False,
        action="store_true",
        help="If true, we copy the entire base on the log directory, and execute from there."
    )
    parser.add_argument(
        "--horovod",
        default=False,
        action="store_true",
        help="Use horovod for distributed training."
    )
    parser.add_argument(
        "--ddp-static-graph",
        default=False,
        action='store_true',
        help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
    )
    parser.add_argument(
        "--no-set-device-rank",
        default=False,
        action="store_true",
        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc)."
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="Default random seed."
    )
    parser.add_argument(
        "--grad-clip-norm", type=float, default=None, help="Gradient clip."
    )
    parser.add_argument(
        "--lock-text",
        default=False,
        action='store_true',
        help="Lock full text tower by disabling gradients.",
    )
    parser.add_argument(
        "--lock-text-unlocked-layers",
        type=int,
        default=0,
        help="Leave last n text tower layers unlocked.",
    )
    parser.add_argument(
        "--lock-text-freeze-layer-norm",
        default=False,
        action='store_true',
        help="Freeze LayerNorm in text tower for any locked layers.",
    )
    parser.add_argument(
        "--log-every-n-steps",
        type=int,
        default=100,
        help="Log every n steps to tensorboard/console/wandb.",
    )
    parser.add_argument(
        "--coca-caption-loss-weight",
        type=float,
        default=2.0,
        help="Weight assigned to caption loss in CoCa."
    )
    parser.add_argument(
        "--coca-contrastive-loss-weight",
        type=float,
        default=1.0,
        help="Weight assigned to contrastive loss when training CoCa."
    )
    parser.add_argument(
        "--remote-sync",
        type=str,
        default=None,
        help="Optionally sync with a remote path specified by this arg",
    )
    parser.add_argument(
        "--remote-sync-frequency",
        type=int,
        default=300,
        help="How frequently to sync to a remote directory if --remote-sync is not None.",
    )
    parser.add_argument(
        "--remote-sync-protocol",
        choices=["s3", "fsspec"],
        default="s3",
        help="How to do the remote sync backup if --remote-sync is not None.",
    )
    parser.add_argument(
        "--delete-previous-checkpoint",
        default=False,
        action="store_true",
        help="If true, delete previous checkpoint after storing a new one."
    )
    parser.add_argument(
        "--distill-model",
        default=None,
        help='Which model arch to distill from, if any.'
    )
    parser.add_argument(
        "--distill-pretrained",
        default=None,
        help='Which pre-trained weights to distill from, if any.'
    )
    return parser


def add_custom_args(parser):
    """Register the project-specific options on ``parser``.

    The options are declared as an ordered table of ``(flag, kwargs)`` pairs
    and registered in a single loop, so the order shown by ``--help`` follows
    the table order.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible object exposing
            ``add_argument``) that receives the options.

    Returns:
        The same ``parser`` object, so calls can be chained.
    """

    def valued(flag, kind, fallback, description):
        # Option that consumes one value converted with ``kind``.
        return flag, {"type": kind, "default": fallback, "help": description}

    def switch(flag, description):
        # Boolean option: False unless the flag is present.
        return flag, {"default": False, "action": 'store_true', "help": description}

    option_table = [
        valued("--label-ratio", float, 0.1, "Subset ratio for paired data."),
        valued("--method", str, "base", "Method for training (base, ours)."),
        valued("--keyword-path", str, None, "Path for keyword candidate set"),
        valued("--dev", int, 0, "Device to run the code."),
        switch("--peft", "Whether to use peft strategy."),
        switch("--VPT", "Whether to use VPT."),
        switch("--text-tuning", "Whether to tune the text encoder when applying peft."),
        valued("--lr-peft", float, None, "Learning rate for peft module."),
        valued("--loss-type", str, None, "Path for keyword candidate set"),
        # The default is deliberately the int literal 10 (argparse does not
        # convert non-string defaults).
        valued("--logit-scale-con", float, 10, "Logit scale for contrastive loss."),
        valued("--pseudo-label-type", str, "ot-image", "Pseudo label type for captions."),
        switch("--text-prompt", "Whether to apply prompt to the text encoder."),
        valued("--lr-add", float, 5e-5, "Learning rate for additional module."),
        valued("--rankmode", int, 0, "Rank mode for rank loss."),
        valued("--lam-rank", float, 1.0, "Regularization weight for rank loss."),
        valued("--lambda-val", float, 1.0, "Lambda_val for rank loss."),
        valued("--wcon", float, 1.0, "Lambda_val for rank loss."),
        switch("--det", "Whether to apply prompt to the text encoder."),
        valued("--steps", int, None, "Number of steps for each epoch."),
        # --cylam1 ... --cylam7 share type, default and help text.
        *[
            valued(f"--cylam{index}", float, 1.0, "Lambda weight for cyclip loss.")
            for index in range(1, 8)
        ],
        valued("--Tr", float, 1.0, "Temperature parameter for exp in rank loss."),
        valued("--ranklam", float, 0.1, "Lambda weight for rank loss."),
        valued("--Tcache", float, 1.0, "Temperature parameter for linear cache."),
        switch("--cache_grad", "Whether to train cache."),
        valued("--qlen", int, None, "The length of queue for mix alignment."),
        valued("--yake-v", str, "yake_0.05", "YAKE keywords version."),
        switch("--mae", "Whether to apply MAE pretrain."),
        valued("--th", float, 0.4, "Threshold for filtering."),
        valued("--smin", float, -0.5, "Min val for scaling."),
        valued("--smax", float, 0.5, "Max val for scaling."),
        valued("--ls-nouns", float, None, "Logit scale for nouns training."),
        valued("--nouns-k", float, 10.0, "For scaling."),
        valued("--nouns-b", float, -5.0, "For scaling."),
        switch("--save_ckpt", "Whether to save the checkpont at last epoch."),
        valued("--cmin", int, 5, "Minimum count for selected nouns."),
        valued("--topk", int, 3, "Topk for sim."),
        valued("--resume-path", str, None, "Resume path for next stage."),
        valued("--resume-mapper", str, None, "Resume path for mapper."),
        valued("--stage", int, 0, "Stage for training."),
        valued("--hdim", int, 256, "Hidden dim for nouns2text."),
        valued("--maskmode", int, 0, "Mask mode."),
        valued("--lossmode", int, 0, "Loss mode."),
        valued("--mu", int, 1, "mu for unlabeled data size."),
        valued("--alpha", float, 0.5, "Alpha for mixup features."),
        valued("--plan", int, 0, "Plan for training."),
        valued("--ksel", int, 3, "K for selecting pseu."),
        valued("--num_words", int, 5, "Number of words for stage2."),
        valued("--pkname", str, None, "Saving name for pickle dump."),
        valued("--pkname1", str, None, "Saving name for pickle dump."),
        valued("--aug", int, 0, "Augment."),
        valued("--layer", int, 4, "Augment."),
        valued("--selidx", int, -1, "Selected idx for texts."),
        valued("--dp", float, 0.1, "Dropout parameter."),
        valued("--tau", float, 1.0, "LA parameter."),
        valued("--words", int, 4, "Number of words."),
        valued("--selk", int, 10, "Number of words."),
        switch("--only_val", "Whether only validate the performance."),
        switch("--eval_last", "Whether only evaluate in the last epoch."),
        valued("--abs_data", str, None, "Absolute path for datasets."),
        switch("--update_nouns", "Whether only evaluate in the last epoch."),
        valued("--Ts", float, 0.03, "Temperature parameter for exp in rank loss."),
        valued("--wnouns", float, 1.0, "Temperature parameter for exp in rank loss."),
        valued("--qsize", int, 512, "Number of words."),
        valued("--qconlen", int, 64, "Number of words."),
        valued("--nk", int, 10, "Number of words."),
        valued("--pratio", float, 0.3, "P ratio for selecting captions."),
        valued("--data-dir", str, None, "Directory for datasets."),
    ]

    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser


def parse_args(args):
    """Parse ``args`` and apply model- and experiment-dependent adjustments.

    Steps, in order:
      1. Build a parser with the base and custom options and parse ``args``.
      2. For every entry returned by ``get_default_params`` for the parsed
         model name, replace an option whose value is ``None`` with that
         entry's value.
      3. Clear ``keyword_path`` unless ``method`` is ``"ours"`` or ``"sclip"``.
      4. If no training data was given but an experiment name was, point
         ``resume`` at that experiment's ``epoch_latest.pt`` and overwrite
         ``model`` / ``seed`` when a known model name or ``seed_<n>`` tag
         occurs in that path.

    Note that step 2 uses the model name from the command line; step 4 may
    change ``model`` afterwards without re-running step 2.

    Args:
        args: Sequence of argument strings passed to
            ``ArgumentParser.parse_args``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = add_custom_args(add_base_args(argparse.ArgumentParser()))
    args = parser.parse_args(args)

    # Options still at None fall back to the per-model values.
    for key, fallback in get_default_params(args.model).items():
        if getattr(args, key) is None:
            setattr(args, key, fallback)

    # Custom setup for convenience.
    if args.method not in ("ours", "sclip"):
        args.keyword_path = None

    if args.train_data is None and args.name is not None:
        checkpoint = f"logs/{args.name}/checkpoints/epoch_latest.pt"
        args.resume = checkpoint
        # When several candidates occur in the path, the last match wins.
        for arch in ("RN50", "ViT-B-32", "ViT-B-16"):
            if arch in checkpoint:
                args.model = arch
        for seed_id in range(3):
            if f"seed_{seed_id}" in checkpoint:
                args.seed = seed_id

    return args
