"""
Model utilities for VLM training.

This module provides helper functions for loading, configuring, and managing
VLM models for training with different frameworks (TRL, Unsloth).
"""

import os
import torch
from pathlib import Path
from typing import Dict, Any, Optional, Union, Tuple, List
import logging
from dataclasses import dataclass
import yaml

# Framework imports (conditional)
try:
    from transformers import (
        AutoModelForVision2Seq, 
        AutoProcessor, 
        AutoTokenizer,
        AutoConfig,
        BitsAndBytesConfig
    )
except ImportError:
    raise ImportError("transformers is required for model utilities")

try:
    from peft import LoraConfig, get_peft_model, TaskType
    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False
    LoraConfig = None

try:
    import unsloth
    from unsloth import FastVisionModel
    UNSLOTH_AVAILABLE = True
except ImportError:
    UNSLOTH_AVAILABLE = False

logger = logging.getLogger(__name__)


@dataclass
class ModelLoadConfig:
    """Configuration for model loading.

    A plain settings container consumed by the loader helpers in this module.
    Only ``model_name`` is required; every other field has a default.
    """
    # Model identifier or path handed to the ``from_pretrained`` loaders.
    model_name: str
    # Dtype name; resolved to a ``torch.dtype`` by ``get_torch_dtype``.
    torch_dtype: str = "bfloat16"
    # Forwarded as ``device_map`` by the TRL/transformers loader.
    device_map: str = "auto"
    # Forwarded as ``trust_remote_code`` to the model and processor loaders.
    trust_remote_code: bool = True
    # Forwarded as ``attn_implementation`` by the TRL/transformers loader.
    attn_implementation: str = "flash_attention_2"
    
    # Quantization settings
    # ``create_bnb_config`` checks the 4-bit flag first, so it wins if both
    # flags are set. The Unsloth loader only reads ``load_in_4bit``.
    load_in_4bit: bool = False
    load_in_8bit: bool = False
    # The three ``bnb_4bit_*`` fields are only used when ``load_in_4bit`` is set.
    bnb_4bit_quant_type: str = "nf4"
    bnb_4bit_use_double_quant: bool = True
    # Dtype name; resolved by ``get_torch_dtype`` like ``torch_dtype``.
    bnb_4bit_compute_dtype: str = "bfloat16"
    
    # LoRA settings
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    lora_bias: str = "none"
    # None (or an empty list) makes the loaders fall back to their built-in
    # attention/MLP projection module list.
    lora_target_modules: Optional[List[str]] = None
    use_rslora: bool = False
    # Only forwarded by ``create_lora_config``; the Unsloth loader does not read it.
    use_dora: bool = False
    
    # Framework-specific settings
    framework: Optional[str] = None  # "trl", "unsloth", or None (auto)
    # Read by the Unsloth loader and written to the saved model_config.yaml.
    max_seq_length: int = 2048


def get_torch_dtype(dtype_str: Union[str, torch.dtype]) -> torch.dtype:
    """Convert a dtype name to the corresponding ``torch.dtype``.

    Args:
        dtype_str: A dtype name such as ``"bfloat16"``, ``"fp16"`` or
            ``"torch.float32"`` (case-insensitive, surrounding whitespace
            ignored), or an actual ``torch.dtype``, which is returned as-is.

    Returns:
        The matching ``torch.dtype``. Unrecognised names log a warning and
        fall back to ``torch.bfloat16``.
    """
    # Callers that already hold a real dtype should not crash on ``.lower()``.
    if isinstance(dtype_str, torch.dtype):
        return dtype_str

    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }

    key = dtype_str.strip().lower()
    # Accept the fully qualified spelling ("torch.float16") as well; without
    # this it would silently resolve to the bfloat16 fallback.
    if key.startswith("torch."):
        key = key[len("torch."):]

    resolved = dtype_map.get(key)
    if resolved is None:
        logger.warning(f"Unknown dtype {dtype_str}, defaulting to bfloat16")
        return torch.bfloat16
    return resolved


