from fastargs import Param, Section
from fastargs.validation import And, OneOf, Or

def get_current_params() -> None:
    """Register every configuration section and parameter with fastargs.

    Declares the ``model_params``, ``dataset``, ``prune_params``,
    ``experiment_params``, ``optimizer`` and ``dist_params`` sections. For each
    parameter it records the type checker, the help text, the default and the
    ``required`` flag. Nothing is returned: the definitions are registered as a
    side effect of calling ``Section(...).params(...)``.

    NOTE(review): ``model_name`` and ``base_dir`` are marked ``required=True``
    but also carry a default. Presumably fastargs only reports a missing value
    when the resolved value is ``None``, so the default always satisfies the
    requirement. Confirm against the fastargs version in use.
    """
    Section("model_params", "model details").params(
        model_name=Param(str, "model_choice", default="resnet18", required=True),
        conv_type=Param(
            And(
                str,
                OneOf(
                    [
                        "ConvMask",
                        "LinearMask",
                        "STRConv",
                        "ConvMaskMW",
                        "LinearMaskMW",
                        "ConvMaskScaled",
                    ]
                ),
            ),
            "layer type used in the model",
            required=True,
        ),
        tf_pretrained=Param(bool, "should the transformer model be pretrained or not.", default=False),
    )

    Section("dataset", "dataset configuration").params(
        dataset_name=Param(
            And(str, OneOf(["CIFAR10", "CIFAR100", "ImageNet", "OxfordPets", "Flowers102"])),
            "Name of dataset",
            required=True,
        ),
        use_ffcv=Param(bool, "use FFCV distributed", default=True),
        batch_size=Param(int, "batch size", default=512),
        num_classes=Param(int, "num classes", default=1000),
        num_workers=Param(int, "num_workers", default=32),
        data_root=Param(str, "path to betons", required=True),
        subsample_frac=Param(float, "fraction of dataset to use in subsampled version", default=0.5),
        criterion=Param(
            And(str, OneOf(["CrossEntropyLoss", "LabelSmoothingLoss"])),
            "Criterion to use for loss function",
            default="CrossEntropyLoss",
        ),
    )

    Section("prune_params", "pruning configuration").params(
        prune_rate=Param(float, "percentage of parameters to remove", default=0.2),
        er_init=Param(float, "sparse init percentage/target", default=0.2),
        er_method=Param(
            And(
                str,
                OneOf(
                    [
                        "percentile_window",
                        "rel_grad_uniform",
                        "grad_mask_train_small_rel_grad_uniform",
                        "grad_mask_train_small_and_large",
                        "grad_mask_train_small_rel_grad",
                        "rel_grad",
                        "grad_mask_train_mid",
                        "uniform",
                        "uniform_mag",
                        "uniform_largest",
                        "er_erk",
                        "er_balanced",
                        "synflow",
                        "snip",
                        "mag",
                        "largest",
                        "just dont",
                        "load_sign_and_mask",
                        "load_mag_and_mask",
                        "load_weight_and_mask",
                        "load_mw_sign_and_mask",
                        "load_only_mask",
                        "anneal_balanced",
                        "load_dense_and_prune_mag",
                        "load_dense_and_prune_mag_mw",
                        "load_mw_trained_dense_and_prune_mag_mw",
                    ]
                ),
            ),
            "sparse init method",
            default="just dont",
        ),
        prune_method=Param(
            And(
                str,
                OneOf(["random_erk", "random_balanced", "synflow", "snip", "mag", "just dont"]),
            ),
            "pruning method",
            default="just dont",
        ),
        num_levels=Param(int, "number of pruning levels", required=True),
        update_sign_init=Param(bool, "update sign based on the gradient at init", default=False),
        update_sign_every_level=Param(bool, "update sign based on the gradient after level", default=False),
        init_type=Param(str, "init type for loading mask in mask ablation", default="warmup"),
        load_level=Param(int, "level of sparsity to be loaded from target dir", default=0),
        target_expt=Param(str, "name of the experiment directory", default="./"),
        target_dir=Param(str, "target directory where experiments are stored", default="./"),
        # NOTE(review): if fastargs' Or returns the first checker that succeeds,
        # `str` accepts everything and the `int` branch is never reached.
        # Verify whether integer checkpoints arrive as str downstream.
        load_sign_at=Param(Or(str, int), "the checkpoint used for loading the sign", default="./"),
        dst=Param(bool, "perform DST, with base method RiGL", default=False),
        dst_method=Param(str, "kind of DST method", default="rigl"),
        dst_every=Param(int, "perform DST every n steps", default=100),
        dst_grad_window=Param(int, "steps to accumulate the gradient", default=100),
        dst_reinit=Param(
            And(str, OneOf(["zero", "same", "flip"])),
            "sign of param when reactivated with RiGL",
            default="zero",
        ),
        acdc=Param(bool, "Train with ACDC, alternating dense and compressed training within a level.", default=False),
        load_only_warmup_sign=Param(bool, "if init_type is warmup, then if we want to load only the signs", default=False),
        str_init_val=Param(float, "the s_init value for STR with continuous sparsification", default=-12800),
        anneal_scale_floor=Param(float, "anneal mask, floor", default=0),
        anneal_scale_ceil=Param(float, "anneal mask, ceiling", default=0.1),
        mw_l1_floor=Param(float, "l1 for mw mask, floor", default=5e-6),
        # Help text previously said "floor" here (copy-paste from mw_l1_floor).
        mw_l1_ceil=Param(float, "l1 for mw mask, ceiling", default=5e-6),
        mw_wd_floor=Param(float, "weight decay for mw mask, floor", default=5e-6),
        mw_wd_ceil=Param(float, "weight decay for mw mask, ceiling", default=5e-5),
        rescale_mw=Param(bool, "rescales the mw training every few epochs.", default=False),
        rescale_till=Param(float, "determines till what point of training do we rescaling mw, in a freq of 10.", default=0.5),
        rescale_every=Param(int, "determines frequency in epochs to do rescaling mw.", default=10),
        start_percentile=Param(float, "starting percentile for creating a mask window", default=0),
        train_first_and_last_grad_mask=Param(bool, "make the first and last layer trainable in grad masking", default=False),
        acdc_prune_method=Param(
            And(str, OneOf(["grad_mask_train_small_and_large", "mag", "rel_grad"])),
            "method to prune the model in AC/DC",
            default="mag",
        ),
        small_weight_frac=Param(float, "fraction of smallest parameters to choose in small + large", default=0.5),
    )

    Section("experiment_params", "parameters to train model").params(
        seed=Param(int, "seed", default=0),
        base_dir=Param(str, "base directory", required=True, default="./experiments"),
        epochs_per_level=Param(int, "number of epochs per level", required=True),
        training_type=Param(And(str, OneOf(["imp", "wr", "lrr"])), "training type", required=True),
        resume_level=Param(int, "level to resume from -- 0 if starting afresh", default=0),
        resume_expt_name=Param(str, "resume path"),
        wandb_project=Param(str, "name of the wandb project we want to run", required=True),
        expt_name=Param(str, "name of the experiment we want to run", required=True),
        num_cycles=Param(int, "number of cyclic repetition of the LR schedule in one cycle", default=1),
        training_precision=Param(
            And(str, OneOf(["bfloat16", "float32"])),
            "numerical precision used for training",
            default="bfloat16",
        ),
        compute_eigenvals=Param(bool, "compute the eigenvalues of the hessian using the eigenthings library", default=False),
        analysis_dir=Param(str, "name of the directory to be analyzed", required=False),
        dense_grad=Param(bool, "if we want to use dense gradients or sparse gradients during training", default=False),
        imagenet_transfer=Param(bool, "if we want to transfer the model from imagenet to a smaller dataset", default=False),
    )

    # Section description previously read "data related stuff", which did not
    # match the section's contents.
    Section("optimizer", "optimizer configuration").params(
        optim_type=Param(str, "type of optimizer", default="SGD"),
        lr=Param(float, "learning rate", required=True),
        momentum=Param(float, "momentum", default=0.9),
        weight_decay=Param(float, "weight decay", default=1e-4),
        warmup_steps=Param(int, "warmup length", default=10),
        cooldown_steps=Param(int, "cooldown steps", default=10),
        mask_gradient_update=Param(bool, "Mask some gradients for an update", default=False),
        update_grad_mask_every=Param(int, "Update the grad mask every n epochs", default=10),
        scheduler_type=Param(
            And(
                str,
                OneOf(
                    [
                        "MultiStepLRWarmup",
                        "ImageNetLRDropsWarmup",
                        "CosineLRWarmup",
                        "TriangularSchedule",
                        "ScheduleFree",
                        "TrapezoidalSchedule",
                        "OneCycleLR",
                    ]
                ),
            ),
            "learning rate scheduler type",
            required=True,
        ),
        lr_min=Param(float, "minimum learning rate for cosine", default=0.01),
        use_sam=Param(bool, "if SAM should be used for optimization", default=False),
        fixed_sign_optim=Param(bool, "let your optimizer train, but the signs will be fixed, only mag is updated", default=False),
        ham_optim=Param(bool, "use the HAM optimizer", default=False),
    )

    Section("dist_params", "distributed parameters").params(
        distributed=Param(bool, "use distributed training", default=True),
        address=Param(str, "default address", default="localhost"),
        port=Param(int, "default port", default=12350),
    )



#
