"""
Distributed training configuration extensions for 72B VLM training.

This module extends the base TrainingConfig with distributed training parameters
while maintaining compatibility with the existing training framework.
"""

from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Union
from .base_trainer import TrainingConfig


@dataclass
class DistributedTrainingConfig(TrainingConfig):
    """
    Extended training configuration for distributed 72B VLM training.

    Adds distributed-training, DeepSpeed/FSDP, quantization, LoRA and cluster
    settings on top of the base ``TrainingConfig``. Helper methods render
    these settings into the dictionaries used for DeepSpeed, FSDP, model
    loading, bitsandbytes quantization and LoRA.

    Several helpers read fields that are expected to come from the base
    ``TrainingConfig`` (not shown here): ``mixed_precision``, ``batch_size``,
    ``gradient_accumulation_steps``, ``logging_steps``, ``use_lora``,
    ``lora_r``, ``lora_alpha`` and ``lora_dropout``. The optional
    ``load_in_4bit`` / ``bnb_4bit_*`` attributes are read defensively via
    ``getattr``.
    """

    # Distributed training settings
    distributed_enabled: bool = False
    distributed_strategy: str = "deepspeed"  # one of "deepspeed", "fsdp", "ddp"
    num_gpus: int = 1
    num_nodes: int = 1

    # DeepSpeed configuration
    deepspeed_config_path: Optional[str] = None
    zero_stage: int = 3  # ZeRO optimization stage (0-3)
    offload_optimizer: bool = True  # CPU-offload optimizer state (needs ZeRO stage >= 1)
    offload_param: bool = True  # CPU-offload parameters (needs ZeRO stage 3)
    cpu_offload_use_pin_memory: bool = True  # pin host memory used by offloading

    # FSDP configuration (values are passed through verbatim by to_fsdp_config)
    fsdp_auto_wrap_policy: str = "TRANSFORMER_BASED_WRAP"
    fsdp_backward_prefetch_policy: str = "BACKWARD_PRE"
    fsdp_forward_prefetch: bool = False
    fsdp_use_orig_params: bool = True
    fsdp_cpu_ram_efficient_loading: bool = True

    # 72B-specific model settings
    model_size: str = "72b"  # Model size identifier (compared case-insensitively)
    load_in_8bit: bool = False  # Alternative to 4-bit
    torch_compile: bool = False  # Usually disabled for quantized models
    flash_attention: bool = True  # selects attn_implementation="flash_attention_2"

    # Enhanced quantization settings for 72B
    bnb_4bit_quant_storage: str = "uint8"
    prepare_model_for_kbit_training: bool = True

    # Memory optimization for large models
    max_memory_per_gpu: Optional[str] = None  # e.g., "45GB"
    low_cpu_mem_usage: bool = True

    # Enhanced LoRA settings for 72B
    lora_target_modules: Optional[List[str]] = None  # None -> built-in default list
    lora_modules_to_save: Optional[List[str]] = None
    use_rslora: bool = False  # RSLoRA for better convergence
    use_dora: bool = False    # DoRA for improved performance

    # 72B-specific training optimizations
    optimizer_type: str = "paged_adamw_8bit"  # Memory-efficient optimizer
    scheduler_type: str = "cosine_with_restarts"
    dataloader_pin_memory: bool = False  # Often better for large models
    dataloader_num_workers: int = 2  # Conservative for memory
    remove_unused_columns: bool = False  # Keep image data

    # Advanced training settings
    neftune_noise_alpha: Optional[float] = None
    early_stopping_patience: int = 3
    early_stopping_threshold: float = 0.001

    # Cluster/SLURM settings
    slurm_job_id: Optional[str] = None
    master_addr: str = "localhost"
    master_port: str = "29500"

    # Monitoring and checkpointing for long training
    checkpoint_every_n_steps: int = 100
    keep_last_n_checkpoints: int = 3

    def to_deepspeed_config(self) -> Dict[str, Any]:
        """
        Generate a DeepSpeed configuration dictionary.

        Precision follows ``mixed_precision``: bf16 when truthy, fp16
        otherwise. CPU-offload sections are emitted only for the ZeRO stages
        that DeepSpeed documents as supporting them:

        - optimizer offload: stage >= 1
        - parameter offload: stage 3

        Returns:
            DeepSpeed configuration for 72B training. ``train_batch_size`` is
            left as ``"auto"``, presumably to be resolved by the Hugging Face
            Trainer integration.
        """
        config = {
            "fp16": {
                "enabled": not self.mixed_precision  # Use fp16 if not bf16
            },
            "bf16": {
                "enabled": self.mixed_precision  # Prefer bf16 for stability
            },
            "zero_optimization": {
                "stage": self.zero_stage,
                "reduce_bucket_size": 5e8,
                "stage3_prefetch_bucket_size": 5e7,
                "stage3_param_persistence_threshold": 1e6,
                "overlap_comm": True,
                "contiguous_gradients": True,
            },
            "gradient_accumulation_steps": self.gradient_accumulation_steps,
            "gradient_clipping": 1.0,
            "steps_per_print": self.logging_steps,
            "train_batch_size": "auto",
            "train_micro_batch_size_per_gpu": self.batch_size,
            "wall_clock_breakdown": False
        }

        zero = config["zero_optimization"]

        # Optimizer-state offloading only applies when ZeRO partitions
        # optimizer state (stage 1 and above).
        if self.offload_optimizer and self.zero_stage >= 1:
            zero["offload_optimizer"] = {
                "device": "cpu",
                "pin_memory": self.cpu_offload_use_pin_memory
            }

        # Parameter offloading requires parameter partitioning (stage 3 only).
        if self.offload_param and self.zero_stage == 3:
            zero["offload_param"] = {
                "device": "cpu",
                "pin_memory": self.cpu_offload_use_pin_memory
            }

        return config

    def to_fsdp_config(self) -> Dict[str, Any]:
        """
        Generate an FSDP configuration dictionary.

        The configurable ``fsdp_*`` fields are copied verbatim. Module-state
        syncing is always enabled and the state-dict type is always
        ``"SHARDED_STATE_DICT"``.

        Returns:
            FSDP configuration for 72B training.
        """
        return {
            "fsdp_auto_wrap_policy": self.fsdp_auto_wrap_policy,
            "fsdp_backward_prefetch_policy": self.fsdp_backward_prefetch_policy,
            "fsdp_forward_prefetch": self.fsdp_forward_prefetch,
            "fsdp_use_orig_params": self.fsdp_use_orig_params,
            "fsdp_cpu_ram_efficient_loading": self.fsdp_cpu_ram_efficient_loading,
            "fsdp_sync_module_states": True,
            "fsdp_state_dict_type": "SHARDED_STATE_DICT",
        }

    def get_model_loading_kwargs(self) -> Dict[str, Any]:
        """
        Get model loading keyword arguments for the 72B setup.

        - ``torch_dtype`` is the string ``"bfloat16"`` when ``mixed_precision``
          is truthy, otherwise ``"float16"``.
        - ``device_map`` is None in distributed mode and ``"auto"`` otherwise.
        - ``max_memory`` is a dict mapping GPU indices ``0..num_gpus-1`` to
          ``max_memory_per_gpu``; it is only added when that field is set.

        Returns:
            Dictionary of model loading arguments.
        """
        kwargs = {
            "torch_dtype": "bfloat16" if self.mixed_precision else "float16",
            "trust_remote_code": True,
            "low_cpu_mem_usage": self.low_cpu_mem_usage,
        }

        # Device map for distributed training
        if self.distributed_enabled:
            # Let DeepSpeed/FSDP handle device placement
            kwargs["device_map"] = None
        else:
            kwargs["device_map"] = "auto"

        # Memory constraints.
        # NOTE(review): max_memory is normally only consulted when a device_map
        # is being inferred; confirm it has any effect with device_map=None.
        if self.max_memory_per_gpu:
            kwargs["max_memory"] = {i: self.max_memory_per_gpu for i in range(self.num_gpus)}

        # Attention implementation
        if self.flash_attention:
            kwargs["attn_implementation"] = "flash_attention_2"

        return kwargs

    def get_quantization_config(self) -> Optional[Dict[str, Any]]:
        """
        Get bitsandbytes quantization parameters for QLoRA.

        Quantization is only produced when ``use_lora`` is enabled. 4-bit
        takes precedence over 8-bit if both are set;
        ``validate_distributed_config`` reports that combination as an error.

        Returns:
            BitsAndBytesConfig keyword parameters, or None when no
            quantization applies.
        """
        if not self.use_lora:
            return None

        if getattr(self, 'load_in_4bit', False):
            return {
                "load_in_4bit": True,
                "bnb_4bit_quant_type": getattr(self, 'bnb_4bit_quant_type', 'nf4'),
                "bnb_4bit_use_double_quant": getattr(self, 'bnb_4bit_use_double_quant', True),
                "bnb_4bit_compute_dtype": getattr(self, 'bnb_4bit_compute_dtype', 'bfloat16'),
                "bnb_4bit_quant_storage": self.bnb_4bit_quant_storage,
            }
        if self.load_in_8bit:
            return {
                "load_in_8bit": True
            }

        return None

    def get_lora_config(self) -> Dict[str, Any]:
        """
        Get the LoRA configuration for 72B models.

        Optional keys (``use_rslora``, ``use_dora``, ``modules_to_save``) are
        only included when enabled or set. List values are copies, so
        consumers mutating the returned dict cannot alter this config.

        Returns:
            LoRA configuration dictionary.
        """
        # Default target modules for Qwen2.5-VL
        default_target_modules = [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj"
        ]

        config = {
            "r": self.lora_r,
            "lora_alpha": self.lora_alpha,
            "lora_dropout": self.lora_dropout,
            "bias": "none",
            "task_type": "CAUSAL_LM",
            "target_modules": list(self.lora_target_modules or default_target_modules),
        }

        # Advanced LoRA features
        if self.use_rslora:
            config["use_rslora"] = True
        if self.use_dora:
            config["use_dora"] = True
        if self.lora_modules_to_save:
            config["modules_to_save"] = list(self.lora_modules_to_save)

        return config

    def validate_distributed_config(self) -> List[str]:
        """
        Validate the training configuration.

        Distributed-specific checks (GPU count, strategy, ZeRO stage, batch
        size) only run when ``distributed_enabled`` is set. The conflicting
        quantization check always runs, because a 4-bit + 8-bit combination
        is invalid on a single GPU too.

        Returns:
            List of human-readable validation errors; empty when valid.
        """
        errors: List[str] = []

        if self.distributed_enabled:
            if self.num_gpus < 2:
                errors.append("Distributed training requires at least 2 GPUs")

            if self.distributed_strategy not in ("deepspeed", "fsdp", "ddp"):
                errors.append(f"Unsupported distributed strategy: {self.distributed_strategy}")
            elif self.distributed_strategy == "deepspeed" and self.zero_stage not in (0, 1, 2, 3):
                errors.append(f"Invalid ZeRO stage: {self.zero_stage} (expected 0-3)")

            if str(self.model_size).lower() == "72b" and self.batch_size > 2:
                errors.append("Batch size > 2 may cause OOM with 72B model")

        # Check memory settings
        if self.load_in_8bit and getattr(self, 'load_in_4bit', False):
            errors.append("Cannot use both 8-bit and 4-bit quantization")

        return errors


