import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import torchaudio


def spectrum_attack(model, audio, label, epsilon, alpha, attack_iters, restarts, target_freq,
                    sample_rate=16000, perturb_low=False):
    """Run an L-inf PGD attack on the STFT magnitude of ``audio`` in a frequency band.

    The audio is transformed with an STFT (n_fft=1024, hop=512, Hann window).
    An additive perturbation bounded by ``epsilon`` is optimised on the magnitude
    bins at or above ``target_freq``, or below it when ``perturb_low`` is True.
    The phase is kept fixed. The sign-gradient steps maximise the cross-entropy
    loss of ``model``. Over ``restarts`` random starts, each sample keeps the
    perturbation that achieved the highest loss.

    Args:
        model: Callable mapping a waveform batch to logits for ``F.cross_entropy``.
        audio: Waveform batch. It must be 2-D (batch, samples) because the STFT
            magnitude is unpacked into three dimensions.
        label: Targets passed to ``F.cross_entropy``.
        epsilon: L-inf bound of the magnitude perturbation.
        alpha: Step size of each sign-gradient ascent step.
        attack_iters: Number of PGD steps per restart (may be 0).
        restarts: Number of random restarts.
        target_freq: Band edge in Hz. It is mapped to the first STFT bin whose
            centre frequency is >= target_freq.
        sample_rate: Sampling rate of ``audio`` in Hz. Defaults to 16000, the
            value that was previously hard-coded.
        perturb_low: If True, perturb the bins below ``target_freq`` instead of
            the bins at/above it.

    Returns:
        The waveform rebuilt from the best perturbed magnitude and the original
        phase. It has the same length as ``audio``.

    Note:
        Gradients are taken w.r.t. the perturbation only, via
        ``torch.autograd.grad``. Model parameters' ``.grad`` are not touched.
    """
    device = audio.device
    n_fft = 1024
    hop_length = 512
    win_length = 1024

    # Create the window directly on the audio's device/dtype so STFT/iSTFT agree.
    window = torch.hann_window(win_length, device=device, dtype=audio.dtype)
    stft_kwargs = {
        "n_fft": n_fft,
        "hop_length": hop_length,
        "win_length": win_length,
        "window": window,
        "center": True,
        "normalized": False,
        "onesided": True,
    }

    spec = torch.stft(audio, return_complex=True, **stft_kwargs)
    magnitude = torch.abs(spec)
    # The phase is never optimised, so precompute the unit phasor once.
    phase_factor = torch.exp(1j * torch.angle(spec))

    num_samples = audio.shape[-1]
    batch_size, num_frequencies, num_frames = magnitude.shape

    # Bin k has centre frequency k * sample_rate / n_fft. Clamp the index so a
    # target at/above Nyquist (or below 0) yields an empty band instead of a
    # negative tensor size.
    index = int(np.ceil(target_freq * n_fft / sample_rate))
    index = min(max(index, 0), num_frequencies)
    if perturb_low:
        band, band_size = slice(None, index), index
    else:
        band, band_size = slice(index, None), num_frequencies - index

    def reconstruct(delta):
        # Add delta to the selected magnitude bins and invert with the original
        # phase. `length` forces the output back to the input length (istft
        # would otherwise truncate to hop * (frames - 1) samples).
        # NOTE(review): magnitude + delta can go negative, which acts as a
        # phase flip of that bin; clamp to >= 0 if that is not intended.
        perturbed = magnitude.clone()
        perturbed[:, band, :] += delta
        return torch.istft(perturbed * phase_factor, length=num_samples, **stft_kwargs)

    max_loss = torch.zeros(batch_size, device=device)
    max_delta = torch.zeros(batch_size, band_size, num_frames, device=device, dtype=magnitude.dtype)

    if band_size == 0:
        # Nothing to perturb: return the plain STFT round-trip.
        return reconstruct(max_delta)

    for _ in range(restarts):
        # Build the random start on the target device so it is a true leaf tensor.
        delta = torch.empty_like(max_delta).uniform_(-epsilon, epsilon).requires_grad_(True)

        for _ in range(attack_iters):
            loss = F.cross_entropy(model(reconstruct(delta)), label)
            (grad,) = torch.autograd.grad(loss, delta)
            with torch.no_grad():
                delta.add_(alpha * torch.sign(grad)).clamp_(-epsilon, epsilon)

        # Score the FINAL perturbation of this restart (not the pre-update one)
        # and keep, per sample, whichever restart achieved the highest loss.
        with torch.no_grad():
            all_loss = F.cross_entropy(model(reconstruct(delta)), label, reduction='none')
            improved = all_loss >= max_loss
            max_delta[improved] = delta[improved]
            max_loss = torch.max(max_loss, all_loss)

    return reconstruct(max_delta)
