import logging
import math

import torch
import torchvision.utils as vutils
import wandb
from torchvision.utils import make_grid


def create_logger(logging_dir):
    """Configure root logging to stdout and a file, then return this module's logger.

    Records at INFO and above are formatted as ``[<timestamp>] <message>``.
    The timestamp is wrapped in ANSI blue escape codes. Output goes to a
    console ``StreamHandler`` and to ``<logging_dir>/log.txt``.

    NOTE: ``logging.basicConfig`` does nothing if the root logger already has
    handlers. In that case this call leaves the existing configuration in place.

    Args:
        logging_dir: directory that will contain ``log.txt``.

    Returns:
        The ``logging.Logger`` named after this module.
    """
    console_handler = logging.StreamHandler()
    file_handler = logging.FileHandler(f"{logging_dir}/log.txt")
    logging.basicConfig(
        handlers=[console_handler, file_handler],
        datefmt="%Y-%m-%d %H:%M:%S",
        format="[\x1b[34m%(asctime)s\x1b[0m] %(message)s",
        level=logging.INFO,
    )
    return logging.getLogger(__name__)


def array2grid(x):
    """Tile a batch of images into one near-square grid as a uint8 numpy array.

    Assumes ``x`` is a (B, C, H, W) image tensor with values nominally in
    [0, 1]. TODO confirm with callers. Values outside that range are clamped.

    Each grid row holds ``round(sqrt(B))`` images. The grid is scaled to
    0..255, rounded by adding 0.5 before the integer cast, moved to the CPU
    and returned in (H, W, C) layout.
    """
    per_row = round(math.sqrt(x.size(0)))
    # value_range has no effect unless make_grid's normalize=True; kept as-is.
    tiled = make_grid(x.clamp(0, 1), nrow=per_row, value_range=(0, 1))
    # +0.5 turns the uint8 truncation into round-to-nearest.
    scaled = tiled.mul(255).add_(0.5).clamp_(0, 255)
    channels_last = scaled.permute(1, 2, 0)
    return channels_last.to("cpu", torch.uint8).numpy()


def grid_image(batch, *, nrow=8, caption=None, normalize=True):
    """
    Turn a (B,C,H,W) tensor into one tiled image and wrap it in wandb.Image.

    Args
    ----
    batch      : torch.Tensor  (B,C,H,W)  — your images (a list of tensors is
                               passed straight through to make_grid)
    nrow       : int           number of imgs per row in the grid
    caption    : str | None    caption shown under the image
    normalize  : bool          rescale floating-point tensors to [0,1] via
                               make_grid's ``normalize``; ignored for integer
                               tensors (e.g. uint8), which are tiled unscaled

    Returns
    -------
    wandb.Image wrapping the tiled (C,H,W) grid tensor.

    NOTE(review): wandb.Image may apply its own rescaling to tensor inputs;
    confirm against the installed wandb version if exact pixel values matter.
    """
    # make_grid's normalization runs in-place float ops (clamp_/sub_/div_) on
    # a clone of the input. Those ops raise a dtype-cast error for integer
    # tensors, so only request it for floating-point tensors, as documented.
    if normalize and torch.is_tensor(batch) and not batch.is_floating_point():
        normalize = False

    # (C,H,W) after make_grid
    grid = vutils.make_grid(
        batch,
        nrow=nrow,
        padding=2,  # thin border between tiles
        normalize=normalize,
    )
    return wandb.Image(grid, caption=caption)