def create_bnb_config(config: ModelLoadConfig) -> Optional[BitsAndBytesConfig]:
    """Build the bitsandbytes quantization config requested by ``config``.

    Returns a 4-bit config when ``load_in_4bit`` is set, otherwise an 8-bit
    config when ``load_in_8bit`` is set, otherwise ``None``. The 4-bit flag
    takes precedence if both are enabled.
    """
    if config.load_in_4bit:
        four_bit_options = {
            "load_in_4bit": True,
            "bnb_4bit_quant_type": config.bnb_4bit_quant_type,
            "bnb_4bit_use_double_quant": config.bnb_4bit_use_double_quant,
            "bnb_4bit_compute_dtype": get_torch_dtype(config.bnb_4bit_compute_dtype),
        }
        return BitsAndBytesConfig(**four_bit_options)

    if config.load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)

    # Neither quantization mode was requested.
    return None


def create_lora_config(config: ModelLoadConfig) -> Optional[LoraConfig]:
    """Create the PEFT LoRA configuration described by ``config``.

    Args:
        config: Model loading configuration carrying the ``lora_*`` settings.

    Returns:
        A ``LoraConfig`` instance, or ``None`` when LoRA is disabled or the
        ``peft`` package is not installed. The latter case logs a warning,
        because the caller would otherwise proceed without adapters and
        without any notice.
    """
    if not config.use_lora:
        return None

    if not PEFT_AVAILABLE:
        logger.warning(
            "LoRA was requested (use_lora=True) but peft is not installed; "
            "no LoRA configuration will be created"
        )
        return None

    # Default target modules for common VLM architectures
    default_target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ]

    # An unset (None) or empty list falls back to the defaults above.
    target_modules = config.lora_target_modules or default_target_modules

    return LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        target_modules=target_modules,
        lora_dropout=config.lora_dropout,
        bias=config.lora_bias,
        task_type=TaskType.CAUSAL_LM,
        use_rslora=config.use_rslora,
        use_dora=config.use_dora,
    )


def detect_framework() -> str:
    """Pick the training framework based on what is installed.

    Returns ``"unsloth"`` when the Unsloth import succeeded at module load,
    otherwise ``"trl"``. The choice is logged at INFO level.
    """
    if not UNSLOTH_AVAILABLE:
        logger.info("Using TRL for training")
        return "trl"

    logger.info("Unsloth detected - using Unsloth for training")
    return "unsloth"


def load_vlm_model_trl(config: ModelLoadConfig) -> Tuple[Any, Any, Any]:
    """
    Load a VLM through transformers for use with TRL.

    The model is loaded with the dtype, device map, attention implementation
    and optional quantization from ``config``. LoRA adapters are attached
    when ``config.use_lora`` is set and a LoRA config could be created.

    Returns:
        Tuple of (model, tokenizer, processor). The tokenizer is
        ``processor.tokenizer`` when the processor has one, otherwise the
        processor itself.

    Raises:
        Exception: Any loading error is logged and re-raised unchanged.
    """
    logger.info(f"Loading VLM model {config.model_name} with TRL")

    quantization = create_bnb_config(config)

    load_options = dict(
        torch_dtype=get_torch_dtype(config.torch_dtype),
        device_map=config.device_map,
        trust_remote_code=config.trust_remote_code,
        attn_implementation=config.attn_implementation,
    )
    # Only pass a quantization config when one was actually requested.
    if quantization is not None:
        load_options["quantization_config"] = quantization

    try:
        model = AutoModelForVision2Seq.from_pretrained(
            config.model_name, **load_options
        )

        processor = AutoProcessor.from_pretrained(
            config.model_name,
            trust_remote_code=config.trust_remote_code,
        )
        # Fall back to the processor itself when it exposes no tokenizer.
        tokenizer = getattr(processor, 'tokenizer', processor)

        adapter_config = create_lora_config(config) if config.use_lora else None
        if adapter_config is not None:
            model = get_peft_model(model, adapter_config)
            logger.info(f"Applied LoRA with rank {config.lora_r}")

        logger.info("Model loaded successfully with TRL")
        return model, tokenizer, processor

    except Exception as e:
        logger.error(f"Failed to load model with TRL: {e}")
        raise


