from typing import Optional, Union

import torch

def _free_memory(index: int) -> int:
    """Return the free bytes on CUDA device *index*.

    Does not change the process's current CUDA device.
    """
    if hasattr(torch.cuda, 'mem_get_info'):
        # Driver-reported free memory; this accounts for usage by other
        # processes and by this process's caching allocator.
        free, _total = torch.cuda.mem_get_info(index)
        return free
    # Older torch without mem_get_info: only this process's tensor
    # allocations can be subtracted from the device total.
    return (
        torch.cuda.get_device_properties(index).total_memory
        - torch.cuda.memory_allocated(index)
    )


def get_device(gpu_id: Optional[Union[str, int]] = None) -> torch.device:
    """Return the torch device to run on.

    Args:
        gpu_id: Device to use. An ``int``, or a string of digits such as
            ``"1"``, is treated as a CUDA device index. Any other string is
            passed to ``torch.device`` unchanged (e.g. ``"cuda:1"``,
            ``"cpu"``). When ``None``, the visible GPU reporting the most
            free memory is chosen.

    Returns:
        ``torch.device('cpu')`` whenever CUDA is unavailable; ``gpu_id`` is
        ignored in that case. Otherwise, the requested or auto-selected
        CUDA device.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')

    if gpu_id is not None:
        if isinstance(gpu_id, int):
            return torch.device(f'cuda:{gpu_id}')
        if gpu_id.strip().isdigit():
            # Bare index strings (common from CLI args) are not accepted by
            # torch.device directly, so map them to an explicit CUDA device.
            return torch.device(f'cuda:{int(gpu_id)}')
        return torch.device(gpu_id)

    count = torch.cuda.device_count()
    if count <= 1:
        return torch.device('cuda:0')

    # Pick the GPU with the most free memory; max() keeps the first
    # (lowest) index on ties. No torch.cuda.set_device call is made, so the
    # caller's current device is left untouched.
    best = max(range(count), key=_free_memory)
    return torch.device(f'cuda:{best}')