import os
from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field, field_validator, model_validator
import argparse



class TrainingConfig(BaseModel):
    """Validated configuration for a fine-tuning job.

    Unknown fields are rejected. ``model``, ``training_file`` and ``loss``
    are required; every other field has a default.
    """

    # Pydantic v2 configuration (replaces the deprecated nested ``Config``
    # class). A plain dict is accepted here because ``ConfigDict`` is a
    # TypedDict.
    model_config = {"extra": "forbid"}  # Prevent extra fields not defined in the model

    # Required model and data paths
    model: str = Field(..., description="Hugging Face model ID")
    training_file: str = Field(..., description="File ID of the training dataset")
    test_file: Optional[str] = Field(None, description="File ID of the test dataset")

    # Output model
    finetuned_model_id: str = Field(
        "{org_id}/{model_name}-{job_id}", description="File ID of the finetuned model"
    )

    # Model configuration
    max_seq_length: int = Field(
        2048, description="Maximum sequence length for training"
    )
    load_in_4bit: bool = Field(
        False, description="Whether to load model in 4-bit quantization"
    )

    # Training type configuration
    loss: Literal["grpo", "kl", "ldifs"] = Field(
        ..., description="Loss function / training type"
    )

    # PEFT configuration
    is_peft: bool = Field(True, description="Whether to use PEFT for training")
    full_finetuning: bool = Field(
        False, description="Whether to use full finetuning"
    )
    target_modules: Optional[List[str]] = Field(
        default=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        description="Target modules for LoRA",
    )
    lora_bias: Literal["all", "none"] = Field(
        "none", description="Value for FastLanguageModel.get_peft_model(bias=?)"
    )

    # LoRA specific arguments
    r: int = Field(16, description="LoRA attention dimension")
    lora_alpha: int = Field(16, description="LoRA alpha parameter")
    lora_dropout: float = Field(0.0, description="LoRA dropout rate")
    use_rslora: bool = Field(True, description="Whether to use RSLoRA")
    merge_before_push: bool = Field(
        True,
        description="Whether to merge model before pushing to Hub. Only merged models can be used as parent models for further finetunes. Only supported for bf16 models.",
    )
    push_to_private: bool = Field(True, description="Whether to push to private Hub")

    # Training hyperparameters
    epochs: int = Field(1, description="Number of training epochs")
    max_steps: Optional[int] = Field(
        None, description="Maximum number of training steps"
    )
    per_device_train_batch_size: int = Field(
        2, description="Training batch size per device"
    )
    gradient_accumulation_steps: int = Field(
        8, description="Number of gradient accumulation steps"
    )
    warmup_steps: int = Field(5, description="Number of warmup steps")
    learning_rate: Union[float, str] = Field(
        1e-4, description="Learning rate or string expression"
    )
    logging_steps: int = Field(1, description="Number of steps between logging")
    optim: str = Field("adamw_8bit", description="Optimizer to use for training")
    weight_decay: float = Field(0.01, description="Weight decay rate")
    lr_scheduler_type: str = Field("linear", description="Learning rate scheduler type")
    seed: int = Field(3407, description="Random seed for reproducibility")
    # NOTE(review): the description mentions DPO/ORPO, which are not among the
    # allowed ``loss`` values above -- confirm what ``beta`` controls here.
    beta: float = Field(0.1, description="Beta parameter for DPO/ORPO training")
    save_steps: int = Field(5000, description="Save checkpoint every X steps")
    output_dir: str = Field(
        "./tmp", description="Output directory for training checkpoints"
    )
    train_on_responses_only: bool = Field(
        False, description="Whether to train on responses only"
    )

    # Grader / reward configuration
    grader_type: str = Field(
        "code_correct", description="Which prompt to use for the grader score model"
    )
    reasoning_grader_type: str = Field(
        "none", description="Which prompt to use for the reasoning grader score model"
    )
    reward_coherence: bool = Field(
        True, description="Include extra reward model rewarding coherent code"
    )
    reward_model: str = Field(
        "gpt-4.1-mini", description="OpenAI model used as score grader model"
    )
    print_training: bool = Field(
        False, description="Print input/output to models during training for debugging"
    )
    define_assistant_reasoning: bool = Field(
        False, description="Whether to include assistant prompt in generation"
    )
    # NOTE(review): the description text looks garbled ("after a fourth
    # epoch") -- confirm the intended meaning with the consumer of this field.
    evaluate_epoch: int = Field(
        0, description="Set a number to evaluate model and store the model after a fourth epoch"
    )
    num_generations: int = Field(2, description="Hyperparameter")
    rl_max_new_tokens: int = Field(256, description="Hyperparameter")
    max_prompt_length: int = Field(256, description="Hyperparameter")
    rl_temperature: float = Field(1.0, description="Hyperparameter")
    rl_top_p: float = Field(0.95, description="Hyperparameter")
    training_method: str = Field("cross_entropy", description="Hyperparameter")
    ldifs_lambda: float = Field(0.1, description="Hyperparameter")
    num_intermediate_layers: int = Field(5, description="Hyperparameter")
    no_test_split: bool = Field(False, description="Whether to not split the test set")

    # Steering configuration
    steering_config: Optional[Dict] = Field(None, description="Steering configuration for projection intervention")
    enable_steering_during_training: bool = Field(False, description="Whether to enable steering during training")

    @model_validator(mode="before")
    @classmethod
    def validate_training_file_prefixes(cls, values):
        """Check ``training_file`` against the filesystem and naming rules.

        Runs on the raw input before field validation.

        * A string naming an existing path is accepted as-is.
        * A list is accepted only if every entry names an existing path.
        * Any other string is subject to the loss-specific prefix rule.

        Raises:
            ValueError: if ``training_file`` is neither a string nor a list,
                if a listed file does not exist, or if the prefix rule for
                the requested loss is violated.
        """
        # "before" validators may receive non-dict input (e.g. a model
        # instance); leave that for pydantic's own validation to handle.
        if not isinstance(values, dict):
            return values

        training_file = values.get("training_file")

        if isinstance(training_file, list):
            for file in training_file:
                if not os.path.exists(file):
                    raise ValueError(f"Training file {file} does not exist")
            # Prefix rules only make sense for a single string identifier.
            return values

        if not isinstance(training_file, str):
            raise ValueError(f"Training file must be a string or a list of strings, got: {training_file}")

        if os.path.exists(training_file):
            return values

        # No default here: a missing ``loss`` is reported by field validation
        # rather than being treated as a preference-based loss.
        loss = values.get("loss")
        if loss in ("dpo", "orpo") and not training_file.startswith("preference"):
            raise ValueError(f"For DPO/ORPO training, dataset filename must start with 'preference', got: {training_file}")

        return values

    @field_validator("finetuned_model_id")
    @classmethod
    def validate_finetuned_model_id(cls, v):
        """Require an ``org/model`` identifier with a usable organisation.

        Raises:
            ValueError: if the value is not exactly two non-empty parts
                separated by ``/``, or if the org is a reserved name.
        """
        parts = v.split("/")
        # Reject empty components such as "/model" or "org/" as well.
        if len(parts) != 2 or not all(parts):
            raise ValueError("Model ID must be in the format 'user/model'")
        org = parts[0]
        if org in ["datasets", "models", "unsloth", "None"]:
            raise ValueError(f"You have set org={org}, but it must be an org you have access to")
        return v

    @field_validator("learning_rate", mode="before")
    @classmethod
    def validate_learning_rate(cls, v):
        """Reject non-positive numeric learning rates.

        String values are passed through unchanged.

        Raises:
            ValueError: if a numeric learning rate is zero or negative.
        """
        # Include ints: they are coerced to float after this validator runs.
        if isinstance(v, (int, float)) and v <= 0:
            raise ValueError("Learning rate must be positive")
        return v

    @field_validator("lora_dropout")
    @classmethod
    def validate_dropout(cls, v):
        """Require the LoRA dropout rate to lie in the closed interval [0, 1]."""
        if not 0 <= v <= 1:
            raise ValueError("Dropout rate must be between 0 and 1")
        return v

    @field_validator("optim")
    @classmethod
    def validate_optimizer(cls, v):
        """Require ``optim`` to be one of the supported optimizer names."""
        allowed_optimizers = ["adamw_8bit", "adamw", "adam", "sgd"]
        if v not in allowed_optimizers:
            raise ValueError(f"Optimizer must be one of {allowed_optimizers}")
        return v

    @field_validator("lr_scheduler_type")
    @classmethod
    def validate_scheduler(cls, v):
        """Require ``lr_scheduler_type`` to be one of the supported names."""
        allowed_schedulers = ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
        if v not in allowed_schedulers:
            raise ValueError(f"Scheduler must be one of {allowed_schedulers}")
        return v

