"""
Configuration classes for the reasoning framework.

This module provides configuration classes that can load from YAML files
and validate settings for different components of the framework.
"""

import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from pathlib import Path


@dataclass
class ModelConfig:
    """Model-loading settings for the VLM and the reasoner.

    The two ``*_model_name`` fields have no default and must be supplied
    when the object is created; every other field is optional.
    """

    # -- Mandatory: dataclass fields without defaults have to be declared first
    vlm_model_name: str
    reasoner_model_name: str

    # -- Per-model options (each falls back to the default shown)
    vlm_checkpoint: Optional[str] = None
    vlm_trust_remote_code: bool = True
    reasoner_checkpoint: Optional[str] = None
    reasoner_trust_remote_code: bool = True

    # -- Options applied when loading a model
    # Stored as plain strings here; any conversion to framework objects is
    # left to whatever consumes this config.
    torch_dtype: str = "bfloat16"
    attn_implementation: str = "flash_attention_2"
    device_map: str = "auto"

    # -- Quantization switches and their 4-bit tuning knobs
    load_in_4bit: bool = False
    load_in_8bit: bool = False
    bnb_4bit_quant_type: str = "nf4"
    bnb_4bit_use_double_quant: bool = True
    bnb_4bit_compute_dtype: str = "bfloat16"


@dataclass
class LoRAConfig:
    """Settings for LoRA (Low-Rank Adaptation) fine-tuning.

    Every field has a default, so ``LoRAConfig()`` is valid on its own.
    """

    use_lora: bool = True
    r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    bias: str = "none"
    task_type: str = "CAUSAL_LM"

    # Built through a factory so each instance owns its own list.
    target_modules: List[str] = field(
        default_factory=lambda: [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]
    )

    use_rslora: bool = False
    use_dora: bool = False


@dataclass
class TrainingConfig:
    """Core training hyper-parameters and output settings.

    All fields are optional; ``TrainingConfig()`` yields the defaults below.
    """

    # -- Optimisation schedule
    num_train_epochs: int = 3
    learning_rate: float = 5e-5
    per_device_train_batch_size: int = 2
    per_device_eval_batch_size: int = 4
    gradient_accumulation_steps: int = 8

    # -- Where and how often results are written
    output_dir: str = "./outputs"
    logging_steps: int = 50
    save_steps: int = 1000

    # -- Model-related limit
    max_seq_length: int = 2048

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields of this config as a plain dictionary.

        Returns:
            A new ``dict`` mapping each field name to its current value.
        """
        from dataclasses import asdict

        return asdict(self)


@dataclass
class EvaluationConfig:
    """Settings for an evaluation run.

    The first six fields have no default and must be provided by the
    caller; the remaining fields are optional.
    """

    # -- Mandatory: dataclass fields without defaults have to be declared first
    reasoning_approach: str
    vlm_api_base: str
    reasoner_type: str
    reasoner_api_base: str
    reasoner_api_key: str
    reasoner_model: str

    # -- Optional, with defaults
    # The dataset list comes from a factory so instances never share it.
    datasets: List[str] = field(
        default_factory=lambda: ["MathVista_MINI", "MathVerse_MINI"]
    )
    judge: str = "exact_matching"
    nproc: int = 32
    work_dir: str = "outputs/"
    dry_run: bool = True
    limit: int = 32
    seed: int = 42
    deterministic: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields of this config as a plain dictionary.

        Returns:
            A new ``dict`` mapping each field name to its current value.
        """
        from dataclasses import asdict

        return asdict(self)


@dataclass
class LoggingConfig:
    """Settings for logging, checkpoint saving and experiment tracking.

    Every field has a default, so ``LoggingConfig()`` is valid on its own.
    """

    # -- Output location and optional run label
    output_dir: str = "./outputs"
    run_name: Optional[str] = None

    # -- How often to log
    logging_strategy: str = "steps"
    logging_steps: int = 50
    logging_first_step: bool = True

    # -- How often to save, and how much to keep
    save_strategy: str = "steps"
    save_steps: int = 1000
    save_total_limit: int = 3
    save_only_model: bool = False

    # -- Tracking backends; created per instance via a factory
    report_to: List[str] = field(default_factory=lambda: ["tensorboard"])

    # -- MLflow
    mlflow_experiment_name: str = "adaptive-vlm-reasoning"
    mlflow_tracking_uri: str = "./mlruns"

    # -- Weights & Biases
    wandb_project: str = "adaptive-vlm-reasoning"
    wandb_entity: Optional[str] = None
    wandb_tags: List[str] = field(
        default_factory=lambda: ["vlm-training", "reasoning"]
    )


