from argparse import ArgumentParser, ArgumentTypeError
import math

def _str_to_bool(value) -> bool:
    """Parse a command-line string into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(value)``, which is True for
    every non-empty string (including ``"False"``), so boolean-valued options
    must parse the text explicitly.

    Args:
        value: the raw command-line token (or an already-boolean default).

    Returns:
        The parsed boolean.

    Raises:
        ArgumentTypeError: if ``value`` is not a recognised boolean spelling;
            argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ArgumentTypeError(f'Expected a boolean value, got {value!r}.')


def add_experiment_args(parser: ArgumentParser) -> None:
    """
    Adds the arguments used by all the models.

    Registers dataset/model selection, optimization, evaluation, LoRA and
    Laplace-LoRA options on ``parser`` (the parser is mutated in place).

    Args:
        parser: the parser instance
    """
    # Dataset and model selection.
    # NOTE(review): only --dataset-type is required; --dataset has no default
    # and is None unless given on the command line.
    parser.add_argument('--dataset', type=str,
                        help='Which dataset to perform training on.')
    # NOTE(review): this help text duplicates --dataset; confirm what
    # --dataset-type selects and update the help accordingly.
    parser.add_argument('--dataset-type', type=str, required=True,
                        help='Which dataset to perform training on.')
    parser.add_argument('--max-seq-len', type=int, default=512)
    parser.add_argument('--modelwrapper', type=str, required=True,
                        help='Model name, one of the following: MLE, MAP, Deep Ensemble, Batch Ensemble, Laplace LoRA, BLoB LoRA')
    parser.add_argument('--model', type=str, required=True,
                        help='Backbone type, one of the following: roberta-base, roberta-large')
    # NOTE(review): this help text duplicates --model; confirm what
    # --model-type selects and update the help accordingly.
    parser.add_argument('--model-type', type=str, required=True,
                        help='Backbone type, one of the following: roberta-base, roberta-large')
    # `type=bool` would map any non-empty string (even "False") to True, which
    # made 8-bit loading impossible to disable; parse the text explicitly.
    parser.add_argument('--load-in-8bit', type=_str_to_bool, default=True,
                        help='Whether to load the model in 8-bit (true/false).')

    # Optimization-specific arguments
    parser.add_argument('--loss', type=str, default='nll',
                        help='Loss name')
    parser.add_argument('--n-epochs', type=int, default=0,
                        help='number of epochs.')
    parser.add_argument(
        "--max-train-steps",
        type=int,
        default=0,
        help="Total number of training steps to perform. If provided, overrides n-epochs.",
    )
    parser.add_argument(
        "--eval-per-steps",
        type=int,
        default=500,
    )
    parser.add_argument(
        "--early-stop-steps",
        type=int,
        default=0,
    )
    parser.add_argument('--batch-size', type=int,
                        help='Batch size.')
    parser.add_argument(
        "--gradient-accumulation-steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='Learning rate.')
    parser.add_argument('--opt', type=str, default='adamw',
                        help='Optimizer type.')
    parser.add_argument('--opt-wd', type=float, default=0.,
                        help='optimizer weight decay.')
    parser.add_argument('--warmup-ratio', type=float, default=0,
                        help='warmup ratio.')
    parser.add_argument('--adam-epsilon', type=float, default=1e-06,
                        help='default adam epsilon.')
    parser.add_argument('--use-slow-tokenizer', action='store_true',
                        help='Use slow tokenizer.')
    parser.add_argument('--add-space', action='store_true',
                        help='Add space between the prompt and the input.')
    parser.add_argument('--is_s2s', action='store_true',
                        help='Whether the model is a sequence-to-sequence model.')
    parser.add_argument(
        "--pad-to-max-length",
        action="store_true",
        help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
    )
    parser.add_argument('--eval-steps', type=int, default=0,
                        help='set 0 to disable')
    parser.add_argument('--num-bins', type=int, default=15,
                        help='num of bins in ECE computation.')
    parser.add_argument('--load-model-path', type=str, default=None)
    parser.add_argument('--load-checkpoint', action='store_true',
                        help='Whether to load checkpoint.')
    parser.add_argument('--log-path', type=str, default='default')
    parser.add_argument('--lm-head', action='store_true')
    parser.add_argument("--testing_set", type=str, default='val')
    parser.add_argument("--ood-ori-dataset", type=str, default=None)

    # LoRA arguments
    parser.add_argument('--lora-r', type=int, default=8)
    parser.add_argument('--lora-alpha', type=int, default=16)
    parser.add_argument('--lora-dropout', type=float, default=0)
    parser.add_argument('--apply-classhead-lora', action='store_true',
                        help='Whether to apply lora on the classhead of model.')
    parser.add_argument('--apply-qkv-head-lora', action='store_true',
                        help='Whether to apply lora on the qkv and lm_head of model.')
    # NOTE(review): help duplicates --apply-qkv-head-lora; the original comment
    # suggests the targets are "q_proj", "k_proj", "v_proj", "o_proj",
    # "gate_proj", "up_proj", "down_proj" — confirm and update the help.
    parser.add_argument('--apply-lora-qwen', action='store_true',
                        help='Whether to apply lora on the qkv and lm_head of model.')

    # laplace_LoRA arguments
    # (These flags keep their historical underscore spelling so existing
    # launch scripts keep working.)
    parser.add_argument('--laplace-train', action='store_true',
                        help='Whether to apply laplace.')
    parser.add_argument("--laplace_sub", type=str, default='all', help='all or last_layer')
    parser.add_argument("--laplace_n_kfac", type=int, default=10, help='The rank used for the large Kronecker factors')
    parser.add_argument("--laplace_lr_threshold", type=int, default=100, help='The Kronecker factor edge size over which to use the low-rank approximation')
    parser.add_argument("--laplace_prior_var", type=float, default=1.0, help='Initial prior variance (before marginal likelihood-based optimisation)')
    parser.add_argument("--laplace_hessian", type=str, default='kron')
    parser.add_argument("--laplace_prior", type=str, default='homo', help='homo, hetero')
    parser.add_argument("--laplace_optim_step", type=int, default=1000)
    parser.add_argument("--laplace_predict", type=str, default='mc_corr_100000', help='probit bridge bridge_norm mc_indep mc_corr')

    # NOTE(review): help duplicates --laplace-train; confirm what this flag
    # visualizes and update the help accordingly.
    parser.add_argument('--laplace-vis', action='store_true',
                        help='Whether to apply laplace.')

    parser.add_argument('--load-laststep-model', action='store_true',
                        help='Whether to load the model in the last training step.')

def add_management_args(parser: ArgumentParser) -> None:
    """
    Arguments for the management of the experiments, e.g., seed, logging, wandb, etc.

    Args:
        parser: the parser instance; arguments are registered on it in place.
    """
    parser.add_argument('--seed', type=int, default=1234,
                        help='The random seed.')
    parser.add_argument('--evaluate', action='store_true',
                        help='Whether to evaluate the model during training.')
    parser.add_argument('--evaluate-uncertainty', action='store_true',
                        help='Whether to evaluate the model uncertainty during training.')
    # This is a string option (not a switch), so the help names the value
    # rather than phrasing it as a yes/no question.
    parser.add_argument('--evaluate-uncertainty-reduction', type=str, default='mean',
                        help='Reduction used when evaluating the uncertainty (default: mean).')
    parser.add_argument('--validation', action='store_true',
                        help='Test on the validation set for each epoch.')
    parser.add_argument('--validation-perc', type=float, default=0.1,
                        help='percentage of the validation data.')
    parser.add_argument('--checkpoint', action='store_true',
                        help='Whether checkpoint the model backbone parameters.')
    # NOTE(review): "dictionary" may mean "directory" here; confirm with the
    # code that consumes args.checkpoint_dic_name.
    parser.add_argument('--checkpoint-dic-name', type=str, default='default',
                        help='The name of the dictionary to save the checkpoint.')
    parser.add_argument(
        "--checkpoint-name",
        type=str,
        default="default",
        help="Name of the checkpoint file to save. Default is 'default'.",
    )

    # Arguments for the Weights & Biases logging tool.
    parser.add_argument('--nowand', action='store_true', help='Inhibit wandb logging')
    parser.add_argument('--wandb-entity', type=str, default='wfz-texas-a-m-university', help='Wandb entity')
    parser.add_argument('--wandb-project', type=str, default='Bayes LoRA', help='Wandb project name')
    parser.add_argument('--wandb-name', type=str, default='', help="Wandb run's name")

    # NOTE(review): the help for --subset-size is a placeholder; document what
    # the value controls (and what the -1 default means) once confirmed.
    parser.add_argument('--subset-size', type=float, default=-1, help='validation-set ....')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of data loader workers')
    parser.add_argument('--case-study-vis', action='store_true', help='whether checking case study during OOD setting')
    parser.add_argument('--ood-emb-vis', action='store_true', help='whether exporting representation during OOD setting')
    # The previous help ("whether load a quantized base model") stated the
    # opposite of what setting a `no_quantization` switch means.
    parser.add_argument('--no_quantization', action='store_true',
                        help='If set, load the base model without quantization.')
    parser.add_argument('--laplace-ood', action='store_true', help='loading laplace model when OOD evaluation')