def load_vlm_model_unsloth(config: ModelLoadConfig) -> Tuple[Any, Any]:
    """
    Load VLM model using Unsloth.

    The base model is loaded with ``FastVisionModel.from_pretrained``. LoRA
    adapters are attached with ``FastVisionModel.get_peft_model`` only when
    ``config.use_lora`` is set; otherwise the base model is returned without
    adapters.

    Args:
        config: Model loading configuration.

    Returns:
        Tuple of (model, tokenizer)

    Raises:
        ImportError: If Unsloth is not installed.
        Exception: Any loading error is logged and re-raised unchanged.
    """
    if not UNSLOTH_AVAILABLE:
        raise ImportError("Unsloth is not available")

    logger.info(f"Loading VLM model {config.model_name} with Unsloth")

    try:
        # Loader arguments only. LoRA hyper-parameters (r, lora_alpha,
        # target_modules, ...) belong to get_peft_model and must not be
        # forwarded to from_pretrained.
        unsloth_kwargs = {
            "model_name": config.model_name,
            "max_seq_length": config.max_seq_length,
            "dtype": get_torch_dtype(config.torch_dtype),
            "load_in_4bit": config.load_in_4bit,
            "trust_remote_code": config.trust_remote_code,
        }
        if config.use_lora:
            unsloth_kwargs["use_gradient_checkpointing"] = "unsloth"

        # Load model with Unsloth
        model, tokenizer = FastVisionModel.from_pretrained(**unsloth_kwargs)

        # Attach LoRA adapters only when requested.
        if config.use_lora:
            target_modules = config.lora_target_modules or [
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj"
            ]
            model = FastVisionModel.get_peft_model(
                model,
                r=config.lora_r,
                lora_alpha=config.lora_alpha,
                lora_dropout=config.lora_dropout,
                bias=config.lora_bias,
                target_modules=target_modules,
                use_gradient_checkpointing="unsloth",
                random_state=42,
                use_rslora=config.use_rslora,
                max_seq_length=config.max_seq_length,
            )
            logger.info(f"Applied LoRA with rank {config.lora_r}")

        logger.info("Model loaded successfully with Unsloth")
        return model, tokenizer

    except Exception as e:
        logger.error(f"Failed to load model with Unsloth: {e}")
        raise


def load_vlm_model(config: ModelLoadConfig) -> Dict[str, Any]:
    """
    Load VLM model with automatic framework detection.

    Uses ``config.framework`` when set, otherwise the result of
    ``detect_framework()``. Anything other than an available ``"unsloth"``
    is loaded through the TRL/transformers path; a warning is logged when
    that is a fallback rather than what was asked for.

    Args:
        config: Model loading configuration

    Returns:
        Dictionary with keys ``"model"``, ``"tokenizer"``, ``"processor"``
        (``None`` on the Unsloth path) and ``"framework"`` (the framework
        actually used).
    """
    # Determine framework
    framework = config.framework or detect_framework()

    if framework == "unsloth":
        if UNSLOTH_AVAILABLE:
            model, tokenizer = load_vlm_model_unsloth(config)
            return {
                "model": model,
                "tokenizer": tokenizer,
                "processor": None,
                "framework": "unsloth"
            }
        # An explicit request cannot be honoured; say so instead of silently
        # switching frameworks.
        logger.warning(
            "Framework 'unsloth' was requested but Unsloth is not installed; "
            "falling back to TRL"
        )
    elif framework != "trl":
        logger.warning(f"Unknown framework {framework!r}; falling back to TRL")

    model, tokenizer, processor = load_vlm_model_trl(config)
    return {
        "model": model,
        "tokenizer": tokenizer,
        "processor": processor,
        "framework": "trl"
    }


