from fastargs import Param, Section
from fastargs.validation import And, OneOf

def get_current_params() -> None:
    """Declare all experiment parameters and their constraints with fastargs.

    Each ``Section(...).params(...)`` call declares one configuration group.
    Choice-valued parameters are validated with ``And(str, OneOf([...]))`` so
    that a misspelled option is rejected at config-validation time rather
    than failing later in the training code.

    Returns:
        None. The declarations are made through fastargs' ``Section`` API.
    """
    Section("model_params", "model details").params(
        # NOTE(review): required=True is likely moot when a default is given.
        model_name=Param(str, "model_choice", default="resnet18", required=True),
        conv_type=Param(
            And(str, OneOf(["ConvMask", "LinearMask", "STRConv", "ConvMaskMW"])),
            "type of (masked) convolution/linear layer to use",
            required=True,
        ),
    )

    Section("dataset", "dataset configuration").params(
        dataset_name=Param(
            And(str, OneOf(["CIFAR10", "CIFAR100", "ImageNet"])),
            "Name of dataset",
            required=True,
        ),
        use_ffcv=Param(bool, "use FFCV distributed", default=True),
        batch_size=Param(int, "batch size", default=512),
        num_classes=Param(int, "num classes", default=1000),
        num_workers=Param(int, "num_workers", default=32),
        data_root=Param(str, "path to betons", required=True),
        subsample_frac=Param(float, "fraction of dataset to use in subsampled version", default=0.5),
        criterion=Param(
            And(str, OneOf(["CrossEntropyLoss", "LabelSmoothingLoss"])),
            "Criterion to use for loss function",
            default='CrossEntropyLoss',
        ),
    )

    Section("prune_params", "pruning configuration").params(
        prune_rate=Param(float, "percentage of parameters to remove", default=0.2),
        er_init=Param(float, "sparse init percentage/target", default=0.2),
        er_method=Param(
            And(str, OneOf(["er_erk", "er_balanced", "synflow", "snip", "just dont"])),
            "method used for the sparse initialization",
            default="just dont",
        ),
        prune_method=Param(
            And(
                str, OneOf(["random_erk", "random_balanced", "synflow", "snip", "mag", "just dont"])
            ),
            "method used to prune parameters at each level",
            default='mag',
        ),
        num_levels=Param(int, "number of pruning levels", required=True),
        update_sign_init=Param(bool, "update sign based on the gradient at init", default=False),
        update_sign_every_level=Param(bool, "update sign based on the gradient after level", default=False),
        init_type=Param(str, "init type for loading mask in mask ablation", default="warmup"),
        load_level=Param(int, "level of sparsity to be loaded from target dir", default=0),
        target_dir=Param(str, "target directory of the experiment to use", default="./"),
        dst=Param(bool, "perform DST, with base method RiGL", default=False),
        acdc=Param(bool, "Train with ACDC, alternating dense and compressed training within a level.", default=False),
        pilot=Param(bool, "Train with PILoT, calculate sparsity at end of level.", default=False),
        structured=Param(bool, "Enforce Structured 2:4 sparsity pattern", default=False),
        rescale=Param(bool, "Train with PILoT and rescale, calculate sparsity at end of level.", default=False),
        # The description was previously copy-pasted from `load_level`.
        # NOTE(review): confirm the exact semantics/units against the training code.
        frequency=Param(int, "frequency (interval) parameter; see training code for its use", default=10),
        load_only_warmup_sign=Param(bool, "if init_type is warmup, then if we want to load only the signs", default=False),
        str_init_val=Param(float, "the s_init value for STR with continuous sparsification", default=-12800),
    )

    Section("experiment_params", "parameters to train model").params(
        seed=Param(int, "seed", default=0),
        # NOTE(review): required=True is likely moot when a default is given.
        base_dir=Param(str, "base directory", required=True, default="./experiments"),
        epochs_per_level=Param(int, "number of epochs per level", required=True),
        training_type=Param(
            And(str, OneOf(["imp", "wr", "lrr"])),
            "training type between pruning levels",
            required=True,
        ),
        resume_level=Param(
            int, "level to resume from -- 0 if starting afresh", default=0
        ),
        resume_expt_name=Param(str, "resume path"),
        wandb_project=Param(str, "name of the wandb project we want to run", required=True),
        expt_name=Param(str, "name of the experiment we want to run", required=True),
        num_cycles=Param(int, "number of cyclic repetition of the LR schedule in one cycle", default=1),
        training_precision=Param(
            And(str, OneOf(['bfloat16', 'float32'])),
            "numeric precision used for training",
            default='bfloat16',
        ),
        compute_eigenvals=Param(bool, "compute the eigenvalues of the hessian using the eigenthings library", default=False),
        analysis_dir=Param(str, "name of the directory to be analyzed", required=False),
    )

    Section("optimizer", "optimizer and learning-rate schedule configuration").params(
        optim_type=Param(str, "type of optimizer", default="SGD"),
        lr=Param(float, "learning rate", required=True),
        momentum=Param(float, "momentum", default=0.9),
        weight_decay=Param(float, "weight decay", default=1e-4),
        warmup_steps=Param(int, "warmup length", default=10),
        cooldown_steps=Param(int, 'cooldown steps', default=10),
        scheduler_type=Param(
            And(
                str,
                OneOf(
                    [
                        "MultiStepLRWarmup",
                        "ImageNetLRDropsWarmup",
                        "CosineLRWarmup",
                        "TriangularSchedule",
                        "ScheduleFree",
                        "TrapezoidalSchedule",
                        "OneCycleLR",
                    ]
                ),
            ),
            "learning-rate scheduler to use",
            required=True,
        ),
        lr_min=Param(float, "minimum learning rate for cosine", default=0.01),
        use_sam=Param(bool, "if SAM should be used for optimization", default=False),
        depth=Param(int, "Depth of the reparameterization only used for MW class", default=2),
        # The key name `inbalance` is kept as-is because configs and callers refer to it.
        inbalance=Param(bool, "Initialization for depth params", default=False),
    )

    Section("dist_params", "distributed parameters").params(
        distributed=Param(bool, "use distributed training", default=True),
        address=Param(str, "default address", default="localhost"),
        port=Param(int, "default port", default=12350),
    )

    Section("custom_regularization", "Adding custom penalty to the loss").params(
        # Bug fix: OneOf was previously passed as the *description* argument,
        # so the allowed values were never validated. Wrap it with And(...)
        # like every other choice-valued parameter in this file.
        regularization=Param(
            And(str, OneOf(["Total_reg_loss", "l1_reg_loss", "loss_l1_nobn", "None"])),
            "first regularization penalty to add to the loss",
            default="None",
        ),
        regularization2=Param(
            And(str, OneOf(["loss_l2_nobn", "None"])),
            "second regularization penalty to add to the loss",
            default="None",
        ),
        gamma=Param(float, "first regularization strength", default=0),
        gamma2=Param(float, "second regularization strength", default=0),
    )
#
