import numpy as np
from operators import R_opt, Rearrange_Covmat
from functions import generate_Cov


def score_of_gaussian(x, mu, scale):
    """Negative score of a univariate normal N(mu, scale**2).

    Returns ``-d/dx log p(x) = (x - mu) / scale**2``; operates element-wise
    on scalars or NumPy arrays.
    """
    variance = scale ** 2
    return (x - mu) / variance


def score_of_multivariate_normal(x, mu, Theta):
    """Negative score of a multivariate normal with mean ``mu`` and precision ``Theta``.

    Computes ``-grad log p(x) = Theta (x - mu)``.

    Parameters
    ----------
    x : np.ndarray
        A single vector (1-D) or a batch of row vectors (2-D).
    mu : np.ndarray
        Mean vector, broadcast against ``x``.
    Theta : np.ndarray
        Precision (inverse covariance) matrix.

    Returns
    -------
    np.ndarray
        Same layout as ``x``.
    """
    centered = x - mu
    if x.ndim != 1:
        # Row-vector batch: right-multiplying equals Theta @ (x - mu).T
        # transposed, assuming Theta is symmetric (true for a precision matrix).
        return centered.dot(Theta)
    return Theta.dot(centered)


def score_of_t(x, df):
    """Negative score of a standard Student's t distribution with ``df`` degrees of freedom.

    Returns ``-d/dx log p(x) = (df + 1) x / (df + x**2)``, element-wise.
    """
    numerator = (df + 1) * x
    denominator = df + x ** 2
    return numerator / denominator


def score_of_gamma(x, gamma_shape, scale):
    """Negative score of Gamma(shape=gamma_shape, scale=scale) for x > 0.

    Returns ``-d/dx log p(x) = 1/scale - (gamma_shape - 1)/x``, element-wise.
    Undefined at x == 0.
    """
    rate = 1 / scale
    return rate - (gamma_shape - 1) / x


def score_of_rayleigh(x, scale):
    """Negative score of a Rayleigh distribution with parameter ``scale``, for x > 0.

    Returns ``-d/dx log p(x) = x/scale**2 - 1/x``, element-wise.
    Undefined at x == 0.
    """
    variance = scale ** 2
    return x / variance - 1 / x


def get_params(dist, shape=(28, 28), kernel_size=(4, 4), stride=(4, 4), seed=0):
    """Return default parameters for ``dist`` plus the parameters its score function needs.

    Also prints a one-line summary of the chosen parameters.

    Parameters
    ----------
    dist : str
        One of 'IID Gaussian', 't', 'Gamma', 'Rayleigh', 'Correlated Gaussian'.
    shape : tuple of int
        Sample shape; used only for 'Correlated Gaussian'.
    kernel_size, stride : tuple of int
        Forwarded to ``R_opt`` / ``Rearrange_Covmat``; used only for
        'Correlated Gaussian'.
    seed : int
        Seed for the random integer mean of 'Correlated Gaussian'.

    Returns
    -------
    params : dict
        Parameters for sampling (the ``params`` argument of ``generate``).
    params_for_score : dict
        Parameters for the matching ``score_of_*`` function. For the scalar
        distributions this is the same dict object as ``params``. For
        'Correlated Gaussian' it holds the mean after ``R_opt`` (flattened) as
        'mu' and the inverse of the ``Rearrange_Covmat`` covariance as 'Theta'.

    Raises
    ------
    ValueError
        If ``dist`` is not recognised.
    """
    rng = np.random.RandomState(seed)
    if dist == 'IID Gaussian':
        params = {'mu': 0, 'scale': 1}
        params_for_score = params
        print(f'Distribution: {dist} | mean: mu = {params["mu"]:.1f}, standard deviation: sigma = {params["scale"]:.1f}')
    elif dist == 't':
        params = {'df': 5}
        params_for_score = params
        print(f'Distribution: {dist} | degree of freedom: df = {params["df"]:.0f}')
    elif dist == 'Gamma':
        params = {'gamma_shape': 5, 'scale': 1}
        params_for_score = params
        print(f'Distribution: {dist} | shape: k = {params["gamma_shape"]:.1f}, scale: theta = {params["scale"]:.1f}')
    elif dist == 'Rayleigh':
        # generate() supports 'Rayleigh' and reads params['scale'], so provide a
        # default here as well instead of rejecting it as undefined.
        params = {'scale': 1}
        params_for_score = params
        print(f'Distribution: {dist} | scale: sigma = {params["scale"]:.1f}')
    elif dist == 'Correlated Gaussian':
        # randint's upper bound is exclusive, so mean entries lie in [-5, 4].
        params = {'mu': rng.randint(-5, 5, size=shape), 'Sigma': generate_Cov(np.prod(shape), mode='exp_decay', rho=0.5)}
        mu, Sigma = params['mu'], params['Sigma']
        # Rearrange mean and covariance with the same kernel/stride so the
        # score is expressed in the rearranged coordinates.
        Rmu = R_opt(mu, kernel_size, stride)
        RSigma = Rearrange_Covmat(Sigma, shape, kernel_size, stride)
        RTheta = np.linalg.inv(RSigma)
        params_for_score = {'mu': Rmu.ravel(), 'Theta': RTheta}
        print(f"Distribution: {dist} | mean: mu, covariance matrix: Sigma in 'params'.")
    else:
        raise ValueError('Undefined distribution')

    return params, params_for_score


