import numpy as np
from scipy.stats import norm, poisson, gamma


def sim_rf(settg, law="norm", approx=0, seed=None):
    """Simulate a random field on a periodic LD x LD lattice.

    Lattice noise drawn from ``law`` is multiplied, in Fourier space, by the
    kernel returned by ``matern_rf`` and transformed back.

    Parameters
    ----------
    settg : dict
        Simulation settings. Read here: ``'LD'`` (lattice side length) and
        ``'lambda'`` (noise intensity). The whole dict is forwarded to
        ``matern_rf`` and echoed back in the result.
    law : {"norm", "pois", "gamma", "bigamma"}
        Distribution of the i.i.d. lattice noise. With ``rate = lambda / LD``:
        "norm" is N(0, rate), "pois" is Poisson(rate), "gamma" is
        Gamma(shape=rate, scale=1), "bigamma" is the difference of two such
        Gamma draws divided by sqrt(2). All four have variance ``rate``.
    approx : int
        Forwarded to ``matern_rf`` (non-zero keeps only low-frequency modes).
    seed : int or None
        If not None, seeds NumPy's *global* random state before drawing.

    Returns
    -------
    dict or str
        ``{'field': real LD x LD array, 'modes': complex Fourier modes,
        'settings': settg}``. For an unknown ``law`` the string
        ``"Unknown law of noise - no output"`` is returned instead (kept for
        backward compatibility; callers have to check the return type).
    """
    LD = settg['LD']
    rate = settg['lambda'] / LD
    shape = (LD, LD)

    # simulate lattice noise (scipy draws from the global NumPy state here)
    if seed is not None:
        np.random.seed(seed)

    if law == "norm":
        noise = norm.rvs(scale=np.sqrt(rate), size=shape)  # Gaussian
    elif law == "pois":
        noise = poisson.rvs(mu=rate, size=shape)  # Poisson
    elif law == "gamma":
        noise = gamma.rvs(a=rate, size=shape)  # Gamma
    elif law == "bigamma":
        # Variance Gamma: difference of two i.i.d. Gamma draws, rescaled so the
        # variance equals `rate` like the other laws.
        noise = (gamma.rvs(a=rate, size=shape) -
                 gamma.rvs(a=rate, size=shape)) / np.sqrt(2)
    else:
        return "Unknown law of noise - no output"

    # Fourier-space kernel
    G_hat_mat = matern_rf(settg=settg, approx=approx)

    # simulate random field: circular convolution of noise with the kernel
    noise_hat = np.fft.fft2(noise)
    phi_hat = noise_hat * G_hat_mat
    # np.fft.ifft2 already includes the 1/LD**2 normalisation (norm="backward"),
    # so no further division is needed; 'field' is exactly the real part of the
    # inverse DFT of the returned 'modes'.
    phi = np.fft.ifft2(phi_hat).real

    return {'field': phi, 'modes': phi_hat, 'settings': settg}


def matern_rf(settg, approx=0):
    """Return the Fourier-space kernel on the LD x LD frequency grid.

    Evaluates

        G_hat(k) = m**(2*nu) / (2 * (2 - cos(k1) - cos(k2)) + m**2)**nu

    at k1, k2 in {2*pi*j / LD : j = 0, ..., LD-1}. At k = (0, 0) the value is
    m**(2*nu) / m**(2*nu), i.e. 1 for m != 0.

    Parameters
    ----------
    settg : dict
        Settings; reads ``'LD'`` (grid side length), ``'m'`` and ``'nu'``.
    approx : int
        If truthy, the kernel is multiplied by ``mask_rf(LD, approx)``. That
        keeps only entries whose row and column indices are both < approx or
        > LD - approx, and zeroes the rest.

    Returns
    -------
    numpy.ndarray
        LD x LD array. Row index i corresponds to k2 = 2*pi*i/LD and column
        index j to k1 = 2*pi*j/LD (``np.meshgrid`` "xy" indexing). The kernel
        is symmetric in (k1, k2), so the array is symmetric.
    """
    LD = settg['LD']
    m = settg['m']
    nu = settg['nu']

    # grid in fourier space
    k1, k2 = np.meshgrid(np.arange(LD, dtype=float), np.arange(LD, dtype=float))
    k1 *= 2 * np.pi / LD
    k2 *= 2 * np.pi / LD

    # Vectorised evaluation over the whole grid (replaces a per-element
    # np.apply_along_axis call with the same formula).
    G_hat_mat = m ** (2 * nu) / (2 * (2 - np.cos(k1) - np.cos(k2)) + m ** 2) ** nu

    if approx:
        G_hat_mat *= mask_rf(LD, approx)

    return G_hat_mat


def mask_rf(LD, approx):
    """Build an LD x LD 0/1 float mask that keeps only low-index Fourier modes.

    When ``approx`` is falsy the mask is all ones. Otherwise an index ``i``
    along either axis is kept when ``i < approx`` or ``i > LD - approx``.
    Entry (i, j) is 1.0 exactly when both i and j are kept, and 0.0 otherwise.
    """
    if not approx:
        return np.ones((LD, LD))

    idx = np.arange(LD)
    # complement of the band approx <= idx <= LD - approx
    kept = (idx < approx) | (idx > LD - approx)
    return (kept[:, None] & kept[None, :]).astype(float)
