# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library of statistical functions."""

import numpy as np
from scipy import stats


def sample_clipped_truncated_normal(left, right, mean,
                                    std,
                                    rng):
  """Returns a number sampled from a clipped truncated normal distribution.

  This is a convenience wrapper around scipy.stats.truncnorm. For details, see:
  https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.truncnorm.html

  Degenerate cases are resolved deterministically, without calling scipy:
    * If `left == right`, the only admissible value is returned.
    * If `std == 0`, the distribution collapses to a point mass at `mean`.
      That point is clamped into `[left, right]`, which is the limit of the
      truncated distribution as `std -> 0`. This keeps the result inside the
      clip range even when `mean` lies outside it.

  Args:
    left: Leftmost value of the range to clip to.
    right: Rightmost value of the range to clip to.
    mean: Mean of the original normal distribution.
    std: Standard deviation of the original normal distribution.
    rng: Random number generator, passed to scipy as `random_state`.

  Returns:
    A value in the closed interval [left, right].
  """
  if left == right:
    # scipy's truncnorm requires a < b strictly, so a zero-width interval
    # would raise; the answer is fully determined anyway.
    return left
  if std == 0:
    # Clamp rather than return `mean` blindly, so the "clip" contract holds.
    return min(max(mean, left), right)
  # The clip values in scipy.stats.truncnorm are for the standard normal, so
  # we convert them according to the range we want.
  a, b = (left - mean) / std, (right - mean) / std
  value = stats.truncnorm.rvs(a, b, loc=mean, scale=std, random_state=rng)
  return value
