
import torch
import psutil
import os
import time
import threading
from typing import Optional

class ResourceMonitor:
    """Context manager recording peak GPU memory, peak RAM and elapsed time.

    Usage::

        with ResourceMonitor() as monitor:
            ...  # code to measure
        monitor.print_summary()

    GPU memory comes from PyTorch's CUDA allocator statistics
    (``torch.cuda.max_memory_allocated``). It is reported relative to the
    amount already allocated when the ``with`` block is entered. GPU tracking
    is skipped, and ``peak_gpu_mem`` stays 0, when ``device`` is not a CUDA
    device or CUDA is unavailable.

    RAM is the resident set size (RSS) of the current process, read via
    ``psutil``. It is sampled on entry, every ``sampling_interval`` seconds by
    a daemon thread, and once more on exit. Spikes shorter than the sampling
    interval can therefore be missed.

    Args:
        device: Device whose CUDA memory is tracked (e.g. ``"cuda"``,
            ``"cuda:1"``).
        sampling_interval: Seconds between background RAM samples.
    """

    def __init__(self, device="cuda", sampling_interval=0.1):
        self.device = device
        self.sampling_interval = sampling_interval

        # GPU bytes allocated at __enter__ (None when GPU is not tracked) and
        # the peak number of bytes allocated above that baseline.
        self.start_gpu_mem: Optional[int] = None
        self.peak_gpu_mem = 0

        # Peak RSS in bytes of the psutil process handle below.
        self.peak_ram = 0
        self.process = None
        self.monitoring_thread: Optional[threading.Thread] = None
        # Kept as a public flag for compatibility; setting it to True also
        # stops the sampler on its next iteration.
        self.stop_monitoring = False
        # The Event lets __exit__ wake the sampler immediately instead of
        # waiting out a full sampling_interval, which may exceed the join
        # timeout.
        self._stop_event = threading.Event()

        # start_time/end_time are epoch timestamps (time.time()). The elapsed
        # time itself is measured with the monotonic perf_counter, so system
        # clock adjustments cannot skew it.
        self.start_time: Optional[float] = None
        self.end_time: Optional[float] = None
        self.wall_clock_time = 0
        self._start_perf: Optional[float] = None

    def _tracks_gpu(self) -> bool:
        """Return True when ``self.device`` is a CUDA device and CUDA is usable."""
        return torch.device(self.device).type == "cuda" and torch.cuda.is_available()

    def _sample_ram(self) -> None:
        """Read the process RSS once and fold it into ``peak_ram``."""
        current_ram = self.process.memory_info().rss
        self.peak_ram = max(self.peak_ram, current_ram)

    def _monitor_ram(self):
        """Sample RSS in the background until a stop is requested.

        This is best-effort. A sampling error is printed and ends the loop,
        so a monitoring failure never disturbs the measured code.
        """
        while not (self.stop_monitoring or self._stop_event.is_set()):
            try:
                self._sample_ram()
            except Exception as e:
                print(f"RAM monitoring error: {e}")
                break
            # Returns early as soon as __exit__ sets the event.
            self._stop_event.wait(self.sampling_interval)

    def __enter__(self):
        """Record GPU and RAM baselines, start the RAM sampler and the timer."""
        if self._tracks_gpu():
            # Reset peak stats so max_memory_allocated() covers only this block.
            torch.cuda.reset_peak_memory_stats(self.device)
            self.start_gpu_mem = torch.cuda.memory_allocated(self.device)
        else:
            self.start_gpu_mem = None

        self.process = psutil.Process(os.getpid())
        self.peak_ram = self.process.memory_info().rss

        self.stop_monitoring = False
        self._stop_event.clear()
        self.monitoring_thread = threading.Thread(target=self._monitor_ram, daemon=True)
        self.monitoring_thread.start()

        self.start_time = time.time()
        self._start_perf = time.perf_counter()
        return self

    def __exit__(self, *args):
        """Stop sampling and finalise the measurements.

        Exceptions raised inside the ``with`` block are not suppressed.
        """
        self.stop_monitoring = True
        self._stop_event.set()
        if self.monitoring_thread is not None and self.monitoring_thread.is_alive():
            self.monitoring_thread.join(timeout=1.0)

        # Final sample catches growth since the sampler's last reading.
        if self.process is not None:
            try:
                self._sample_ram()
            except Exception as e:
                print(f"RAM monitoring error: {e}")

        if self.start_gpu_mem is not None:
            self.peak_gpu_mem = torch.cuda.max_memory_allocated(self.device) - self.start_gpu_mem
        else:
            self.peak_gpu_mem = 0

        self.end_time = time.time()
        self.wall_clock_time = time.perf_counter() - self._start_perf

    def get_summary(self):
        """Return the measurements as a dict.

        Memory is in decimal gigabytes (bytes / 1e9). Time is given in
        seconds, minutes and hours.
        """
        return {
            'gpu_peak_gb': self.peak_gpu_mem / 1e9,
            'ram_peak_gb': self.peak_ram / 1e9,
            'wall_clock_time_seconds': self.wall_clock_time,
            'wall_clock_time_minutes': self.wall_clock_time / 60,
            'wall_clock_time_hours': self.wall_clock_time / 3600
        }

    def print_summary(self):
        """Print the values from :meth:`get_summary` in a human-readable form."""
        summary = self.get_summary()
        print("=== Resource Usage Summary ===")
        print(f"GPU Peak Memory: {summary['gpu_peak_gb']:.2f} GB")
        print(f"RAM Peak Memory: {summary['ram_peak_gb']:.2f} GB")
        print(f"Wall-clock Time: {summary['wall_clock_time_seconds']:.2f} seconds")
        print(f"Wall-clock Time: {summary['wall_clock_time_minutes']:.2f} minutes")
        print(f"Wall-clock Time: {summary['wall_clock_time_hours']:.2f} hours")
        print("===============================")

if __name__ == "__main__":
    # Demo: measure a short fake "training" run. Use the GPU when present;
    # otherwise fall back to the CPU so the script still runs.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    monitor = ResourceMonitor(device=device, sampling_interval=0.1)

    with monitor:
        print("Starting training simulation...")
        time.sleep(2)
        # Allocate directly on the target device. `torch.randn(...).cuda()`
        # would first build the tensor (~400 MB with the default float32
        # dtype) in host RAM and then copy it, inflating the reported RAM
        # peak.
        dummy_tensor = torch.randn(10000, 10000, device=device)
        time.sleep(1)
        del dummy_tensor

    monitor.print_summary()

    summary = monitor.get_summary()
    print(f"GPU Memory: {summary['gpu_peak_gb']:.2f} GB")
    print(f"Training Time: {summary['wall_clock_time_minutes']:.2f} minutes")