def stretched_gaussian_like(x, sigma_target, stretch_factor=1.6):
    """Sample power-warped Gaussian noise shaped like ``x``.

    A standard-normal tensor is drawn with ``torch.randn_like(x)``, each
    element's magnitude is raised to the power ``1 / stretch_factor`` while
    its sign is kept, and the result is rescaled so that its sample standard
    deviation equals ``sigma_target``.

    Args:
        x: Tensor whose shape, dtype and device the noise copies.
        sigma_target: Standard deviation the returned noise is rescaled to.
        stretch_factor: Reciprocal of the exponent applied to ``|g|``.

    Returns:
        The warped and rescaled noise tensor.
    """
    exponent = 1 / stretch_factor
    base = torch.randn_like(x)  # standard Gaussian draw
    # Sign-preserving power transform of the magnitudes.
    warped = base.sign() * base.abs().pow(exponent)
    # Rescale in place so the empirical std matches the requested sigma.
    return warped.mul_(sigma_target / warped.std())

import torch
import numpy as np

def loss_fn(config, device, standart_dimensions):
    """Build a denoising-loss closure from a configuration mapping.

    Args:
        config: Mapping providing the required keys ``sigma_data``,
            ``sigma_noise``, ``dist``, ``precond``, ``use_t``, ``trim`` and
            ``robust``, plus the optional ``min_k`` (defaults to 1).
        device: Not used by this function; kept for interface compatibility.
        standart_dimensions: Not used by this function; kept for interface
            compatibility.

    Returns:
        A callable ``loss_fn_norm(model, batch, eps=1e-5)`` that returns a
        scalar loss tensor.
    """
    sigma_data = config["sigma_data"]
    sigma_noise = config["sigma_noise"]
    dist = config["dist"]
    precond = config["precond"]
    use_t = config["use_t"]

    trim = config["trim"]
    robust = config["robust"]
    min_k = config.get("min_k", 1)

    def loss_fn_norm(model, batch, eps=1e-5):
        """Noise ``batch``, denoise it with ``model`` and return the loss.

        Noise scales are created with shape ``(B, 1, 1)``, so ``batch`` is
        expected to be a 3-D tensor with the batch dimension first
        (NOTE(review): confirm against the data loader).

        Args:
            model: Callable invoked as ``model(x_in)`` or, when ``use_t`` is
                set, as ``model(x_in, c_noise)`` with ``c_noise`` of shape
                ``(B,)``.
            batch: Clean data tensor.
            eps: Lower bound of the uniform draws used by the ``rand``,
                ``rand2``, ``rand3`` and ``NCSBAD`` noise schedules.

        Returns:
            Scalar tensor: the (optionally trimmed) mean over the per-sample
            weighted squared errors.

        Raises:
            ValueError: If ``dist``, ``precond`` or ``robust`` has an
                unsupported value.
        """
        dtype = torch.float32
        y = batch

        B = y.shape[0]
        # Column of ones, one entry per sample; broadcasts over the rest.
        normal = torch.ones([B, 1, 1], device=y.device, dtype=torch.float32)
        sigma = normal * sigma_noise

        # ---- sample the noise n and the matching noise level sigma ----
        if dist == "gaus":
            n = torch.randn_like(y) * sigma
        elif dist == "bigaus":
            n = stretched_gaussian_like(y, sigma_noise)
        elif dist == "Tabsyn":
            # Log-normal noise level: ln(sigma) ~ N(P_mean, P_std^2).
            P_mean = -1.2
            P_std = 1.2
            rnd_normal = torch.randn([B, 1, 1], device=y.device)
            sigma = (rnd_normal * P_std + P_mean).exp()
            n = torch.randn_like(y) * sigma
        elif dist == "rand":
            # One uniform noise level in [eps, 1) shared by the whole batch.
            sigma = normal * (torch.rand(1, device=y.device) * (1. - eps) + eps)
            n = torch.randn_like(y) * sigma
        elif dist == "rand2":
            # Independent uniform noise level in [eps, 1) per sample.
            sigma = torch.rand([B, 1, 1], device=y.device) * (1. - eps) + eps
            n = torch.randn_like(y) * sigma
        elif dist == "rand3":
            # Per-sample uniform noise level mapped into (a, b).
            a, b = 0.12, 0.35
            u = torch.rand([B, 1, 1], device=y.device) * (1. - eps) + eps
            sigma = a + (b - a) * u
            n = torch.randn_like(y) * sigma
        elif dist == 'NCSBAD':
            # sigma(t) = sqrt((sigma_t^(2t) - 1) / (2 ln sigma_t)), t ~ U[eps, 1).
            sigma_t = 0.01
            random_t = torch.rand([B, 1, 1], device=y.device) * (1. - eps) + eps
            sigma = torch.sqrt((sigma_t ** (2 * random_t) - 1.) / 2. / np.log(sigma_t))
            n = torch.randn_like(y) * sigma
        else:
            raise ValueError(f"Unknown dist={dist}")

        x = y + n

        # ---- preconditioning coefficients ----
        if precond == "yes":
            var_sum = sigma ** 2 + sigma_data ** 2
            weight = var_sum / (sigma * sigma_data) ** 2
            c_skip = sigma_data ** 2 / var_sum
            c_out = sigma * sigma_data / var_sum.sqrt()
            c_in = 1 / var_sum.sqrt()
        elif precond == "one":
            weight = normal
            c_skip = normal
            c_out = normal
            c_in = normal
        elif precond == "no":
            weight = normal
            # No skip connection. Sized by the batch dimension so that it
            # broadcasts against x like the other coefficients.
            c_skip = torch.zeros([B, 1, 1], device=y.device)
            c_out = normal
            c_in = normal
        else:
            raise ValueError(f"Unknown precond={precond}")

        sigma = sigma.to(torch.float32).reshape(-1, 1, 1)

        x_in = c_in * x

        if not use_t:
            F_x = model(x_in.to(dtype))
        else:
            # Noise-level conditioning, one scalar per sample.
            c_noise = sigma.log() / 4
            F_x = model(x_in.to(dtype), c_noise.flatten())

        D_x = c_skip * x + c_out * F_x.to(torch.float32)

        # Element-wise weighted squared error, shape (B, ...).
        elem_loss = weight * (D_x - y).pow(2)

        # Collapse every non-batch dimension into one scalar per sample.
        per_sample = elem_loss.reshape(B, -1).mean(dim=1)

        # ---- robust aggregation across the batch ----
        if robust == "None" or trim <= 0.0:
            loss = per_sample.mean()
        elif robust == "trim":
            # Drop the largest `trim` fraction of per-sample losses, keeping
            # at least min_k and at most B of them.
            k = int(np.floor((1.0 - trim) * B))
            k = min(max(k, min_k), B)
            loss = torch.topk(per_sample, k, largest=False).values.mean()
        else:
            raise ValueError(f"Unknown robust={robust}")

        return loss

    return loss_fn_norm

    