import argparse
import datetime
import time
import warnings
from pathlib import Path
from shutil import rmtree

import torch
import torch.utils.data
import torchvision
import torchvision.transforms
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm

import models
from utils import presets
from utils import utils
from utils.sampler import RASampler
from utils.transforms import get_mixup_cutmix
from utils.using_wandb import init_wandb


def train_one_epoch(
    model,
    criterion,
    optimizer,
    data_loader,
    device,
    epoch,
    args,
    model_ema=None,
    wandb_run=None,
    scaler=None,
    finetuning=False,
    gumbel_lr=None,
):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: Network to optimise; it is switched to train mode first.
        criterion: Callable computing the loss from ``(output, target)``.
        optimizer: Optimizer stepped once per batch.
        data_loader: Training loader; its dataset must expose ``classes``.
        device: ``torch.device`` the batches are moved to.
        epoch: Zero-based epoch index, used to derive the global step.
        args: Namespace providing ``clip_grad_norm``, ``model_ema_steps``,
            ``lr_warmup_epochs`` and ``print_freq``.
        model_ema: Optional EMA wrapper updated every ``model_ema_steps``
            global steps.
        wandb_run: Optional run object; metrics are logged every
            ``print_freq`` global steps when provided.
        scaler: Optional gradient scaler; autocast is enabled only when it
            is not ``None``.
        finetuning: When true and the model has ``changed_layers``,
            ``models.train_phase`` is applied after ``model.train()``.
        gumbel_lr: Optional scheduler stepped once per batch; when set,
            ``model.tau`` is included in the wandb log.

    Raises:
        AssertionError: If the loss is NaN or infinite.
    """
    model.train()
    if finetuning and hasattr(model, "changed_layers"):
        # Presumably restricts train mode to the fine-tuned layers -- confirm
        # against models.train_phase.
        models.train_phase(model, model.changed_layers)

    metric_logger = utils.MetricLogger(delimiter="  ")
    # Top-5 accuracy is only reported when there are more than five classes.
    topk_display = [1, 5] if len(data_loader.dataset.classes) > 5 else [1]

    # Global step counter, continued from previous epochs.
    step = epoch * len(data_loader)

    pbar = tqdm(
        data_loader,
        total=len(data_loader),
        desc="Train",
        unit="batch",
        ncols=100,
    )
    for image, target in pbar:
        image, target = image.to(device), target.to(device)
        with torch.amp.autocast(device.type, enabled=scaler is not None):
            output = model(image)
            loss = criterion(output, target)

        assert (
            torch.isfinite(loss).all().item()
        ), f"Loss is NaN or Infinite, get: {loss}"

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            if args.clip_grad_norm is not None:
                # Gradients must be unscaled before clipping so the threshold
                # applies to their true magnitude.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

        if model_ema and step % args.model_ema_steps == 0:
            model_ema.update_parameters(model)
            if epoch < args.lr_warmup_epochs:
                # Reset ema buffer to keep copying weights during warmup period
                model_ema.n_averaged.fill_(0)

        acc = utils.accuracy(output, target, topk=topk_display)
        for i, k in enumerate(topk_display):
            metric_logger.meters[f"acc{k}"].update(acc[i].item(), n=image.shape[0])

        if gumbel_lr is not None:
            gumbel_lr.step()

        # Refresh the logged metrics every 'print_freq' global steps.
        if step % args.print_freq == 0:
            loss_value = loss.item()
            current_lr = optimizer.param_groups[0]["lr"]

            # WANDB: log the current batch loss, learning rate and running accuracy
            if wandb_run is not None:
                log_dict = {
                    "train/step": step,
                    "train/loss": loss_value,
                    "train/lr": current_lr,
                }
                if gumbel_lr is not None:
                    log_dict["train/tau"] = model.tau
                for k in topk_display:
                    log_dict[f"train/acc{k}"] = metric_logger.meters[
                        f"acc{k}"
                    ].global_avg
                wandb_run.log(log_dict)

            postfix = {
                "loss": f"{loss_value:.4g}",
                "lr": f"{current_lr:.4g}",
            }
            for k in topk_display:
                postfix[f"acc{k}"] = "{:.2f}".format(
                    metric_logger.meters[f"acc{k}"].global_avg
                )
            # set_postfix refreshes the bar itself. The iterable wrapper already
            # advances the counter once per batch, so no manual pbar.update()
            # is issued here (it would over-count the displayed progress).
            pbar.set_postfix(postfix)
        step += 1


def evaluate(
    model, criterion, data_loader, device, epoch, print_freq=100, wandb_run=None
):
    """Evaluate ``model`` on ``data_loader`` and print an accuracy summary.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        criterion: Callable computing the loss from ``(output, target)``.
        data_loader: Evaluation loader; its dataset must expose ``classes``.
        device: ``torch.device`` the batches are moved to.
        epoch: Epoch index, used for the global step and the summary line.
        print_freq: Progress/wandb metrics are refreshed every ``print_freq``
            global steps.
        wandb_run: Optional run object used for logging.

    Returns:
        A tuple ``(accuracy, loss)`` holding the global average of the
        smallest displayed top-k accuracy meter and of the loss meter.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")

    # Top-5 accuracy is only reported when there are more than five classes.
    topk_display = [1, 5] if len(data_loader.dataset.classes) > 5 else [1]

    step = epoch * len(data_loader)
    num_processed_samples = 0
    with torch.inference_mode():
        pbar = tqdm(
            data_loader,
            total=len(data_loader),
            desc="Test",
            unit="batch",
            ncols=100,
        )
        for image, target in pbar:
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)

            loss_value = loss.item()
            metric_logger.update(loss=loss_value)
            acc = utils.accuracy(output, target, topk=topk_display)
            for i, k in enumerate(topk_display):
                metric_logger.meters[f"acc{k}"].update(acc[i].item(), n=image.shape[0])

            # Refresh the logged metrics every 'print_freq' global steps.
            if step % print_freq == 0:

                # WANDB: log the current batch loss and running accuracy
                if wandb_run is not None:
                    log_dict = {"test/step": step, "test/loss": loss_value}
                    for k in topk_display:
                        log_dict[f"test/acc{k}"] = metric_logger.meters[
                            f"acc{k}"
                        ].global_avg
                    wandb_run.log(log_dict)

                postfix = {"loss": f"{loss_value:.4g}"}
                for k in topk_display:
                    postfix[f"acc{k}"] = "{:.2f}".format(
                        metric_logger.meters[f"acc{k}"].global_avg
                    )
                # set_postfix refreshes the bar itself. The iterable wrapper
                # already advances the counter once per batch, so calling
                # pbar.update() here would over-count the displayed progress.
                pbar.set_postfix(postfix)

            step += 1
            num_processed_samples += image.shape[0]

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)

    # get_rank() raises when no process group exists, so only query it when
    # torch.distributed has actually been initialised.
    distributed_ready = (
        torch.distributed.is_available() and torch.distributed.is_initialized()
    )
    is_main_process = not distributed_ready or torch.distributed.get_rank() == 0
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and is_main_process
    ):
        # A mismatch means some samples were skipped or seen more than once
        # (e.g. a distributed sampler padding the dataset), which skews metrics.
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()
    log_out = "\n" + "=" * 25 + f" Epoch={epoch} "
    for k in topk_display:
        log_out += "Acc@{k}={score.global_avg:.2f} ".format(
            k=k, score=metric_logger.meters[f"acc{k}"]
        )
    log_out += "=" * 25
    print(log_out)

    return (
        metric_logger.meters[f"acc{min(topk_display)}"].global_avg,
        metric_logger.meters["loss"].global_avg,
    )


