"""
Memory management utilities for handling large-scale distance computations.
"""

import numpy as np
import psutil
import torch
from typing import Optional, Tuple
from loguru import logger


def estimate_matrix_memory_gb(n: int, m: int = None, dtype=np.float32) -> float:
    """
    Return the memory footprint, in GB, of a dense ``n × m`` matrix.

    Args:
        n: Row count.
        m: Column count; when omitted the matrix is treated as square (``n × n``).
        dtype: Element type; its ``itemsize`` gives the bytes per entry
            (default: float32).

    Returns:
        Size in GB, using 1 GB = 1024**3 bytes.
    """
    cols = n if m is None else m
    element_size = np.dtype(dtype).itemsize
    return (n * cols * element_size) / (1024**3)


def get_available_memory_gb(use_gpu: bool = False, device: Optional[torch.device] = None) -> float:
    """
    Get available memory in GB.

    The GPU path is used only when all of these hold:
    - ``use_gpu`` is True;
    - CUDA is available;
    - ``device`` is None (meaning GPU 0) or a CUDA device.

    In that case the free memory reported by ``torch.cuda.mem_get_info`` is returned.

    In every other case the available system RAM reported by
    ``psutil.virtual_memory()`` is returned. This covers GPU not requested, CUDA
    unavailable, a non-CUDA device, or the GPU query raising.

    Args:
        use_gpu: Whether to check GPU memory.
        device: GPU device to check (only consulted if use_gpu=True). A CUDA
            device without an index is treated as index 0.

    Returns:
        Available memory in GB (1 GB = 1024**3 bytes). Always a float.
    """
    if use_gpu and torch.cuda.is_available():
        try:
            # Resolve which CUDA device to query; None means "not a CUDA device".
            device_id = None
            if device is None:
                device_id = 0
            elif device.type == 'cuda':
                device_id = device.index if device.index is not None else 0
            # Any other device type falls through to the system-RAM path below.

            if device_id is not None:
                # Block until queued work on the device has finished before querying.
                torch.cuda.synchronize(device_id)
                free_memory, _total_memory = torch.cuda.mem_get_info(device_id)
                return free_memory / (1024**3)
        except Exception as e:
            logger.warning(f"Failed to get GPU memory: {e}, falling back to CPU memory")

    # Reached whenever the GPU path did not return. This must be unconditional so
    # the function never returns None (e.g. use_gpu=True on a machine without CUDA).
    memory = psutil.virtual_memory()
    return memory.available / (1024**3)


def compute_optimal_chunk_size(
    n: int,
    m: int = None,
    available_memory_gb: Optional[float] = None,
    safety_factor: float = 0.3,
    min_chunk_size: int = 100,
    max_chunk_size: int = 10000,
    use_gpu: bool = False,
    device: Optional[torch.device] = None,
    dtype=np.float32,
) -> int:
    """
    Compute optimal chunk size for distance matrix computation.

    A chunk is a ``chunk_size × m`` block of the full ``n × m`` matrix. The
    chunk size is the number of rows whose block fits in
    ``available_memory_gb * safety_factor``. That value is clamped to
    ``[min_chunk_size, max_chunk_size]`` and finally capped at ``n``. The result
    is always at least 1, so it can be used directly as a ``range`` step.

    Args:
        n: Number of rows in the full matrix.
        m: Number of columns (default: n for square matrix).
        available_memory_gb: Available memory in GB (auto-detected if None).
        safety_factor: Fraction of available memory to use (0.3 = 30%).
        min_chunk_size: Lower bound applied before capping at ``n``. It applies
            even when memory would only allow fewer rows.
        max_chunk_size: Maximum chunk size.
        use_gpu: Whether using GPU (used only for auto-detecting memory).
        device: GPU device (used only for auto-detecting memory).
        dtype: Element type of the chunk; its ``itemsize`` gives the bytes per
            entry (default: float32, i.e. 4 bytes, matching the previous
            hard-coded value).

    Returns:
        Optimal chunk size, between 1 and ``max(n, 1)``.
    """
    if m is None:
        m = n

    # Get available memory
    if available_memory_gb is None:
        available_memory_gb = get_available_memory_gb(use_gpu=use_gpu, device=device)

    # Account for safety factor
    usable_memory_gb = available_memory_gb * safety_factor

    # Memory cost of one row of a (chunk_size × m) chunk
    bytes_per_element = np.dtype(dtype).itemsize
    memory_per_row_gb = m * bytes_per_element / (1024**3)

    if memory_per_row_gb > 0:
        # Calculate how many rows fit in memory
        chunk_size = int(usable_memory_gb / memory_per_row_gb)
    else:
        # Zero-width rows cost nothing, so memory imposes no limit.
        # This also avoids division by zero when m == 0.
        chunk_size = max_chunk_size

    # Clamp to min/max
    chunk_size = max(min_chunk_size, min(chunk_size, max_chunk_size))

    # Never exceed the row count, but stay >= 1 so callers can use it as a step.
    return max(1, min(chunk_size, n))