def generate(n, shape, dist='gaussian', params=None, seed=0):
    """Draw ``n`` samples from ``dist`` and return them with the matching score function.

    Parameters
    ----------
    n : int
        Number of samples.
    shape : tuple of int
        Shape of one sample; ``X`` has shape ``(n, *shape)``.
    dist : str
        'IID Gaussian' (alias 'gaussian', the default), 'Correlated Gaussian',
        't', 'Gamma' or 'Rayleigh'.
    params : dict, optional
        Distribution parameters, looked up by key (e.g. the ``params`` returned
        by ``get_params``):

        - 'IID Gaussian': 'mu', 'scale'
        - 'Correlated Gaussian': 'mu' (flattened before sampling), 'Sigma'
        - 't': 'df'
        - 'Gamma': 'gamma_shape', 'scale'
        - 'Rayleigh': 'scale'
    seed : int
        Seed for ``np.random.RandomState``.

    Returns
    -------
    X : np.ndarray
        Samples of shape ``(n, *shape)``.
    score_function : callable
        The ``score_of_*`` function corresponding to ``dist``.

    Raises
    ------
    ValueError
        If ``dist`` is not recognised.
    KeyError
        If ``params`` lacks a key required by ``dist``.
    """
    # None instead of a shared mutable ``{}`` default.
    params = {} if params is None else params
    rng = np.random.RandomState(seed)

    # Parameters are read by key rather than unpacked from ``params.values()``,
    # which silently depended on dict insertion order (e.g. {'scale': 2, 'mu': 0}
    # would have swapped mean and scale).
    # 'gaussian' is the declared default of ``dist``; accept it as an alias so
    # the default no longer always raises ValueError.
    if dist in ('IID Gaussian', 'gaussian'):
        X = rng.normal(params['mu'], params['scale'], size=(n, *shape))
        score_function = score_of_gaussian
    elif dist == 'Correlated Gaussian':
        mean = np.ravel(params['mu'])
        X = rng.multivariate_normal(mean, params['Sigma'], size=n).reshape((n, *shape))
        score_function = score_of_multivariate_normal
    elif dist == 't':
        X = rng.standard_t(df=params['df'], size=(n, *shape))
        score_function = score_of_t
    elif dist == 'Gamma':
        X = rng.gamma(shape=params['gamma_shape'], scale=params['scale'], size=(n, *shape))
        score_function = score_of_gamma
    elif dist == 'Rayleigh':
        X = rng.rayleigh(scale=params['scale'], size=(n, *shape))
        score_function = score_of_rayleigh
    else:
        raise ValueError("Unrecognized distribution, distributions should be: 'IID Gaussian', 'Correlated Gaussian', 't', 'Gamma' and 'Rayleigh'.")

    return X, score_function