def save_model_and_config(
    model: Any,
    tokenizer: Any,
    output_dir: Union[str, Path],
    config: ModelLoadConfig,
    processor: Optional[Any] = None
) -> None:
    """
    Save model, tokenizer, processor and configuration.

    The model, tokenizer and (if given) processor are written with their own
    ``save_pretrained`` methods. A ``model_config.yaml`` describing the load
    settings is written next to them; its ``lora_config`` entry is ``None``
    when LoRA is disabled.

    Args:
        model: The trained model
        tokenizer: The tokenizer
        output_dir: Directory to save to (created if missing)
        config: Model configuration
        processor: Optional processor for VLM models

    Raises:
        Exception: Any save error is logged and re-raised unchanged.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"Saving model to {output_dir}")

    try:
        # Save model
        model.save_pretrained(output_dir)

        # Save tokenizer
        tokenizer.save_pretrained(output_dir)

        # Save processor if available
        if processor is not None:
            processor.save_pretrained(output_dir)

        # Save configuration
        lora_section = None
        if config.use_lora:
            lora_section = {
                "use_lora": config.use_lora,
                "r": config.lora_r,
                "alpha": config.lora_alpha,
                "dropout": config.lora_dropout,
                "target_modules": config.lora_target_modules,
                # Recorded so the file describes the full adapter setup;
                # readers that do not know these keys can ignore them.
                "bias": config.lora_bias,
                "use_rslora": config.use_rslora,
                "use_dora": config.use_dora,
            }

        config_dict = {
            "model_name": config.model_name,
            "torch_dtype": config.torch_dtype,
            "framework": config.framework,
            "max_seq_length": config.max_seq_length,
            "lora_config": lora_section,
        }

        config_path = output_dir / "model_config.yaml"
        # Explicit encoding so the file does not depend on the platform default.
        with open(config_path, 'w', encoding="utf-8") as f:
            yaml.dump(config_dict, f, default_flow_style=False)

        logger.info("Model saved successfully")

    except Exception as e:
        logger.error(f"Failed to save model: {e}")
        raise


def load_model_from_config(config_path: Union[str, Path]) -> Dict[str, Any]:
    """
    Load model from saved configuration.

    The directory containing ``config_path`` is used as the model path. LoRA
    is enabled only when the file has a non-empty ``lora_config`` section
    whose ``use_lora`` is true; a missing or null section means the model is
    loaded without adapters.

    Args:
        config_path: Path to model config YAML

    Returns:
        Dictionary containing loaded model components

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
    """
    config_path = Path(config_path)

    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")

    # Load configuration; an empty file parses to None, so treat it as {}.
    with open(config_path, 'r', encoding="utf-8") as f:
        config_dict = yaml.safe_load(f) or {}

    # A null/missing section must disable LoRA. Relying on the dataclass
    # default (use_lora=True) would attach adapters to a model that was
    # saved without them.
    lora_section = config_dict.get("lora_config") or {}

    config = ModelLoadConfig(
        model_name=str(config_path.parent),  # Use directory as model path
        torch_dtype=config_dict.get("torch_dtype", "bfloat16"),
        framework=config_dict.get("framework"),
        max_seq_length=config_dict.get("max_seq_length", 2048),
        use_lora=lora_section.get("use_lora", False),
        lora_r=lora_section.get("r", 16),
        lora_alpha=lora_section.get("alpha", 32),
        lora_dropout=lora_section.get("dropout", 0.1),
        lora_target_modules=lora_section.get("target_modules"),
        # Optional keys; the fallbacks match the ModelLoadConfig defaults.
        lora_bias=lora_section.get("bias", "none"),
        use_rslora=lora_section.get("use_rslora", False),
        use_dora=lora_section.get("use_dora", False),
    )

    return load_vlm_model(config)


def get_model_memory_usage(model: Any) -> Dict[str, float]:
    """
    Report parameter size and CUDA memory counters for a model.

    Args:
        model: PyTorch model

    Returns:
        Dictionary with memory figures in MB (``model_parameters_mb``,
        ``gpu_allocated_mb``, ``gpu_max_allocated_mb``, ``gpu_reserved_mb``).
        When CUDA is unavailable or a query fails, a dictionary with a single
        ``"error"`` message is returned instead.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    bytes_per_mb = 1024 ** 2

    try:
        # Total bytes occupied by all parameters.
        total_bytes = 0
        for tensor in model.parameters():
            total_bytes += tensor.numel() * tensor.element_size()

        usage = {
            "model_parameters_mb": total_bytes / bytes_per_mb,
            "gpu_allocated_mb": torch.cuda.memory_allocated() / bytes_per_mb,
            "gpu_max_allocated_mb": torch.cuda.max_memory_allocated() / bytes_per_mb,
            "gpu_reserved_mb": torch.cuda.memory_reserved() / bytes_per_mb,
        }
        return usage

    except Exception as e:
        logger.error(f"Failed to get memory usage: {e}")
        return {"error": str(e)}


