import argparse
from utils.experiment_utils import DEFAULT_HF_TOKEN, DEFAULT_HF_CACHE


def parse_args(argv=None) -> argparse.Namespace:
    """
    Parse command line arguments for continual learning with SABCD fine-tuning and SAIM merging.

    Args:
        argv: Optional sequence of argument strings to parse. When None (the
            default) argparse falls back to ``sys.argv[1:]``, which is the
            behavior this function always had.

    Returns:
        args: Parsed argument object (argparse.Namespace)
    """
    parser = argparse.ArgumentParser(
        description="LLM Continual Learning - Using SABCD Fine-tuning and SAIM Merging")

    # Basic parameters
    parser.add_argument('--base-model', default='meta-llama/Llama-3.2-1B-Instruct', type=str,
                        help='Base model path')
    parser.add_argument('--save-path', default='./experiments/', type=str,
                        help='Merged model save path')
    parser.add_argument('--cache-dir', default=DEFAULT_HF_CACHE, type=str,
                        help='HuggingFace cache directory')
    parser.add_argument('--token', default=DEFAULT_HF_TOKEN, type=str,
                        help='HuggingFace API token')

    # Dataset parameters - modified to use custom 7 tasks
    parser.add_argument('--task-names', default='C-STANCE,FOMC,MeetingBank,ScienceQA,NumGLUE-cm,NumGLUE-ds,20Minuten', type=str,
                        help='Comma-separated task name list')
    # `--custom-datasets` is a store_true flag whose default is already True, so
    # on its own it could never be switched off. `--no-custom-datasets` provides
    # the opt-out; both write the same `custom_datasets` attribute, and the
    # group rejects passing both at once.
    dataset_source = parser.add_mutually_exclusive_group()
    dataset_source.add_argument('--custom-datasets', dest='custom_datasets',
                                action='store_true', default=True,
                                help='Use custom datasets (do not use MergeBench datasets)')
    dataset_source.add_argument('--no-custom-datasets', dest='custom_datasets',
                                action='store_false',
                                help='Do not use custom datasets (disables the --custom-datasets default)')

    # Fine-tuning parameters
    parser.add_argument('--batch-size', default=8, type=int,
                        help='Fine-tuning batch size')
    parser.add_argument('--learning-rate', default=2e-5, type=float,
                        help='Learning rate')
    parser.add_argument('--epochs', default=5, type=int,
                        help='Training epochs')
    parser.add_argument('--train-from-base', action='store_true', default=False,
                        help='Always fine-tune based on the original pre-trained model, and use the base model as reference when computing task vectors')
    # Optimizer selection for fine-tuning
    parser.add_argument('--optimizer', default='sabcd', type=str,
                        choices=['sabcd', 'adam', 'sam'],
                        help='Choose optimizer type: sabcd (SABCD optimizer), adam (standard AdamW optimizer), sam (SAM optimizer)')
    parser.add_argument('--selection-percent', default=0.3, type=float,
                        help='SABCD selection parameter percentage')
    parser.add_argument('--rho', default=0.05, type=float,
                        help='SABCD perturbation radius')
    parser.add_argument('--adaptive', default=False, action='store_true',
                        help='Use adaptive')

    # Merging parameters
    parser.add_argument('--task-vector-from-base', default=False, action='store_true',
                        help='Compute task vector from base model (True: fine-tuned - base model, False: fine-tuned - current merged model)')
    parser.add_argument('--scaling-coef', default=1.0, type=float,
                        help='Merge scaling coefficient')
    parser.add_argument('--use-default-scaling', action='store_true', default=False,
                        help='Use algorithm default scaling coefficient, ignore --scaling-coef parameter value')
    parser.add_argument('--merge-method', default='SAIM', type=str,
                        choices=['SAIM', 'task_arithmetic', 'magmax',
                                 'ties_merge', 'dare', 'swa'],
                        help='Merge method: SAIM, task_arithmetic, magmax, ties_merge, dare or swa')

    # Evaluation parameters
    parser.add_argument('--evaluate-at-end', action='store_true', default=False,
                        help='Whether to evaluate after all tasks are completed, default False')

    # Continual learning parameters
    parser.add_argument('--start-task', default=0, type=int,
                        help='Start task index (0-6)')
    parser.add_argument('--end-task', default=6, type=int,
                        help='End task index (0-6)')
    parser.add_argument('--continue-experiment', action='store_true',
                        help='Continue previous experiment')
    parser.add_argument('--prev-experiment-dir', default=None, type=str,
                        help='Previous experiment directory')

    return parser.parse_args(argv)

