import math

import numpy as np
import torch
import torch.nn.functional as F


def gaussian_kernel(size, sigma):
    """Return a 1-D Gaussian kernel of length ``size`` that sums to 1.

    Sample ``i`` is taken at offset ``i - size // 2``, so the peak sits at
    index ``size // 2``; for an even ``size`` the kernel is therefore not
    symmetric.

    Args:
        size: Number of taps.
        sigma: Standard deviation of the Gaussian, in taps.

    Returns:
        A float32 tensor of shape ``(size,)``.
    """
    offsets = torch.arange(size, dtype=torch.float32) - (size // 2)
    weights = torch.exp(-offsets.pow(2) / (2 * sigma ** 2))
    return weights / weights.sum()


def create_window(window_size, channel, sigma=1.5):
    """Build a separable 2-D Gaussian window for depthwise convolution.

    The 2-D kernel is the outer product of ``gaussian_kernel(window_size,
    sigma)`` with itself, replicated once per channel.

    Args:
        window_size: Side length of the square window.
        channel: Number of channels; one identical kernel per channel.
        sigma: Standard deviation of the Gaussian.

    Returns:
        A contiguous tensor of shape ``(channel, 1, window_size, window_size)``,
        the layout ``F.conv2d`` expects for a weight used with
        ``groups=channel``.
    """
    kernel_1d = gaussian_kernel(window_size, sigma)
    if kernel_1d.dim() == 0:
        # Promote a 0-d result to 1-D so the outer product is well-defined.
        kernel_1d = kernel_1d.unsqueeze(0)

    kernel_2d = torch.outer(kernel_1d, kernel_1d)
    stacked = kernel_2d[None, None].expand(
        channel, 1, window_size, window_size
    )
    # expand() returns a broadcast view; materialize it for conv2d.
    return stacked.contiguous()


def calculate_ssim(
    img1, img2, window_size=11, size_average=True, val_range=1.0, sigma=1.5
):
    """Compute the structural similarity (SSIM) between two images.

    Args:
        img1, img2: Image tensors. A 3-D input gets a leading batch dimension
            added. The rest of the code treats dim 1 as channels and the last
            two dims as height/width, i.e. it presumably expects
            ``(N, C, H, W)`` — confirm with callers.
        window_size: Requested Gaussian window size. It is clamped to the
            image's smaller spatial side, forced odd, and never below 3.
        size_average: If True, return a scalar mean over everything.
            Otherwise return one value per batch element.
        val_range: Dynamic range of the pixel values (sets the C1/C2
            stabilizers).
        sigma: Standard deviation of the Gaussian window (default 1.5).

    Returns:
        A tensor: 0-d when ``size_average`` is True, else of shape ``(N,)``.
    """
    if img1.dim() == 3:
        img1 = img1.unsqueeze(0)
    if img2.dim() == 3:
        img2 = img2.unsqueeze(0)

    channel = img1.size(1)

    # Fit the window inside small images and keep it odd, so that
    # padding = window_size // 2 keeps it centered.
    h, w = img1.shape[-2:]
    window_size = min(window_size, h, w)
    if window_size % 2 == 0:
        window_size -= 1
    window_size = max(window_size, 3)

    # Place the window on the input's device and dtype in one step. This
    # works for CPU, CUDA (any device index) and other accelerators alike.
    window = create_window(window_size, channel, sigma).to(
        device=img1.device, dtype=img1.dtype
    )

    return _ssim(
        img1, img2, window, window_size, channel, size_average, val_range
    )


def _ssim(
    img1, img2, window, window_size, channel, size_average=True, val_range=1.0
):
    C1 = (0.01 * val_range) ** 2
    C2 = (0.03 * val_range) ** 2

    padding = window_size // 2

    mu1 = F.conv2d(img1, window, padding=padding, groups=channel)
    mu2 = F.conv2d(img2, window, padding=padding, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = (
        F.conv2d(img1 * img1, window, padding=padding, groups=channel) - mu1_sq
    )
    sigma2_sq = (
        F.conv2d(img2 * img2, window, padding=padding, groups=channel) - mu2_sq
    )
    sigma12 = (
        F.conv2d(img1 * img2, window, padding=padding, groups=channel)
        - mu1_mu2
    )

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    )

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


def calculate_psnr(img1, img2, max_val=1.0):
    """Compute the peak signal-to-noise ratio between two tensors, in dB.

    PSNR = 20 * log10(max_val) - 10 * log10(MSE), with the MSE taken over
    all elements.

    Args:
        img1, img2: Tensors of broadcast-compatible shape. Integer inputs
            (e.g. uint8) are promoted to float before subtracting, so the
            difference cannot wrap around and ``torch.mean`` accepts it.
        max_val: Peak signal value (a number or 0-d tensor), e.g. 1.0 for
            normalized images or 255 for 8-bit.

    Returns:
        A Python float. Returns 100.0 as a sentinel when MSE < 1e-10
        (identical or near-identical inputs), and 0.0 when the result is
        NaN or infinite.
    """
    if not img1.is_floating_point():
        img1 = img1.float()
    if not img2.is_floating_point():
        img2 = img2.float()

    mse = torch.mean((img1 - img2) ** 2)
    if mse < 1e-10:
        return 100.0

    # Put max_val on mse's device/dtype. as_tensor also avoids the
    # copy-construct warning when max_val is already a tensor.
    peak = torch.as_tensor(max_val, dtype=mse.dtype, device=mse.device)
    psnr_value = 20 * torch.log10(peak) - 10 * torch.log10(mse)

    if not torch.isfinite(psnr_value):
        return 0.0

    return psnr_value.item()


def batch_psnr(img1, img2, max_val=1.0):
    """Average ``calculate_psnr`` over the entries of dimension 0.

    Args:
        img1, img2: Tensors of identical shape. Each slice ``img[i]`` along
            dim 0 is scored separately.
        max_val: Peak signal value forwarded to ``calculate_psnr``.

    Returns:
        A Python float: the mean of the finite per-sample PSNRs. Returns 0.0
        on a shape mismatch or when no sample produced a finite value.
    """
    if img1.shape != img2.shape:
        print(
            f"Warning: Shape mismatch in PSNR calculation: {img1.shape} vs {img2.shape}"
        )
        return 0.0

    psnr_values = []
    for i, (sample1, sample2) in enumerate(zip(img1, img2)):
        try:
            psnr = calculate_psnr(sample1, sample2, max_val)
        except Exception as e:
            # Best-effort: one failing sample should not abort the batch.
            print(f"Error calculating PSNR for sample {i}: {e}")
            continue
        # calculate_psnr returns a plain float, so check it directly rather
        # than wrapping it in throwaway tensors.
        if math.isfinite(psnr):
            psnr_values.append(psnr)

    return float(np.mean(psnr_values)) if psnr_values else 0.0


def batch_ssim(img1, img2, window_size=11, val_range=1.0):
    """Return the batch-averaged SSIM as a Python float.

    Delegates to ``calculate_ssim`` with ``size_average=True``.

    Args:
        img1, img2: Tensors that must have identical shapes.
        window_size: Requested Gaussian window size.
        val_range: Dynamic range of the pixel values.

    Returns:
        The SSIM value, or 0.0 on a shape mismatch, a NaN/inf result, or any
        exception raised while computing it (the error is printed).
    """
    if img1.shape != img2.shape:
        print(
            f"Warning: Shape mismatch in SSIM calculation: {img1.shape} vs {img2.shape}"
        )
        return 0.0

    try:
        score = calculate_ssim(
            img1, img2, window_size, size_average=True, val_range=val_range
        )
        is_finite = not (torch.isnan(score) or torch.isinf(score))
        return score.item() if is_finite else 0.0
    except Exception as e:
        print(f"Error calculating SSIM: {e}")
        return 0.0