def create_distributed_config_from_yaml(yaml_path: str, **overrides) -> DistributedTrainingConfig:
    """
    Create a distributed training config from a YAML file.

    The YAML document is flattened with ``_flatten_yaml_config``. ``overrides``
    are then applied on top, and the result is validated.

    Args:
        yaml_path: Path to the YAML configuration file (read as UTF-8).
        **overrides: Flat ``DistributedTrainingConfig`` field values that
            take precedence over values read from YAML.

    Returns:
        A validated DistributedTrainingConfig.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the YAML top level is not a mapping, or if
            ``validate_distributed_config`` reports any errors.
    """
    import yaml

    # Load YAML config. safe_load never constructs arbitrary Python objects.
    with open(yaml_path, 'r', encoding='utf-8') as f:
        config_data = yaml.safe_load(f)

    # An empty document loads as None; treat it as "no settings" so that
    # defaults apply instead of crashing on attribute access later.
    if config_data is None:
        config_data = {}
    if not isinstance(config_data, dict):
        raise ValueError(
            f"Expected a mapping at the top level of {yaml_path}, "
            f"got {type(config_data).__name__}"
        )

    # Flatten nested configuration, then let explicit overrides win.
    flat_config = _flatten_yaml_config(config_data)
    flat_config.update(overrides)

    config = DistributedTrainingConfig(**flat_config)

    # Validate
    errors = config.validate_distributed_config()
    if errors:
        error_msg = "\n".join(f"  - {error}" for error in errors)
        raise ValueError(f"Configuration errors:\n{error_msg}")

    return config


