"""
Performance monitoring utilities for samplers.

Provides decorators and utilities to track runtime, CPU memory, and GPU memory usage.
"""

import functools
import logging
import time
from typing import Any, Callable, Dict, Optional

import torch

logger = logging.getLogger(__name__)

# psutil is optional: it provides per-process RSS for CPU memory tracking.
# When it is missing, PSUTIL_AVAILABLE is False and get_memory_mb() reports 0.0.
try:
    import psutil
except ImportError:
    PSUTIL_AVAILABLE = False
    logger.warning(
        "psutil not available. Memory monitoring will use basic tracking. "
        "Install with: pip install psutil"
    )
else:
    PSUTIL_AVAILABLE = True


def get_memory_mb() -> float:
    """Return the resident set size (RSS) of the current process in MB.

    MB here means 1024 * 1024 bytes. When psutil could not be imported this
    returns 0.0, so memory deltas computed from it read as zero rather than
    raising.
    """
    if not PSUTIL_AVAILABLE:
        return 0.0
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / (1024 * 1024)


def get_gpu_memory_mb() -> Dict[str, float]:
    """Return CUDA memory usage in MB (1024 * 1024 bytes).

    Returns:
        ``{"allocated": ..., "reserved": ...}`` built from
        ``torch.cuda.memory_allocated()`` and ``torch.cuda.memory_reserved()``
        when CUDA is available; otherwise an empty dict.
    """
    if not torch.cuda.is_available():
        return {}
    bytes_per_mb = 1024 * 1024
    return {
        "allocated": torch.cuda.memory_allocated() / bytes_per_mb,
        "reserved": torch.cuda.memory_reserved() / bytes_per_mb,
    }


_BYTES_PER_MB = 1024 * 1024


def _write_scalars(
    writer: Any, sampler_type: str, epoch: Any, scalars: Dict[str, float]
) -> None:
    """Write each ``name -> value`` pair as ``"{sampler_type}/{name}"`` at step ``epoch``."""
    for name, value in scalars.items():
        writer.add_scalar(f"{sampler_type}/{name}", value, epoch)


def monitor_performance(sampler_type: str, *, synchronize_cuda: bool = True):
    """
    Decorator to monitor performance of a sampler's ``sample()`` method.

    Tracks, per call:
    - Runtime (seconds, wall clock via ``time.perf_counter``)
    - CPU (process RSS) memory delta and final value (MB), via
      ``get_memory_mb()``; both read 0 when psutil is unavailable
    - GPU memory delta, peak and final allocated (MB), if CUDA is available

    Metrics are always emitted with ``logger.info``. They are additionally
    written to TensorBoard when BOTH ``writer`` and ``epoch`` are passed as
    keyword arguments to the decorated method. Those kwargs are forwarded to
    the method unchanged, so the method must accept them.

    NOTE: every call runs ``torch.cuda.reset_peak_memory_stats()`` when CUDA
    is available. Any other code (including nested decorated calls) that
    relies on ``torch.cuda.max_memory_allocated()`` will see the peak reset.

    Args:
        sampler_type: Type identifier for logging (e.g., 'batch_sampler',
            'negative_sampler'). Used as the log prefix and the TensorBoard
            tag prefix.
        synchronize_cuda: When CUDA is available, call
            ``torch.cuda.synchronize()`` before starting and after stopping
            the timer. CUDA kernels launch asynchronously, so without this the
            measured runtime can miss GPU work still queued when the method
            returns, or include work queued before it. Pass False to skip the
            synchronization overhead.
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs) -> Any:
            # Optional logging targets; left in kwargs so they still reach func.
            writer = kwargs.get("writer")
            epoch = kwargs.get("epoch")

            gpu_available = torch.cuda.is_available()
            sync = gpu_available and synchronize_cuda
            gpu_mem_start = 0.0
            if gpu_available:
                if sync:
                    # Drain previously queued kernels so they are not timed.
                    torch.cuda.synchronize()
                torch.cuda.reset_peak_memory_stats()
                gpu_mem_start = torch.cuda.memory_allocated() / _BYTES_PER_MB

            # Sample RSS before starting the clock so profiler overhead is not timed.
            memory_start = get_memory_mb()
            start_time = time.perf_counter()

            result = func(self, *args, **kwargs)

            if sync:
                # Wait for kernels launched by func before stopping the clock.
                torch.cuda.synchronize()
            runtime = time.perf_counter() - start_time
            memory_end = get_memory_mb()
            memory_delta = memory_end - memory_start

            logger.info(
                "[%s] Performance | runtime=%.3fs "
                "memory_delta=%.2fMB memory_final=%.2fMB",
                sampler_type, runtime, memory_delta, memory_end,
            )

            # Insertion order fixes the order of TensorBoard writes below.
            scalars: Dict[str, float] = {
                "runtime_seconds": runtime,
                "memory_delta_mb": memory_delta,
                "memory_final_mb": memory_end,
            }

            if gpu_available:
                gpu_mem_end = torch.cuda.memory_allocated() / _BYTES_PER_MB
                gpu_mem_peak = torch.cuda.max_memory_allocated() / _BYTES_PER_MB
                gpu_mem_delta = gpu_mem_end - gpu_mem_start
                logger.info(
                    "[%s] GPU Memory | delta=%.2fMB peak=%.2fMB final=%.2fMB",
                    sampler_type, gpu_mem_delta, gpu_mem_peak, gpu_mem_end,
                )
                scalars["gpu_memory_delta_mb"] = gpu_mem_delta
                scalars["gpu_memory_peak_mb"] = gpu_mem_peak
                scalars["gpu_memory_final_mb"] = gpu_mem_end

            # `epoch is not None` (not truthiness) so epoch 0 is still logged.
            if writer is not None and epoch is not None:
                _write_scalars(writer, sampler_type, epoch, scalars)
                logger.info(
                    "[%s] Performance metrics logged to tensorboard for epoch %s",
                    sampler_type, epoch,
                )

            return result

        return wrapper
    return decorator
