"""Configuration dataclass and argument parsing for token classification training."""

import argparse
from dataclasses import dataclass
from typing import Optional, Sequence

from calib.common_train_utils import compute_gradient_accumulation_steps


@dataclass
class TrainArgs:
    """Configuration for token classification training."""
    
    # Required arguments
    completion_path: str
    eval_completion_path: str
    output_dir: str
    precomputed_path: Optional[str] = None
    eval_precomputed_path: Optional[str] = None
    
    wandb_project: str = "calibrator"
    
    # Seed
    seed: int = 42
    dataset_seed: int = 42
    
    # Model configuration
    model_name: str = None
    hidden_size: Optional[int] = None
    intermediate_size: Optional[int] = None
    num_attention_heads: Optional[int] = None
    num_hidden_layers: int = 1
    num_key_value_heads: Optional[int] = None
    hidden_act: str = None
    base_model_name: Optional[str] = None
    tokenizer_name: str = None
    base_device: Optional[str] = None  # If None, uses same device as training model
    base_num_layers: Optional[int] = None  # If None, uses default from model config
    dropout: float = 0.0
    group_size: int = 1
    mlp_hidden_size: Optional[int] = None
    architecture: str = "probe"  # probe, summary, direct
    num_summarization_vectors: int = 1
    max_context_len: int = 8192
    increment_position_ids: bool = False
    agent_emb: bool = False
    attn_types: str = ""
    node_features: str = ""
    bin_aggregate: bool = False
    causal_bin_aggregate: bool = False
    no_early_node_features_projection: bool = False
    late_node_features_projection: bool = False
    group_softmax: bool = False
    sum_group_softmax: bool = False
    attend_all_group_softmax: bool = False
    late_group_softmax: bool = False
    late_node_features_projection_norm: bool = False
    sum_bin_aggregate: bool = False

    # Training configuration
    epochs: int = 1
    batch_size: int = 4
    base_batch_size: Optional[int] = None
    effective_batch_size: int = 64
    lr: float = 2e-5
    gating_lr: float = None
    agent_lr: float = None
    eval_steps: Optional[int] = None
    save_steps: float = None
    lr_scheduler_type: str = "cosine"
    max_grad_norm: float = 1.0
    sampler_group_size: Optional[int] = None
    
    # Data configuration
    num_prompts: int = None
    eval_num_prompts: int = None
    balance_difficulty: bool = False
    eval_balance_difficulty: bool = False
    select_difficulty: str = None
    eval_select_difficulty: str = None
    num_difficulty_bins: int = 8
    max_seq_len: Optional[int] = None
    last_rollout_only: bool = False
    last_label_only: bool = False
    last_token_only: bool = False
    paragraph_delimiter_token_id: int = None
    
    # Checkpoint configuration
    resume_from_checkpoint: Optional[str] = None
    load_model_from: Optional[str] = None
    
    # Only inj_roll_truncate mode is supported (no configuration needed)
    
    # Evaluation configuration
    eval_only: bool = False
    eval_on_start: bool = False
    wm_group_size: Optional[int] = None
    wm_type: str = "omni"
    eval_temps: str = "1.0"
    eval_tradeoff: bool = False
    save_pred: bool = False
    
    # Attention backend configuration
    attn_backend: str = "sdpa"
    base_attn_backend: str = "flex_attention"
    
    # Computed fields (set during parsing)
    gradient_accumulation_steps: int = 1


