import torch

def standard_normal_pdf(x):
    """Evaluate the standard normal density element-wise.

    phi(x) = exp(-x**2 / 2) / sqrt(2 * pi)

    Args:
        x: input tensor; its device is used for the normalising constant.

    Returns:
        Tensor of density values, broadcast to the shape of ``x``.
    """
    norm_const = torch.tensor(2 * torch.pi, device=x.device).sqrt().reciprocal()
    kernel = x.pow(2).mul(-0.5).exp()
    return norm_const * kernel


def standard_normal_quantile(m, s, confidence=0.9):
    """Return the ``confidence``-quantile of a Normal(m, s) distribution.

    Despite the name, this is the quantile of a general (untruncated) normal
    distribution with mean ``m`` and standard deviation ``s``; pass
    ``m=0, s=1`` to get the standard-normal quantile.

    Args:
        m: mean (tensor or Python number).
        s: standard deviation, > 0 (tensor or Python number).
        confidence: probability level in (0, 1); a Python number or tensor.

    Returns:
        Tensor ``m + s * Phi^{-1}(confidence)``, broadcast over ``m`` and ``s``.
    """
    dist = torch.distributions.Normal(m, s)
    # Build the probability level in the distribution's precision (never below
    # the default dtype) and on its device. This keeps float64 inputs from
    # being pushed through a float32 erfinv, and it also works when m/s are
    # plain Python numbers.
    dtype = torch.promote_types(dist.loc.dtype, dist.scale.dtype)
    dtype = torch.promote_types(dtype, torch.get_default_dtype())
    level = torch.as_tensor(confidence, dtype=dtype, device=dist.loc.device)
    return dist.icdf(level)


def standard_normal_cdf(x):
    """Evaluate the standard normal CDF Phi(x) element-wise.

    Delegates to ``torch.distributions.Normal(0, 1).cdf``, so ``x`` is handled
    exactly as that method handles its ``value`` argument.

    Args:
        x: input tensor.

    Returns:
        Tensor of cumulative probabilities with the shape of ``x``.
    """
    return torch.distributions.Normal(loc=0, scale=1).cdf(x)


def value_at_rask(mu, sigma, confidence=0.9):
    """Compute the per-sample Value-at-Risk (VaR) under a normal model.

    For every row ``i`` this returns the ``confidence``-quantile of
    ``Normal(mu[i], sigma[i])``, i.e. ``mu[i] + sigma[i] * Phi^{-1}(confidence)``.
    The function name is kept as-is for caller compatibility.

    Args:
        mu: mean values, shape (B, 1).
        sigma: standard deviations (used as the Normal scale, not a variance),
            shape (B, 1).
        confidence: quantile level, default 0.9.

    Returns:
        VaR tensor, shape (B, 1). An empty batch yields an empty tensor
        instead of raising.
    """
    # Rows are processed independently; zip stops at the shorter input.
    var = [standard_normal_quantile(m, s, confidence) for m, s in zip(mu, sigma)]
    if not var:
        # torch.stack rejects an empty list, so return an empty batch instead.
        return mu.new_empty((0, *mu.shape[1:]))
    return torch.stack(var, dim=0)


def var2pro(mu_bar, sigma, confidence=0.9):
    """Squash the ``confidence``-quantile of Normal(mu_bar, sigma) with a sigmoid.

    Args:
        mu_bar: mean tensor.
        sigma: standard-deviation tensor (> 0).
        confidence: quantile level. It defaults to 0.9, which is
            ``standard_normal_quantile``'s default, so existing two-argument
            calls are unchanged.

    Returns:
        ``sigmoid(standard_normal_quantile(mu_bar, sigma, confidence))``,
        broadcast over the inputs.
    """
    return torch.sigmoid(standard_normal_quantile(mu_bar, sigma, confidence))

def conditional_value_at_rask(mu, sigma, matching, confidence=0.9):
    """Compute a per-sample Conditional Value-at-Risk (CVaR) under a normal model.

    Each sample is handled independently. Samples with ``matching == 1`` use
    the upper-quantile branch. Every other value uses the branch built on
    ``1 - mu``.

    Args:
        mu: mean values, shape (B, 1).
        sigma: standard deviations (used as the Normal scale, not a variance),
            shape (B, 1).
        matching: per-sample flag, shape (B,); 1 means mismatching, 0 means matching.
        confidence: quantile level, default 0.9.

    Returns:
        CVaR tensor, shape (B, 1). An empty batch yields an empty tensor
        instead of raising.
    """
    one = torch.tensor(1., device=mu.device)
    zero = torch.tensor(0., device=mu.device)
    # These standard-normal quantiles do not depend on the sample, so they
    # are computed once instead of on every loop iteration.
    z_upper = standard_normal_quantile(zero, one, confidence)
    z_lower = standard_normal_quantile(zero, one, 1-confidence)

    cvar = []
    for m, s, ma in zip(mu, sigma, matching):
        if ma == 1:
            quantile = m + s * z_upper
            # NOTE(review): this appears to be the mean of a normal truncated to
            # [quantile, 1]. The textbook formula evaluates phi/Phi at the
            # standardized bounds (quantile - m) / s and (1 - m) / s, but here
            # they are evaluated at the raw values. Confirm this is intended.
            factor = standard_normal_pdf(one) - standard_normal_pdf(quantile)
            factor = factor / (standard_normal_cdf(one) - standard_normal_cdf(quantile))
            cvar.append(m - s * factor)
        else:
            quantile = 1 - m - s * z_lower
            # NOTE(review): there is the same raw-vs-standardized concern as
            # above. Also, relative to the textbook mean of a normal truncated
            # to [0, quantile], the correction term here has the opposite sign.
            # Confirm this is intended.
            factor = standard_normal_pdf(quantile) - standard_normal_pdf(zero)
            factor = factor / (standard_normal_cdf(quantile) - standard_normal_cdf(zero))
            cvar.append(1 - m + s * factor)

    if not cvar:
        # torch.stack rejects an empty list, so return an empty batch instead.
        return mu.new_empty((0, *mu.shape[1:]))
    return torch.stack(cvar, dim=0)