def parse_merge_args(argv=None) -> argparse.Namespace:
    """
    Parse command line arguments for continual merging of pre-trained fine-tuned (MergeBench) models.

    Args:
        argv: Optional sequence of argument strings to parse. When None (the
            default) argparse falls back to ``sys.argv[1:]``, which is the
            behavior this function always had.

    Returns:
        args: Parsed argument object (argparse.Namespace)
    """
    parser = argparse.ArgumentParser(
        description="LLM Continual Merging - Using Pre-trained Fine-tuned Models and SAIM Merging")

    # Basic parameters
    # NOTE(review): Meta's Llama 3.1 release does not appear to include a 3B
    # checkpoint (3B is a Llama 3.2 size); confirm that this model id and the
    # --finetuned-model-prefix below resolve on the Hub.
    parser.add_argument('--base-model', default='meta-llama/Llama-3.1-3B', type=str,
                        help='Base model path')

    parser.add_argument('--save-path', default='./merge_experiments/', type=str,
                        help='Merged model save path')
    parser.add_argument('--cache-dir', default=DEFAULT_HF_CACHE, type=str,
                        help='HuggingFace cache directory')
    parser.add_argument('--token', default=DEFAULT_HF_TOKEN, type=str,
                        help='HuggingFace API token')

    # Fine-tuned model parameters
    parser.add_argument('--finetuned-model-prefix', default='MergeBench/Llama-3.1-3B', type=str,
                        help='Fine-tuned model prefix')

    # Merging parameters
    parser.add_argument('--scaling-coef', default=1.0, type=float,
                        help='Merge scaling coefficient')
    parser.add_argument('--use-default-scaling', action='store_true',
                        help='Use algorithm default scaling coefficient, ignore --scaling-coef parameter value')

    # Task parameters
    parser.add_argument('--start-task', default=0, type=int,
                        help='Start task index (0-4)')
    parser.add_argument('--end-task', default=3, type=int,
                        help='End task index (0-4)')

    # Continue experiment parameters
    parser.add_argument('--continue-experiment', action='store_true',
                        help='Continue previous experiment')
    parser.add_argument('--prev-experiment-dir', default=None, type=str,
                        help='Previous experiment directory')

    # Only save final model option.
    # `--final-model-only` is a store_true flag whose default is already True,
    # so on its own it could never be switched off. `--no-final-model-only`
    # provides the opt-out; both write the same `final_model_only` attribute,
    # and the group rejects passing both at once.
    final_model_group = parser.add_mutually_exclusive_group()
    final_model_group.add_argument('--final-model-only', dest='final_model_only',
                                   action='store_true', default=True,
                                   help='Only save the final merged model and evaluate, do not save intermediate models')
    final_model_group.add_argument('--no-final-model-only', dest='final_model_only',
                                   action='store_false',
                                   help='Also save intermediate merged models (disables the --final-model-only default)')

    # Merge method selection
    parser.add_argument('--merge-method', default='SAIM', type=str,
                        choices=['SAIM', 'task_arithmetic',
                                 'magmax', 'ties_merge', 'dare', 'swa'],
                        help='Merge method: SAIM, task_arithmetic, magmax, ties_merge, dare or swa')

    return parser.parse_args(argv)