def parse_args() -> TrainArgs:
    """Parse command line arguments into TrainArgs dataclass."""
    parser = argparse.ArgumentParser()
    
    # Required arguments (optional if precomputed paths provided)
    parser.add_argument("--completion_path", default=None)
    parser.add_argument("--eval_completion_path", default=None)
    parser.add_argument("--precomputed_path", type=str, default=None,
                        help="Path to precomputed hidden states file for training data")
    parser.add_argument("--eval_precomputed_path", type=str, default=None,
                        help="Path to precomputed hidden states file for evaluation data")

    # Logging
    parser.add_argument("--output_dir", default="outputs/checkpoints")
    parser.add_argument("--wandb_project", default="calibrator")
    

    # Seed
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--dataset_seed", type=int, default=42)
    
    # Model configuration
    parser.add_argument("--model_name", default=None,
                       help="Name of the training model")
    parser.add_argument("--hidden_size", type=int, default=None,
                       help="Hidden size for the model")
    parser.add_argument("--intermediate_size", type=int, default=None,
                       help="Intermediate size for the model")
    parser.add_argument("--num_attention_heads", type=int, default=None)
    parser.add_argument("--num_hidden_layers", type=int, default=1,
                       help="Number of layers for the model")
    parser.add_argument("--num_key_value_heads", type=int, default=None)
    parser.add_argument("--hidden_act", type=str, default=None,
                       help="Activation function for the model")
    parser.add_argument("--base_model_name", default=None,
                       help="Name of the base model used to generate hidden states. If None, uses model_name")
    parser.add_argument("--tokenizer_name", default=None,
                       help="Name of the tokenizer used to tokenize the data")
    parser.add_argument("--base_device", type=str, default=None,
                       help="Device for base model (e.g., 'cuda:0'). If None, uses same device as training model")
    parser.add_argument("--base_num_layers", type=int, default=None,
                       help="Number of layers to use in the base model. If None, uses the model's default configuration")
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--group_size", type=int, default=1,
                       help="Group size for model")
    parser.add_argument("--architecture", type=str, default="probe",
                       help="Multi-sequence calibrator architecture type")
    parser.add_argument("--mlp_hidden_size", type=int, default=None,
                       help="Hidden size for the mlp of the model")
    parser.add_argument("--num_summarization_vectors", type=int, default=1,
                       help="Number of summarization vectors for summary architecture")
    parser.add_argument("--max_context_len", type=int, default=8192,
                       help="Maximum context length for multi-sequence calibrator")
    parser.add_argument("--increment_position_ids", action="store_true",
                       help="If set, use increment_position_ids for multi-sequence calibrator")
    parser.add_argument("--agent_emb", action="store_true",
                       help="If set, use agent_emb for multi-sequence calibrator")
    parser.add_argument("--attn_types", type=str, default="",
                       help="If set, use attn_types for multi-sequence calibrator. Choices: omni, omni_intra, omni_bin, omni_indiv, causal, causal_intra, causal_bin, causal_indiv")
    parser.add_argument("--node_features", type=str, default="",
                       help="If set, use node_features for multi-sequence calibrator. Choices: log_prob, position_id, bin_fraction")
    parser.add_argument("--bin_aggregate", action="store_true",
                       help="If set, use bin_aggregate for multi-sequence calibrator")
    parser.add_argument("--no_early_node_features_projection", action="store_true",
                       help="If set, use no_early_node_features_projection for multi-sequence calibrator")
    parser.add_argument("--late_node_features_projection", action="store_true",
                       help="If set, use late_node_features_projection for multi-sequence calibrator")
    parser.add_argument("--group_softmax", action="store_true",
                       help="If set, use group_softmax for multi-sequence calibrator")
    parser.add_argument("--sum_group_softmax", action="store_true",
                       help="If set, use sum_group_softmax for multi-sequence calibrator")
    parser.add_argument("--attend_all_group_softmax", action="store_true",
                       help="If set, use attend_all_group_softmax for multi-sequence calibrator")
    parser.add_argument("--late_group_softmax", action="store_true",
                       help="If set, use late_group_softmax for multi-sequence calibrator")
    parser.add_argument("--late_node_features_projection_norm", action="store_true",
                       help="If set, use late_node_features_projection_norm for multi-sequence calibrator")
    parser.add_argument("--sum_bin_aggregate", action="store_true",
                       help="If set, use sum_bin_aggregate for multi-sequence calibrator")
    parser.add_argument("--causal_bin_aggregate", action="store_true",
                       help="If set, use causal_bin_aggregate for multi-sequence calibrator")
                       
    # Training configuration
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--base_batch_size", type=int, default=None)
    parser.add_argument("--effective_batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--gating_lr", type=float, default=None)
    parser.add_argument("--agent_lr", type=float, default=None)
    parser.add_argument("--eval_steps", type=float, default=None)
    parser.add_argument("--save_steps", type=float, default=None)
    parser.add_argument("--lr_scheduler_type", type=str, default="constant")
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--sampler_group_size", type=int, default=None,
                       help="Group size for sampler")
    
    # Data configuration
    parser.add_argument("--num_prompts", type=int, default=None, 
                       help="If > 0, restrict dataset to first N prompts")
    parser.add_argument("--eval_num_prompts", type=int, default=None, 
                       help="If > 0, restrict eval dataset to first N prompts")
    parser.add_argument("--balance_difficulty", action="store_true",
                       help="If set, balance difficulty of train and eval datasets")
    parser.add_argument("--eval_balance_difficulty", action="store_true",
                       help="If set, balance difficulty of eval dataset")
    parser.add_argument("--select_difficulty", type=str, default=None,
                       help="If set, select difficulty of train and eval datasets")
    parser.add_argument("--eval_select_difficulty", type=str, default=None,
                       help="If set, select difficulty of eval dataset")
    parser.add_argument("--num_difficulty_bins", type=int, default=8,
                       help="Number of difficulty bins to use for balancing difficulty")
    parser.add_argument("--max_seq_len", type=int, default=None,
                       help="Optional maximum tokenized sequence length (including special tokens). "
                            "If set, inputs will be truncated to this length. "
                            "Defaults to the saved max_model_len from completions_rollouts.py.")
    parser.add_argument("--last_rollout_only", action="store_true",
                       help="If set, inject rollouts and labels only after the last paragraph instead of after every paragraph")
    parser.add_argument("--last_label_only", action="store_true",
                       help="If set, use only the last token of each rollout for training/inference")
    parser.add_argument("--last_token_only", action="store_true",
                       help="If set, use only the last token of each rollout for training/inference")
    parser.add_argument("--paragraph_delimiter_token_id", type=int, default=None,
                       help="Token ID used as paragraph delimiter. If not provided, will attempt to get from saved args.")

    # Checkpoint configuration
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
    parser.add_argument("--load_model_from", type=str, default=None)

    # Only inj_roll_truncate mode is supported
    
    # Evaluation configuration
    parser.add_argument("--eval_only", action="store_true", 
                       help="Only evaluate the model without training (requires --load_model_from)")
    parser.add_argument("--eval_on_start", action="store_true", 
                       help="Run evaluation at the start of training for debugging purposes")
    parser.add_argument("--wm_group_size", type=int, default=None,
                       help="Group size for wm")
    parser.add_argument("--wm_type", type=str, default="omni", choices=["omni", "causal", "recent"],
                       help="If set, use only the most recent token for wm")
    parser.add_argument("--eval_temps", type=str, default="1.0",
                       help="Comma-separated temperature values for calibration evaluation (e.g., '1.0,2.5,5.0')")
    parser.add_argument("--eval_tradeoff", action="store_true",
                       help="If set, compute tradeoff metrics")
    parser.add_argument("--save_pred", action="store_true",
                       help="If set, save evaluation predictions and dataset")
    
    # Attention backend configuration
    parser.add_argument("--attn_backend", type=str, default="sdpa", 
                       choices=["eager", "flex_attention", "flash_attention_2", "flash_attention_3", "sdpa"],
                       help="Attention backend to use")
    parser.add_argument("--base_attn_backend", type=str, default="flex_attention", 
                       choices=["eager", "flex_attention", "flash_attention_2", "flash_attention_3", "sdpa"],
                       help="Attention backend to use for base model")
    

    args = parser.parse_args()

    # Validate that either completion_path or precomputed_path is provided
    if args.completion_path is None and args.precomputed_path is None:
        raise ValueError("Either --completion_path or --precomputed_path must be provided")
    if args.eval_completion_path is None and args.eval_precomputed_path is None:
        raise ValueError("Either --eval_completion_path or --eval_precomputed_path must be provided")

    # Only inj_roll_truncate mode is supported

    # # Validate eval_only option
    # if args.eval_only and args.load_model_from is None:
    #     raise ValueError("--eval_only requires --load_model_from to specify a model checkpoint to evaluate.")
    
    args.gradient_accumulation_steps = compute_gradient_accumulation_steps(
        args.effective_batch_size, args.batch_size)
    
    # Convert parsed args to dataclass
    return TrainArgs(**vars(args))


