import math

import matplotlib.pyplot as plt
import torch


def plot_2d(image: torch.Tensor, figname: str = "test.pdf"):
    """Render a tensor as a 2-D image and save it to ``figname``.

    The tensor is detached from autograd and arranged for display by rank:

    * 1-D: folded into a ``2**(n // 2)`` by ``2**(n - n // 2)`` matrix, where
      ``n = floor(log2(len(image)))``; the length must be a power of two.
    * 3-D: the leading axis is dropped, which requires it to have size 1.
    * any other rank: handed to ``imshow`` unchanged.

    The image is drawn on a dedicated figure that is closed after saving, so
    repeated calls neither paint over one another nor accumulate open figures.

    Args:
        image: Tensor to plot.
        figname: Output path passed to ``savefig``.

    Raises:
        ValueError: If ``image`` is 1-D and empty.
        RuntimeError: If a 1-D length is not a power of two, or the leading
            axis of a 3-D tensor is not of size 1 (raised by ``reshape``).
    """
    # Detach first: a tensor that requires grad cannot be converted to NumPy.
    image = image.detach()
    shape = image.shape
    if len(shape) == 1:
        length = shape[0]
        if length == 0:
            raise ValueError("cannot plot an empty 1-D tensor")
        n = int(math.log2(length))
        rows = 2 ** (n // 2)
        cols = 2 ** (n - n // 2)
        # reshape (not view) so non-contiguous inputs, e.g. strided slices,
        # are accepted instead of raising.
        image = image.reshape(rows, cols)
    elif len(shape) == 3:
        image = image.reshape(shape[1], shape[2])

    # Use a private figure rather than pyplot's implicit current figure, and
    # always release it, even if drawing or saving fails.
    fig, ax = plt.subplots()
    try:
        ax.imshow(image.cpu().numpy())
        fig.savefig(figname)
    finally:
        plt.close(fig)