def print_model_info(model: Any, tokenizer: Any, config: ModelLoadConfig) -> None:
    """Log model information and memory usage at INFO level.

    Args:
        model: Model exposing ``parameters()``.
        tokenizer: Accepted for interface compatibility; not read here.
        config: Configuration the model was loaded with.
    """
    logger.info("=== Model Information ===")
    logger.info(f"Model: {config.model_name}")
    logger.info(f"Framework: {config.framework}")
    logger.info(f"Max sequence length: {config.max_seq_length}")

    # Count total and trainable parameters in a single pass.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    # Guard against a parameter-less model, which would divide by zero.
    trainable_pct = 100 * trainable_params / total_params if total_params else 0.0

    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Trainable ratio: {trainable_pct:.2f}%")

    # Memory usage (skipped when the helper reports an error, e.g. no CUDA)
    memory_info = get_model_memory_usage(model)
    if "error" not in memory_info:
        logger.info(f"Model size: {memory_info['model_parameters_mb']:.1f} MB")
        logger.info(f"GPU memory allocated: {memory_info['gpu_allocated_mb']:.1f} MB")
        logger.info(f"GPU memory reserved: {memory_info['gpu_reserved_mb']:.1f} MB")

    # LoRA info
    if config.use_lora and hasattr(model, 'peft_config'):
        logger.info("=== LoRA Configuration ===")
        logger.info(f"LoRA rank: {config.lora_r}")
        logger.info(f"LoRA alpha: {config.lora_alpha}")
        logger.info(f"LoRA dropout: {config.lora_dropout}")
        logger.info(f"Target modules: {config.lora_target_modules}")

    logger.info("=" * 30)


def create_model_config_from_yaml(yaml_path: Union[str, Path]) -> ModelLoadConfig:
    """
    Create ModelLoadConfig from YAML training configuration.

    Reads the ``model``, ``lora`` and ``training`` sections plus the
    top-level ``framework`` key. Missing or empty sections are treated as
    empty mappings, so every setting except ``model.vlm_model_name`` falls
    back to its default.

    Args:
        yaml_path: Path to training configuration YAML

    Returns:
        ModelLoadConfig instance

    Raises:
        ValueError: If ``model.vlm_model_name`` is missing or empty.
    """
    yaml_path = Path(yaml_path)

    # An empty file parses to None; normalise to an empty mapping.
    with open(yaml_path, 'r', encoding="utf-8") as f:
        config = yaml.safe_load(f) or {}

    # A section written with no body (e.g. "model:") parses to None, so
    # ``.get(key, {})`` alone is not enough to guarantee a dict.
    model_config = config.get("model") or {}
    lora_config = config.get("lora") or {}
    training_config = config.get("training") or {}

    # Get model name from config - fail if not provided
    vlm_model_name = model_config.get("vlm_model_name")
    if not vlm_model_name:
        raise ValueError("vlm_model_name is required in configuration")

    return ModelLoadConfig(
        model_name=vlm_model_name,
        torch_dtype=model_config.get("torch_dtype", "bfloat16"),
        device_map=model_config.get("device_map", "auto"),
        trust_remote_code=model_config.get("trust_remote_code", True),
        attn_implementation=model_config.get("attn_implementation", "flash_attention_2"),

        # Quantization
        load_in_4bit=model_config.get("load_in_4bit", False),
        load_in_8bit=model_config.get("load_in_8bit", False),
        bnb_4bit_quant_type=model_config.get("bnb_4bit_quant_type", "nf4"),
        bnb_4bit_use_double_quant=model_config.get("bnb_4bit_use_double_quant", True),
        bnb_4bit_compute_dtype=model_config.get("bnb_4bit_compute_dtype", "bfloat16"),

        # LoRA
        use_lora=lora_config.get("use_lora", True),
        lora_r=lora_config.get("r", 16),
        lora_alpha=lora_config.get("lora_alpha", 32),
        lora_dropout=lora_config.get("lora_dropout", 0.1),
        lora_bias=lora_config.get("bias", "none"),
        lora_target_modules=lora_config.get("target_modules"),
        use_rslora=lora_config.get("use_rslora", False),
        use_dora=lora_config.get("use_dora", False),

        # Framework and sequence length
        framework=config.get("framework"),
        max_seq_length=training_config.get("max_seq_length", 2048),
    )