"""
Unsloth training adapter for framework-agnostic VLM training.

This module implements Unsloth-specific training logic (Unsloth advertises up to
90% memory reduction) and an entry point for GRPO (Group Relative Policy
Optimization), which currently falls back to SFT training.
"""

import time
from typing import Dict, Any, Optional, List, Union
from pathlib import Path
import torch
from transformers import PreTrainedModel

from .base_trainer import BaseTrainingAdapter, TrainingConfig, TrainingResult


class UnslothTrainingAdapter(BaseTrainingAdapter):
    """
    Unsloth training adapter with memory-efficient training.

    This adapter provides integration with Unsloth for fast, memory-efficient
    training of VLM components, featuring:
    - Large memory savings (Unsloth advertises up to ~90%) via 4-bit loading,
      LoRA adapters and Unsloth gradient checkpointing
    - A GRPO (Group Relative Policy Optimization) entry point; currently a
      placeholder that falls back to SFT
    - Merged 16-bit export that can be served with vLLM
    """

    def __init__(self):
        """Initialize Unsloth training adapter."""
        super().__init__("unsloth")
        # Cached result of is_available(); None means "not checked yet".
        self._unsloth_available = None
        self._required_modules = [
            "unsloth",
            "trl"  # Unsloth uses TRL under the hood
        ]
        # Set by initialize_model()/load_adapter(); read by the SFT trainer
        # factory and by the save/export helpers.
        self.tokenizer = None

    def is_available(self) -> bool:
        """
        Check if Unsloth and dependencies are available.

        The result is cached after the first call.

        Returns:
            True if Unsloth can be used, False otherwise
        """
        if self._unsloth_available is not None:
            return self._unsloth_available

        try:
            # Imported only to probe availability.
            from unsloth import FastLanguageModel  # noqa: F401
            import trl  # noqa: F401
        except Exception as e:
            # Importing unsloth may raise exceptions other than ImportError
            # (for example when no supported GPU is found). Any import-time
            # failure means the backend is unusable, so report False instead
            # of crashing the availability probe.
            print(f"Unsloth not available: {e}")
            self._unsloth_available = False
        else:
            self._unsloth_available = True
        return self._unsloth_available

    def initialize_model(
        self,
        config: TrainingConfig,
        scaffold: 'BaseReasoningScaffold'
    ) -> PreTrainedModel:
        """
        Initialize model for Unsloth training.

        Loads ``config.vlm_model_name`` in 4-bit, optionally wraps it with
        LoRA adapters, applies parameter freezing, and stores the model,
        tokenizer and config on the adapter.

        Args:
            config: Training configuration
            scaffold: Reasoning scaffold to train (not used by this adapter)

        Returns:
            Model prepared for Unsloth training

        Raises:
            RuntimeError: If Unsloth is not available.
        """
        if not self.is_available():
            raise RuntimeError("Unsloth is not available in this environment")

        from unsloth import FastLanguageModel, is_bfloat16_supported

        # Load in the same half-precision dtype that prepare_trainer() enables
        # (bf16 when the GPU supports it, else fp16). Loading fp16 weights
        # while the trainer runs with bf16=True is a dtype mismatch.
        # None lets Unsloth pick the dtype automatically.
        dtype = None
        if config.mixed_precision:
            dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16

        # Load model with Unsloth optimizations
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=config.vlm_model_name,
            max_seq_length=config.max_sequence_length,
            dtype=dtype,
            load_in_4bit=True,  # Unsloth's 4-bit quantization
        )

        # Set up LoRA with Unsloth optimizations
        if config.use_lora:
            model = FastLanguageModel.get_peft_model(
                model,
                r=config.lora_r,
                lora_alpha=config.lora_alpha,
                lora_dropout=config.lora_dropout,
                target_modules=self._get_target_modules(model, config),
                bias="none",
                use_gradient_checkpointing="unsloth" if config.gradient_checkpointing else False,
                random_state=42,
                use_rslora=False,  # Rank Stabilized LoRA disabled
                loftq_config=None,
            )

        # Set up parameter freezing
        self.setup_freezing(model, config)

        # Store tokenizer/model/config for later use by trainers and savers
        self.tokenizer = tokenizer
        self.model = model
        self.current_config = config

        return model

    def prepare_trainer(
        self,
        model: PreTrainedModel,
        train_dataset,
        eval_dataset,
        config: TrainingConfig
    ):
        """
        Prepare Unsloth trainer.

        Builds ``TrainingArguments`` (with any valid overrides from
        ``config.framework_config`` passed straight to the constructor) and
        creates the trainer for ``config.training_method``.

        Args:
            model: Initialized Unsloth model
            train_dataset: Training dataset
            eval_dataset: Evaluation dataset (optional; falsy disables eval)
            config: Training configuration

        Returns:
            Unsloth trainer object

        Raises:
            ValueError: If the training method is not "sft" or "grpo".
        """
        from dataclasses import fields

        from transformers import TrainingArguments
        from unsloth import is_bfloat16_supported

        use_bf16 = is_bfloat16_supported()
        has_eval = bool(eval_dataset)

        valid_fields = {f.name for f in fields(TrainingArguments)}
        # transformers renamed `evaluation_strategy` to `eval_strategy` and
        # later removed the old name; use whichever this version accepts.
        eval_strategy_key = (
            "eval_strategy" if "eval_strategy" in valid_fields else "evaluation_strategy"
        )

        # Set up training arguments optimized for Unsloth
        args_kwargs = dict(
            output_dir=config.output_dir,
            per_device_train_batch_size=config.batch_size,
            per_device_eval_batch_size=config.batch_size,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            # Let the Trainer derive warmup from the real number of steps
            # instead of guessing a fixed steps-per-epoch.
            warmup_ratio=config.warmup_ratio,
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
            num_train_epochs=config.max_epochs,

            # Mixed precision: bf16 when supported, otherwise fp16
            fp16=not use_bf16 and config.mixed_precision,
            bf16=use_bf16 and config.mixed_precision,
            logging_steps=config.logging_steps,
            optim="adamw_8bit",  # 8-bit optimizer for memory efficiency

            # Saving and evaluation
            save_steps=config.save_steps,
            eval_steps=config.eval_steps if has_eval else None,
            save_strategy="steps",
            load_best_model_at_end=has_eval,
            metric_for_best_model="eval_loss" if has_eval else None,
            greater_is_better=False,

            # Memory optimizations
            dataloader_pin_memory=False,
            group_by_length=True,  # Group sequences by length for efficiency
            remove_unused_columns=False,
        )
        args_kwargs[eval_strategy_key] = "steps" if has_eval else "no"

        # Framework-specific overrides go into the constructor so that
        # TrainingArguments.__post_init__ validates them and recomputes any
        # derived fields (setattr after construction would bypass that).
        if config.framework_config:
            for key, value in config.framework_config.items():
                if key in valid_fields:
                    args_kwargs[key] = value
                else:
                    print(f"Ignoring unknown TrainingArguments option: {key}")

        training_args = TrainingArguments(**args_kwargs)

        # Choose trainer based on training method
        if config.training_method == "sft":
            trainer = self._create_sft_trainer(
                model, train_dataset, eval_dataset, training_args, config
            )
        elif config.training_method == "grpo":
            trainer = self._create_grpo_trainer(
                model, train_dataset, eval_dataset, training_args, config
            )
        else:
            raise ValueError(f"Unsupported training method for Unsloth: {config.training_method}")

        self.trainer = trainer
        return trainer

    def train(
        self,
        trainer,
        config: TrainingConfig
    ) -> TrainingResult:
        """
        Execute Unsloth training.

        Any exception raised while training or saving is caught and reported
        through a ``TrainingResult`` with ``success=False``.

        Args:
            trainer: Unsloth trainer object
            config: Training configuration (``resume_from_checkpoint`` is
                honored if present)

        Returns:
            Training result with metrics and artifacts
        """
        import traceback

        start_time = time.time()
        gib = 1024 ** 3

        try:
            # Show memory usage before training (only when a GPU is present;
            # memory reporting must not abort the run).
            max_memory = None
            if torch.cuda.is_available():
                gpu_stats = torch.cuda.get_device_properties(0)
                start_gpu_memory = round(torch.cuda.max_memory_reserved() / gib, 3)
                max_memory = round(gpu_stats.total_memory / gib, 3)
                print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
                print(f"Memory before training = {start_gpu_memory} GB.")

            # Start training
            resume_checkpoint = getattr(config, 'resume_from_checkpoint', None)
            if resume_checkpoint:
                print(f"Resuming training from checkpoint: {resume_checkpoint}")
                train_result = trainer.train(resume_from_checkpoint=resume_checkpoint)
            else:
                train_result = trainer.train()

            # Show memory usage after training
            if max_memory:
                used_memory = round(torch.cuda.max_memory_reserved() / gib, 3)
                used_memory_percent = round(used_memory / max_memory * 100, 3)
                print(f"Peak memory during training = {used_memory} GB ({used_memory_percent}%).")

            # Get training metrics
            final_loss = train_result.training_loss
            train_losses = []
            eval_metrics = {}

            # Per-step training logs carry the loss under 'loss'; 'train_loss'
            # appears only once, in the final summary entry.
            for log_entry in getattr(trainer.state, 'log_history', None) or []:
                if 'loss' in log_entry:
                    train_losses.append(log_entry['loss'])

                # Collect eval metrics with the 'eval_' prefix stripped
                for key, value in log_entry.items():
                    if key.startswith('eval_'):
                        eval_metrics.setdefault(key[len('eval_'):], []).append(value)

            # Save model in Unsloth format
            model_path = str(Path(config.output_dir) / "final_model")
            trainer.save_model(model_path)

            # Save adapter using Unsloth's method
            adapter_path = ""
            if config.use_lora:
                adapter_path = self.save_adapter(trainer.model, config.output_dir, config)

            # Calculate training time
            training_time = time.time() - start_time

            # Best eval metric is the lowest eval loss seen
            best_eval_metric = None
            if eval_metrics.get('loss'):
                best_eval_metric = min(eval_metrics['loss'])

            return TrainingResult(
                final_loss=final_loss,
                best_eval_metric=best_eval_metric,
                training_time=training_time,
                model_path=model_path,
                adapter_path=adapter_path,
                train_losses=train_losses,
                eval_metrics=eval_metrics if eval_metrics else None,
                config_used=config,
                framework="unsloth",
                success=True
            )

        except Exception as e:
            training_time = time.time() - start_time
            # Keep the traceback visible; the result only carries str(e).
            traceback.print_exc()

            return TrainingResult(
                final_loss=float('inf'),
                training_time=training_time,
                config_used=config,
                framework="unsloth",
                success=False,
                error_message=str(e)
            )

    def _create_sft_trainer(self, model, train_dataset, eval_dataset, training_args, config):
        """Create Unsloth SFT trainer reading text from the dataset's "text" field."""
        from trl import SFTTrainer
        from unsloth.chat_templates import get_chat_template

        # Apply the ChatML chat template when possible; otherwise keep the
        # tokenizer as-is (best effort, but report why).
        try:
            self.tokenizer = get_chat_template(
                self.tokenizer,
                chat_template="chatml"  # Can be configured
            )
        except Exception as e:
            print(f"Could not apply chat template, using default tokenizer: {e}")

        return SFTTrainer(
            model=model,
            tokenizer=self.tokenizer,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            args=training_args,
            dataset_text_field="text",  # Assuming text field in dataset
            max_seq_length=config.max_sequence_length,
            dataset_num_proc=2,  # Number of processes for dataset processing
        )

    def _create_grpo_trainer(self, model, train_dataset, eval_dataset, training_args, config):
        """Create GRPO trainer for reward-based training (currently falls back to SFT)."""
        # GRPO (Group Relative Policy Optimization) is an RL method that
        # scores groups of sampled completions with reward functions. A real
        # implementation needs reward functions, which this adapter does not
        # receive yet, so fall back to the SFT trainer for now.
        print("Warning: GRPO trainer not fully implemented, falling back to SFT")
        return self._create_sft_trainer(model, train_dataset, eval_dataset, training_args, config)

    def _get_target_modules(self, model, config: TrainingConfig) -> List[str]:
        """
        Get appropriate LoRA target modules for Unsloth.

        The architecture is guessed by substring match on the lowercased
        ``config.vlm_model_name``; the first matching entry wins.

        Args:
            model: Model to analyze (not inspected; kept for interface)
            config: Training configuration

        Returns:
            List of module names to target with LoRA
        """
        # Attention + MLP projections per architecture family
        target_modules_map = {
            "llama": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            "qwen": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            "mistral": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            "gemma": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            "phi": ["q_proj", "k_proj", "v_proj", "dense"],
        }

        # Detect model architecture
        model_name = config.vlm_model_name.lower()

        for arch, modules in target_modules_map.items():
            if arch in model_name:
                return modules

        # Default fallback: attention projections only
        return ["q_proj", "k_proj", "v_proj", "o_proj"]

    def save_adapter(
        self,
        model: PreTrainedModel,
        output_dir: Union[str, Path],
        config: TrainingConfig
    ) -> str:
        """
        Save LoRA adapter using Unsloth's saving.

        Writes the adapter and tokenizer to ``<output_dir>/unsloth_adapter``,
        then tries to also write a merged 16-bit copy to its ``merged``
        subdirectory. If the adapter itself cannot be saved, the base-class
        implementation is used instead.

        Args:
            model: Trained Unsloth model with LoRA
            output_dir: Directory to save to
            config: Training configuration

        Returns:
            Path to saved adapter ("" when LoRA is disabled)
        """
        if not config.use_lora:
            return ""

        try:
            # Imported to confirm Unsloth is usable; otherwise use the fallback.
            from unsloth import FastLanguageModel  # noqa: F401

            adapter_path = Path(output_dir) / "unsloth_adapter"
            adapter_path.mkdir(parents=True, exist_ok=True)

            model.save_pretrained(str(adapter_path))
            self.tokenizer.save_pretrained(str(adapter_path))
        except Exception as e:
            print(f"Error saving Unsloth adapter: {e}")
            # Fallback to standard saving
            return super().save_adapter(model, output_dir, config)

        # The merged copy is a convenience export (full model size in 16-bit);
        # failing here must not discard the adapter already written above.
        try:
            model.save_pretrained_merged(
                str(adapter_path / "merged"),
                tokenizer=self.tokenizer,
                save_method="merged_16bit",
            )
        except Exception as e:
            print(f"Error saving merged Unsloth model: {e}")

        return str(adapter_path)

    def load_adapter(
        self,
        model: PreTrainedModel,
        adapter_path: Union[str, Path]
    ) -> PreTrainedModel:
        """
        Load LoRA adapter using Unsloth's loading.

        Note that the model is reloaded from ``adapter_path`` via
        ``FastLanguageModel.from_pretrained``; the passed ``model`` is only
        returned unchanged when loading fails.

        Args:
            model: Base model
            adapter_path: Path to adapter weights

        Returns:
            Model with loaded adapter, or the input model on failure
        """
        current_config = getattr(self, "current_config", None)
        max_seq_length = current_config.max_sequence_length if current_config else 2048

        try:
            from unsloth import FastLanguageModel

            loaded_model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=str(adapter_path),
                max_seq_length=max_seq_length,
                dtype=torch.float16,
                load_in_4bit=True,
            )
        except Exception as e:
            print(f"Error loading Unsloth adapter: {e}")
            return model

        self.tokenizer = tokenizer
        return loaded_model

    def validate_config(self, config: TrainingConfig) -> List[str]:
        """
        Validate Unsloth-specific configuration.

        Args:
            config: Configuration to validate

        Returns:
            List of validation messages (entries prefixed "Warning:" are
            advisory)
        """
        errors = super().validate_config(config)

        # Unsloth-specific validation
        if config.training_method not in ["sft", "grpo"]:
            errors.append(f"Unsloth does not support training method: {config.training_method}")

        if config.max_sequence_length > 32768:
            errors.append("Unsloth may have issues with sequence lengths > 32K")

        # Unsloth works best with LoRA
        if not config.use_lora:
            errors.append("Warning: Unsloth is optimized for LoRA training")

        return errors

    def export_for_vllm(
        self,
        model: PreTrainedModel,
        output_dir: Union[str, Path],
        config: TrainingConfig
    ) -> str:
        """
        Export trained model for vLLM inference.

        Writes a merged 16-bit model plus tokenizer to
        ``<output_dir>/vllm_export``.

        Args:
            model: Trained model
            output_dir: Directory to export to
            config: Training configuration (currently unused)

        Returns:
            Path to exported model, or "" on failure
        """
        try:
            # Imported to confirm Unsloth is usable before exporting.
            from unsloth import FastLanguageModel  # noqa: F401

            vllm_path = Path(output_dir) / "vllm_export"
            vllm_path.mkdir(parents=True, exist_ok=True)

            # Export merged model for vLLM
            model.save_pretrained_merged(
                str(vllm_path),
                tokenizer=self.tokenizer,
                save_method="merged_16bit",
            )

            print(f"Model exported for vLLM to: {vllm_path}")
            return str(vllm_path)

        except Exception as e:
            print(f"Error exporting for vLLM: {e}")
            return ""

    def get_memory_info(self) -> Dict[str, Any]:
        """
        Get detailed memory information for Unsloth.

        Returns:
            The base-class statistics, extended on CUDA with ``max_memory_gb``
            (float), ``memory_efficiency`` (percentage string) and
            ``unsloth_optimized`` (bool)
        """
        base_info = super().get_memory_info()

        if torch.cuda.is_available():
            gpu_stats = torch.cuda.get_device_properties(0)
            max_memory = gpu_stats.total_memory / 1e9
            # Fall back to querying torch if the base class omits this key.
            reserved_gb = base_info.get('reserved_gb')
            if reserved_gb is None:
                reserved_gb = torch.cuda.memory_reserved() / 1e9

            base_info.update({
                'max_memory_gb': max_memory,
                'memory_efficiency': f"{(reserved_gb / max_memory) * 100:.1f}%",
                'unsloth_optimized': True
            })

        return base_info