def stretched_gaussian_like(x, sigma_target, stretch_factor=1.6):
    """Sample power-transformed Gaussian noise with the shape of ``x``.

    A standard-normal tensor is drawn, every value keeps its sign while its
    magnitude is raised to ``1 / stretch_factor``, and the whole tensor is
    rescaled so that its sample standard deviation equals ``sigma_target``.

    Args:
        x: Tensor whose shape, dtype and device the noise copies.
        sigma_target: Standard deviation the returned noise is scaled to.
        stretch_factor: Inverse of the exponent applied to the magnitudes.

    Returns:
        Noise tensor shaped like ``x``.
    """
    base = torch.randn_like(x)
    exponent = 1 / stretch_factor
    # Warp the magnitudes only; the sign of each draw is preserved.
    warped = base.sign() * base.abs().pow(exponent)
    # Normalise the empirical spread to the requested sigma (in place).
    warped.mul_(sigma_target / warped.std())
    return warped

import torch
import numpy as np

def _sample_sigma(dist, ones, sigma_noise):
    """Draw the per-sample noise level for the noise schedule ``dist``.

    Args:
        dist: Name of the schedule: "gaus", "bigaus", "Tabsyn", "rand",
            "rand2", "rand3" or "NCSBAD".
        ones: Tensor of ones with shape ``[B, 1, 1]``; its shape and device
            are reused for the random draws.
        sigma_noise: Fixed noise level used by "gaus" and "bigaus".

    Returns:
        Tensor of shape ``[B, 1, 1]`` holding one sigma per batch entry.

    Raises:
        ValueError: If ``dist`` is not one of the known schedules.
    """
    eps = 1e-5  # keeps the uniform draws strictly above zero
    shape, device = ones.shape, ones.device

    if dist in ("gaus", "bigaus"):
        # Same constant sigma for every sample.
        return ones * sigma_noise
    if dist == "Tabsyn":
        # Log-normal sigma: exp(N(p_mean, p_std^2)), drawn per sample.
        p_mean, p_std = -1.2, 1.2
        rnd_normal = torch.randn(shape, device=device)
        return (rnd_normal * p_std + p_mean).exp()
    if dist == "rand":
        # A single uniform sigma in [eps, 1) shared by the whole batch.
        return ones * (torch.rand(1, device=device) * (1. - eps) + eps)
    if dist == "rand2":
        # Independent uniform sigma in [eps, 1) per sample.
        return torch.rand(shape, device=device) * (1. - eps) + eps
    if dist == "rand3":
        # Independent uniform sigma restricted to roughly [a, b).
        a, b = 0.12, 0.35
        return a + (b - a) * (torch.rand(shape, device=device) * (1. - eps) + eps)
    if dist == "NCSBAD":
        sigma_t = 0.01
        random_t = (torch.rand(shape, device=device) * (1. - eps) + eps) * 999
        return torch.sqrt((sigma_t ** (2 * random_t) - 1.) / 2. / np.log(sigma_t))
    raise ValueError(f"unknown noise distribution: {dist!r}")


def _precond_coefficients(precond, sigma, sigma_data, ones):
    """Return ``(c_skip, c_out, c_in)`` for the preconditioning mode ``precond``.

    Args:
        precond: "yes" for sigma-dependent coefficients, "one" for all ones
            (denoised = x + F(x)), "no" for a zero skip term (denoised = F(x)).
        sigma: Per-sample noise level, shape ``[B, 1, 1]``.
        sigma_data: Data standard deviation used by the "yes" mode.
        ones: Tensor of ones with shape ``[B, 1, 1]``.

    Raises:
        ValueError: If ``precond`` is not "yes", "one" or "no".
    """
    if precond == "yes":
        # These formulas match the EDM-style preconditioning coefficients.
        total_var = sigma ** 2 + sigma_data ** 2
        c_skip = sigma_data ** 2 / total_var
        c_out = sigma * sigma_data / total_var.sqrt()
        c_in = 1 / total_var.sqrt()
        return c_skip, c_out, c_in
    if precond == "one":
        return ones, ones, ones
    if precond == "no":
        # Zero skip term with the same [B, 1, 1] shape as the other
        # coefficients so it broadcasts against any [N, B, S, F] input.
        return torch.zeros_like(ones), ones, ones
    raise ValueError(f"unknown precond mode: {precond!r}")


