import CRPS.CRPS as pscore
import numpy as np
import torch
import xarray as xr
from skimage.transform import resize


def get_crps(gt, samples):
    """Average the per-element CRPS scores of a sample ensemble against ground truth.

    Args:
        gt: Ground-truth values (torch tensor or array); flattened to 1-D.
        samples: Torch tensor whose first dimension indexes ensemble members.
            All remaining dimensions are flattened and must line up
            element-wise with the flattened ``gt``.

    Returns:
        Dict with keys ``"crps"``, ``"fcrps"`` and ``"acrps"``. Each value is
        the mean, over all elements, of the corresponding value returned by
        ``pscore(member_values, target).compute()``.

    Raises:
        ValueError: If ``gt`` is empty, or if its size differs from the
            per-member size of ``samples``.
    """
    gt = gt.flatten()
    if isinstance(gt, torch.Tensor):
        # Keep gt on the same device as samples; both are handed to pscore together.
        gt = gt.detach().cpu()
    # Detach and move to the CPU once, not once per element inside the loop.
    samples = samples.detach().flatten(start_dim=1).cpu()

    n = len(gt)
    if n == 0:
        raise ValueError("gt must contain at least one element")
    if samples.shape[1] != n:
        raise ValueError(
            f"samples has {samples.shape[1]} values per member but gt has {n} elements"
        )

    crps_sum = 0
    fcrps_sum = 0
    acrps_sum = 0
    # samples.T[i] is the i-th column: every member's value for element i.
    for member_values, target in zip(samples.T, gt):
        crps, fcrps, acrps = pscore(member_values, target).compute()
        crps_sum += crps
        fcrps_sum += fcrps
        acrps_sum += acrps
    return {
        "crps": crps_sum / n,
        "fcrps": fcrps_sum / n,
        "acrps": acrps_sum / n,
    }


def get_ood_measurement(
    R,
    idx=0,
    noise_std=0.01,
    path="/XXXX-2/XXXX-1/scratch/uq_diffusion/era5/data/era5_2m_temperature_2009-2017_01.grib",
    mask_center=(100, 20),
    mask_radius=4,
):
    """Build a masked, noisy, out-of-distribution measurement from ERA5 ``t2m`` data.

    Steps:
    1. Load the ``t2m`` variable and min-max normalise it over the whole
       dataset (all indices).
    2. Take entry ``idx`` and resize it to ``(R, R)``.
    3. Set a small disc of pixels to 1.
    4. Add Gaussian noise.
    5. Rescale the result to ``[-1, 1]``.

    Args:
        R: Output side length in pixels.
        idx: Index into the first axis of the ``t2m`` values. This presumably
            selects a time step -- TODO confirm against the GRIB layout.
        noise_std: Standard deviation of the additive Gaussian noise. The
            noise is drawn from numpy's global RNG.
        path: Path of the dataset file opened with ``xr.open_dataset``.
        mask_center: ``(column, row)`` centre of the saturated disc. An
            alternative used previously was ``(90, 100)``. With the default
            ``(100, 20)`` and radius 4, the disc only lands inside the image
            when ``R >= 97``; for smaller ``R`` nothing is masked.
        mask_radius: Radius of the saturated disc in pixels.

    Returns:
        A float32 torch tensor of shape ``(R, R)`` with values in ``[-1, 1]``.
    """
    # Use a context manager so the file handle is released; `.values`
    # materialises the data in memory before the dataset is closed.
    with xr.open_dataset(path) as ds:
        data = ds["t2m"].values

    # Normalise with the global min/max so every idx shares one scale.
    data = (data - data.min()) / (data.max() - data.min())
    y = resize(data[idx], (R, R), mode="reflect", anti_aliasing=True, preserve_range=True)

    # Saturate a disc of pixels to create the out-of-distribution feature.
    mask = create_circular_mask(R, R, radius=mask_radius, center=mask_center)
    y[mask] = 1

    y += np.random.randn(*y.shape) * noise_std
    # Re-normalise to [0, 1] after masking and noise, then map to [-1, 1].
    y = (y - y.min()) / (y.max() - y.min())
    y = y * 2 - 1
    return torch.from_numpy(y).float()


def create_circular_mask(h, w, center=None, radius=None):
    """Return a boolean ``(h, w)`` array that is True on and inside a circle.

    Args:
        h: Number of rows.
        w: Number of columns.
        center: ``(x, y)`` pair, i.e. ``(column, row)``, of the circle centre.
            Defaults to the middle of the image.
        radius: Circle radius in pixels. Defaults to the smallest distance
            between the centre and the image walls.
    """
    if center is None:
        center = (int(w / 2), int(h / 2))
    cx = center[0]
    cy = center[1]

    if radius is None:
        radius = min(cx, cy, w - cx, h - cy)

    # Open grids broadcast to a full (h, w) distance map.
    rows, cols = np.ogrid[:h, :w]
    distance = np.sqrt((cols - cx) ** 2 + (rows - cy) ** 2)
    return distance <= radius


def PSNR(x, y, data_range=1.0):
    """Return the peak signal-to-noise ratio, in dB, between ``x`` and ``y``.

    Args:
        x: Array-like signal.
        y: Array-like signal, broadcast-compatible with ``x``.
        data_range: Peak-to-peak range of the signal. The default ``1.0``
            matches data in ``[0, 1]``; use ``2.0`` for data in ``[-1, 1]``.

    Returns:
        ``20 * log10(data_range / sqrt(MSE))``, or ``inf`` when the inputs
        are identical (MSE of zero).
    """
    # Promote to float64 so integer inputs (e.g. uint8 images) do not wrap
    # around when subtracted.
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        # Perfect reconstruction: avoid the divide-by-zero warning.
        return float("inf")
    return 20 * np.log10(data_range / np.sqrt(mse))
