"""
Distributed TRL adapter for 72B VLM training.

This module extends the existing TRLTrainingAdapter to support distributed training
with DeepSpeed, FSDP, and other optimization strategies for large models.
"""

import os
import json
import time
import tempfile
from pathlib import Path
from typing import Dict, Any, Optional, Union

import torch
import torch.distributed as dist
from transformers import PreTrainedModel, TrainingArguments

from .trl_adapter import TRLTrainingAdapter
from .distributed_config import DistributedTrainingConfig
from .base_trainer import TrainingResult
from ...utils.logging import get_logger


class DistributedTRLAdapter(TRLTrainingAdapter):
    """
    Extended TRL adapter for distributed 72B VLM training.
    
    This adapter extends the existing TRLTrainingAdapter with:
    - DeepSpeed ZeRO-3 integration
    - FSDP support  
    - Multi-GPU QLoRA optimization
    - Memory-efficient distributed training
    """
    
    def __init__(self):
        """Set up the adapter and probe the process environment.

        Runs the parent initializer, overrides ``framework_name``, attaches a
        logger named after the concrete class, and seeds single-process
        defaults before calling ``_detect_distributed_environment``.
        """
        super().__init__()
        self.logger = get_logger(self.__class__.__name__)
        self.framework_name = "distributed_trl"

        # Single-process defaults; the detection step below overwrites them
        # when the launcher environment variables are present.
        self.is_distributed, self.world_size = False, 1
        self.rank = self.local_rank = 0

        # Needs self.logger and the defaults above to be in place already.
        self._detect_distributed_environment()
    
    def _detect_distributed_environment(self):
        """Detect if we're running in a distributed environment."""
        if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
            self.is_distributed = True
            self.rank = int(os.environ['RANK'])
            self.world_size = int(os.environ['WORLD_SIZE'])
            self.local_rank = int(os.environ.get('LOCAL_RANK', 0))
            
            # Initialize process group if not already done
            if not dist.is_initialized():
                dist.init_process_group(backend='nccl')
            
            # Set CUDA device
            if torch.cuda.is_available():
                torch.cuda.set_device(self.local_rank)
            
            self.logger.info(
                f"Distributed environment detected: rank={self.rank}, "
                f"world_size={self.world_size}, local_rank={self.local_rank}"
            )
    
    def is_available(self) -> bool:
        """Report whether distributed TRL training can be used here.

        Returns:
            ``True`` only when the parent adapter reports itself available and
            both ``deepspeed`` and ``torch.distributed.fsdp`` can be imported;
            ``False`` otherwise (an import failure is logged as a warning).
        """
        base_ready = super().is_available()
        if not base_ready:
            return False

        # The imports are probes only; the imported names are not used.
        try:
            import deepspeed
            from torch.distributed.fsdp import FullyShardedDataParallel
        except ImportError as exc:
            self.logger.warning(f"Distributed training dependencies not available: {exc}")
            return False

        return True
    
    def initialize_model(
        self,
        config: DistributedTrainingConfig,
        scaffold: 'BaseReasoningScaffold'
    ) -> PreTrainedModel:
        """
        Build the model through the parent adapter and remember the config.

        Args:
            config: Distributed training configuration; must be a
                ``DistributedTrainingConfig`` instance
            scaffold: Reasoning scaffold forwarded unchanged to the parent

        Returns:
            The model returned by the parent ``initialize_model``

        Raises:
            TypeError: If ``config`` is not a ``DistributedTrainingConfig``
        """
        if isinstance(config, DistributedTrainingConfig):
            self.logger.info(f"Initializing {config.model_size} model for distributed training")

            # Model construction is delegated entirely to the parent adapter.
            prepared_model = super().initialize_model(config, scaffold)

            # Recorded only once the parent call has succeeded.
            self.current_config = config
            return prepared_model

        raise TypeError("DistributedTRLAdapter requires DistributedTrainingConfig")
    
    def prepare_trainer(
        self,
        model: PreTrainedModel,
        train_dataset,
        eval_dataset,
        config: DistributedTrainingConfig
    ):
        """
        Prepare distributed trainer.

        Writes a DeepSpeed config file when the DeepSpeed strategy is
        selected, builds the ``TrainingArguments`` and hands everything to
        ``_create_sft_trainer``. The resulting trainer is also stored on
        ``self.trainer``.

        Args:
            model: Initialized model
            train_dataset: Training dataset
            eval_dataset: Evaluation dataset (optional); when missing or
                empty, evaluation and best-model loading are disabled
            config: Distributed training configuration

        Returns:
            Configured trainer
        """
        # A DeepSpeed config file is only needed for the DeepSpeed strategy.
        deepspeed_config_path = None
        if config.distributed_enabled and config.distributed_strategy == "deepspeed":
            deepspeed_config_path = self._create_deepspeed_config(config)

        # The evaluation dataset is forwarded so the evaluation-related
        # arguments can be derived from it.
        training_args = self._create_distributed_training_args(
            config,
            deepspeed_config_path,
            eval_dataset=eval_dataset,
        )

        trainer = self._create_sft_trainer(model, train_dataset, eval_dataset, training_args, config)

        self.trainer = trainer
        return trainer

    def _create_distributed_training_args(
        self,
        config: DistributedTrainingConfig,
        deepspeed_config_path: Optional[str] = None,
        eval_dataset=None
    ) -> TrainingArguments:
        """Create training arguments optimized for distributed 72B training.

        Args:
            config: Distributed training configuration supplying the
                hyper-parameters read below
            deepspeed_config_path: Path to a DeepSpeed JSON config, or
                ``None`` when DeepSpeed is not used
            eval_dataset: Evaluation dataset, used only to decide whether
                evaluation is enabled. Previously this name was read without
                being in scope, so every call raised ``NameError``.

        Returns:
            A ``TrainingArguments`` instance built from the collected settings
        """
        # None or an empty dataset both disable evaluation.
        evaluation_enabled = bool(eval_dataset)

        args_dict = {
            'output_dir': config.output_dir,
            'per_device_train_batch_size': config.batch_size,
            'per_device_eval_batch_size': config.batch_size,
            'gradient_accumulation_steps': config.gradient_accumulation_steps,
            'learning_rate': config.learning_rate,
            'weight_decay': config.weight_decay,
            'warmup_ratio': config.warmup_ratio,
            'num_train_epochs': config.max_epochs,
            'logging_steps': config.logging_steps,
            'save_steps': config.save_steps,
            'eval_steps': config.eval_steps,
            'save_strategy': 'steps',
            # NOTE(review): newer transformers releases spell this
            # `eval_strategy` -- confirm against the pinned version.
            'evaluation_strategy': 'steps' if evaluation_enabled else 'no',
            # Best-model selection needs eval metrics, so it follows the
            # same switch as evaluation itself.
            'load_best_model_at_end': evaluation_enabled,
            'metric_for_best_model': 'eval_loss',
            'greater_is_better': False,
            'save_total_limit': config.keep_last_n_checkpoints,

            # Precision settings
            'fp16': False,  # Use bf16 for 72B stability
            'bf16': config.mixed_precision,
            'tf32': True,

            # Memory optimization
            'gradient_checkpointing': config.gradient_checkpointing,
            'dataloader_pin_memory': config.dataloader_pin_memory,
            'dataloader_num_workers': config.dataloader_num_workers,
            'remove_unused_columns': config.remove_unused_columns,

            # Optimizer settings optimized for 72B
            'optim': config.optimizer_type,
            'lr_scheduler_type': config.scheduler_type,
            'max_grad_norm': 1.0,

            # Distributed settings
            'ddp_backend': 'nccl' if config.distributed_enabled else None,
            'deepspeed': deepspeed_config_path,

            # Advanced settings
            'group_by_length': False,  # Better for conversation data
            'length_column_name': 'length',

            # Reporting
            'report_to': getattr(config, 'report_to', None),
            'run_name': getattr(config, 'run_name', None),

            # NEFTune
            'neftune_noise_alpha': config.neftune_noise_alpha,

            # Reproducibility
            'seed': getattr(config, 'seed', 42),
            'data_seed': getattr(config, 'seed', 42),
        }

        # FSDP settings come from the config and override any base keys
        # with the same name.
        if config.distributed_enabled and config.distributed_strategy == "fsdp":
            args_dict.update(config.to_fsdp_config())

        return TrainingArguments(**args_dict)
    
    def _create_deepspeed_config(self, config: DistributedTrainingConfig) -> str:
        """Create DeepSpeed configuration file."""
        deepspeed_config = config.to_deepspeed_config()
        
        # Create temporary config file
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(deepspeed_config, f, indent=2)
            config_path = f.name
        
        self.logger.info(f"Created DeepSpeed config: {config_path}")
        if self.rank == 0:  # Only log on main process
            self.logger.debug(f"DeepSpeed config: {json.dumps(deepspeed_config, indent=2)}")
        
        return config_path
    
    def train(
        self,
        trainer,
        config: DistributedTrainingConfig
    ) -> TrainingResult:
        """
        Execute distributed training.

        Optionally attaches early stopping, runs ``trainer.train`` (resuming
        from ``config.resume_from_checkpoint`` when set), saves the final
        model and builds a ``TrainingResult``. Any exception is caught and
        reported as a failed result rather than raised.

        Args:
            trainer: Configured trainer
            config: Distributed training configuration

        Returns:
            Training results. Rank 0 returns the full result (loss history,
            paths, config); other ranks return a minimal successful result.
            On failure every rank returns a result with ``success=False``
            and ``final_loss=inf``.
        """
        # Started outside the try block so the failure path can always
        # compute an elapsed time.
        start_time = time.time()

        try:
            # Early stopping is opt-in; a missing or None patience disables it.
            patience = getattr(config, 'early_stopping_patience', None) or 0
            if patience > 0:
                from transformers import EarlyStoppingCallback
                trainer.add_callback(
                    EarlyStoppingCallback(
                        early_stopping_patience=patience,
                        early_stopping_threshold=(
                            getattr(config, 'early_stopping_threshold', None) or 0.0
                        ),
                    )
                )

            is_main_process = self.rank == 0

            # Log training start (only on main process)
            if is_main_process:
                self.logger.info("Starting distributed training...")
                # NOTE(review): count_parameters is not defined in this class;
                # presumably inherited from the parent adapter.
                self.logger.info(f"Model parameters: {self.count_parameters(trainer.model)}")

            resume_checkpoint = getattr(config, 'resume_from_checkpoint', None)
            if resume_checkpoint:
                self.logger.info(f"Resuming training from checkpoint: {resume_checkpoint}")
                train_result = trainer.train(resume_from_checkpoint=resume_checkpoint)
            else:
                train_result = trainer.train()

            training_time = time.time() - start_time
            final_loss = train_result.training_loss

            # Every rank must call save_model: with sharded strategies
            # (DeepSpeed ZeRO-3, FSDP) gathering the weights is a collective
            # operation, and calling it on rank 0 only can hang the job.
            # The trainer itself restricts the actual write to the main
            # process.
            model_path = Path(config.output_dir) / "final_model"
            trainer.save_model(str(model_path))

            if not is_main_process:
                # Non-main processes return a minimal result
                return TrainingResult(
                    final_loss=final_loss,
                    training_time=training_time,
                    framework="distributed_trl",
                    success=True
                )

            log_history = getattr(trainer.state, 'log_history', None) or []
            train_losses, eval_metrics = self._collect_distributed_loss_history(log_history)

            # Save adapter if using PEFT
            # NOTE(review): with sharded parameters this rank-0-only save may
            # not see full weights -- verify the adapter directory contents.
            adapter_path = None
            if config.use_lora:
                adapter_path = Path(config.output_dir) / "adapter"
                if hasattr(trainer.model, 'save_pretrained'):
                    trainer.model.save_pretrained(adapter_path)

            # Lowest evaluation loss seen during training, if any was logged.
            best_eval_metric = min(eval_metrics['loss']) if eval_metrics else None

            result = TrainingResult(
                final_loss=final_loss,
                best_eval_metric=best_eval_metric,
                training_time=training_time,
                model_path=str(model_path),
                adapter_path=str(adapter_path) if adapter_path else None,
                train_losses=train_losses,
                eval_metrics=eval_metrics,
                config_used=config,
                framework="distributed_trl",
                success=True
            )

            self.logger.info("Distributed training completed successfully!")
            self.logger.info(f"Final training loss: {final_loss:.4f}")
            self.logger.info(f"Total training time: {training_time:.2f} seconds")

            return result

        except Exception as e:
            training_time = time.time() - start_time
            error_msg = f"Distributed training failed: {str(e)}"

            # Logged on every rank: a failure can be specific to one rank and
            # would otherwise leave no traceback anywhere.
            self.logger.error(error_msg, exc_info=True)

            return TrainingResult(
                final_loss=float('inf'),
                training_time=training_time,
                config_used=config,
                framework="distributed_trl",
                success=False,
                error_message=error_msg
            )

    @staticmethod
    def _collect_distributed_loss_history(log_history):
        """Split a trainer log history into training losses and eval losses.

        Args:
            log_history: Iterable of log-entry dicts

        Returns:
            ``(train_losses, eval_metrics)`` where ``train_losses`` is the
            list of per-entry ``'loss'`` values (falling back to
            ``'train_loss'`` entries when no ``'loss'`` key was logged) and
            ``eval_metrics`` is ``{'loss': [...]}`` built from ``'eval_loss'``
            entries, or ``None`` when there are none.
        """
        train_losses = [entry['loss'] for entry in log_history if 'loss' in entry]
        if not train_losses:
            # Only the end-of-training summary entry carries 'train_loss'.
            train_losses = [
                entry['train_loss'] for entry in log_history if 'train_loss' in entry
            ]

        eval_losses = [entry['eval_loss'] for entry in log_history if 'eval_loss' in entry]
        eval_metrics = {'loss': eval_losses} if eval_losses else None
        return train_losses, eval_metrics
    
    def cleanup(self):
        """Cleanup distributed training resources."""
        try:
            # Cleanup distributed process group
            if self.is_distributed and dist.is_initialized():
                dist.destroy_process_group()
                self.logger.info("Cleaned up distributed process group")
        except Exception as e:
            self.logger.warning(f"Error during cleanup: {e}")
    
    def __del__(self):
        """Destructor to ensure cleanup."""
        try:
            self.cleanup()
        except:
            pass  # Ignore errors during destruction 