from __future__ import annotations

import inspect
from collections.abc import Iterable
from contextlib import contextmanager
from dataclasses import dataclass
from functools import wraps
from typing import Any

import matplotlib.pyplot as plt
import torch


@dataclass
class NoGradient:
    """Proxy that forwards attribute access to ``wrapped`` without autograd.

    Attribute lookups on the proxy are resolved on ``wrapped`` inside
    ``torch.no_grad()``. Routines (functions, bound methods, builtin methods)
    are returned wrapped so that the *call itself* also runs under
    ``torch.no_grad()``; the lookup alone is not enough, because the original
    method would otherwise execute after the context has already exited.
    Non-routine attributes (plain values, submodules, classes) are returned
    unchanged so their types are preserved.

    Note: dunder methods such as ``__call__`` are looked up on the type, not
    the instance, so ``NoGradient(x)(...)`` is not intercepted here.
    """

    wrapped: Any  # object whose attributes are proxied

    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs for names not found normally. If ``wrapped``
        # itself is missing (e.g. an instance created via __new__ during
        # copy/unpickling), fail with AttributeError instead of recursing.
        if name == 'wrapped':
            raise AttributeError(name)

        with torch.no_grad():
            attribute = getattr(self.wrapped, name)

        # Only wrap routines: callable objects such as nn.Module submodules
        # must keep their type so their own attributes remain reachable.
        if not inspect.isroutine(attribute):
            return attribute

        @wraps(attribute)
        def call_without_grad(*args: Any, **kwargs: Any) -> Any:
            with torch.no_grad():
                return attribute(*args, **kwargs)

        return call_without_grad


def draw(*tensors: torch.Tensor | Iterable[torch.Tensor],
         show: bool = True) -> plt.Figure:
    """Plot image tensors side by side in a single row of axes.

    Args:
        *tensors: One tensor per panel. A single non-tensor argument is
            treated as an iterable of tensors (list, tuple, generator, ...).
            3-D tensors are treated as channels-first and permuted to
            channels-last for ``imshow``; 2-D tensors are shown as-is.
        show: If true, display the figure and then close it.

    Returns:
        The matplotlib figure containing one axis per image.

    Raises:
        ValueError: If no tensors are given.
    """
    images = tensors
    if len(images) == 1 and not isinstance(images[0], torch.Tensor):
        # Materialize so generators/iterators support len() below.
        images = list(images[0])
    if not images:
        raise ValueError('draw() needs at least one tensor')

    figure, [axes] = plt.subplots(1, len(images), squeeze=False)
    for image, axis in zip(images, axes):
        axis.axis('off')
        pixels = image.detach().cpu()
        if pixels.dim() == 3:
            # Move dim 0 (channels) to the end, the layout imshow expects.
            pixels = pixels.permute(1, 2, 0)
        axis.imshow(pixels)

    if show:
        plt.show()
        # Close this figure specifically, not whichever one is "current".
        plt.close(figure)
    return figure


@contextmanager
def frozen(model: torch.nn.Module):
    """Temporarily disable gradients for every parameter of ``model``.

    Each parameter's ``requires_grad`` flag is recorded on entry and restored
    exactly on exit, even if the body raises. Parameters that were already
    frozen therefore stay frozen, and nested ``frozen`` blocks compose.

    Yields:
        The same ``model`` object.
    """
    # Snapshot before touching anything so a partial failure can be undone.
    saved = [(param, param.requires_grad) for param in model.parameters()]
    try:
        for param, _ in saved:
            param.requires_grad = False
        yield model
    finally:
        for param, flag in saved:
            param.requires_grad = flag
