import torch
import os

def get_used_gpus():
    """Return the GPUs on which the current process has PyTorch memory allocated.

    Devices are addressed by PyTorch's logical indices
    ``0 .. torch.cuda.device_count() - 1``. ``CUDA_VISIBLE_DEVICES`` holds
    physical IDs (or UUIDs) and only restricts/reorders which devices are
    visible; its entries are not valid PyTorch device indices, so it is
    printed for information only and never used for indexing.

    Returns:
        A list with one dict per device whose
        ``torch.cuda.memory_allocated`` value is non-zero. Each dict has the
        keys ``index`` (logical device index), ``name``,
        ``memory_allocated`` (bytes / 1024**2) and ``total_memory``
        (bytes / 1024**3). The list is empty when CUDA is unavailable or no
        device has memory allocated.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available, no GPUs are being used.")
        return []

    # Informational only: shows how the visible device set was restricted.
    visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    print('visible_devices', visible_devices)

    # Check which of the visible (logically indexed) GPUs have memory allocated
    used_gpus = []
    for i in range(torch.cuda.device_count()):
        memory_allocated = torch.cuda.memory_allocated(i)
        if memory_allocated <= 0:
            continue
        props = torch.cuda.get_device_properties(i)
        used_gpus.append({
            'index': i,
            'name': props.name,
            'memory_allocated': memory_allocated / 1024**2,  # MB
            'total_memory': props.total_memory / 1024**3,    # GB
        })

    return used_gpus

def print_used_gpus():
    """Write a summary of the GPUs reported by ``get_used_gpus()`` to stdout.

    Prints a single notice when the list is empty; otherwise prints a
    banner, one entry per GPU (index, name, allocated memory, total memory)
    and a final count.
    """
    gpu_records = get_used_gpus()

    if gpu_records:
        heavy_rule = "=" * 60
        light_rule = "-" * 40

        # Banner: title framed by two heavy rules.
        for banner_line in (
            heavy_rule,
            "GPU information used by the current training script:",
            heavy_rule,
        ):
            print(banner_line)

        # One four-line entry per reported device.
        for record in gpu_records:
            print(f"GPU {record['index']}: {record['name']}")
            print(f"  Memory allocated: {record['memory_allocated']:.2f} MB")
            print(f"  Total memory: {record['total_memory']:.2f} GB")
            print(light_rule)

        gpu_count = len(gpu_records)
        print(f"Total GPUs used: {gpu_count}")
    else:
        print("No GPUs are being used by the current training script.")
def _first_tensor_device(batch):
    """Return the device of the first tensor found in *batch*, or None."""
    if isinstance(batch, torch.Tensor):
        return batch.device
    if isinstance(batch, dict):
        batch = batch.values()
    elif not isinstance(batch, (list, tuple)):
        return None
    for item in batch:
        device = _first_tensor_device(item)
        if device is not None:
            return device
    return None


def check_training_gpu_usage(net=None, data_loader=None):
    """Print a diagnostic report about where the model and data live.

    Args:
        net: Optional model. The device of its first parameter is reported,
            and it is checked for being a ``torch.nn.DataParallel`` wrapper.
        data_loader: Optional iterable of batches. One batch is drawn from a
            fresh iterator and the device of the first tensor found in it
            (the batch itself, or an element of a list/tuple/dict batch) is
            reported.

    Returns:
        None. All results are printed; the report ends with the output of
        ``print_used_gpus()``.
    """
    print("=" * 60)
    print("Training configuration check:")
    print("=" * 60)

    # Check if the model is on GPU
    if net is not None:
        # A default avoids StopIteration for modules without parameters.
        first_param = next(net.parameters(), None)
        if first_param is None:
            print("Model has no parameters, unable to determine its device")
        else:
            model_device = first_param.device
            print(f"Model device: {model_device}")

            # Compare the device type so that e.g. 'cpu:0' is still CPU.
            if model_device.type == 'cpu':
                print("⚠️  Warning: Model is on CPU, not on GPU!")
            else:
                print("✅ Model is on GPU")

    # Check data loader
    if data_loader is not None:
        try:
            sample_batch = next(iter(data_loader))
        except Exception as exc:
            # Best-effort check: report why it failed instead of crashing,
            # but let KeyboardInterrupt/SystemExit propagate.
            print(f"Unable to check data device: {exc!r}")
        else:
            data_device = _first_tensor_device(sample_batch)
            if data_device is None:
                print("Unable to check data device")
            else:
                print(f"Data device: {data_device}")

                if data_device.type == 'cpu':
                    print("⚠️  Warning: Data is on CPU, needs to be moved to GPU!")
                else:
                    print("✅ Data is on GPU")

    # Check multi-GPU configuration
    if net is not None and isinstance(net, torch.nn.DataParallel):
        print(f"✅ Using DataParallel, distributed across {len(net.device_ids)} GPUs")
        print(f"   GPU device IDs used: {net.device_ids}")
    elif torch.cuda.device_count() > 1:
        print(f"⚠️  Detected {torch.cuda.device_count()} GPUs, but DataParallel is not used")

    print_used_gpus()

# Example usage: call check_training_gpu_usage() at the appropriate places in
# your training code (here: before and after an initial evaluation pass).
# NOTE(review): `net`, `trainloader`, `testloader`, `test`, `args` and `device`
# are not defined or imported in the visible part of this file; presumably
# they come from the surrounding training script — confirm, otherwise running
# this module directly raises NameError.
if __name__ == "__main__":
    
    # Check GPU usage here!
    print("\n" + "=" * 60)
    print("Pre-training GPU usage check:")
    print("=" * 60)
    check_training_gpu_usage(net, trainloader)
    
    # Test initial accuracy
    test_acc, predicted = test(args, net, testloader, device, 0)
    print("scratch prediction ", test_acc)
    
    # Check GPU usage after testing
    print("\n" + "=" * 60)
    print("Post-testing GPU usage check:")
    print("=" * 60)
    check_training_gpu_usage(net, trainloader)

    # The rest of the training code...