import gc
import os

import torch as t


def get_tensors_on_gpu() -> list[t.Tensor]:
    """Collect the CUDA tensors currently tracked by the garbage collector.

    Walks ``gc.get_objects()`` and keeps every object that is a
    ``torch.Tensor`` whose ``is_cuda`` attribute is true.

    Returns:
        The matching tensors, in the order the garbage collector reported
        them. The list is empty when no CUDA tensor is found.
    """
    tensors_on_gpu: list[t.Tensor] = []
    for obj in gc.get_objects():
        try:
            if isinstance(obj, t.Tensor) and obj.is_cuda:
                tensors_on_gpu.append(obj)
        except Exception:
            # Best-effort scan over arbitrary live objects: anything that
            # raises while being inspected is skipped so that one misbehaving
            # object cannot abort the whole walk.
            continue
    return tensors_on_gpu


def log_largest_tensors(num_tensors: int = 30) -> None:
    """Print a table of the largest tensors returned by ``get_tensors_on_gpu``.

    Each row shows the tensor's size in bytes (``element_size() * nelement()``),
    its shape and its dtype. Rows are ordered from largest to smallest.

    Args:
        num_tensors: Maximum number of rows to print. Negative values are
            treated as zero, so only the table header is printed.
    """
    # Compute each tensor's byte size once; it serves as both the sort key
    # and the printed value.
    sized_tensors = [
        (tensor.element_size() * tensor.nelement(), tensor)
        for tensor in get_tensors_on_gpu()
    ]
    # Sort on the size alone so tensors themselves are never compared and
    # equally sized tensors keep their discovery order.
    sized_tensors.sort(key=lambda pair: pair[0], reverse=True)

    # Clamp the limit: a negative slice bound would silently drop the
    # smallest tensors instead of limiting the row count.
    row_limit = max(num_tensors, 0)

    print(f"{'Size (bytes)':>15} {'Shape':>20} {'Data Type':>15}")
    print("=" * 55)
    for size_in_bytes, tensor in sized_tensors[:row_limit]:
        print(f"{size_in_bytes:>15} {str(tensor.shape):>20} {str(tensor.dtype):>15}")


def start_interactive_debugger() -> None:
    """Open an interactive shell, preferring IPython over the stdlib ``code`` REPL.

    The session header/banner names this source file. A namespace dict, built
    from a copy of the module globals overlaid with this function's locals, is
    passed to whichever shell is started. If IPython cannot be imported, the
    standard ``code.interact`` interpreter is used instead.
    """
    # Base name of this file, shown in the session header/banner.
    doc = os.path.basename(__file__)

    # Snapshot the module globals, then overlay the locals visible at this
    # point. Here locals() contains only `doc` and `namespace` itself, so the
    # local variable names in this function are part of what the shell exposes.
    namespace = globals().copy()
    namespace.update(locals().copy())

    try:
        from IPython import embed

        # Starts an IPython shell with the filename as the header
        # NOTE(review): confirm that embed() honours a `local` keyword; the
        # namespace argument IPython documents appears to be `user_ns`, and
        # embed() may instead pick up the calling frame's variables -- TODO verify.
        embed(
            header=f"Interactive session started from: {doc}",
            local=namespace,
            using="sync",
        )
    except ImportError:
        import code

        # Fallback to the standard Python interactive interpreter, pass the filename as banner
        code.interact(banner=f"Interactive session started from: {doc}", local=namespace)


# Script entry point: when this file is run directly, print the table of the
# largest GPU tensors, sorted by size, using the default row limit.
if __name__ == "__main__":
    log_largest_tensors()