def _get_cache_path(args: argparse.Namespace, train_phase: bool) -> None:
    """Resolve the dataset cache file for one phase and store it on ``args``.

    ``args.cache_path`` is read as the requested cache directory and then
    overwritten in place with the resulting ``.pt`` file path, so nothing is
    returned.

    The directory is chosen as follows:

    * ``args.cache_path`` itself (after ``~`` expansion) when it is an
      existing directory;
    * the parent directory when ``args.cache_path`` is a cache file path
      produced by an earlier call of this function (this keeps the
      user-supplied directory across the train and validation calls);
    * ``~/.torch/datasets`` otherwise, including when it is ``None``.

    Args:
        args: Namespace providing ``dataset_name`` and ``cache_path``.
        train_phase: ``True`` for the training split, ``False`` for the
            validation/test split; the two yield different file names.

    Raises:
        AssertionError: If ``args.dataset_name`` is ``None``.
    """
    import hashlib

    assert (
        args.dataset_name is not None
    ), "You must provide name data if can use cache memory"
    phase = "train" if train_phase else "test"
    digest = hashlib.sha1(
        f"{args.dataset_name}_phase-{phase}".encode()
    ).hexdigest()

    cache_dir = Path.home() / ".torch" / "datasets"
    if args.cache_path is not None:
        # Expand '~' before probing the filesystem, otherwise is_dir() fails
        # for paths such as '~/cache'.
        requested = Path(args.cache_path).expanduser()
        if requested.is_dir():
            cache_dir = requested
        elif (
            requested.suffix == ".pt"
            and requested.name.startswith(f"{args.dataset_name}_")
            and requested.parent.is_dir()
        ):
            # args.cache_path was already replaced by a cache *file* on a
            # previous call; keep using the directory it lives in.
            cache_dir = requested.parent

    args.cache_path = cache_dir / f"{args.dataset_name}_{digest[:10]}.pt"


