"""
Training utilities for VLM training system.

This module provides utility functions for training management,
monitoring, checkpointing, and optimization.
"""

import os
import gc
import json
import random
import time
import logging
import shutil
from pathlib import Path
from typing import Dict, Any, List, Optional, Union, Tuple
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import yaml

import torch
import numpy as np
from transformers import get_scheduler

logger = logging.getLogger(__name__)


@dataclass
class TrainingMetrics:
    """Training metrics recorded for a single training step.

    ``TrainingMonitor.log_metrics`` serializes instances to JSONL via
    ``__dict__``, so every field value needs to be JSON-serializable.
    """
    step: int  # training step this record describes
    epoch: int  # epoch in which the step occurred
    loss: float  # loss reported for this step
    learning_rate: float  # learning rate at this step
    grad_norm: Optional[float] = None  # gradient norm, if the caller provides one
    memory_usage: Optional[Dict[str, float]] = None  # presumably per-device GPU memory figures (e.g. from MemoryManager.track_memory) — confirm with caller
    throughput: Optional[float] = None  # samples/second
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())  # naive local-time ISO-8601 string taken at creation


@dataclass
class CheckpointInfo:
    """Metadata describing one saved checkpoint.

    ``CheckpointManager`` persists instances to ``checkpoint_info.json`` via
    ``__dict__`` and rebuilds them with ``CheckpointInfo(**item)``, so field
    names must stay stable and values must be JSON-serializable.
    """
    checkpoint_path: str  # checkpoint directory, stored as a string
    step: int  # training step at which the checkpoint was saved
    epoch: int  # epoch at which the checkpoint was saved
    loss: float  # loss supplied at save time; the lowest one is treated as "best"
    timestamp: str  # ISO-8601 string of the save time
    model_config: Dict[str, Any]  # model configuration supplied by the caller
    training_config: Dict[str, Any]  # training configuration supplied by the caller


class TrainingMonitor:
    """Collect per-step training metrics in memory and append them to a JSONL file."""

    def __init__(self, output_dir: Union[str, Path], log_interval: int = 10):
        """
        Create a monitor that writes into ``output_dir``.

        Args:
            output_dir: Directory for ``training_metrics.jsonl`` and the final
                ``training_summary.json`` (created if it does not exist)
            log_interval: Emit a console log line every ``log_interval`` steps
        """
        self.output_dir = Path(output_dir)
        self.metrics_file = self.output_dir / "training_metrics.jsonl"
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.log_interval = log_interval

        self.metrics_history = []
        self.total_samples = 0  # not updated here; callers may increment it
        self.start_time = time.time()
        self.last_log_time = time.time()

    def log_metrics(self, metrics: TrainingMetrics) -> None:
        """Record ``metrics``, append it as one JSON line, and log it on interval steps."""
        self.metrics_history.append(metrics)

        with open(self.metrics_file, 'a') as handle:
            json.dump(metrics.__dict__, handle)
            handle.write('\n')

        if metrics.step % self.log_interval == 0:
            self._log_to_console(metrics)

    def _log_to_console(self, metrics: TrainingMetrics) -> None:
        """Emit a single pipe-separated progress line through the module logger."""
        runtime = str(timedelta(seconds=int(time.time() - self.start_time)))

        fields = [
            f"Step {metrics.step:>6}",
            f"Epoch {metrics.epoch:>3}",
            f"Loss {metrics.loss:.4f}",
            f"LR {metrics.learning_rate:.2e}",
            f"Elapsed {runtime}",
        ]
        if metrics.grad_norm is not None:
            fields.append(f"GradNorm {metrics.grad_norm:.3f}")
        if metrics.throughput is not None:
            fields.append(f"{metrics.throughput:.1f} samples/s")

        logger.info(" | ".join(fields))

    def get_average_loss(self, last_n_steps: int = 100) -> float:
        """Return the mean loss of the newest ``last_n_steps`` records (``inf`` if none)."""
        if not self.metrics_history:
            return float('inf')

        window = self.metrics_history[-last_n_steps:]
        return sum(record.loss for record in window) / len(window)

    def save_summary(self) -> None:
        """Write ``training_summary.json`` describing the run; does nothing if no metrics were logged."""
        if not self.metrics_history:
            return

        elapsed = time.time() - self.start_time
        last = self.metrics_history[-1]
        throughput = self.total_samples / elapsed if elapsed > 0 else 0

        summary = {
            "training_completed": True,
            "total_steps": last.step,
            "total_epochs": last.epoch,
            "final_loss": last.loss,
            "total_training_time": elapsed,
            "average_loss_last_100": self.get_average_loss(100),
            "total_samples_processed": self.total_samples,
            "throughput_samples_per_second": throughput,
            "completed_at": datetime.now().isoformat()
        }

        summary_path = self.output_dir / "training_summary.json"
        with open(summary_path, 'w') as handle:
            json.dump(summary, handle, indent=2)

        logger.info(f"Training summary saved to {summary_path}")


