import argparse
from dataclasses import dataclass
from typing import Optional

from sandbagging_research_sprint.finetuning.peft_finetuning import main

from eliciting_contexts.utils.constants import DEVICE, WANDB_ENTITY


@dataclass(eq=True, frozen=True)
class SandbaggingConfig:
    """Immutable settings for a password-locked fine-tuning run.

    In this file, an instance is built by the ``__main__`` entry point and
    handed to ``peft_finetuning.main``. Because the dataclass is frozen,
    fields cannot be reassigned after construction.
    """

    # --- Model ---------------------------------------------------------
    # Hugging Face model identifier to fine-tune.
    model_name: str = "google/gemma-2-9b-it"
    # Device string taken from the project-wide DEVICE constant.
    device: str = DEVICE

    # --- Training ------------------------------------------------------
    batch_size: int = 64

    # --- Weights & Biases ----------------------------------------------
    # NOTE(review): "passsword" (three s's) looks like a typo; left unchanged
    # so runs keep logging to the existing project. Confirm before renaming.
    wandb_project: str = "train-passsword"
    wandb_entity: str = WANDB_ENTITY
    # Presumably forwarded to wandb's `mode` setting — TODO confirm.
    wandb_mode: str = "online"

    # --- Hugging Face Hub ----------------------------------------------
    # Presumably controls uploading the trained model under
    # `huggingface_entity` — verify in peft_finetuning.
    push_to_hub: bool = True
    huggingface_entity: str = "contextmodification"

    # --- Data ----------------------------------------------------------
    # Optional cap on the number of samples; None presumably means "no cap".
    n_samples: Optional[int] = None
    # Hugging Face dataset identifier used for training.
    dataset_name: str = "contextmodification/password-locked-dataset"


if __name__ == "__main__":
    # Command-line entry point: build a SandbaggingConfig from CLI flags and
    # pass it to the PEFT fine-tuning routine.
    parser = argparse.ArgumentParser(description="Train a password model")
    parser.add_argument(
        "--dataset_name",
        type=str,
        required=True,
        help="Name of the dataset to use for training",
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        required=True,
        help="Name of the Weights & Biases project",
    )
    # These defaults reproduce the values this script previously hard-coded
    # (a smaller model and batch than SandbaggingConfig's own defaults), so
    # existing invocations without these flags behave exactly as before.
    parser.add_argument(
        "--model_name",
        type=str,
        default="google/gemma-2-2b-it",
        help="Model to fine-tune (default: %(default)s)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Training batch size (default: %(default)s)",
    )
    # None matches SandbaggingConfig's default, i.e. no explicit limit.
    parser.add_argument(
        "--n_samples",
        type=int,
        default=None,
        help="Optionally limit the number of training samples",
    )

    args = parser.parse_args()

    # Fail fast with a usage message instead of deep inside training.
    if args.batch_size <= 0:
        parser.error("--batch_size must be a positive integer")
    if args.n_samples is not None and args.n_samples <= 0:
        parser.error("--n_samples must be a positive integer when given")

    cfg = SandbaggingConfig(
        model_name=args.model_name,
        batch_size=args.batch_size,
        dataset_name=args.dataset_name,
        wandb_project=args.wandb_project,
        n_samples=args.n_samples,
    )
    main(cfg)