def check_memory_sufficient(
    n: int,
    m: int = None,
    required_memory_gb: Optional[float] = None,
    use_gpu: bool = False,
    device: Optional[torch.device] = None,
    safety_factor: float = 0.8
) -> Tuple[bool, str]:
    """
    Check if sufficient memory is available for operation.

    The operation is considered to fit when
    ``available_memory_gb * safety_factor >= required_memory_gb``.

    Args:
        n: Matrix dimension 1.
        m: Matrix dimension 2 (default: n).
        required_memory_gb: Required memory in GB (auto-estimated as a float32
            ``n × m`` matrix if None).
        use_gpu: Whether using GPU.
        device: GPU device (None means GPU 0 when use_gpu=True).
        safety_factor: Fraction of the available memory that may be used.

    Returns:
        (is_sufficient, message)
    """
    if m is None:
        m = n

    # Estimate required memory
    if required_memory_gb is None:
        required_memory_gb = estimate_matrix_memory_gb(n, m)

    # Get available memory
    available_memory_gb = get_available_memory_gb(use_gpu=use_gpu, device=device)

    # Label the memory kind with the same device rules used when measuring it.
    # device=None means GPU 0, and non-CUDA devices mean RAM.
    # NOTE(review): if the GPU query itself fails, the measurement falls back to
    # RAM but this label still says "GPU".
    queries_gpu = (
        use_gpu
        and torch.cuda.is_available()
        and (device is None or getattr(device, "type", None) == 'cuda')
    )
    memory_type = "GPU" if queries_gpu else "RAM"

    if available_memory_gb * safety_factor < required_memory_gb:
        message = (
            f"Insufficient {memory_type}: need {required_memory_gb:.2f} GB, "
            f"have {available_memory_gb:.2f} GB available. "
            f"Consider using chunked computation or reducing dataset size."
        )
        return False, message

    message = (
        f"Sufficient {memory_type}: need {required_memory_gb:.2f} GB, "
        f"have {available_memory_gb:.2f} GB available."
    )
    return True, message


def warn_if_memory_insufficient(
    n: int,
    m: int = None,
    operation_name: str = "operation",
    use_gpu: bool = False,
    device: Optional[torch.device] = None,
    auto_chunk: bool = True
) -> bool:
    """
    Log whether an ``n × m`` operation is expected to fit in memory.

    The check is delegated to ``check_memory_sufficient``. A shortfall is logged
    at warning level and a fit at debug level. Both messages are prefixed with
    ``operation_name``.

    Args:
        n: Matrix dimension 1.
        m: Matrix dimension 2 (default: n).
        operation_name: Label used as the prefix of every log message.
        use_gpu: Whether using GPU.
        device: GPU device.
        auto_chunk: Only controls whether an extra debug message about enabling
            chunked computation is emitted on a shortfall. This function does
            not switch chunking on itself.

    Returns:
        True if memory is sufficient, False otherwise.
    """
    fits, report = check_memory_sufficient(n, m, use_gpu=use_gpu, device=device)

    if fits:
        logger.debug(f"{operation_name}: {report}")
        return True

    logger.warning(f"{operation_name}: {report}")
    if auto_chunk:
        logger.debug(f"Automatically enabling chunked computation for {operation_name}")
    return False


def log_memory_usage(stage: str, use_gpu: bool = False, device: Optional[torch.device] = None):
    """
    Log current memory usage at debug level.

    GPU statistics are logged only when all of these hold:
    - ``use_gpu`` is True;
    - CUDA is available;
    - ``device`` is None (meaning GPU 0) or a CUDA device.

    Otherwise system RAM statistics are logged. This mirrors the fallback rules
    of ``get_available_memory_gb``. A failure while querying the GPU is logged
    as a warning and is not raised.

    Args:
        stage: Description of current stage (used as the log prefix).
        use_gpu: Whether to log GPU memory.
        device: GPU device. A CUDA device without an index is treated as index 0.
    """
    on_cuda = (
        use_gpu
        and torch.cuda.is_available()
        and (device is None or getattr(device, "type", None) == 'cuda')
    )

    if not on_cuda:
        memory = psutil.virtual_memory()
        used_gb = memory.used / (1024**3)
        available_gb = memory.available / (1024**3)
        total_gb = memory.total / (1024**3)
        logger.debug(
            f"[{stage}] RAM: {used_gb:.2f} GB used, "
            f"{available_gb:.2f} GB available / {total_gb:.2f} GB total"
        )
        return

    try:
        device_id = 0 if device is None or device.index is None else device.index

        torch.cuda.synchronize(device_id)
        allocated = torch.cuda.memory_allocated(device_id) / (1024**3)
        reserved = torch.cuda.memory_reserved(device_id) / (1024**3)
        free, total = torch.cuda.mem_get_info(device_id)
        free_gb = free / (1024**3)
        total_gb = total / (1024**3)
        logger.debug(
            f"[{stage}] GPU Memory: {allocated:.2f} GB allocated, "
            f"{reserved:.2f} GB reserved, {free_gb:.2f} GB free / {total_gb:.2f} GB total"
        )
    except Exception as e:
        # Best-effort diagnostics: never let logging break the caller.
        logger.warning(f"Failed to log GPU memory: {e}")
