import torch
from torch.nn import functional as F
import numpy as np
import math
import matplotlib.pyplot as plt
from didigress.utils import PlaceHolder, remove_mean_with_mask


def sum_except_batch(x):
    """Sum a tensor over every dimension except the leading one.

    Args:
        x: tensor whose first dimension is kept.

    Returns:
        A 1-D tensor with ``x.size(0)`` entries, each the sum of the
        corresponding slice ``x[i]``.
    """
    batch_size = x.size(0)
    # Collapse all trailing dimensions into one so a single reduction suffices.
    flattened = x.reshape(batch_size, -1)
    return flattened.sum(dim=-1)


def assert_correctly_masked(variable, node_mask):
    """Assert that ``variable`` has no NaN and is (near) zero wherever ``node_mask`` is 0.

    Args:
        variable: tensor to check.
        node_mask: mask broadcastable against ``variable``; it is converted
            with ``.long()`` and entries equal to 0 mark masked positions.

    Raises:
        AssertionError: if ``variable`` contains NaN, or if any masked entry
            has absolute value of at least 1e-4.
    """
    assert not torch.isnan(variable).any(), f"Shape:{variable.shape}"
    # Keep only the entries that are supposed to be masked out.
    leaked = variable * (1 - node_mask.long())
    assert leaked.abs().max().item() < 1e-4, \
        f'Variables not masked properly. {leaked}'


def sample_gaussian_with_mask(size, node_mask):
    """Draw standard-normal noise of shape ``size`` and multiply it by ``node_mask``.

    The noise is created with ``torch.randn(size)`` (default device) and then
    moved to ``node_mask.device`` before the mask is applied.

    Args:
        size: shape passed to ``torch.randn``.
        node_mask: tensor broadcastable against ``size``.

    Returns:
        The masked noise tensor.
    """
    noise = torch.randn(size)
    noise = noise.to(node_mask.device)
    return noise * node_mask

"""
No need to rework this method as it is only called in situations where
3D is enabled
"""
def remove_mean_with_mask(pos, node_mask):
    """Center the valid positions of each graph around the origin.

    Args:
        pos: bs x n x 3 (float32). Masked entries must already be zero.
        node_mask: bs x n (bool), True for valid nodes.

    Returns:
        ``pos`` with the per-graph mean over valid nodes subtracted from the
        valid nodes; masked rows are left untouched. A graph with no valid
        node is returned unchanged (all zeros) instead of NaN.

    Raises:
        AssertionError: if ``node_mask`` is not boolean or if masked
            positions are not (numerically) zero.
    """
    assert node_mask.dtype == torch.bool, f"Wrong dtype for the mask: {node_mask.dtype}"
    node_mask = node_mask.unsqueeze(-1)
    # Sum of absolute values over all masked entries (must be ~0).
    masked_max_abs_value = (pos * (~node_mask)).abs().sum().item()
    assert masked_max_abs_value < 1e-5, f'Error {masked_max_abs_value} too high'
    # Clamp the node count so that a fully masked graph yields mean 0 rather
    # than 0 / 0 = NaN (NaN * False is still NaN, which would poison ``pos``).
    N = node_mask.sum(1, keepdim=True).clamp(min=1)

    # Masked entries are zero, so the plain sum equals the sum over valid nodes.
    mean = torch.sum(pos, dim=1, keepdim=True) / N
    pos = pos - mean * node_mask
    return pos


def clip_noise_schedule(alphas2, clip_value=0.001):
    """Clip the per-step ratio alpha2_t / alpha2_{t-1} of a noise schedule.

    The schedule is given as cumulative ``alpha^2`` values. Each step ratio is
    clipped to ``[clip_value, 1]`` and the cumulative product is rebuilt from
    the clipped ratios. This may help improve stability during sampling.

    Args:
        alphas2: 1-D array of cumulative alpha^2 values.
        clip_value: lower bound applied to every step ratio.

    Returns:
        Array of the same length as ``alphas2`` with the re-accumulated schedule.
    """
    # Prepend 1 so that the first ratio is alphas2[0] / 1.
    padded = np.concatenate([np.ones(1), alphas2], axis=0)
    step_ratios = padded[1:] / padded[:-1]
    step_ratios = np.clip(step_ratios, a_min=clip_value, a_max=1.)
    return np.cumprod(step_ratios, axis=0)


