import os
import torch
import numpy as np
import subprocess as sp
from typing import Tuple
import logging
import gc
from hydra.experimental.callback import Callback
from omegaconf import DictConfig

def byte_to_megabyte(value: int, digits: int = 2) -> float:
    """Convert a byte count to units of 2**20 bytes (MiB), rounded to ``digits`` decimals."""
    bytes_per_unit = 2**20
    converted = value / bytes_per_unit
    return round(converted, digits)


def medibyte_to_megabyte(value: int, digits: int = 2) -> float:
    """Convert mebibytes (MiB, 2**20 bytes) to megabytes (MB, 10**6 bytes).

    Args:
        value: Amount in mebibytes ("medibyte" in the function name).
        digits: Number of decimals to round the result to.

    Returns:
        The amount in megabytes, rounded to ``digits`` decimals.
    """
    # 1 MiB = 1_048_576 bytes = 1.048576 MB exactly (the previous 1.0485 was truncated).
    return round(1.048576 * value, digits)


def get_gpu_memory_from_nvidia_smi(device: int = 0) -> Tuple[float, float]:
    """
    Get free and used memory of a single GPU, in MiB.

    The CUDA runtime is queried first via ``torch.cuda.mem_get_info``; on that
    path ``device`` is a CUDA index (i.e. relative to ``CUDA_VISIBLE_DEVICES``).
    If CUDA is unavailable or the query fails, falls back to a single
    ``nvidia-smi --id=<device>`` query. If both fail, ``(0, 0)`` is returned.

    Args:
        device: Index of the GPU to query.

    Returns:
        Tuple[float, float]: (free_memory_mb, used_memory_mb). On the CUDA path
        these are byte counts divided by 1024**2 (floats); on the nvidia-smi
        path they are the integers nvidia-smi reports.
    """
    if torch.cuda.is_available():
        try:
            free, total = torch.cuda.mem_get_info(device)
        except (RuntimeError, AssertionError):
            # e.g. invalid device ordinal -- try nvidia-smi instead
            pass
        else:
            return free / 1024**2, (total - free) / 1024**2

    try:
        # Query both fields in one call so they describe the same moment.
        output = sp.check_output(
            [
                "nvidia-smi",
                "--query-gpu=memory.free,memory.used",
                "--format=csv,nounits,noheader",
                f"--id={device}",
            ]
        )
        free_mb, used_mb = output.decode("ascii").splitlines()[0].split(",")
        return int(free_mb), int(used_mb)
    except (OSError, ValueError, IndexError, sp.CalledProcessError):
        # nvidia-smi missing/failing, or its output could not be parsed
        return 0, 0


def get_gpu_memory() -> Tuple[int, int]:
    """
    Get free and used memory of the first GPU listed by nvidia-smi, in MiB.

    Both values come from a single ``nvidia-smi`` invocation, so they describe
    the same moment.

    Returns:
        Tuple[int, int]: (free_memory_mb, used_memory_mb), or ``(0, 0)`` when
        nvidia-smi is missing, fails, or produces output that cannot be parsed.
    """
    try:
        output = sp.check_output(
            [
                "nvidia-smi",
                "--query-gpu=memory.free,memory.used",
                "--format=csv,nounits,noheader",
            ]
        )
    except (OSError, sp.CalledProcessError):
        # nvidia-smi not available or failed, return 0s
        return 0, 0

    try:
        # One line per GPU, e.g. "1234, 5678"; only the first GPU is reported.
        free_mb, used_mb = output.decode("ascii").splitlines()[0].split(",")
        return int(free_mb), int(used_mb)
    except (ValueError, IndexError):
        # Error parsing nvidia-smi output, return 0s
        return 0, 0