def interference_fn(config, device):
    """Build a scoring function based on Monte-Carlo denoising error.

    The returned closure adds noise to several copies of a batch, lets the
    model denoise them and reports the squared difference between the
    denoised output and the clean input, averaged over the copies.

    Args:
        config: Mapping providing "sigma_data", "sigma_noise", "dist",
            "precond", "num" (number of noisy copies) and "use_t" (whether
            the model also receives a noise-level input).
        device: Device the batch is moved to before scoring.

    Returns:
        Callable ``interference(config, model, batch)``.
    """
    sigma_data = config["sigma_data"]
    sigma_noise = config["sigma_noise"]
    dist = config["dist"]
    precond = config["precond"]
    num_iterations = config["num"]
    use_t = config['use_t']

    def interference(config, model, batch):
        """Score ``batch`` with ``model``.

        Args:
            config: Unused here; the settings captured by the factory apply.
            model: Callable taking the scaled noisy input (and, when
                ``use_t`` is set, a flat tensor of ``log(sigma) / 4`` values).
            batch: Clean input; assumed 3-D ``[B, S, F]`` since it is
                repeated with four factors after ``unsqueeze(0)`` — TODO confirm.

        Returns:
            Tensor with the shape of ``batch`` holding the squared denoising
            error averaged over the ``num`` noisy copies.

        Raises:
            ValueError: If the configured "dist" or "precond" is unknown.
        """
        # Move the data first so every helper tensor below is created on the
        # same device as the tensors it is combined with.
        batch = batch.to(device)
        y = batch.unsqueeze(0).repeat(num_iterations, 1, 1, 1)

        with torch.no_grad():
            ones = torch.ones([batch.shape[0], 1, 1], device=batch.device)

            # Sigma is drawn before the noise to keep the RNG call order stable.
            sigma = _sample_sigma(dist, ones, sigma_noise)
            if dist == "bigaus":
                n = stretched_gaussian_like(y, sigma_noise)
            else:
                n = torch.randn_like(y) * sigma
            x = y + n

            c_skip, c_out, c_in = _precond_coefficients(
                precond, sigma, sigma_data, ones
            )

            x_in = (c_in * x).to(torch.float32)
            if use_t:
                # Noise-level conditioning: one log(sigma) / 4 value per sample.
                c_noise = sigma.to(torch.float32).reshape(-1, 1, 1).log() / 4
                F_x = model(x_in, c_noise.flatten())
            else:
                F_x = model(x_in)

            D_y = c_skip * x + c_out * F_x.to(torch.float32)
            residual = D_y - y

        # Average the squared error over the noisy copies (dim 0).
        return torch.square(residual).mean(dim=0)

    return interference


def interference_fn2(config, device):
    """Build a scoring function that averages squared denoising error.

    The returned closure runs a fixed number of denoising passes over the
    batch and returns the mean, over those passes, of the element-wise
    squared difference between the denoised output and the clean batch.

    Args:
        config: Mapping providing "sigma_data", "P_mean" and "P_std".
        device: Accepted for signature parity; not used by this function.

    Returns:
        Callable ``interference(config, model, batch)``.
    """
    sigma_data = config["sigma_data"]
    # Read but not used below; kept so a config lacking these keys still
    # fails at construction time exactly as before.
    P_mean = config["P_mean"]
    P_std = config["P_std"]

    num_iterations = 15  # number of denoising passes that are averaged
    fixed_sigma = 0.2    # noise level used when the model does not add its own

    def interference(config, model, batch):
        """Score ``batch`` with ``model``.

        Args:
            config: Mapping providing "sigma_per_feature". When truthy the
                model is called as ``model(batch, 0)`` and must return a
                4-tuple whose first item is the denoised output; otherwise
                noise is added here and the model is called with the scaled
                noisy input only.
            model: The denoising model.
            batch: Clean input; assumed 3-D ``[B, S, F]`` so that the
                ``[B, 1, 1]`` sigma broadcasts against it — TODO confirm.

        Returns:
            Tensor with the shape of ``batch``: squared error averaged over
            the denoising passes.
        """
        squared_errors = []
        with torch.no_grad():
            for _ in range(num_iterations):
                if config["sigma_per_feature"]:
                    # 0 is a placeholder noise level; presumably the model
                    # samples its own noise internally — verify against model.
                    D_x, _weight, _x, _sigma = model(batch, 0)
                else:
                    ones = torch.ones([batch.shape[0], 1, 1], device=batch.device)
                    sigma = ones * fixed_sigma

                    # Corrupt the clean batch with Gaussian noise.
                    x = (batch + torch.randn_like(batch) * sigma).to(torch.float32)

                    # Sigma-dependent input/output scaling around the model.
                    sigma = sigma.to(torch.float32).reshape(-1, 1, 1)
                    total_var = sigma ** 2 + sigma_data ** 2
                    c_skip = sigma_data ** 2 / total_var
                    c_out = sigma * sigma_data / total_var.sqrt()
                    c_in = 1 / total_var.sqrt()

                    F_x = model((c_in * x).to(torch.float32))
                    D_x = c_skip * x + c_out * F_x.to(torch.float32)

                # Computed for both branches so the per-feature path also
                # contributes a loss instead of hitting an unbound name.
                squared_errors.append((D_x - batch) ** 2)

        return torch.stack(squared_errors).sum(dim=0) / num_iterations

    return interference