import os
from typing import Callable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision.utils import make_grid, save_image


def _ensure_dir(path: str) -> None:
    """Create directory ``path`` (including parents) if it is non-empty.

    An empty string (e.g. ``os.path.dirname`` of a bare filename) is a no-op.
    Existing directories are left untouched.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)


def to_uint8_rgb(imgs: torch.Tensor, size: Optional[int]) -> torch.Tensor:
    """Map float tensors in [-1, 1] to uint8 RGB and optionally resize.

    Args:
        imgs: BCHW tensor with values nominally in [-1, 1]; values outside
            that range are clamped. Single-channel inputs are replicated to
            three channels. The input tensor is NOT modified.
        size: If not None, the batch is bilinearly resized to
            ``(size, size)`` before quantization.

    Returns:
        A new uint8 tensor of shape (B, C', H', W') with values in [0, 255],
        where C' is 3 for single-channel input and C otherwise.

    Raises:
        ValueError: If ``imgs`` is not 4-dimensional.
    """
    if imgs.dim() != 4:
        raise ValueError("Expected BCHW tensor for image conversion")
    # Out-of-place clamp so the caller's tensor is left untouched; detach
    # because the uint8 result cannot carry gradients, and an in-place op on
    # a leaf tensor that requires grad would raise.
    imgs = imgs.detach().clamp(-1.0, 1.0)
    imgs = (imgs + 1.0) / 2.0
    if imgs.shape[1] == 1:
        imgs = imgs.repeat(1, 3, 1, 1)
    if size is not None:
        imgs = F.interpolate(imgs, size=(size, size), mode="bilinear", align_corners=False)
    return (imgs * 255.0).round().clamp(0, 255).to(torch.uint8)


class Uint8Dataset(Dataset):
    """Map-style dataset serving the rows of a single pre-built tensor.

    Indexing returns ``tensor[index]`` directly, with no copy or transform.
    The tensor is expected to hold uint8 image data (per the constructor
    argument name); this is not enforced here.
    """

    def __init__(self, tensor_uint8: torch.Tensor) -> None:
        # Kept as-is: no dtype or device conversion happens here.
        self.tensor = tensor_uint8

    def __len__(self) -> int:  # pragma: no cover - trivial
        leading = self.tensor.size(0)
        return int(leading)

    def __getitem__(self, index: int) -> torch.Tensor:  # pragma: no cover - trivial
        backing = self.tensor
        return backing[index]


def compute_fid(
    real: torch.Tensor,
    gen: torch.Tensor,
    *,
    device: torch.device,
    image_size: int,
    batch_size: int,
) -> float:
    """Compute FID between two image batches using torch-fidelity.

    Both batches are converted with ``to_uint8_rgb`` (expects BCHW tensors in
    [-1, 1]), resized to ``image_size`` and moved to CPU before scoring.

    Args:
        real: Reference images, BCHW in [-1, 1].
        gen: Generated images, BCHW in [-1, 1].
        device: Device whose type decides whether torch-fidelity runs on
            CUDA. Accepts a ``torch.device`` or anything ``torch.device()``
            accepts (e.g. ``"cuda:0"``, ``"cpu"``).
        image_size: Square side length both batches are resized to.
        batch_size: Batch size passed to torch-fidelity.

    Returns:
        The Frechet Inception Distance as a Python float.

    Raises:
        ImportError: If ``torch_fidelity`` is not installed.
    """
    try:
        import torch_fidelity
    except ImportError as exc:  # pragma: no cover - dependency guard
        raise ImportError("torch_fidelity is required for FID computation") from exc

    # Normalize so plain strings work as well as torch.device instances.
    use_cuda = torch.device(device).type == "cuda"

    real_ds = Uint8Dataset(to_uint8_rgb(real, image_size).cpu())
    gen_ds = Uint8Dataset(to_uint8_rgb(gen, image_size).cpu())

    metrics = torch_fidelity.calculate_metrics(
        input1=real_ds,
        input2=gen_ds,
        fid=True,
        batch_size=batch_size,
        cuda=use_cuda,
        verbose=False,
    )
    return float(metrics["frechet_inception_distance"])


def compute_cleanfid_from_generator(
    *,
    generator: Callable[[int], torch.Tensor],
    image_shape: torch.Size,
    fid_image_size: int,
    fid_batch_size: int,
    fid_num_gen: int,
    dataset_name: str,
    dataset_split: str = "train",
    mode: str = "legacy_tensorflow",
    device: Optional[torch.device] = None,
) -> float:
    """Compute FID via CleanFID by repeatedly sampling from `generator`.

    Args:
        generator: Callable taking a sample count and returning that many
            samples in [-1, 1], either flat ``(N, prod(image_shape))`` or
            already shaped ``(N, *image_shape)``.
        image_shape: Per-sample image shape used to unflatten samples;
            presumably (C, H, W), since ``to_uint8_rgb`` requires the
            reshaped batch to be 4-D.
        fid_image_size: Resolution generated images are resized to; also
            passed to CleanFID as ``dataset_res``.
        fid_batch_size: CleanFID batch size (values below 1 become 1).
        fid_num_gen: Total number of images CleanFID generates.
        dataset_name: Reference dataset name; lower-cased before use.
        dataset_split: Reference dataset split.
        mode: CleanFID mode. ``"legacy_tensorflow"`` is switched to
            ``"clean"`` when the dataset name contains "celeba".
        device: Device for CleanFID's feature extraction. ``None`` leaves
            CleanFID's own default in place.

    Returns:
        The FID score as a Python float.

    Raises:
        ImportError: If ``cleanfid`` is not installed.
    """
    try:
        from cleanfid import fid as clean_fid
    except ImportError as exc:  # pragma: no cover - dependency guard
        raise ImportError("cleanfid is required for CleanFID evaluation") from exc

    fid_batch = max(1, int(fid_batch_size))
    dataset_res = int(fid_image_size)

    # Use "clean" mode for CelebA to avoid downloading pre-computed statistics
    # which can fail with HTTPError. Keep legacy_tensorflow for other datasets.
    dataset_lower = str(dataset_name).lower()
    if "celeba" in dataset_lower and mode == "legacy_tensorflow":
        mode = "clean"

    def _gen(unused_latent):
        # CleanFID hands us a latent batch; only its leading dimension is
        # used, to decide how many images to draw from `generator`.
        count = fid_batch
        if hasattr(unused_latent, "shape") and len(unused_latent.shape) > 0:
            try:
                count = int(unused_latent.shape[0])
            except (TypeError, ValueError):
                pass
        samples = generator(count)
        imgs = reshape_flat_samples(samples, image_shape)
        # Quantize to uint8 (same conversion as compute_fid), then hand
        # CleanFID float32 values in [0, 255].
        return to_uint8_rgb(imgs, dataset_res).to(torch.float32)

    # Only forward `device` when the caller chose one, so existing callers
    # keep CleanFID's default behaviour.
    extra_kwargs = {}
    if device is not None:
        extra_kwargs["device"] = torch.device(device)

    score = clean_fid.compute_fid(
        gen=_gen,
        dataset_name=dataset_lower,
        dataset_split=dataset_split,
        dataset_res=dataset_res,
        num_gen=int(fid_num_gen),
        batch_size=fid_batch,
        mode=mode,
        **extra_kwargs,
    )
    return float(score)


def save_image_grid(
    images: torch.Tensor,
    *,
    path: str,
    nrow: int = 8,
) -> np.ndarray:
    """Save a grid of [-1,1]-scaled images and return an array for logging.

    Args:
        images: Batch of images (BCHW) with values nominally in [-1, 1];
            values outside that range are clamped. May live on any device
            and may require grad.
        path: Output image file; its parent directory is created if needed.
        nrow: Number of images per grid row.

    Returns:
        The saved grid as an (H, W, C) float array with values in [0, 1].
    """
    _ensure_dir(os.path.dirname(path))
    # Detach and move to CPU once: `.numpy()` refuses tensors that require
    # grad, and the same grid is used for both the file and the return value.
    images = images.detach().cpu().clamp(-1.0, 1.0)
    grid = make_grid(images, nrow=nrow, normalize=True, value_range=(-1.0, 1.0))
    # Saving the already-built 3-D grid writes it as-is (no extra padding).
    save_image(grid, path)
    return grid.permute(1, 2, 0).numpy()


def reshape_flat_samples(samples: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """Return ``samples`` as a batch of images of per-sample shape ``shape``.

    Args:
        samples: Either a flat 2-D batch ``(N, prod(shape))`` or a batch that
            already has shape ``(N, *shape)``.
        shape: Per-sample target shape (``torch.Size``, tuple or list).

    Returns:
        A view of flat input reshaped to ``(N, *shape)``, or ``samples``
        itself when it is already correctly shaped.

    Raises:
        ValueError: If ``samples`` matches neither accepted layout.
    """
    target = torch.Size(shape)
    if samples.dim() == 2:
        # Check the size up front so a mismatch surfaces as this function's
        # ValueError rather than an opaque RuntimeError from `view`.
        if samples.shape[1] != target.numel():
            raise ValueError(
                "Unexpected sample shape for image reshape: "
                f"got {tuple(samples.shape)}, expected (N, {target.numel()})"
            )
        return samples.view(samples.shape[0], *target)
    if samples.shape[1:] == target:
        return samples
    raise ValueError(
        "Unexpected sample shape for image reshape: "
        f"got {tuple(samples.shape)}, expected (N, *{tuple(target)})"
    )
