import torch

def _batch_fits(model, device, dataset, input_index, batch_size):
    """Return True if a no-grad forward pass on the first `batch_size` samples succeeds.

    Returns False when the pass raises a RuntimeError whose message mentions
    "out of memory"; any other RuntimeError is re-raised. The input batch is
    released and ``torch.cuda.empty_cache()`` is called after every attempt,
    whether it succeeded or not.
    """
    inputs = None
    try:
        inputs = torch.stack(
            [dataset[i][input_index] for i in range(batch_size)]
        ).to(device)
        with torch.no_grad():
            model(inputs)
        return True
    except RuntimeError as e:
        # Allocation failures are recognised by their message text.
        if 'out of memory' in str(e).lower():
            return False
        raise
    finally:
        # Drop our reference to the batch before emptying the cache so its
        # memory is no longer held by this frame.
        del inputs
        torch.cuda.empty_cache()


def find_max_batch_size(model, device, dataset, input_index=0, *, cpu_batch_limit=4096):
    """Find the largest inference batch size that runs on `device` without OOM.

    The model is moved to `device` (and stays there). On CPU no probing is
    done: the result is ``min(cpu_batch_limit, len(dataset))``. On any other
    device, batches built from the first k samples are run under
    ``torch.no_grad()`` in eval mode. First k doubles until a batch fails or
    exceeds the dataset, then a binary search narrows the gap. The result is
    the exact largest size that fit, capped at ``len(dataset)``, rather than
    only the largest power of two. The search assumes memory use grows
    monotonically with batch size.

    Args:
        model: module to probe; must support ``.to``, ``.training``,
            ``.eval()``, ``.train(mode)`` and being called on a batch.
        device: target device (string or ``torch.device``).
        dataset: sized, indexable collection; ``dataset[i][input_index]`` must
            yield a tensor. Presumably all such tensors share one shape, as
            ``torch.stack`` requires — confirm against callers.
        input_index: position of the model input within each dataset item.
        cpu_batch_limit: fixed cap returned for CPU devices (default 4096).

    Returns:
        The largest batch size that ran successfully, or 0 if even a single
        sample runs out of memory.

    Raises:
        AssertionError: if `dataset` is empty.
        RuntimeError: any non-out-of-memory error raised by the forward pass.
    """
    assert len(dataset) > 0, "Dataset must not be empty."

    model.to(device)
    num_samples = len(dataset)

    # CPU memory is not probed; use a fixed cap instead. Comparing only the
    # device type (text before any ':') also accepts forms like 'cpu:0'.
    if str(device).split(':', 1)[0] == 'cpu':
        return min(cpu_batch_limit, num_samples)

    # Remember the caller's train/eval mode so it can be restored below.
    original_training = model.training
    model.eval()

    try:
        if not _batch_fits(model, device, dataset, input_index, 1):
            return 0  # can't run even a single sample

        largest_ok = 1
        # Sentinel: sizes above len(dataset) are never tried.
        smallest_failed = num_samples + 1

        # Phase 1: exponential search (doubling) for an upper bound.
        candidate = 2
        while candidate <= num_samples:
            if not _batch_fits(model, device, dataset, input_index, candidate):
                smallest_failed = candidate
                break
            largest_ok = candidate
            candidate *= 2

        # Phase 2: binary search in (largest_ok, smallest_failed) so the
        # result is not limited to powers of two.
        while smallest_failed - largest_ok > 1:
            mid = (largest_ok + smallest_failed) // 2
            if _batch_fits(model, device, dataset, input_index, mid):
                largest_ok = mid
            else:
                smallest_failed = mid

        print("Maximum safe batch size for inference on device {} is: {}".format(device, largest_ok))

        return largest_ok

    finally:
        # restore original model state
        model.train(original_training)