from typing import Any
from math import sqrt
import matplotlib.pyplot as plt


def subplots(
    nr: int = ...,
    nc: int = ...,
    *,
    n: int = ...,
    cellsize=(6, 4),
    **kwargs
) -> tuple[plt.Figure, Any]:
    """Create a figure holding a grid of subplots, sized per cell.

    Thin wrapper around ``matplotlib.pyplot.subplots`` that derives any
    omitted grid dimension from the number of panels ``n`` and scales
    ``figsize`` so that every cell is ``cellsize`` in size.

    Args:
        nr: Number of grid rows. If omitted, it is derived from ``n``
            (and ``nc`` when that is given).
        nc: Number of grid columns. If omitted, it is derived from ``n``
            (and ``nr`` when that is given).
        n: Number of panels actually used. Required unless both ``nr`` and
            ``nc`` are given, in which case it defaults to ``nr * nc``.
            Axes at flattened (row-major) positions ``n`` and beyond have
            their axis turned off.
        cellsize: ``(width, height)`` of one cell; multiplied by ``nc`` and
            ``nr`` respectively to build ``figsize``.
        **kwargs: Forwarded unchanged to ``matplotlib.pyplot.subplots``.

    Returns:
        ``(fig, axes)`` exactly as returned by ``matplotlib.pyplot.subplots``
        (a single Axes or an array of Axes, depending on the grid shape and
        ``squeeze``).

    Raises:
        TypeError: If a grid dimension is omitted and ``n`` is not given.
        ValueError: If ``n`` or a given grid dimension is not positive when
            it is needed to derive the missing dimension(s).
    """
    if nr is ... or nc is ...:
        # Explicit checks rather than ``assert``: asserts vanish under -O.
        if n is ...:
            raise TypeError("n is required unless both nr and nc are given")
        if n < 1:
            raise ValueError(f"n must be a positive integer, got {n!r}")

        if nr is ... and nc is ...:
            # Near-square layout: columns from the integer square root, then
            # just enough rows to hold all n panels.
            nc = n // int(sqrt(n))
            nr = -(-n // nc)  # ceil(n / nc)
        elif nr is not ...:
            if nr < 1:
                raise ValueError(f"nr must be a positive integer, got {nr!r}")
            nc = -(-n // nr)  # ceil(n / nr)
        else:
            if nc < 1:
                raise ValueError(f"nc must be a positive integer, got {nc!r}")
            nr = -(-n // nc)  # ceil(n / nc)

    if n is ...:
        n = nr * nc

    fig, axes = plt.subplots(
        nr, nc,
        figsize=(cellsize[0] * nc, cellsize[1] * nr),
        **kwargs
    )

    # With the default squeeze=True a 1x1 grid yields a bare Axes rather than
    # an ndarray, so normalise to a sliceable sequence before hiding cells.
    flat_axes = axes.flat if hasattr(axes, "flat") else [axes]

    # Hide the unused trailing cells (positions n and beyond, row-major).
    for ax in flat_axes[n:]:
        ax.axis('off')

    return fig, axes