def cosine_beta_schedule(timesteps, s=0.008, raise_to_power: float = 1):
    """Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ.

    Args:
        timesteps: number of diffusion steps; the returned array has
            ``timesteps + 1`` entries.
        s: offset used inside the cosine.
        raise_to_power: exponent applied to the final cumulative product
            (skipped when equal to 1).

    Returns:
        1-D array with the cumulative product of ``1 - beta`` after the betas
        have been clipped to ``[0, 0.999]``.
    """
    steps = timesteps + 2
    grid = np.linspace(0, steps, steps)
    angle = ((grid / steps) + s) / (1 + s) * np.pi * 0.5
    cumulative = np.cos(angle) ** 2
    # Normalise so that the schedule starts at exactly 1.
    cumulative = cumulative / cumulative[0]

    betas = 1 - (cumulative[1:] / cumulative[:-1])
    betas = np.clip(betas, a_min=0, a_max=0.999)
    cumulative = np.cumprod(1. - betas, axis=0)

    if raise_to_power == 1:
        return cumulative
    return np.power(cumulative, raise_to_power)

"""
def cosine_beta_schedule_discrete(timesteps, nu_arr, s=0.008):
    #Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ.
    steps = timesteps + 2
    x = np.linspace(0, steps, steps)
    x = np.expand_dims(x, 0)  # ((1, steps))

    nu_arr = np.array(nu_arr)  # (components, )  # X, charges, E, y, pos
    nu_arr = np.expand_dims(nu_arr, 1)  # ((components, 1))

    alphas_cumprod = np.cos(0.5 * np.pi * (((x / steps) ** nu_arr) + s) / (1 + s)) ** 2  # ((components, steps))
    # divide every element of alphas_cumprod by the first element of alphas_cumprod
    alphas_cumprod_new = alphas_cumprod / np.expand_dims(alphas_cumprod[:, 0], 1)
    # remove the first element of alphas_cumprod and then multiply every element by the one before it
    alphas = (alphas_cumprod_new[:, 1:] / alphas_cumprod_new[:, :-1])

    betas = 1 - alphas  # ((components, steps)) # X, charges, E, y, pos
    betas = np.swapaxes(betas, 0, 1)
    # plt.figure()
    # plt.plot(x[0, 1:], alphas[-1, ...], label='alpha')
    # plt.plot(x[0, 1:], betas[..., -1], label='betas')
    # plt.plot(x[0, ], alphas_cumprod[-1, ...], label='alpha_bar')
    # plt.show()
    # assert False
    return betas
"""