def find_devices(max_devices: int = 1, greedy: bool = False, gamma: int = 12):
    """
    Choose which GPU(s) to run on.

    Args:
        max_devices: Maximum number of GPUs to select.
        greedy: If True, pick the GPUs with the most free memory. Otherwise
            sample GPUs without replacement with probability proportional to
            ``free_memory ** gamma``.
        gamma: Exponent sharpening the sampling distribution towards the GPUs
            with the most free memory.

    Returns:
        - ``max_devices`` unchanged when CUDA is unavailable (a count, not a list);
        - ``1`` when exactly one GPU is visible;
        - otherwise a list of at most ``max_devices`` ids taken from
          ``CUDA_VISIBLE_DEVICES`` (or ``0..n-1`` when it is unset or not a
          list of integers).

        When sampling is impossible (every GPU reports zero free memory, or
        fewer GPUs have nonzero weight than requested), the greedy order is used.
    """
    # if no gpus are available, return the requested device count unchanged
    if not torch.cuda.is_available():
        return max_devices
    n_gpus = torch.cuda.device_count()
    # if only 1 gpu, return 1 (i.e., the number of devices)
    if n_gpus == 1:
        return 1

    # if multiple gpus are available, return gpu id list with length <= max_devices
    visible_devices = list(range(n_gpus))
    env_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    if env_devices is not None:
        try:
            # skip empty entries such as the one left by a trailing comma
            parsed = [int(i) for i in env_devices.split(',') if i.strip()]
        except ValueError:
            # e.g. GPU UUIDs: no integer ids to report, keep CUDA indices
            parsed = []
        if parsed:
            # CUDA only exposes the first `n_gpus` entries
            visible_devices = parsed[:n_gpus]

    # Position i in `visible_devices` is CUDA device i, so memory is queried by
    # that index rather than by the (possibly remapped) id from the env var.
    available_memory = np.asarray(
        [get_gpu_memory_from_nvidia_smi(i)[0] for i in range(len(visible_devices))],
        dtype=float,
    )
    n_select = max(0, min(max_devices, len(visible_devices)))

    if not greedy and available_memory.sum() > 0:
        # sample `n_select` gpus according to available capacity
        p = (available_memory / np.linalg.norm(available_memory, gamma)) ** gamma
        # ensure p sums to 1
        p = p / p.sum()
        # sampling without replacement needs at least n_select nonzero weights
        if np.count_nonzero(p) >= n_select:
            idx_devices = np.random.choice(
                np.arange(len(p)), size=n_select, replace=False, p=p
            )
            return [visible_devices[i] for i in idx_devices]

    # greedy (or sampling impossible): gpus sorted by available capacity, largest first
    idx_devices = np.argsort(available_memory)[::-1].tolist()[:n_select]
    return [visible_devices[i] for i in idx_devices]

def log_gpu_memory_usage(stage: str = "", device: int = 0):
    """
    Print current GPU memory usage for debugging memory leaks.

    Reports PyTorch's allocated and reserved ("cached") memory for ``device`` in
    GiB and, when the query succeeds, the free/used figures returned by
    ``get_gpu_memory_from_nvidia_smi``. Does nothing when CUDA is unavailable.
    Ordinary errors are reported as a printed warning instead of being raised.

    Args:
        stage: Label included in the printed line.
        device: CUDA device index to report on.
    """
    if not torch.cuda.is_available():
        return

    try:
        # PyTorch memory stats
        allocated = torch.cuda.memory_allocated(device) / (1024**3)  # GB
        cached = torch.cuda.memory_reserved(device) / (1024**3)  # GB

        # nvidia-smi stats
        try:
            free_mem, used_mem = get_gpu_memory_from_nvidia_smi(device)
            free_gb = free_mem / 1024
            used_gb = used_mem / 1024
            print(f"🔍 GPU Memory {stage}: PyTorch allocated={allocated:.2f}GB, cached={cached:.2f}GB | nvidia-smi free={free_gb:.2f}GB, used={used_gb:.2f}GB")
        except Exception:
            # `Exception`, not a bare except: KeyboardInterrupt/SystemExit must propagate
            print(f"🔍 GPU Memory {stage}: PyTorch allocated={allocated:.2f}GB, cached={cached:.2f}GB")

    except Exception as e:
        print(f"⚠️  Could not log GPU memory: {e}")


class FilterCallback(logging.Filter):
    """
    Logging filter that drops two known error messages from the ``neptune`` logger.

    A record is rejected only when its logger name is exactly ``"neptune"`` and
    its formatted message starts with one of ``SUPPRESSED_PREFIXES``. All other
    records pass through.
    """

    # Message prefixes to drop for records from the "neptune" logger.
    SUPPRESSED_PREFIXES = (
        "Error occurred during asynchronous operation processing: X-coordinates (step) must be strictly increasing for series attribute",
        "Error occurred during asynchronous operation processing: Timestamp must be non-decreasing for series attribute",
    )

    def filter(self, record: logging.LogRecord) -> bool:
        """Return False to drop a suppressed neptune message, True to keep the record."""
        if record.name != "neptune":
            return True
        # str.startswith accepts a tuple, so the message is formatted only once.
        return not record.getMessage().startswith(self.SUPPRESSED_PREFIXES)


class CUDACleanupCallback(Callback):
    """
    Hydra callback that performs garbage collection and clears the CUDA cache at the end of each job.
    This helps prevent CUDA out of memory errors in sweep/multirun scenarios.
    """

    def on_job_end(self, config: DictConfig, **kwargs) -> None:
        """
        Run garbage collection, then clear the CUDA cache, then log memory usage.

        Args:
            config: The job's Hydra config (unused).
            **kwargs: Extra arguments passed by Hydra (unused).
        """
        # Collect first: tensors kept alive only by reference cycles are freed
        # here, so the empty_cache() call below can release their cached blocks.
        # Emptying the cache before collecting would leave that memory cached.
        gc.collect()
        print("🗑️  Garbage collection performed.")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print("🧹 CUDA cache cleared.")

        # Log final memory state for debugging
        log_gpu_memory_usage("POST_CLEANUP")
