import numpy as np
from scipy.stats import distributions


def kl_divergent_norm(
    rv: distributions.rv_frozen,
    kl_divergence: float,
) -> distributions.rv_frozen:
    """Return a normal distribution at a given KL divergence from ``rv``.

    The result keeps the standard deviation of ``rv`` and shifts its mean
    upward. For two normal distributions sharing a standard deviation
    ``sigma``, the KL divergence (in nats) is
    ``(mu_1 - mu_0) ** 2 / (2 * sigma ** 2)``. This is symmetric in the
    equal-variance case. A mean shift of ``sigma * sqrt(2 * kl_divergence)``
    therefore realises exactly the requested divergence.

    The exact-divergence guarantee only holds when ``rv`` is itself a normal
    distribution. For any other frozen distribution, only its mean and
    standard deviation are used.

    Args:
        rv: Frozen scipy distribution supplying the reference mean and std.
        kl_divergence: Target divergence in nats; must be non-negative.
            Array-likes broadcast into an array-valued ``loc``.

    Returns:
        A frozen ``scipy.stats.norm`` with ``scale == rv.std()`` and
        ``loc == rv.mean() + rv.std() * sqrt(2 * kl_divergence)``.

    Raises:
        ValueError: If any ``kl_divergence`` value is negative or NaN.
    """
    # NaN >= 0 is False, so this rejects NaN as well as negatives. Without
    # the check, np.sqrt would silently produce a NaN location.
    if not np.all(np.greater_equal(kl_divergence, 0)):
        raise ValueError(
            f"kl_divergence must be non-negative, got {kl_divergence!r}"
        )

    std = rv.std()  # evaluate once; reused for both the shift and the scale
    new_mean = rv.mean() + std * np.sqrt(2 * kl_divergence)

    return distributions.norm(loc=new_mean, scale=std)