@dataclass
class HardwareConfig:
    """Hardware and performance related settings.

    Every field has a default, so ``HardwareConfig()`` is valid on its own.
    """

    # -- Multi-GPU (DDP) options
    ddp_backend: Optional[str] = None
    ddp_find_unused_parameters: bool = False
    ddp_broadcast_buffers: Optional[bool] = None

    # -- DeepSpeed integration (None leaves it unset)
    deepspeed: Optional[str] = None

    # -- FSDP options; the list is created fresh for each instance
    fsdp: List[str] = field(default_factory=list)
    fsdp_min_num_params: int = 0

    # -- Memory / data-loader behaviour
    dataloader_persistent_workers: bool = True
    skip_memory_metrics: bool = False


@dataclass
class DataConfig:
    """Settings for dataset locations and data processing.

    Every field has a default, so ``DataConfig()`` is valid on its own.
    """

    # -- Dataset locations (None means "not set")
    train_data_path: Optional[str] = None
    eval_data_path: Optional[str] = None

    # -- Processing controls
    preprocessing_num_workers: int = 4
    max_samples: Optional[int] = None

    # -- Serialisation format name
    format_type: str = "chatml"

    # -- Image handling
    include_images: bool = True
    image_processor: Optional[str] = None

    # -- Tokenizer keyword arguments; a fresh dict is built per instance
    tokenizer_kwargs: Dict[str, Any] = field(
        default_factory=lambda: dict(
            padding_side="right",
            truncation_side="right",
            add_special_tokens=True,
        )
    )


@dataclass
class AdaptiveConfig:
    """Parameters for the adaptive reasoning mode.

    Every field has a default, so ``AdaptiveConfig()`` is valid on its own.
    """

    # -- Main loop and threshold settings
    max_iterations: int = 7
    enable_verification: bool = True
    verification_threshold: float = 0.8
    confidence_threshold: float = 0.9
    early_stopping_patience: int = 2

    # -- Iteration bounds
    min_iterations: int = 1
    iteration_timeout: float = 120.0

    # -- Verification behaviour
    # Accepted strategy names: "self_consistency", "external"
    verification_strategy: str = "self_consistency"
    verification_samples: int = 3

    # -- Experimental confidence-estimation switches (both off by default)
    enable_vlm_confidence: bool = False
    use_confidence_in_reasoner: bool = False


@dataclass
class TwoStageConfig:
    """Parameters for the two-stage reasoning mode.

    All three flags default to ``False``.
    """

    # -- Setting specific to the two-stage scaffold
    enable_verification: bool = False

    # -- Experimental confidence-estimation switches
    enable_vlm_confidence: bool = False
    use_confidence_in_reasoner: bool = False


@dataclass
class APIConfig:
    """Connection settings for the VLM and reasoner APIs.

    The four leading fields have no default and must be provided by the
    caller; keys and timeouts are optional.
    """

    # -- Mandatory: dataclass fields without defaults have to be declared first
    vlm_api_base: str
    reasoner_api_base: str
    reasoner_model: str
    # Accepted reasoner types: "deepseek", "openai", "custom_api"
    reasoner_type: str

    # -- Optional, with defaults
    vlm_api_key: str = "EMPTY"
    vlm_timeout: float = 900.0
    reasoner_api_key: str = "EMPTY"
    reasoner_timeout: float = 900.0


@dataclass
class ReproducibilityConfig:
    """Settings intended to make runs reproducible.

    Every field has a default, so ``ReproducibilityConfig()`` is valid on
    its own.
    """

    seed: int = 42
    deterministic: bool = False

    # -- Values tied to environment variables for deterministic behaviour
    # NOTE(review): presumably consumed as CUBLAS_WORKSPACE_CONFIG and
    # PYTHONHASHSEED by the code that applies this config - confirm there.
    cuda_deterministic: bool = False
    cublas_workspace_config: str = ":4096:8"
    pythonhashseed: Optional[str] = None