def load_data(args):
    """Build the train/validation datasets and their samplers.

    Both datasets are ``torchvision.datasets.ImageFolder`` instances read from
    ``<data_path>/train`` and ``<data_path>/val`` when those directories
    exist, otherwise from ``args.data_path`` itself. When
    ``args.cache_dataset`` is set, each dataset is loaded from / saved to the
    cache file computed by ``_get_cache_path`` (which rewrites
    ``args.cache_path``; after this function returns it points at the
    validation cache file).

    Args:
        args: Namespace with the data, preprocessing, caching and
            distributed settings read below.

    Returns:
        ``(dataset, dataset_test, train_sampler, test_sampler)``.
    """
    print("Loading data")
    val_resize_size = args.val_resize_size
    val_crop_size = args.val_crop_size
    train_crop_size = args.train_crop_size
    interpolation = InterpolationMode(args.interpolation)

    trn_dir = Path(args.data_path).joinpath("train")
    trn_dir = str(trn_dir) if trn_dir.is_dir() else args.data_path
    val_dir = Path(args.data_path).joinpath("val")
    val_dir = str(val_dir) if val_dir.is_dir() else args.data_path
    data_type = getattr(args, "data_type", "other")

    # Custom normalisation statistics are only forwarded for MNIST.
    normalize_kwargs = (
        {"mean": args.normalize_img_mean, "std": args.normalize_img_std}
        if args.dataset_name == "mnist"
        else {}
    )

    print("Loading training data")
    st = time.time()
    _get_cache_path(args, train_phase=True)
    if args.cache_dataset and args.cache_path.is_file():
        print(f"Loading dataset_train from {args.cache_path}")
        # The cache holds a pickled (ImageFolder, path) tuple, which the
        # restricted weights_only unpickler rejects. Full unpickling executes
        # arbitrary code, so only point the cache at trusted locations.
        dataset, _ = torch.load(str(args.cache_path), weights_only=False)
    else:
        if data_type == "cropped":
            print("Loading dataset train with cropped data")
            preprocessing = presets.ClassificationCropped(
                resize_size=train_crop_size,
                interpolation=interpolation,
                use_v2=args.use_v2,
            )
        elif data_type == "full":
            print("Loading dataset train with full data")
            preprocessing = presets.ClassificationFull(
                resize_size=train_crop_size,
                interpolation=interpolation,
                train=True,
                use_v2=args.use_v2,
            )
        else:
            preprocessing = presets.ClassificationPresetTrain(
                crop_size=train_crop_size,
                interpolation=interpolation,
                auto_augment_policy=getattr(args, "auto_augment", None),
                random_erase_prob=getattr(args, "random_erase", 0.0),
                ra_magnitude=getattr(args, "ra_magnitude", None),
                augmix_severity=getattr(args, "augmix_severity", None),
                backend=args.backend,
                use_v2=args.use_v2,
                **normalize_kwargs,
            )

        dataset = torchvision.datasets.ImageFolder(trn_dir, preprocessing)

        if args.cache_dataset:
            print(f"Saving dataset_train to {args.cache_path}")
            args.cache_path.parent.mkdir(parents=True, exist_ok=True)
            utils.save_on_master((dataset, trn_dir), str(args.cache_path))
    print(f"Took {time.time() - st:.4f}")

    print("Loading validation data")
    _get_cache_path(args, train_phase=False)
    if args.cache_dataset and args.cache_path.is_file():
        print(f"Loading dataset_test from {args.cache_path}")
        # Same reasoning as for the training cache: a pickled dataset cannot
        # be read with weights_only=True.
        dataset_test, _ = torch.load(str(args.cache_path), weights_only=False)
    else:
        if args.weights and args.test_only:
            # Evaluate pretrained weights with the transforms they ship with.
            weights = torchvision.models.get_weight(args.weights)
            preprocessing = weights.transforms(antialias=True)
            if args.backend == "tensor":
                preprocessing = torchvision.transforms.Compose(
                    [torchvision.transforms.PILToTensor(), preprocessing]
                )
        elif data_type == "cropped":
            print("Loading dataset test with cropped data")
            preprocessing = presets.ClassificationCropped(
                resize_size=val_crop_size,
                interpolation=interpolation,
                use_v2=args.use_v2,
            )
        elif data_type == "full":
            print("Loading dataset test with full data")
            preprocessing = presets.ClassificationFull(
                resize_size=val_crop_size,
                interpolation=interpolation,
                train=False,
                use_v2=args.use_v2,
            )
        else:
            preprocessing = presets.ClassificationPresetEval(
                crop_size=val_crop_size,
                resize_size=val_resize_size,
                interpolation=interpolation,
                backend=args.backend,
                use_v2=args.use_v2,
                **normalize_kwargs,
            )

        dataset_test = torchvision.datasets.ImageFolder(val_dir, preprocessing)

        if args.cache_dataset:
            print(f"Saving dataset_test to {args.cache_path}")
            Path(args.cache_path).parent.mkdir(parents=True, exist_ok=True)
            utils.save_on_master((dataset_test, val_dir), str(args.cache_path))

    print("Creating data loaders")
    if args.distributed:
        if getattr(args, "ra_sampler", False):
            train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
        else:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset_test, shuffle=False
        )
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler


def main(args):
    """Run the full training (or evaluation-only) pipeline.

    The function prepares a timestamped output directory, optionally sets up
    a wandb run, builds data loaders, model, optimizer, LR schedulers and
    (optionally) an EMA copy of the model, then either evaluates once
    (``args.test_only``) or trains for ``args.start_epoch .. args.epochs``.
    After every epoch ``checkpoint.pth`` is written and the checkpoint of the
    best-accuracy epoch is kept as ``model_<epoch>.pth``.

    Args:
        args: Parsed command-line namespace. ``args.output_dir``,
            ``args.wandb_run_name``, ``args.wandb_dir``, ``args.lr_scheduler``
            and (through ``load_data``) ``args.cache_path`` are rewritten in
            place.

    Raises:
        RuntimeError: For an unsupported optimizer, LR scheduler or warmup
            method.
    """
    args.output_dir = Path(args.output_dir) / datetime.datetime.now().strftime(
        "%Y-%m-%d_%H%M%S.%f"
    )

    wandb_run = None
    if not args.test_only:
        # Create dirs that do not exist
        args.output_dir.mkdir(parents=True, exist_ok=True)

        if args.wandb_project is not None:
            # Default the run name to the timestamped output directory name.
            args.wandb_run_name = (
                f"{args.output_dir.name}"
                if args.wandb_run_name is None
                else args.wandb_run_name
            )

            args.wandb_dir = Path(args.wandb_dir) / args.wandb_run_name

            # Create dirs that do not exist
            args.wandb_dir.mkdir(parents=True, exist_ok=True)

            # Initialize WandB; each phase gets its own step axis.
            wandb_run = init_wandb(args)
            for phase_name in ["train", "test", "epoch"]:
                wandb_run.define_metric(f"{phase_name}/step")
                wandb_run.define_metric(
                    f"{phase_name}/*", step_metric=f"{phase_name}/step"
                )
            wandb_run.define_metric("epoch/acc", summary="max")

    if args.distributed:
        utils.init_distributed_mode(args)
    print(utils.print_namespace(args))

    device = torch.device(
        "cuda" if torch.cuda.is_available() and args.device == "cuda" else "cpu"
    )
    print(f"\033[0;1;31mDevice: {device}\033[0m")

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    dataset, dataset_test, train_sampler, test_sampler = load_data(args)

    num_classes = len(dataset.classes)
    mixup_cutmix = get_mixup_cutmix(
        mixup_alpha=args.mixup_alpha,
        cutmix_alpha=args.cutmix_alpha,
        num_classes=num_classes,
        use_v2=args.use_v2,
    )
    if mixup_cutmix is not None:

        def collate_fn(batch):
            return mixup_cutmix(*default_collate(batch))

    else:
        collate_fn = default_collate

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
    )

    print("Creating model")
    model = torchvision.models.get_model(
        args.model,
        weights=args.weights,
        num_classes=num_classes,
        # Models whose name starts with "own_" take extra Gumbel arguments.
        **(
            {"gumbel_dim": args.gumbel_dim, "tau": args.gumbel_tau[0]}
            if args.model.startswith("own_")
            else {}
        ),
    )
    model.to(device)

    if args.finetuning and hasattr(model, "changed_layers"):
        print(
            f"\033[0;1;33mFinetuning the selected layer of the model {model.changed_layers}\033[0m"
        )
        # Freeze everything except the top-level modules in changed_layers.
        for name, param in model.named_parameters():
            param.requires_grad = name.split(".")[0] in model.changed_layers

    # Count the parameters that will actually receive gradients.
    parameters_to_optimize = [
        param for param in model.parameters() if param.requires_grad
    ]
    total_params = sum(p.numel() for p in parameters_to_optimize)
    print("Total trainable parameters:", total_params)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    custom_keys_weight_decay = []
    if args.bias_weight_decay is not None:
        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
    if args.transformer_embedding_decay is not None:
        for key in [
            "class_token",
            "position_embedding",
            "relative_position_bias_table",
        ]:
            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
    parameters = utils.set_weight_decay(
        model,
        args.weight_decay,
        norm_weight_decay=args.norm_weight_decay,
        custom_keys_weight_decay=(
            custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None
        ),
    )

    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            eps=0.0316,
            alpha=0.9,
        )
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW(
            parameters, lr=args.lr, weight_decay=args.weight_decay
        )
    else:
        raise RuntimeError(
            f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported."
        )

    # Older PyTorch releases lack torch.amp.GradScaler; fall back to the CUDA one.
    try:
        scaler = torch.amp.GradScaler(device.type) if args.amp else None
    except AttributeError:
        scaler = torch.cuda.amp.GradScaler() if args.amp else None

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma
        )
    elif args.lr_scheduler == "cosineannealinglr":
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=args.lr_gamma
        )
    elif args.lr_scheduler == "reducelronplateau":
        main_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=args.lr_gamma,
            patience=args.lr_step_size,
            min_lr=args.lr_min,
        )
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR, ExponentialLR "
            "and ReduceLROnPlateau are supported."
        )

    if args.lr_warmup_epochs > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer,
                start_factor=args.lr_warmup_decay,
                total_iters=args.lr_warmup_epochs,
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer,
                factor=args.lr_warmup_decay,
                total_iters=args.lr_warmup_epochs,
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
            milestones=[args.lr_warmup_epochs],
        )
    else:
        lr_scheduler = main_lr_scheduler

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
    if args.model_ema:
        # Decay adjustment that aims to keep the decay independent of other hyper-parameters originally proposed at:
        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
        #
        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
        # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
        alpha = 1.0 - args.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        model_ema = utils.ExponentialMovingAverage(
            model_without_ddp, device=device, decay=1.0 - alpha
        )

    if args.resume:
        # NOTE: weights_only=False unpickles arbitrary objects (the checkpoint
        # stores the args namespace); only resume from trusted files.
        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)

        missing_keys, unexpected_keys = model_without_ddp.load_state_dict(
            checkpoint["model"], strict=False
        )
        warning_message = ""
        if missing_keys:
            warning_message += f"\033[0;1;33mMissing keys: {missing_keys}\033[0m "
        if unexpected_keys:
            warning_message += f"\033[0;1;36mUnexpected keys: {unexpected_keys}\033[0m"
        if warning_message:
            warnings.warn(warning_message)

        # NOTE(review): restoring optimizer / scheduler state and the start
        # epoch is disabled below, so a resumed run restarts its schedule --
        # presumably intentional for fine-tuning; confirm.
        # if not args.test_only:
        #     optimizer.load_state_dict(checkpoint["optimizer"])
        #     lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        # args.start_epoch = checkpoint["epoch"] + 1
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])
        if scaler:
            scaler.load_state_dict(checkpoint["scaler"])
        print(
            f"Resuming model from file '{args.resume}' that was trained by {checkpoint['epoch'] + 1} epochs."
        )

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # Prefer the EMA weights for evaluation when they are available.
        evaluate(
            model_ema if model_ema else model,
            criterion,
            data_loader_test,
            device=device,
            epoch=0,
            print_freq=args.print_freq,
            wandb_run=None,
        )
        return

    # Temperature annealing is only used for "own_" models with a non-zero
    # initial tau.
    gumbel_lr = (
        None
        if args.gumbel_tau[0] == 0 or not args.model.startswith("own_")
        else models.utils.GumbelScheduler(
            model,
            "tau",
            *args.gumbel_tau,
            total_iters=(args.gumbel_range[1] - args.gumbel_range[0])
            * len(data_loader),
            annealing_type=args.gumbel_annealing_strategy,
            last_iter=-args.gumbel_range[0] * len(data_loader),
        )
    )

    print("Start training")
    max_acc = {"value": 0, "epoch": 0}
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(
            model,
            criterion,
            optimizer,
            data_loader,
            device,
            epoch,
            args,
            model_ema,
            wandb_run,
            scaler,
            args.finetuning,
            gumbel_lr,
        )
        acc, loss_val = evaluate(
            model,
            criterion,
            data_loader_test,
            device,
            epoch,
            args.print_freq,
            wandb_run,
        )
        if model_ema:
            # When EMA is enabled its metrics replace those of the raw model.
            acc, loss_val = evaluate(
                model_ema,
                criterion,
                data_loader_test,
                device,
                epoch,
                args.print_freq,
                wandb_run,
            )
        if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            lr_scheduler.step(loss_val)
        else:
            lr_scheduler.step()
        if wandb_run is not None:
            wandb_run.log({"epoch/acc": acc, "epoch/step": epoch})
        if args.output_dir.is_dir():
            checkpoint = {
                # Save the unwrapped module so the keys carry no DDP
                # "module." prefix and match what the resume path loads into.
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict()
            if scaler:
                checkpoint["scaler"] = scaler.state_dict()
            if acc > max_acc["value"]:
                previous_best = args.output_dir / f"model_{max_acc['epoch']}.pth"
                current_best = args.output_dir / f"model_{epoch}.pth"
                utils.save_on_master(checkpoint, str(current_best))
                # Drop the superseded best checkpoint, but never the file that
                # was just written (both names coincide at epoch 0).
                # missing_ok tolerates several ranks racing on the same file.
                if previous_best != current_best:
                    previous_best.unlink(missing_ok=True)
                max_acc["value"] = acc
                max_acc["epoch"] = epoch
            utils.save_on_master(checkpoint, str(args.output_dir / "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")

    # Finish WandB run
    if wandb_run is not None:
        wandb_run.finish()
        rmtree(args.wandb_dir)
    # The dataset cache file left in args.cache_path is removed after training.
    if args.cache_dataset and args.cache_path.is_file():
        args.cache_path.unlink()


def get_args_parser(add_help=True):
    """Build the command-line argument parser for the training script.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; set to ``False``
            to omit the automatic ``-h/--help`` option.

    Returns:
        argparse.ArgumentParser: Parser with all training, augmentation,
        distributed, gumbel and W&B options registered.
    """

    def parse_floats(input_str):
        """Convert a comma-separated string into a list of 1 or 3 floats.

        Raises:
            argparse.ArgumentTypeError: If an item is not a valid float or
                the number of items is neither 1 nor 3.
        """
        try:
            floats = [float(item) for item in input_str.split(",")]
        except ValueError as err:
            # Surface a readable message instead of argparse's generic
            # "invalid parse_floats value" fallback.
            raise argparse.ArgumentTypeError(
                f"Could not parse {input_str!r} as comma-separated float numbers."
            ) from err
        if len(floats) not in (1, 3):
            raise argparse.ArgumentTypeError(
                "Input must be a list of 1 or 3 float numbers."
            )
        return floats

    parser = argparse.ArgumentParser(
        description="PyTorch Classification Training", add_help=add_help
    )

    parser.add_argument(
        "--data-path",
        default="/datasets01/imagenet_full_size/061417/",
        type=str,
        help="dataset path",
    )
    parser.add_argument("--model", default="resnet18", type=str, help="model name")
    parser.add_argument(
        "--device",
        default="cuda",
        type=str.lower,
        help="device (Use cuda or cpu Default: cuda)",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        default=32,
        type=int,
        help="images per gpu, the total batch size is $NGPU x batch_size",
    )
    parser.add_argument(
        "--epochs",
        default=90,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "-j",
        "--workers",
        default=16,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 16)",
    )
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument(
        "--momentum", default=0.9, type=float, metavar="M", help="momentum"
    )
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "--norm-weight-decay",
        default=None,
        type=float,
        help="weight decay for Normalization layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--bias-weight-decay",
        default=None,
        type=float,
        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--transformer-embedding-decay",
        default=None,
        type=float,
        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--label-smoothing",
        default=0.0,
        type=float,
        help="label smoothing (default: 0.0)",
        dest="label_smoothing",
    )
    parser.add_argument(
        "--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)"
    )
    parser.add_argument(
        "--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)"
    )
    parser.add_argument(
        "--lr-scheduler",
        default="steplr",
        type=str,
        help="the lr scheduler (default: steplr)",
    )
    parser.add_argument(
        "--lr-warmup-epochs",
        default=0,
        type=int,
        help="the number of epochs to warmup (default: 0)",
    )
    parser.add_argument(
        "--lr-warmup-method",
        default="constant",
        type=str,
        help="the warmup method (default: constant)",
    )
    parser.add_argument(
        "--lr-warmup-decay", default=0.01, type=float, help="the decay for lr"
    )
    parser.add_argument(
        "--lr-step-size",
        default=30,
        type=int,
        help="decrease lr every step-size epochs",
    )
    parser.add_argument(
        "--lr-gamma",
        default=0.1,
        type=float,
        help="decrease lr by a factor of lr-gamma",
    )
    parser.add_argument(
        "--lr-min",
        default=0.0,
        type=float,
        help="minimum lr of lr schedule (default: 0.0)",
    )
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument(
        "--output-dir", default=".", type=str, help="path to save outputs"
    )
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument(
        "--start-epoch", default=0, type=int, metavar="N", help="start epoch"
    )
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--dataset-name",
        default=None,
        type=str.lower,
        help="dataset name",
    )
    parser.add_argument(
        "--cache-path",
        dest="cache_path",
        help="Where cache will be save",
        default="/local/data",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--auto-augment",
        default=None,
        type=str,
        help="auto augment policy (default: None)",
    )
    parser.add_argument(
        "--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy"
    )
    parser.add_argument(
        "--augmix-severity", default=3, type=int, help="severity of augmix policy"
    )
    parser.add_argument(
        "--random-erase",
        default=0.0,
        type=float,
        help="random erasing probability (default: 0.0)",
    )

    # Mixed precision training parameters
    parser.add_argument(
        "--amp",
        action="store_true",
        help="Use torch.cuda.amp for mixed precision training",
    )

    # distributed training parameters
    parser.add_argument(
        "--world-size", default=1, type=int, help="number of distributed processes"
    )
    parser.add_argument(
        "--dist-url",
        default="env://",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--model-ema",
        action="store_true",
        help="enable tracking Exponential Moving Average of model parameters",
    )
    parser.add_argument(
        "--model-ema-steps",
        type=int,
        default=32,
        help="the number of iterations that controls how often to update the EMA model (default: 32)",
    )
    parser.add_argument(
        "--model-ema-decay",
        type=float,
        default=0.99998,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
    )
    parser.add_argument(
        "--use-deterministic-algorithms",
        action="store_true",
        help="Forces the use of deterministic algorithms only.",
    )
    parser.add_argument(
        "--interpolation",
        default="bilinear",
        type=str,
        help="the interpolation method (default: bilinear)",
    )
    parser.add_argument(
        "--val-resize-size",
        default=256,
        type=int,
        help="the resize size used for validation (default: 256)",
    )
    parser.add_argument(
        "--val-crop-size",
        default=224,
        type=int,
        help="the central crop size used for validation (default: 224)",
    )
    parser.add_argument(
        "--train-crop-size",
        default=224,
        type=int,
        help="the random crop size used for training (default: 224)",
    )
    parser.add_argument(
        "--clip-grad-norm",
        default=None,
        type=float,
        help="the maximum gradient norm (default None)",
    )
    parser.add_argument(
        "--ra-sampler",
        action="store_true",
        help="whether to use Repeated Augmentation in training",
    )
    parser.add_argument(
        "--ra-reps",
        default=3,
        type=int,
        help="number of repetitions for Repeated Augmentation (default: 3)",
    )
    parser.add_argument(
        "--weights", default=None, type=str, help="the weights enum name to load"
    )
    parser.add_argument(
        "--backend",
        default="PIL",
        type=str.lower,
        help="PIL or tensor - case insensitive",
    )
    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")

    # Own params
    parser.add_argument(
        "--finetuning", action="store_true", help="Fine tune part parameters of model"
    )
    parser.add_argument(
        "--distributed",
        action="store_true",
        help="Use distributed mode for data parallelism",
    )
    parser.add_argument(
        "--normalize-img-mean",
        type=parse_floats,
        default=None,
        help="Normalization parameter mean (1 or 3 float numbers)",
    )
    parser.add_argument(
        "--normalize-img-std",
        type=parse_floats,
        default=None,
        # Previously duplicated the "mean" wording of the option above.
        help="Normalization parameter std (1 or 3 float numbers)",
    )
    parser.add_argument(
        "--data_type",
        choices=["cropped", "full", "other"],
        default="other",
        help="Transform type of data used",
    )
    # -----------------------------------
    # Parameters for gumbel trick
    # -----------------------------------
    parser.add_argument("--gumbel-dim", default=1, type=int, choices=[1, -1])
    parser.add_argument(
        "--gumbel_tau",
        type=float,
        nargs=2,
        default=(1, 0.2),
    )
    parser.add_argument(
        "--gumbel_range",
        type=int,
        nargs=2,
        default=(20, 90),
        help="Range of working gumbel trick in epoch",
    )
    parser.add_argument(
        "--gumbel_annealing_strategy",
        default="cosine",
        choices=["linear", "constant", "exponential", "cosine"],
    )
    # ----------------------
    # Parameters for W&B
    # ----------------------
    parser.add_argument(
        "--wandb-project",
        default=None,
        type=str,
        help="the name of the project",
    )
    parser.add_argument(
        "--wandb-entity",
        default=None,
        type=str,
        help="an entity is a username or team name where you're sending runs",
    )
    parser.add_argument(
        "--wandb-run-name",
        default=None,
        type=str,
        help="a short display name for this run",
    )
    parser.add_argument(
        "--wandb-dir",
        default=None,
        type=str,
        help="an absolute path to a directory where metadata will be stored",
    )
    return parser


if __name__ == "__main__":
    # Keep a handle on the parser so that validation failures can be reported
    # through parser.error(), which prints the usage text and exits.
    parser = get_args_parser()
    args = parser.parse_args()

    # Normalize the order of the two-value options regardless of how the user
    # typed them: tau in descending order, epoch range in ascending order.
    args.gumbel_tau = tuple(sorted(args.gumbel_tau, reverse=True))
    args.gumbel_range = tuple(sorted(args.gumbel_range))

    # The upper bound of --gumbel_range may equal --epochs but must not exceed
    # it (the defaults, 90 and 90, rely on equality being accepted).
    if args.gumbel_range[1] > args.epochs:
        parser.error("--gumbel_range values must not be greater than --epochs")

    main(args)