def _flatten_yaml_config(config_data: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten nested YAML configuration for DistributedTrainingConfig."""
    flat_config = {}
    
    # Training section
    training = config_data.get('training', {})
    flat_config.update({
        'vlm_model_name': training.get('model_name'),
        'learning_rate': training.get('learning_rate', 1e-4),
        'batch_size': training.get('per_device_train_batch_size', 1),
        'gradient_accumulation_steps': training.get('gradient_accumulation_steps', 8),
        'max_epochs': training.get('num_train_epochs', 2),
        'warmup_ratio': training.get('warmup_ratio', 0.1),
        'weight_decay': training.get('weight_decay', 0.01),
        'gradient_checkpointing': training.get('gradient_checkpointing', True),
        'mixed_precision': training.get('bf16', True),
    })
    
    # Distributed section
    distributed = training.get('distributed', {})
    if distributed.get('enabled'):
        flat_config.update({
            'distributed_enabled': True,
            'distributed_strategy': distributed.get('strategy', 'deepspeed'),
            'num_gpus': distributed.get('num_gpus', 8),
        })
    
    # Quantization section
    quantization = training.get('quantization', {})
    if quantization.get('load_in_4bit'):
        flat_config.update({
            'load_in_4bit': True,
            'bnb_4bit_quant_type': quantization.get('bnb_4bit_quant_type', 'nf4'),
            'bnb_4bit_use_double_quant': quantization.get('bnb_4bit_use_double_quant', True),
            'bnb_4bit_compute_dtype': quantization.get('bnb_4bit_compute_dtype', 'bfloat16'),
        })
    elif quantization.get('load_in_8bit'):
        flat_config['load_in_8bit'] = True
    
    # LoRA section
    lora = training.get('lora', {})
    flat_config.update({
        'use_lora': True,  # Always use LoRA for 72B
        'lora_r': lora.get('r', 32),
        'lora_alpha': lora.get('alpha', 64),
        'lora_dropout': lora.get('dropout', 0.05),
        'lora_target_modules': lora.get('target_modules'),
    })
    
    # Data section
    data = config_data.get('data', {})
    flat_config.update({
        'dataset_path': data.get('dataset_path'),
        'max_sequence_length': data.get('max_seq_length', 2048),
    })
    
    # Output section
    output = config_data.get('output', {})
    flat_config.update({
        'output_dir': output.get('output_dir', './results'),
        'experiment_name': output.get('run_name'),
    })
    
    # Training method
    flat_config['training_method'] = config_data.get('training_method', 'sft')
    
    return flat_config 