@dataclass
class FrameworkConfig:
    """Main configuration class that combines all sub-configurations.

    Note:
        ``model``, ``evaluation`` and ``api`` wrap dataclasses that declare
        required fields, so their default factories cannot succeed. Omitting
        any of the three raises ``TypeError`` at construction time; pass them
        explicitly or build the object through :meth:`from_dict`.
    """

    # Project metadata
    project_name: str = "adaptive-vlm-reasoning"
    experiment_name: str = "default_experiment"
    description: str = "Reasoning framework configuration"
    version: str = "1.0.0"

    # Git tracking
    git_commit: Optional[str] = None
    git_branch: Optional[str] = None

    # Framework selection
    framework: Optional[str] = None  # Options: "trl", "unsloth", None (auto-detect)

    # Sub-configurations
    model: ModelConfig = field(default_factory=ModelConfig)
    lora: LoRAConfig = field(default_factory=LoRAConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)
    hardware: HardwareConfig = field(default_factory=HardwareConfig)
    data: DataConfig = field(default_factory=DataConfig)
    adaptive: AdaptiveConfig = field(default_factory=AdaptiveConfig)
    two_stage: TwoStageConfig = field(default_factory=TwoStageConfig)
    api: APIConfig = field(default_factory=APIConfig)
    reproducibility: ReproducibilityConfig = field(default_factory=ReproducibilityConfig)

    # Advanced settings
    advanced: Dict[str, Any] = field(default_factory=dict)

    def validate(self) -> List[str]:
        """Validate the configuration without raising.

        Returns:
            A list of human-readable error messages; empty when no problem
            was found.
        """
        errors: List[str] = []

        # Required values
        if self.data.train_data_path is None:
            errors.append("data.train_data_path is required for training")

        # An empty name is as unusable as a missing one.
        if not self.model.vlm_model_name:
            errors.append("model.vlm_model_name is required")

        # Conflicting settings.
        # The training sub-config is not guaranteed to define fp16/bf16;
        # a missing flag counts as disabled instead of raising AttributeError.
        fp16_enabled = getattr(self.training, "fp16", False)
        bf16_enabled = getattr(self.training, "bf16", False)
        if fp16_enabled and bf16_enabled:
            errors.append("Cannot enable both fp16 and bf16")

        if self.lora.use_lora and not self.lora.target_modules:
            errors.append("target_modules must be specified when using LoRA")

        # Numeric ranges
        if self.training.learning_rate <= 0:
            errors.append("learning_rate must be positive")

        if self.lora.r <= 0:
            errors.append("LoRA rank (r) must be positive")

        return errors

    def to_dict(self) -> Dict[str, Any]:
        """Convert the configuration, including sub-configs, to a dictionary.

        Returns:
            A new nested ``dict`` produced by ``dataclasses.asdict``.
        """
        import dataclasses
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> 'FrameworkConfig':
        """Create a configuration from a (possibly nested) dictionary.

        Args:
            config_dict: Mapping of field names to values. Values for
                sub-configuration fields may be mappings (converted to the
                matching dataclass), ready-made instances (kept as they
                are), or ``None`` (treated as omitted, so the field default
                applies). The input mapping is not modified.

        Returns:
            A new ``FrameworkConfig`` instance.

        Raises:
            TypeError: If an unknown key is supplied, or a sub-configuration
                with required fields is missing or incomplete.
        """
        import dataclasses
        from collections.abc import Mapping

        # Shallow copy so the caller's mapping is left untouched.
        kwargs = dict(config_dict)

        for spec in dataclasses.fields(cls):
            # Only fields annotated with a dataclass type need conversion.
            if spec.name not in kwargs or not dataclasses.is_dataclass(spec.type):
                continue

            section = kwargs[spec.name]
            if section is None:
                # e.g. an empty section in a config file: use the default.
                del kwargs[spec.name]
            elif isinstance(section, Mapping):
                kwargs[spec.name] = spec.type(**section)
            # Anything else (typically an existing instance) passes through.

        return cls(**kwargs)


def get_config_path(config_name: str, config_type: str = "training") -> Path:
    """Resolve the YAML file that belongs to a named configuration.

    Args:
        config_name: Name of the config file (without .yaml extension)
        config_type: Type of config ("training", "evaluation", "models")

    Returns:
        Path to the configuration file

    Raises:
        FileNotFoundError: If the resolved file does not exist.
    """
    # The framework root sits four directory levels above this module.
    root = Path(__file__).resolve()
    for _ in range(4):
        root = root.parent

    candidate = root / "configs" / config_type / f"{config_name}.yaml"
    if candidate.exists():
        return candidate

    raise FileNotFoundError(f"Configuration file not found: {candidate}")


def list_available_configs(config_type: str = "training") -> List[str]:
    """Enumerate the configuration files present for one config type.

    Args:
        config_type: Type of config ("training", "evaluation", "models")

    Returns:
        Sorted list of available configuration names (without .yaml
        extension); empty when the directory for ``config_type`` does not
        exist.
    """
    # The framework root sits four directory levels above this module.
    root = Path(__file__).resolve()
    for _ in range(4):
        root = root.parent

    config_dir = root / "configs" / config_type
    if not config_dir.exists():
        return []

    return sorted(entry.stem for entry in config_dir.glob("*.yaml"))