class CheckpointManager:
    """Saves model checkpoints and enforces a rolling retention limit.

    Each checkpoint is written to ``<output_dir>/checkpoint-<step>``. The list of
    tracked checkpoints is persisted in ``<output_dir>/checkpoint_info.json`` so
    retention keeps working across restarts.
    """

    def __init__(
        self,
        output_dir: Union[str, Path],
        save_total_limit: Optional[int] = 3,
        save_strategy: str = "steps",
        save_steps: int = 1000
    ):
        """
        Initialize checkpoint manager.

        Args:
            output_dir: Output directory for checkpoints (created if missing)
            save_total_limit: Maximum number of checkpoints to keep; ``None`` or a
                value <= 0 disables deletion of old checkpoints
            save_strategy: Save strategy ("steps", "epoch", "no")
            save_steps: Steps between saves for the "steps" strategy
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.save_total_limit = save_total_limit
        self.save_strategy = save_strategy
        self.save_steps = save_steps

        self.checkpoints = []
        self.checkpoint_info_file = self.output_dir / "checkpoint_info.json"

        # Restore tracking from a previous run, if any.
        self._load_checkpoint_info()

    def should_save(self, step: int, epoch: int) -> bool:
        """Return whether a checkpoint should be saved now.

        ``epoch`` is not inspected. With the "epoch" strategy this always returns
        True; presumably the caller only asks at epoch boundaries.
        """
        if self.save_strategy == "steps":
            # A non-positive interval would raise ZeroDivisionError; treat it as "never".
            return self.save_steps > 0 and step % self.save_steps == 0
        if self.save_strategy == "epoch":
            return True
        # "no" and any unrecognized strategy.
        return False

    def save_checkpoint(
        self,
        model: Any,
        tokenizer: Any,
        step: int,
        epoch: int,
        loss: float,
        model_config: Dict[str, Any],
        training_config: Dict[str, Any],
        processor: Optional[Any] = None
    ) -> str:
        """
        Save model checkpoint.

        Writes the model, tokenizer, and optional processor (via their
        ``save_pretrained``), plus ``training_config.yaml`` and
        ``training_state.json``. It then registers the checkpoint and prunes old
        ones beyond ``save_total_limit``.

        Args:
            model: Model to save
            tokenizer: Tokenizer to save
            step: Current training step
            epoch: Current epoch
            loss: Current loss
            model_config: Model configuration
            training_config: Training configuration
            processor: Optional processor to save

        Returns:
            Path to saved checkpoint

        Raises:
            Any exception raised while writing the checkpoint files. The partial
            checkpoint directory is removed before re-raising.
        """
        checkpoint_dir = self.output_dir / f"checkpoint-{step}"
        checkpoint_dir.mkdir(exist_ok=True)
        checkpoint_path = str(checkpoint_dir)

        logger.info(f"Saving checkpoint to {checkpoint_dir}")

        try:
            model.save_pretrained(checkpoint_dir)
            tokenizer.save_pretrained(checkpoint_dir)
            if processor is not None:
                processor.save_pretrained(checkpoint_dir)

            with open(checkpoint_dir / "training_config.yaml", 'w') as f:
                yaml.dump(training_config, f, default_flow_style=False)

            timestamp = datetime.now().isoformat()
            training_state = {
                "step": step,
                "epoch": epoch,
                "loss": loss,
                "timestamp": timestamp,
                "model_config": model_config
            }
            with open(checkpoint_dir / "training_state.json", 'w') as f:
                json.dump(training_state, f, indent=2)

        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")
            # Remove the partial checkpoint. ignore_errors keeps the original
            # exception (not a cleanup error) as the one that propagates.
            shutil.rmtree(checkpoint_dir, ignore_errors=True)
            # The directory is gone, so stop tracking any earlier entry for it.
            if self._forget_checkpoint(checkpoint_path):
                self._save_checkpoint_info()
            raise

        # Re-saving the same step (e.g. after resuming) overwrites the directory.
        # Replace the old entry instead of tracking the path twice. Otherwise
        # cleanup could delete a directory that is still listed as a checkpoint.
        self._forget_checkpoint(checkpoint_path)
        self.checkpoints.append(CheckpointInfo(
            checkpoint_path=checkpoint_path,
            step=step,
            epoch=epoch,
            loss=loss,
            timestamp=timestamp,
            model_config=model_config,
            training_config=training_config
        ))
        self._cleanup_old_checkpoints()
        self._save_checkpoint_info()

        logger.info(f"Checkpoint saved successfully: {checkpoint_dir}")
        return checkpoint_path

    def _forget_checkpoint(self, checkpoint_path: str) -> bool:
        """Drop tracking entries whose path equals ``checkpoint_path``; return True if any were dropped."""
        remaining = [c for c in self.checkpoints if c.checkpoint_path != checkpoint_path]
        removed = len(remaining) != len(self.checkpoints)
        self.checkpoints = remaining
        return removed

    def _cleanup_old_checkpoints(self) -> None:
        """Delete the oldest checkpoints (by step) beyond ``save_total_limit``."""
        limit = self.save_total_limit
        # None / non-positive limits mean "keep everything".
        if limit is None or limit <= 0 or len(self.checkpoints) <= limit:
            return

        self.checkpoints.sort(key=lambda x: x.step)
        checkpoints_to_remove = self.checkpoints[:-limit]

        for checkpoint_info in checkpoints_to_remove:
            checkpoint_path = Path(checkpoint_info.checkpoint_path)
            if checkpoint_path.exists():
                try:
                    shutil.rmtree(checkpoint_path)
                    logger.info(f"Removed old checkpoint: {checkpoint_path}")
                except Exception as e:
                    # Best-effort: a leftover directory should not abort training.
                    logger.warning(f"Failed to remove checkpoint {checkpoint_path}: {e}")

        self.checkpoints = self.checkpoints[-limit:]

    def _load_checkpoint_info(self) -> None:
        """Load tracked checkpoints from ``checkpoint_info.json``; on failure start with none."""
        if not self.checkpoint_info_file.exists():
            return

        try:
            with open(self.checkpoint_info_file, 'r') as f:
                data = json.load(f)

            self.checkpoints = [CheckpointInfo(**item) for item in data.get("checkpoints", [])]
            logger.info(f"Loaded {len(self.checkpoints)} checkpoint(s) info")

        except Exception as e:
            logger.warning(f"Failed to load checkpoint info: {e}")
            self.checkpoints = []

    def _save_checkpoint_info(self) -> None:
        """Persist tracked checkpoints to ``checkpoint_info.json``.

        Writes to a temporary file and then atomically replaces the target. A
        crash mid-write therefore cannot leave a truncated file that would make
        the next run forget (and never prune) existing checkpoints. Failures
        are logged, not raised.
        """
        tmp_file = self.checkpoint_info_file.with_name(self.checkpoint_info_file.name + ".tmp")
        try:
            data = {
                "checkpoints": [checkpoint.__dict__ for checkpoint in self.checkpoints],
                "updated_at": datetime.now().isoformat()
            }

            with open(tmp_file, 'w') as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_file, self.checkpoint_info_file)

        except Exception as e:
            logger.warning(f"Failed to save checkpoint info: {e}")
            try:
                if tmp_file.exists():
                    tmp_file.unlink()
            except OSError:
                pass

    def get_latest_checkpoint(self) -> Optional[CheckpointInfo]:
        """Return the tracked checkpoint with the highest step, or None."""
        if not self.checkpoints:
            return None

        return max(self.checkpoints, key=lambda x: x.step)

    def get_best_checkpoint(self) -> Optional[CheckpointInfo]:
        """Return the tracked checkpoint with the lowest loss, or None."""
        if not self.checkpoints:
            return None

        return min(self.checkpoints, key=lambda x: x.loss)


class MemoryManager:
    """Record CUDA memory statistics and periodically release cached memory."""

    def __init__(self, cleanup_frequency: int = 100):
        """
        Create a memory manager.

        Args:
            cleanup_frequency: A cleanup actually runs on every
                ``cleanup_frequency``-th call to ``cleanup_memory`` (or when forced)
        """
        self.memory_history = []
        self.step_count = 0
        self.cleanup_frequency = cleanup_frequency

    def track_memory(self, step: int) -> Dict[str, float]:
        """Snapshot allocated/reserved/peak memory (GB = 1e9 bytes) for each CUDA device.

        Returns an empty dict, and records nothing, when CUDA is unavailable.
        """
        if not torch.cuda.is_available():
            return {}

        gigabyte = 1e9
        snapshot = {}
        for idx in range(torch.cuda.device_count()):
            prefix = f"cuda:{idx}"
            snapshot[f"{prefix}_allocated"] = torch.cuda.memory_allocated(idx) / gigabyte
            snapshot[f"{prefix}_reserved"] = torch.cuda.memory_reserved(idx) / gigabyte
            snapshot[f"{prefix}_max_allocated"] = torch.cuda.max_memory_allocated(idx) / gigabyte

        self.memory_history.append({"step": step, "memory": snapshot})
        return snapshot

    def cleanup_memory(self, force: bool = False) -> None:
        """Count this call and, when due or forced, empty the CUDA cache and run ``gc.collect``."""
        self.step_count += 1

        # Short-circuit on ``force`` so the modulo is skipped for forced calls.
        if not force and self.step_count % self.cleanup_frequency != 0:
            return

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        logger.debug(f"Memory cleanup performed at step {self.step_count}")

    def get_memory_summary(self) -> Dict[str, Any]:
        """Summarize recorded snapshots: latest values plus per-device peak/mean allocated GB."""
        if not self.memory_history:
            return {}

        newest = self.memory_history[-1]["memory"]
        peaks = {}
        means = {}

        if torch.cuda.is_available():
            for idx in range(torch.cuda.device_count()):
                device = f"cuda:{idx}"
                key = f"{device}_allocated"
                if key not in newest:
                    continue
                series = [record["memory"].get(key, 0) for record in self.memory_history]
                peaks[device] = max(series)
                means[device] = sum(series) / len(series)

        return {
            "latest_memory_gb": newest,
            "peak_memory_gb": peaks,
            "average_memory_gb": means
        }


class LearningRateScheduler:
    """Wrapper around ``transformers.get_scheduler`` with warmup support.

    ``num_cycles`` is forwarded only to the ``cosine_with_restarts`` schedule,
    and ``min_lr_ratio`` only to the ``cosine_with_min_lr`` schedule. Other
    schedules use transformers' defaults. In particular, plain "cosine" keeps its
    half-cosine decay to zero; forwarding ``num_cycles=1`` there would make the
    LR rise back to its peak by the end of training.
    """

    def __init__(
        self,
        optimizer: Any,
        num_training_steps: int,
        num_warmup_steps: int = 0,
        scheduler_type: str = "cosine",
        num_cycles: int = 1,
        min_lr_ratio: float = 0.0
    ):
        """
        Initialize learning rate scheduler.

        Args:
            optimizer: PyTorch optimizer
            num_training_steps: Total training steps
            num_warmup_steps: Warmup steps
            scheduler_type: transformers scheduler name (e.g. "linear", "cosine",
                "cosine_with_restarts", "cosine_with_min_lr")
            num_cycles: Number of hard restarts for "cosine_with_restarts"
            min_lr_ratio: Minimum LR as ratio of initial LR; only enforced by
                "cosine_with_min_lr" (reported in ``get_schedule_info`` regardless)
        """
        self.optimizer = optimizer
        self.num_training_steps = num_training_steps
        self.num_warmup_steps = num_warmup_steps
        self.scheduler_type = scheduler_type
        self.num_cycles = num_cycles
        self.min_lr_ratio = min_lr_ratio

        # Initial LR taken from the first parameter group.
        self.base_lr = optimizer.param_groups[0]['lr']
        self.min_lr = self.base_lr * min_lr_ratio

        # get_scheduler() does not accept ``num_cycles`` directly. Schedule-specific
        # options must go through ``scheduler_specific_kwargs``. That argument is
        # only passed when needed, so older transformers releases without it still
        # work for the plain schedules.
        call_kwargs: Dict[str, Any] = {
            "name": scheduler_type,
            "optimizer": optimizer,
            "num_warmup_steps": num_warmup_steps,
            "num_training_steps": num_training_steps,
        }
        specific_kwargs = self._scheduler_specific_kwargs(scheduler_type, num_cycles, min_lr_ratio)
        if specific_kwargs:
            call_kwargs["scheduler_specific_kwargs"] = specific_kwargs

        self.scheduler = get_scheduler(**call_kwargs)

    @staticmethod
    def _scheduler_specific_kwargs(
        scheduler_type: str,
        num_cycles: int,
        min_lr_ratio: float
    ) -> Dict[str, Any]:
        """Return the extra keyword arguments the named transformers schedule expects."""
        if scheduler_type == "cosine_with_restarts":
            return {"num_cycles": num_cycles}
        if scheduler_type == "cosine_with_min_lr":
            return {"min_lr_rate": min_lr_ratio}
        return {}

    def step(self) -> float:
        """Step the scheduler and return current learning rate."""
        self.scheduler.step()
        return self.get_lr()

    def get_lr(self) -> float:
        """Get current learning rate (of the first parameter group)."""
        return self.optimizer.param_groups[0]['lr']

    def get_schedule_info(self) -> Dict[str, Any]:
        """Get scheduler information."""
        return {
            "scheduler_type": self.scheduler_type,
            "base_lr": self.base_lr,
            "min_lr": self.min_lr,
            "num_training_steps": self.num_training_steps,
            "num_warmup_steps": self.num_warmup_steps,
            "current_lr": self.get_lr()
        }


class EarlyStoppingMonitor:
    """Signals when a monitored metric has stopped improving.

    A score counts as an improvement only if it beats the best score seen so far
    by strictly more than ``min_delta``, in the direction given by ``mode``.
    """

    def __init__(
        self,
        patience: int = 10,
        min_delta: float = 1e-4,
        monitor: str = "loss",
        mode: str = "min"
    ):
        """
        Initialize early stopping monitor.

        Args:
            patience: Number of consecutive non-improving checks before stopping
            min_delta: Minimum change to qualify as improvement
            monitor: Metric name (used in logs and status only)
            mode: "min" (lower is better); any other value is treated as "max"
        """
        self.patience = patience
        self.min_delta = min_delta
        self.monitor = monitor
        self.mode = mode

        self.best_score = None
        self.wait_count = 0
        self.stopped = False

        self.improvement_op = np.less if mode == "min" else np.greater

    def _improvement_threshold(self) -> float:
        """Return the value a new score must strictly beat to count as an improvement."""
        # The margin must move in the direction of improvement. For "max" mode,
        # subtracting min_delta would accept slightly worse scores as improvements.
        if self.mode == "min":
            return self.best_score - self.min_delta
        return self.best_score + self.min_delta

    def check_improvement(self, current_score: float) -> bool:
        """
        Check if current score represents improvement.

        Args:
            current_score: Current metric value

        Returns:
            True if training should stop
        """
        if self.best_score is None:
            # The first observation only establishes the baseline.
            self.best_score = current_score
            return False

        if self.improvement_op(current_score, self._improvement_threshold()):
            self.best_score = current_score
            self.wait_count = 0
            return False

        self.wait_count += 1
        if self.wait_count >= self.patience:
            self.stopped = True
            logger.info(f"Early stopping triggered. Best {self.monitor}: {self.best_score}")
            return True

        return False

    def get_status(self) -> Dict[str, Any]:
        """Get early stopping status."""
        return {
            "monitor": self.monitor,
            "mode": self.mode,
            "patience": self.patience,
            "wait_count": self.wait_count,
            "best_score": self.best_score,
            "stopped": self.stopped
        }


def setup_training_environment(config: Dict[str, Any]) -> None:
    """
    Set up training environment.

    Exports configured environment variables, seeds the Python, NumPy and
    PyTorch RNGs, and, when CUDA is available, configures TF32 matmul and cuDNN
    benchmarking and optionally empties the CUDA cache.

    Args:
        config: Training configuration. Keys read: ``seed`` (default 42),
            ``environment.env_vars``, ``environment.memory_settings.torch_cuda_empty_cache``
            and ``training.tf32`` (default True). Missing or null sections are
            treated as empty.
    """
    environment_cfg = config.get("environment") or {}

    # Export env vars before touching torch.cuda. Variables such as
    # CUDA_VISIBLE_DEVICES are only honored if set before CUDA is initialized,
    # and torch.cuda.is_available() can trigger that initialization.
    for key, value in (environment_cfg.get("env_vars") or {}).items():
        os.environ[key] = str(value)

    # Seed every RNG training code commonly draws from (including Python's random).
    seed = config.get("seed")
    if seed is None:
        seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

        torch.backends.cuda.matmul.allow_tf32 = (config.get("training") or {}).get("tf32", True)
        # Benchmark mode trades bit-for-bit reproducibility for speed.
        torch.backends.cudnn.benchmark = True

        memory_settings = environment_cfg.get("memory_settings") or {}
        if memory_settings.get("torch_cuda_empty_cache"):
            torch.cuda.empty_cache()


def estimate_training_time(
    num_samples: int,
    batch_size: int,
    num_epochs: int,
    samples_per_second: Optional[float] = None
) -> Dict[str, Any]:
    """
    Estimate training time.

    Steps per epoch use floor division (a trailing partial batch is not
    counted) with a minimum of 1.

    Args:
        num_samples: Number of training samples
        batch_size: Effective batch size (must be positive)
        num_epochs: Number of training epochs
        samples_per_second: Processing speed (if known); time estimates are
            only added when this is truthy

    Returns:
        Training time estimates

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    steps_per_epoch = max(1, num_samples // batch_size)
    total_steps = steps_per_epoch * num_epochs
    total_samples = num_samples * num_epochs

    estimates = {
        "total_steps": total_steps,
        "steps_per_epoch": steps_per_epoch,
        "total_samples": total_samples,
        "num_epochs": num_epochs
    }

    if samples_per_second:
        total_time_seconds = total_samples / samples_per_second
        # Compute the per-epoch time directly rather than dividing the total by
        # num_epochs, which raised ZeroDivisionError for num_epochs == 0.
        epoch_time_seconds = num_samples / samples_per_second
        estimates.update({
            "estimated_total_time_hours": total_time_seconds / 3600,
            "estimated_time_per_epoch_hours": epoch_time_seconds / 3600,
            "samples_per_second": samples_per_second
        })

    return estimates


def create_training_summary(
    config: Dict[str, Any],
    metrics_history: List[TrainingMetrics],
    checkpoint_manager: CheckpointManager,
    memory_manager: MemoryManager
) -> Dict[str, Any]:
    """
    Create comprehensive training summary.

    Args:
        config: Training configuration
        metrics_history: Training metrics history (oldest first)
        checkpoint_manager: Checkpoint manager
        memory_manager: Memory manager

    Returns:
        Nested summary dict, or ``{"error": ...}`` when ``metrics_history`` is
        empty. ``loss_statistics.loss_reduction_percent`` is ``None`` when the
        initial loss is exactly zero (the percentage is undefined).
    """
    if not metrics_history:
        return {"error": "No training metrics available"}

    first, last = metrics_history[0], metrics_history[-1]
    start_time = datetime.fromisoformat(first.timestamp)
    end_time = datetime.fromisoformat(last.timestamp)
    duration = end_time - start_time

    losses = [m.loss for m in metrics_history]
    learning_rates = [m.learning_rate for m in metrics_history]

    initial_loss, final_loss = losses[0], losses[-1]
    loss_reduction = initial_loss - final_loss
    # Avoid ZeroDivisionError when the first logged loss is exactly 0.
    loss_reduction_percent = (loss_reduction / initial_loss) * 100 if initial_loss != 0 else None

    # Records with missing (None) or zero throughput are excluded from the average.
    throughputs = [m.throughput for m in metrics_history if m.throughput]
    average_throughput = sum(throughputs) / len(throughputs) if throughputs else None

    # Look each checkpoint up once instead of twice per field.
    best_checkpoint = checkpoint_manager.get_best_checkpoint()
    latest_checkpoint = checkpoint_manager.get_latest_checkpoint()

    # Null sections in a YAML config load as None; treat them as empty.
    model_cfg = config.get("model") or {}
    training_cfg = config.get("training") or {}

    summary = {
        "training_config": {
            "model": model_cfg.get("vlm_model_name"),
            "training_method": config.get("training_method"),
            "framework": config.get("framework"),
            "batch_size": training_cfg.get("per_device_train_batch_size"),
            "learning_rate": training_cfg.get("learning_rate"),
            "num_epochs": training_cfg.get("num_train_epochs")
        },

        "training_progress": {
            "total_steps": last.step,
            "total_epochs": last.epoch,
            "duration_hours": duration.total_seconds() / 3600,
            "start_time": start_time.isoformat(),
            "end_time": end_time.isoformat()
        },

        "loss_statistics": {
            "initial_loss": initial_loss,
            "final_loss": final_loss,
            "min_loss": min(losses),
            "max_loss": max(losses),
            "loss_reduction": loss_reduction,
            "loss_reduction_percent": loss_reduction_percent
        },

        "learning_rate_info": {
            "initial_lr": learning_rates[0],
            "final_lr": learning_rates[-1],
            "min_lr": min(learning_rates),
            "max_lr": max(learning_rates)
        },

        "checkpoints": {
            "total_checkpoints": len(checkpoint_manager.checkpoints),
            "best_checkpoint": best_checkpoint.__dict__ if best_checkpoint else None,
            "latest_checkpoint": latest_checkpoint.__dict__ if latest_checkpoint else None
        },

        "memory_usage": memory_manager.get_memory_summary(),

        "performance": {
            "average_throughput": average_throughput,
            # NOTE(review): this counts logged metric records, not training
            # samples; the key name is kept for backward compatibility.
            "total_training_samples": len(metrics_history)
        }
    }

    return summary


def validate_training_setup(
    config: Dict[str, Any],
    train_dataset: Any,
    eval_dataset: Optional[Any] = None,
    min_gpu_memory_gb: float = 12.0
) -> List[str]:
    """
    Validate training setup and return list of issues.

    All checks are advisory: problems are reported as human-readable strings
    rather than raised.

    Args:
        config: Training configuration
        train_dataset: Training dataset; size checks are skipped when it does
            not support ``len()`` (e.g. streaming/iterable datasets)
        eval_dataset: Optional evaluation dataset (currently not inspected)
        min_gpu_memory_gb: GPUs with less total memory than this are reported

    Returns:
        List of validation issues (empty if none were found)
    """
    issues: List[str] = []
    training_cfg = config.get("training") or {}

    # Dataset size checks.
    try:
        dataset_size = len(train_dataset)
    except TypeError:
        dataset_size = None

    batch_size = training_cfg.get("per_device_train_batch_size", 1)
    if dataset_size == 0:
        issues.append("Training dataset is empty")
    elif dataset_size is not None and batch_size > dataset_size:
        issues.append(f"Batch size ({batch_size}) larger than dataset size ({dataset_size})")

    # GPU availability and memory.
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            memory_gb = torch.cuda.get_device_properties(i).total_memory / 1e9
            if memory_gb < min_gpu_memory_gb:
                issues.append(f"GPU {i} has low memory: {memory_gb:.1f}GB")
    else:
        issues.append("No CUDA GPUs available - training will be very slow")

    # Learning rate sanity. PyYAML (YAML 1.1) loads values such as ``1e-4`` (no
    # decimal point) as strings, so coerce before comparing instead of raising
    # TypeError. An unset learning rate is not reported.
    lr = training_cfg.get("learning_rate")
    if lr is not None:
        try:
            lr_value = float(lr)
        except (TypeError, ValueError):
            issues.append(f"Learning rate ({lr!r}) is not a valid number")
        else:
            if lr_value > 1e-3:
                issues.append(f"Learning rate ({lr}) may be too high")
            elif lr_value < 1e-7:
                issues.append(f"Learning rate ({lr}) may be too low")

    # Training method compatibility.
    training_method = config.get("training_method")
    if training_method in ["ppo", "grpo"] and not config.get("reward_model"):
        issues.append("RL training methods require reward model configuration")

    return issues