def cosine_beta_schedule_discrete(timesteps, nu_arr, s=0.008, compute_full = False):
    """Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ.

    One schedule is built per entry of ``nu_arr``; the normalised time
    ``x / steps`` is raised to that entry before entering the cosine.

    Args:
        timesteps: number of diffusion steps T.
        nu_arr: sequence of exponents, one per component.
        s: offset used inside the cosine.
        compute_full: when True, return the cumulative products between every
            pair of steps instead of only the cumulative product from step 0.

    Returns:
        A tuple ``(betas, true_alphas, alphas_cumprod)``:

        * ``betas``: array ``(T + 1, components)``, ``1 - alpha`` *before*
          clipping, with a leading row of zeros.
        * ``true_alphas``: array ``(T + 1, components)``, ``1 - clip(beta, 0,
          0.9999)`` with a leading row of ones.
        * ``alphas_cumprod``: if ``compute_full`` is False, the raw cosine
          values with shape ``(components, T + 1, 1)``. Otherwise an array of
          shape ``(components, T + 1, T)`` whose entry ``[c, t, s]`` is the
          product of ``true_alphas[s + 1 .. t, c]`` (1 when ``t <= s``).
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    x = np.expand_dims(x, 0)  # (1, steps)

    nu_arr = np.array(nu_arr)  # (components, )
    nu_arr = np.expand_dims(nu_arr, 1)  # (components, 1)

    alphas_cumprod = np.cos(0.5 * np.pi * (((x / steps) ** nu_arr) + s) / (1 + s)) ** 2  # (components, steps)
    # Divide every element of alphas_cumprod by its first element.
    alphas_cumprod_new = alphas_cumprod / np.expand_dims(alphas_cumprod[:, 0], 1)
    # Per-step alphas: ratio of consecutive cumulative products.
    alphas = (alphas_cumprod_new[:, 1:] / alphas_cumprod_new[:, :-1])  # (components, timesteps)

    n_comps = alphas.shape[0]
    # Pad with a leading column of ones (step 0 applies no noise).
    true_alphas = np.ones((n_comps, alphas.shape[1] + 1))
    true_alphas[:, 1:] = alphas

    betas = 1 - true_alphas  # (components, timesteps + 1)

    # I know this step feels useless but it ensures consistency with MiDi's method.
    # Only true_alphas is clipped; the returned betas stay unclipped.
    true_alphas[:, 1:] = 1 - betas[:, 1:].clip(0, 0.9999)

    betas = np.swapaxes(betas, 0, 1)  # (timesteps + 1, components)

    if compute_full:
        # Rows index the start step s (T of them), columns index t with one
        # extra column of padding for t = 0.
        alphas_bar = np.ones((n_comps, timesteps, timesteps + 1))

        for c in range(n_comps):
            for t in range(1, timesteps + 1):
                alphas_bar[c, t - 1, t:] = np.cumprod(true_alphas[c, t:])

        # (n_comps, timesteps + 1, timesteps) = (n_comps, t, s)
        alphas_cumprod = alphas_bar.swapaxes(-1, -2)
    else:
        # NOTE(review): this returns the un-normalised cosine values rather
        # than alphas_cumprod_new -- confirm this is intended.
        # (n_comps, timesteps + 1, 1) = (n_comps, t, 1)
        alphas_cumprod = alphas_cumprod[..., np.newaxis]

    true_alphas = np.swapaxes(true_alphas, 0, 1)  # (timesteps + 1, components)

    return betas, true_alphas, alphas_cumprod


def gaussian_KL(q_mu, q_sigma):
    """Elementwise KL divergence between N(q_mu, q_sigma^2) and the standard normal.

    Args:
        q_mu: Mean of distribution q.
        q_sigma: Standard deviation of distribution q.

    Returns:
        A tensor with the broadcast shape of the inputs holding
        ``log(1 / q_sigma) + 0.5 * (q_sigma^2 + q_mu^2) - 0.5``.
        No reduction is applied here.
    """
    log_term = torch.log(1 / q_sigma)
    quadratic_term = 0.5 * (q_sigma ** 2 + q_mu ** 2)
    return log_term + quadratic_term - 0.5


def cdf_std_gaussian(x):
    """Return the standard normal CDF of ``x``, computed elementwise via ``torch.erf``."""
    scaled = x / math.sqrt(2)
    return 0.5 * (1. + torch.erf(scaled))


def SNR(gamma):
    """Compute the signal-to-noise ratio (alpha^2 / sigma^2) as ``exp(-gamma)``."""
    negated = -gamma
    return torch.exp(negated)


def inflate_batch_array(array, target_shape):
    """View a per-batch array as ``(batch_size, 1, ..., 1)``.

    ``array`` is expected to have shape ``(batch_size,)`` or
    ``(batch_size, 1, ..., 1)``. Only the *length* of ``target_shape`` is used:
    the result has as many dimensions as ``target_shape``, with the batch size
    taken from ``array.size(0)`` and every other dimension equal to 1.
    """
    n_singleton_dims = len(target_shape) - 1
    inflated_shape = (array.size(0),) + (1,) * n_singleton_dims
    return array.view(inflated_shape)


def sigma(gamma, target_shape):
    """Compute sigma = sqrt(sigmoid(gamma)), inflated to the rank of ``target_shape``."""
    sigma_squared = torch.sigmoid(gamma)
    return inflate_batch_array(torch.sqrt(sigma_squared), target_shape)


def alpha(gamma, target_shape):
    """Compute alpha = sqrt(sigmoid(-gamma)), inflated to the rank of ``target_shape``."""
    alpha_squared = torch.sigmoid(-gamma)
    return inflate_batch_array(torch.sqrt(alpha_squared), target_shape)


def check_mask_correct(variables, node_mask):
    """Run ``assert_correctly_masked`` on every non-empty entry of ``variables``.

    Entries whose ``len`` is 0 are skipped.
    """
    for variable in variables:
        if len(variable) == 0:
            continue
        assert_correctly_masked(variable, node_mask)


def check_tensor_same_size(*args):
    """Assert that every argument has the same ``size()`` as the first one.

    With zero or one argument nothing is checked.

    Raises:
        AssertionError: if any later argument's size differs from the first.
    """
    for other in args[1:]:
        assert args[0].size() == other.size()


def sigma_and_alpha_t_given_s(gamma_t: torch.Tensor, gamma_s: torch.Tensor, target_size: torch.Size):
    """
    Computes sigma t given s and alpha t given s from gamma_t and gamma_s.
    Used during sampling.

    These are defined as:
        alpha t given s = alpha t / alpha s,
        sigma t given s = sqrt(1 - (alpha t given s) ^2 ).

    Returns:
        A tuple ``(sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s)``, each
        inflated to the rank of ``target_size``.
    """
    # sigma^2_{t|s} = -expm1(softplus(gamma_s) - softplus(gamma_t))
    softplus_gap = F.softplus(gamma_s) - F.softplus(gamma_t)
    sigma2_t_given_s = inflate_batch_array(-torch.expm1(softplus_gap), target_size)

    # alpha_{t|s} = alpha_t / alpha_s, evaluated in log space:
    # log alpha^2 = logsigmoid(-gamma).
    log_alpha2_ratio = F.logsigmoid(-gamma_t) - F.logsigmoid(-gamma_s)
    alpha_t_given_s = inflate_batch_array(torch.exp(0.5 * log_alpha2_ratio), target_size)

    return sigma2_t_given_s, torch.sqrt(sigma2_t_given_s), alpha_t_given_s


def reverse_tensor(x):
    """Return a copy of ``x`` with the order of its first dimension reversed."""
    return torch.flip(x, dims=(0,))


def check_issues_norm_values(gamma, norm_val1, norm_val2, num_stdevs=8):
    """Check that 1 / norm_value is still at least ``num_stdevs`` * sigma_0.

    Args:
        gamma: callable evaluated on a ``(1, 1)`` tensor of zeros to obtain gamma_0.
        norm_val1: first normalization value.
        norm_val2: second normalization value; the larger of the two is checked.
        num_stdevs: number of standard deviations that must fit in 1 / norm_value.

    Raises:
        ValueError: if ``sigma_0 * num_stdevs > 1 / max(norm_val1, norm_val2)``.
    """
    zeros = torch.zeros((1, 1))
    sigma_0 = sigma(gamma(zeros), target_shape=zeros.size()).item()
    max_norm_value = max(norm_val1, norm_val2)
    inverse_norm = 1. / max_norm_value
    if not (sigma_0 * num_stdevs > inverse_norm):
        return
    raise ValueError(
        f'Value for normalization value {max_norm_value} probably too '
        f'large with sigma_0 {sigma_0:.5f}*{num_stdevs} and '
        f'1 / norm_value = {inverse_norm}')


def sample_discrete_features(probX, probE, node_mask, prob_charges = None):
    ''' Sample features from multinomial distributions with the given probabilities.

        NOTE: ``probX``, ``probE`` and ``prob_charges`` are modified in place:
        masked rows (and the diagonal of ``probE``) are overwritten with a
        uniform distribution so that ``multinomial`` can sample from them.

        :param probX: bs, n, dx_out        node features
        :param probE: bs, n, n, de_out     edge features
        :param node_mask: bs, n            boolean mask of valid nodes
        :param prob_charges: bs, n, dc_out charge features, or None
        :return: PlaceHolder with X (bs, n), E (bs, n, n, symmetric with a zero
                 diagonal), an empty y of shape (bs, 0), charges (bs, n) or
                 None, and pos=None.
    '''
    bs, n = node_mask.shape
    # Noise X
    # The masked rows should define probability distributions as well
    probX[~node_mask] = 1 / probX.shape[-1]

    # Flatten the probability tensor to sample with multinomial
    probX = probX.reshape(bs * n, -1)       # (bs * n, dx_out)

    # Sample X
    X_t = probX.multinomial(1)              # (bs * n, 1)
    X_t = X_t.reshape(bs, n)                # (bs, n)

    # Noise E
    # The masked rows should define probability distributions as well
    inverse_edge_mask = ~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2))
    # Build the diagonal mask directly as bool on probE's device, so the
    # indexing below never mixes devices.
    diag_mask = torch.eye(n, dtype=torch.bool, device=probE.device).unsqueeze(0).expand(bs, -1, -1)

    probE[inverse_edge_mask] = 1 / probE.shape[-1]
    probE[diag_mask] = 1 / probE.shape[-1]

    probE = probE.reshape(bs * n * n, -1)

    # Sample E, then keep the strict upper triangle and mirror it so that the
    # result is symmetric with a zero diagonal.
    E_t = probE.multinomial(1).reshape(bs, n, n)   # (bs, n, n)
    E_t = torch.triu(E_t, diagonal=1)
    E_t = (E_t + torch.transpose(E_t, 1, 2))

    charges_t = None
    if prob_charges is not None:
        # The masked rows should define probability distributions as well
        prob_charges[~node_mask] = 1 / prob_charges.shape[-1]

        # Flatten the probability tensor to sample with multinomial
        prob_charges = prob_charges.reshape(bs * n, -1)  # (bs * n, dc_out)

        # Sample charges
        charges_t = prob_charges.multinomial(1)
        charges_t = charges_t.reshape(bs, n)

    return PlaceHolder(X=X_t, E=E_t, y=torch.zeros(bs, 0, device=X_t.device), charges=charges_t,  pos=None)


def compute_posterior_distribution(M, M_t, Qt_M, Qsb_M, Qtb_M):
    ''' Compute xt @ Qt.T * x0 @ Qsb / x0 @ Qtb @ xt.T

        M and M_t are X, E or charges (clean and noisy respectively). Both are
        flattened to (bs, N, d) with N = n or n * n and cast to float32.

        :param M: clean features
        :param M_t: noisy features
        :param Qt_M: (bs, d, d) one-step transition matrix
        :param Qsb_M: (bs, d, d) cumulative transition matrix up to s
        :param Qtb_M: (bs, d, d) cumulative transition matrix up to t
        :return: (bs, N, d) tensor
    '''
    clean = M.flatten(start_dim=1, end_dim=-2).to(torch.float32)      # (bs, N, d)
    noisy = M_t.flatten(start_dim=1, end_dim=-2).to(torch.float32)    # (bs, N, d)

    # Numerator: (xt @ Qt.T) * (x0 @ Qsb), elementwise.
    numerator = (noisy @ torch.transpose(Qt_M, -2, -1)) * (clean @ Qsb_M)   # (bs, N, d)

    # Denominator: x0 @ Qtb @ xt.T, evaluated as a row-wise dot product.
    normalizer = ((clean @ Qtb_M) * noisy).sum(dim=-1)   # (bs, N)

    # The small constant guards against division by zero.
    return numerator / (1e-19 + normalizer.unsqueeze(-1))    # (bs, N, d)


def compute_batched_over0_posterior_distribution(X_t, Qt, Qsb, Qtb, use_ins_del=False, verbose=False):
    """ Compute xt @ Qt.T * x0 @ Qsb / x0 @ Qtb @ xt.T for each possible value of x0.

        X_t is X, E or charges.
        X_t: bs, n, dt          or bs, n, n, dt
        Qt: bs, d_t-1, dt
        Qsb: bs, d0, d_t-1
        Qtb: bs, d0, dt.

        When ``use_ins_del`` is True, ``Qsb`` and ``Qtb`` are instead 8-tuples
        ``(A, B, C, D, alpha_bar, Z, Z_bar, Z_bar_minus_1)`` that are unpacked
        below (shapes are not checked here -- see the calling code).

        When ``verbose`` is True, intermediate tensors are printed; in ins-del
        mode this also changes the global torch print options.

        Returns a tensor of shape bs, N, d0, d_t-1 (N = n or n * n).
    """
    # Flatten feature tensors
    # Careful with this line. It does nothing if X is a node feature. If X is an edge features it maps to
    # bs x (n ** 2) x d
    X_t = X_t.flatten(start_dim=1, end_dim=-2).to(torch.float32)            # bs x N x dt

    if use_ins_del:
        A, B, C, D, alpha_bt, Zt, Ztb, Ztm1b = Qtb
        _, _, _, _, alpha_bs, Zs, Zsb, Zsm1b = Qsb

        if verbose:
            torch.set_printoptions(threshold=100000,linewidth=2000)
            print("1-Zt:\n", (1-Zt).squeeze(-1))
            print("Ztm1b:\n", (Ztm1b).squeeze(-1))

        A_term = X_t*(alpha_bt*Ztb)
        B_term = X_t*((1-alpha_bt)*Ztb)
        C_term = X_t*((1-Zt)*Ztm1b)
        D_term = X_t*(1-Ztm1b)

    Qt_T = Qt.transpose(-1, -2)                 # bs, dt, d_t-1
    left_term = X_t @ Qt_T                      # bs, N, d_t-1

    if use_ins_del:
        # In ins-del, the "true" Qt is
        # zt*(at*A + (1-at)*B) + (1-zt)*C = zt*Qt + (1-zt)*C
        left_term = Zt*left_term + ((1-Zt)*X_t)@C.transpose(-1, -2)

    left_term = left_term.unsqueeze(dim=2)      # bs, N, 1, d_t-1

    if verbose:
        print(f"Qt_T = {(Qt_T).size()}:\n", Qt_T)
        print(f"X_t = {X_t.size()}:\n", X_t)
        # Zt only exists in ins-del mode; printing it unconditionally raised
        # a NameError in the standard mode.
        if use_ins_del:
            print(f"Zt = {Zt.size()}\n", Zt.squeeze(-1))
        print(f"left_term=X_t @ Qt_T (size:{left_term.size()}):\n", left_term)

    # value unpacking
    if use_ins_del:
        # The right term is effectively Qbar_{t-1|s}. We make one for each item
        # in x_t since each one has a different activation time.
        right_term = (Zsb*alpha_bs)[:,:,:,None]*A + (Zsb*(1-alpha_bs))[:,:,:,None]*B + \
                     ((1-Zs)*Zsm1b)[:,:,:,None]*C + (1-Zsm1b)[:,:,:,None]*D
        if verbose:
            print(f"right_term Qbar_t-1|s (size:{right_term.size()}):\n", right_term)
    else:
        right_term = Qsb.unsqueeze(1)               # bs, 1, d0, d_t-1

    numerator = left_term * right_term          # bs, N, d0, d_t-1

    if use_ins_del:
        # This is pretty much the denominator
        A_term_T = A_term.transpose(-1, -2)
        B_term_T = B_term.transpose(-1, -2)
        C_term_T = C_term.transpose(-1, -2)
        D_term_T = D_term.transpose(-1, -2)
        prod     = A@A_term_T + B@B_term_T + C@C_term_T + D@D_term_T
        if verbose:
            print(f"A_term_T (size:{A_term_T.size()})\n{A_term_T}")
            print(f"A = \n{A}")
            print(f"B_term_T (size:{B_term_T.size()})\n{B_term_T}")
            print(f"B = \n{B}")
            print(f"C_term_T (size:{C_term_T.size()})\n{C_term_T}")
            print(f"C = \n{C}")
            print(f"D_term_T (size:{D_term_T.size()})\n{D_term_T}")
            print(f"D = \n{D}")
            print(f"prod Qtb @ xt.T (size:{prod.size()}):\n{prod}")
    else:
        X_t_transposed = X_t.transpose(-1, -2)      # bs, dt, N
        prod = Qtb @ X_t_transposed                 # bs, d0, N

    prod = prod.transpose(-1, -2)               # bs, N, d0
    denominator = prod.unsqueeze(-1)            # bs, N, d0, 1
    # Avoid dividing by exactly zero (written in place into ``prod``).
    denominator[denominator == 0] = 1e-6

    out = numerator / denominator
    if verbose:
        print(f"out (size:{out.size()})", out)
    return out


def mask_distributions(probs, node_mask):
    """Overwrite masked rows with a fixed distribution and renormalise.

    Masked rows are set to a one-hot distribution on index 0 so that they do
    not contribute to the loss. Afterwards 1e-7 is added to every entry and
    each row is renormalised to sum to 1.

    Args:
        probs: object with tensor attributes ``X`` (bs, n, dx), ``E``
            (bs, n, n, de) and ``charges`` (bs, n, dc, or None). It is
            modified in place.
        node_mask: (bs, n) boolean mask of valid nodes.

    Returns:
        The same ``probs`` object.
    """
    device = probs.X.device
    row_X = torch.zeros(probs.X.size(-1), dtype=torch.float, device=device)
    row_X[0] = 1.

    row_E = torch.zeros(probs.E.size(-1), dtype=torch.float, device=device)
    row_E[0] = 1.

    probs.X[~node_mask] = row_X

    # An edge is valid only if both endpoints are valid and it is off-diagonal.
    diag_mask = ~torch.eye(node_mask.size(1), device=node_mask.device, dtype=torch.bool).unsqueeze(0)
    probs.E[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2) * diag_mask), :] = row_E

    probs.X = probs.X + 1e-7
    probs.X = probs.X / torch.sum(probs.X, dim=-1, keepdim=True)

    probs.E = probs.E + 1e-7
    probs.E = probs.E / torch.sum(probs.E, dim=-1, keepdim=True)

    #TODO: is this safe? should we add the cfg structure as a parameter and check there?
    # Identity check: comparing a tensor to None with != relies on tensor
    # comparison fallbacks and is not the idiomatic way to test for None.
    if probs.charges is not None:
        row_charges = torch.zeros(probs.charges.size(-1), dtype=torch.float, device=device)
        row_charges[0] = 1.

        probs.charges[~node_mask] = row_charges

        probs.charges = probs.charges + 1e-7
        probs.charges = probs.charges / torch.sum(probs.charges, dim=-1, keepdim=True)

    return probs


def posterior_distributions(clean_data, noisy_data, Qt, Qsb, Qtb):
    """Compute the posterior distributions for X, E and (optionally) charges.

    Each of ``clean_data``, ``noisy_data``, ``Qt``, ``Qsb`` and ``Qtb`` must
    expose ``X``, ``E`` and ``charges`` attributes; every component is passed
    to ``compute_posterior_distribution`` independently. Charges are skipped
    when ``clean_data.charges`` is None.

    Returns:
        A PlaceHolder with X (bs, n, dx), E (bs, n * n, de), charges
        (bs, n, dc) or None, and y = pos = None.
    """
    prob_X = compute_posterior_distribution(M=clean_data.X, M_t=noisy_data.X,
                                            Qt_M=Qt.X, Qsb_M=Qsb.X, Qtb_M=Qtb.X)   # (bs, n, dx)
    prob_E = compute_posterior_distribution(M=clean_data.E, M_t=noisy_data.E,
                                            Qt_M=Qt.E, Qsb_M=Qsb.E, Qtb_M=Qtb.E)   # (bs, n * n, de)

    prob_c = None
    # Identity check instead of ``!= None`` (which would go through tensor
    # comparison when charges is a tensor).
    if clean_data.charges is not None:
        #TODO: is this safe? should we add the cfg structure as a parameter and check there?
        prob_c = compute_posterior_distribution(M=clean_data.charges, M_t=noisy_data.charges,
                                                Qt_M=Qt.charges, Qsb_M=Qsb.charges, Qtb_M=Qtb.charges)   # (bs, n, dc)

    return PlaceHolder(X=prob_X, E=prob_E, charges=prob_c, y=None, pos=None)

def compute_linear_zeta(T, D, w, thr=1e-5):
    """Build a sigmoid-shaped zeta schedule and its normalised absolute gradient.

    Args:
        T: number of diffusion steps.
        D: center of the transition (on a time axis running from 0 to 1).
        w: transition width; the lower, the sharper the transition.
        thr: values closer than ``thr`` to 0 (resp. 1) are snapped to 0 (resp. 1).

    Returns:
        A tuple ``(z, dz)`` of 1-D arrays with ``T + 1`` entries each. ``z``
        starts at 1 and ends at 0; ``dz`` starts and ends at 0 and is
        normalised to sum to 1.
    """
    start_value = 1
    end_value = 0

    time_grid = np.linspace(0, 1, T)
    gate = 1.0 / (1.0 + np.exp(-(D - time_grid) / w))

    zeta = start_value + (end_value - start_value)*(1 - gate)
    zeta[-1] = 0
    zeta[zeta < thr] = 0
    zeta[zeta > 1-thr] = 1

    d_zeta = np.abs(np.gradient(zeta))

    # Prepend the initial state: zeta = start_value with no change.
    zeta = np.concatenate(([start_value], zeta))
    d_zeta = np.concatenate(([0], d_zeta))

    zeta[-1] = 0
    d_zeta[-1] = 0

    d_zeta[d_zeta < thr] = 0
    d_zeta[d_zeta > 1-thr] = 1
    d_zeta = d_zeta / d_zeta.sum()

    return zeta.T, d_zeta.T

def compute_zeta_bar_matrix(z, thr=1e-12):
    """Build the matrix of cumulative products of ``z`` for every start index.

    With ``n = z.shape[0]``, row ``i`` of the intermediate ``(n - 1, n)``
    matrix holds 1 at column ``i`` and ``cumprod(z[i + 1:])`` in the columns
    after it, where products not larger than ``thr`` are raised to ``thr`` and
    the last column is forced to 0. The transpose, of shape ``(n, n - 1)``,
    is returned.

    NOTE(review): ``z`` is indexed as a 1-D array here -- confirm against callers.
    """
    n = z.shape[0]
    z_bar = np.zeros((n - 1, n))
    for start in range(n - 1):
        tail = np.cumprod(z[start + 1:])

        # TODO: find a better way to deal with numerical convergence toward 0?
        tail[tail <= thr] = thr
        tail[-1] = 0

        z_bar[start, start] = 1
        z_bar[start, start + 1:] = tail

    return z_bar.T

def compute_p_nodes(max_n, p_min, p_max=1, allow_zero_nodes=True):
    """Build a table of tent-shaped distributions over node counts 0..max_n.

    Row ``i`` is a window of a symmetric ramp (rising linearly from ``p_min``
    to ``p_max`` and back) positioned so that its peak falls at column ``i``;
    each row is normalised to sum to 1.

    Args:
        max_n: largest node count; the table is ``(max_n + 1, max_n + 1)``.
        p_min: value at the ends of the ramp.
        p_max: value at the peak of the ramp.
        allow_zero_nodes: when False, column 0 of every row is zeroed before
            normalisation.

    Returns:
        The ``(max_n + 1, max_n + 1)`` array of row-normalised weights.
    """
    size = max_n + 1
    ramp = np.linspace(p_min, p_max, size)
    tent = np.concatenate([ramp, np.flip(ramp)[1:]])

    table = np.zeros((size, size))
    for i in range(size):
        # Slide the window left as i grows so that the peak lands on column i.
        window = tent[(max_n - i):(2 * max_n + 1 - i)].copy()
        if not allow_zero_nodes:
            window[0] = 0
        table[i] = window / sum(window)

    return table

def compute_delete_probs(zetas, zetas_bar):
    """Build the table of deletion probabilities for every start step.

    With ``T = zetas.shape[0] - 1``, an intermediate ``(T, T + 2)`` table is
    filled: for start step ``s`` and each ``t`` in ``s + 1 .. T``, entry
    ``[s, t]`` is ``zetas[t] / (1 + zetas[t] + zetas_bar[t, s])``; the last
    column holds the survival term ``1 / (1 + zetas[-1] + zetas_bar[-1, s])``.
    The shapes of both inputs are printed.

    Returns:
        The transposed table, of shape ``(T + 2, T)``.

    Raises:
        AssertionError: if the rows of the table do not sum to 1 (total
            absolute deviation above 1e-12).
    """
    T = zetas.shape[0] - 1  # zetas has T + 1 entries
    print("zetas.shape = ", zetas.shape)
    print("zetas_bar.shape = ", zetas_bar.shape)

    # T rows: the start step can be anything between 0 and T - 1.
    # T + 2 columns: column 0 is padding, the last column is the survival slot.
    probs = np.zeros((T, T+2))

    for start in range(T):
        # "Standard" slots: probability of being deleted at step `step`
        # when created at step `start`.
        for step in range(start+1, T+1):
            # zetas_bar is indexed [t, s].
            probs[start, step] = zetas[step]/(1 + zetas[step] + zetas_bar[step, start])

        # Probability of surviving.
        probs[start, -1] = 1/(1+ zetas[-1] + zetas_bar[-1, start])

    row_deviation = np.abs(np.sum(probs, axis=-1) - 1)
    total_deviation = np.sum(row_deviation)
    assert((total_deviation <= 1e-12)), "compute_delete_probs does not sum to 1"

    return probs.T

# Unused
def compute_s_loss(T, max_margin_fraction):
    """Build a ``(T, T)`` float16 table with a triangular bump around the diagonal.

    Row ``s`` has 1 at column ``s``. With
    ``margin_steps = floor((T - s) * max_margin_fraction)``, when
    ``margin_steps >= 2`` the ``margin_steps`` columns before ``s`` receive a
    linear ramp from 0 up to 0.5 and the ``margin_steps`` columns after ``s``
    receive the mirrored ramp from 0.5 down to 0. Ramps that would extend past
    either edge of the table are truncated, keeping the values closest to
    the diagonal.

    Args:
        T: number of steps (table size).
        max_margin_fraction: fraction of the remaining steps ``T - s`` used as
            the ramp width.

    Returns:
        The ``(T, T)`` array of dtype ``np.float16``.
    """
    s_loss = np.zeros((T, T), dtype=np.float16)

    for s in range(T):
        s_loss[s, s] = 1

        margin_steps = math.floor((T - s) * max_margin_fraction)
        if margin_steps < 2:
            continue

        l_data = np.linspace(0, 0.5, num=margin_steps)
        r_data = l_data[::-1]

        if s > 0:
            # Clamp at the left edge and keep only the ramp values that fit.
            l_start = max(s - margin_steps, 0)
            s_loss[s, l_start:s] = l_data[-(s - l_start):]

        if s + 1 < T:
            # Clamp at the right edge and keep only the ramp values that fit.
            # (The previous code sliced r_data to an empty array here, which
            # made the assignment fail with a broadcasting error.)
            r_end = min(s + 1 + margin_steps, T)
            s_loss[s, s + 1:r_end] = r_data[:r_end - (s + 1)]

    return s_loss