# co_retriever/arguments.py

from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Groups the retriever encoder settings, the reference LLM, the weighting
    between the contrastive and Revela losses, LoRA settings, and performance
    flags. Values are sanity-checked in ``__post_init__``.

    Raises:
        ValueError: if ``contrastive_loss_weight`` is outside [0, 1], or if
            ``temperature`` / ``attn_temperature`` is not strictly positive.
    """
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    reference_model_name_or_path: Optional[str] = field(
        default="meta-llama/Llama-3.2-1B",
        metadata={"help": "Path to the pretrained reference LLM."},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    retriever_tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Separate tokenizer for the retriever if different from model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}
    )

    # Loss balancing. Per the help text, the contrastive loss is weighted by
    # alpha and the Revela loss by (1 - alpha), so alpha must lie in [0, 1].
    contrastive_loss_weight: float = field(
        default=0.5,
        metadata={"help": "The weight (alpha) for the contrastive loss. The Revela loss weight will be (1 - alpha)."}
    )

    pooling: str = field(
        default='mean',
        metadata={"help": "Pooling method for the retriever encoder. E.g., 'cls' or 'mean'."}
    )
    normalize: bool = field(
        default=False,
        metadata={"help": "Normalize query and passage representations."}
    )
    temperature: float = field(
        default=1.0,
        metadata={"help": "Temperature for the contrastive loss softmax."}
    )
    attn_temperature: float = field(
        default=1.0,
        metadata={"help": "Temperature for the Revela attention softmax."}
    )
    exclude_diagonal: bool = field(
        default=True,
        metadata={"help": "Exclude diagonal scores in the Revela attention calculation."}
    )
    # NOTE(review): the meaning of this flag lives in the reference model
    # implementation, which is not visible here — confirm before relying on it.
    disable_v_norm: bool = field(
        default=False,
        metadata={"help": "Custom flag for the reference model (if supported)."}
    )

    # LoRA arguments
    lora: bool = field(default=False, metadata={"help": "Enable LoRA for parameter-efficient fine-tuning."})
    retriever_lora_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Path to a pretrained LoRA adapter for the retriever."}
    )
    reference_lora_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Path to a pretrained LoRA adapter for the reference model."}
    )
    # NOTE(review): `reference_training` and `freeze_reference` overlap; how a
    # conflicting combination is resolved is decided by the model-building code
    # (not visible here).
    reference_training: bool = field(
        default=True, metadata={"help": "Whether to train the reference model (e.g., its LoRA adapters)."}
    )
    freeze_reference: bool = field(
        default=False, metadata={"help": "Completely freeze all weights of the reference model."}
    )
    lora_r: int = field(default=8, metadata={"help": "LoRA attention dimension (rank)."})
    lora_alpha: int = field(default=64, metadata={"help": "LoRA scaling factor."})
    lora_dropout: float = field(default=0.1, metadata={"help": "Dropout probability for LoRA layers."})
    # Kept as a single comma-separated string so it maps to one CLI flag;
    # splitting into module names is left to the caller.
    lora_target_modules: str = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={"help": "Comma-separated list of module names to apply LoRA to."}
    )

    # Performance arguments
    flash_attention_2: bool = field(
        default=False, metadata={"help": "Enable Flash Attention 2 for training."}
    )
    bnb_4bit: bool = field(
        default=False, metadata={"help": "Enable 4-bit quantization via bitsandbytes."}
    )

    def __post_init__(self):
        """Reject argument values that would make the loss ill-defined."""
        # Outside [0, 1] one of the two losses would get a negative weight.
        if not 0.0 <= self.contrastive_loss_weight <= 1.0:
            raise ValueError(
                f"contrastive_loss_weight must be in [0, 1], got {self.contrastive_loss_weight}"
            )
        # A zero or negative softmax temperature is not meaningful; fail at
        # argument-parsing time rather than deep inside the training loop.
        for name in ("temperature", "attn_temperature"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be > 0, got {value}")

@dataclass
class DataArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Covers dataset selection and sharding, retrieval/negative settings,
    training group size, and tokenization lengths/prefixes. Values are
    sanity-checked in ``__post_init__``.

    Raises:
        ValueError: if ``dataset_number_of_shards < 1``, if
            ``dataset_shard_index`` is outside ``[0, dataset_number_of_shards)``,
            or if ``train_group_size < 1``.
    """
    dataset_name: str = field(
        default='Tevatron/msmarco-passage', metadata={"help": "Hugging Face dataset name or path to a local directory."}
    )
    dataset_config: Optional[str] = field(
        default=None, metadata={"help": "Hugging Face dataset config, for datasets with multiple configurations."}
    )
    dataset_path: Optional[str] = field(
        default=None, metadata={"help": "Path to local data files (e.g., a specific .json file)."}
    )
    dataset_split: str = field(default='train', metadata={"help": "The dataset split to use for training."})
    dataset_cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Directory to cache the downloaded dataset."}
    )
    dataset_number_of_shards: int = field(
        default=1, metadata={"help": "For large datasets, shard the dataset into this many parts."}
    )
    dataset_shard_index: int = field(
        default=0, metadata={"help": "If sharding, which shard to use for this training run."}
    )

    # Retrieval / negative settings
    top_k: int = field(
        default=8, metadata={"help": "top k passages to retrieve"}
    )
    chunk_neg: bool = field(
        default=False,
        metadata={"help": "Enable chunk-level negatives (e.g., treat other chunks from the same doc as negatives)."}
    )

    # Core training data settings
    train_group_size: int = field(
        default=8, metadata={"help": "Number of passages per query for training (1 positive + N negatives)."}
    )
    positive_passage_no_shuffle: bool = field(
        default=False, metadata={"help": "Always use the first positive passage."}
    )
    negative_passage_no_shuffle: bool = field(
        default=False, metadata={"help": "Always use the first N-1 negative passages."}
    )

    # Sequence length and prefix settings
    query_max_len: Optional[int] = field(
        default=32,
        metadata={"help": "Maximum sequence length for queries after tokenization."}
    )
    passage_max_len: Optional[int] = field(
        default=128,
        metadata={"help": "Maximum sequence length for passages after tokenization."}
    )
    query_prefix: str = field(default='query: ', metadata={"help": "Prefix to add to queries."})
    passage_prefix: str = field(default='passage: ', metadata={"help": "Prefix to add to passages."})
    append_eos_token: bool = field(
        default=False, metadata={"help": "Append EOS token to retriever inputs (required for some Llama-based retrievers)."}
    )

    pad_to_multiple_of: Optional[int] = field(
        default=None,
        metadata={"help": "Pad sequences to a multiple of this value for hardware acceleration."}
    )

    # Arguments specific to the old LLMEnhancedDataset (bm25_retrieval_file,
    # update_index_batch_size, various sampling flags, etc.) were removed;
    # `top_k` is still declared above.

    def __post_init__(self):
        """Reject inconsistent sharding or group-size settings early."""
        if self.dataset_number_of_shards < 1:
            raise ValueError(
                f"dataset_number_of_shards must be >= 1, got {self.dataset_number_of_shards}"
            )
        # The shard index must address one of the configured shards.
        if not 0 <= self.dataset_shard_index < self.dataset_number_of_shards:
            raise ValueError(
                f"dataset_shard_index must be in [0, {self.dataset_number_of_shards}), "
                f"got {self.dataset_shard_index}"
            )
        # Each group needs at least the single positive passage.
        if self.train_group_size < 1:
            raise ValueError(f"train_group_size must be >= 1, got {self.train_group_size}")


@dataclass
class TevatronTrainingArguments(TrainingArguments):
    """
    Training arguments specific to Tevatron training scripts.

    Adds gradient-cache settings on top of ``TrainingArguments`` and redeclares
    ``warmup_ratio`` and ``run_name`` to give them project-specific defaults.
    """
    # Redeclared to change the inherited default.
    warmup_ratio: float = field(default=0.1, metadata={"help": "Linear warmup over warmup_ratio*total_steps."})

    # Gradient Cache settings
    grad_cache: bool = field(default=False, metadata={"help": "Enable gradient caching for large batch sizes."})
    gc_q_chunk_size: int = field(
        default=4, metadata={"help": "Gradient-cache chunk size for queries (used when grad_cache is enabled)."}
    )
    gc_p_chunk_size: int = field(
        default=32, metadata={"help": "Gradient-cache chunk size for passages (used when grad_cache is enabled)."}
    )

    # Redeclared to change the inherited default.
    run_name: str = field(default="co-retriever-run", metadata={"help": "A name for the W&B run."})

    # REMOVED: Obsolete arguments for the incompatible UpdateIndexCallback
    # - update_index_training, update_index_steps