import torch
import numpy as np
from scipy.ndimage import gaussian_filter1d

def window_energy_dist_loss(output, target, input_neo, choice='gaussian', base_loss='mse', sampling_rate=30000, smooth_sigma=2.0):
    """Pointwise loss weighted by a parametric window fitted to a smoothed energy trace.

    The energy trace (``input_neo``; presumably the Nonlinear Energy Operator of
    the input signal -- confirm with caller) is smoothed, min-max rescaled per
    row, and used to fit a peaked window (Gaussian, Laplace or Cauchy) centred
    on the strongest energy near the middle of each row. The normalised window
    then weights an elementwise L1 or squared error between ``output`` and
    ``target``.

    Args:
        output: 2-D tensor of shape (B, L).
        target: tensor broadcastable against ``output`` for ``output - target``.
        input_neo: energy trace; row ``i`` is smoothed independently and must
            yield a (B, L) profile so it can be combined with the window.
        choice: window family: ``'gaussian'``, ``'laplace'`` or ``'cauchy'``.
        base_loss: elementwise loss: ``'mse'`` or ``'l1'``.
        sampling_rate: sets the width of the peak-search window around the
            row centre, ``3.0 * sampling_rate / 1000`` samples in total (about
            3 ms if ``sampling_rate`` is in Hz -- assumption). At least a
            two-sample window around the centre is always searched.
        smooth_sigma: standard deviation (in samples) of the Gaussian
            smoothing applied to ``input_neo``.

    Returns:
        Scalar tensor: mean over all B*L elements of ``loss_element * weights``.
        Each row of ``weights`` sums to ~1, so this is the per-row weighted
        loss averaged over rows and divided by L.

    Raises:
        ValueError: if ``choice`` or ``base_loss`` is not recognised.
    """
    # Validate up front: previously an unknown value left `weights` /
    # `loss_element` unbound and failed later with an opaque error.
    if choice not in ('laplace', 'gaussian', 'cauchy'):
        raise ValueError(f"Invalid choice {choice!r}, must be 'laplace', 'gaussian' or 'cauchy'")
    if base_loss not in ('l1', 'mse'):
        raise ValueError(f"Invalid base_loss {base_loss!r}, must be 'l1' or 'mse'")

    B, L = output.shape
    device = output.device

    # Smooth on a detached CPU copy (SciPy); the resulting weights therefore
    # carry no gradient with respect to input_neo.
    neo_np = input_neo.detach().cpu().numpy()
    neo_smoothed = np.stack([gaussian_filter1d(neo_np[i], sigma=smooth_sigma) for i in range(B)], axis=0)
    energy = torch.tensor(neo_smoothed, dtype=torch.float32, device=device)

    # Per-row min-max rescale into [0, 1]; epsilon guards against flat rows.
    e_min = energy.min(dim=1, keepdim=True).values
    e_max = energy.max(dim=1, keepdim=True).values
    energy = (energy - e_min) / (e_max - e_min + 1e-8)
    energy_mass = energy.sum(dim=1, keepdim=True) + 1e-8

    # Sample positions 0..L-1 as a (1, L) row; broadcasts against mu of shape (B, 1).
    x = torch.linspace(0, L - 1, L, device=device).unsqueeze(0)

    # Adaptive centre: only search for the energy peak in a narrow window
    # around the middle of the row, so a second spike elsewhere in the row
    # (more likely at low sampling rates) cannot capture the window centre.
    # radius >= 1 keeps the slice non-empty even at very low sampling rates.
    center = L // 2
    radius = max(1, int((3.0 * sampling_rate / 1000.0) / 2))
    start = max(0, center - radius)
    end = min(L, center + radius)
    mu = torch.argmax(energy[:, start:end], dim=1, keepdim=True).float() + start

    # The exact distribution family matters less than the energy-driven
    # spread estimate: the energy profile acts as the attention signal.
    offset = x - mu
    if choice == 'gaussian':
        # Energy-weighted variance around mu; std floored at 5 samples.
        var = torch.sum(offset ** 2 * energy, dim=1, keepdim=True) / energy_mass
        std = torch.clamp(torch.sqrt(var + 1e-8), min=5.0)
        weights = torch.exp(-0.5 * (offset / std) ** 2)
    else:
        # Laplace and Cauchy share one scale estimate: the energy-weighted
        # mean absolute deviation around mu, floored at 3 samples.
        abs_offset = torch.abs(offset)
        scale = torch.sum(abs_offset * energy, dim=1, keepdim=True) / energy_mass
        scale = torch.clamp(scale, min=3.0)
        if choice == 'laplace':
            weights = torch.exp(-abs_offset / scale)
        else:  # 'cauchy'
            weights = 1 / (1 + (offset / scale) ** 2)

    # Each row of weights sums to ~1.
    weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-8)

    diff = output - target
    loss_element = torch.abs(diff) if base_loss == 'l1' else diff ** 2

    return (loss_element * weights).mean()

def window_energy_direct_loss(output, target, input_neo, base_loss='mse', smooth_sigma=2.0):
    """Pointwise loss weighted directly by a smoothed, normalised energy profile.

    Each row of ``input_neo`` is Gaussian-smoothed and min-max rescaled to
    [0, 1], then divided by its sum so it forms a distribution over positions.
    That distribution weights an elementwise L1 or squared error between
    ``output`` and ``target``. A perfectly flat energy row rescales to all
    zeros and therefore contributes nothing to the loss.

    Args:
        output: 2-D tensor of shape (B, L).
        target: tensor broadcastable against ``output`` for ``output - target``.
        input_neo: energy trace; row ``i`` is smoothed independently and must
            yield a profile that combines elementwise with the (B, L) loss.
        base_loss: elementwise loss: ``'mse'`` or ``'l1'``.
        smooth_sigma: standard deviation (in samples) of the Gaussian smoothing.

    Returns:
        Scalar tensor: mean over all elements of ``loss_element * weights``.

    Raises:
        ValueError: if ``base_loss`` is not ``'l1'`` or ``'mse'``.
    """
    n_rows, _seq_len = output.shape

    # Smoothing runs in SciPy on a detached host copy, so the weights are
    # constants from autograd's point of view.
    neo_host = input_neo.detach().cpu().numpy()
    smoothed_rows = [gaussian_filter1d(neo_host[row], sigma=smooth_sigma) for row in range(n_rows)]
    profile = torch.tensor(np.stack(smoothed_rows, axis=0), dtype=torch.float32, device=output.device)

    # Per-row min-max rescale into [0, 1].
    row_lo = profile.min(dim=1, keepdim=True)[0]
    row_hi = profile.max(dim=1, keepdim=True)[0]
    profile = (profile - row_lo) / (row_hi - row_lo + 1e-8)

    # Divide by the row mass so each row sums to ~1.
    row_mass = profile.sum(dim=1, keepdim=True) + 1e-8
    weights = profile / row_mass

    if base_loss not in ('l1', 'mse'):
        raise ValueError("Invalid base_loss, must be 'l1' or 'mse'")
    diff = output - target
    loss_element = torch.abs(diff) if base_loss == 'l1' else diff ** 2

    return torch.mean(loss_element * weights)
