import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import torchaudio

# Compute device as a string: "cuda" when torch can see a GPU, else "cpu".
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Module-level value bounds (lower -1, upper 1).
# NOTE(review): presumably intended for clamping waveforms to [-1, 1];
# attack_phase_spec in this file does not reference them — confirm usage.
lower_limit = -1
upper_limit = 1

def attack_phase_spec(model, audio, label, epsilon, alpha, attack_iters, restarts, target_freq,
                      bandwidth=2000, sample_rate=16000):
    """Craft an adversarial waveform by perturbing the STFT phase in one frequency band.

    Runs an L-infinity PGD attack with random restarts on the phase of the
    short-time Fourier transform of ``audio``. The magnitude is kept
    unchanged, and only the bins covering roughly
    ``[target_freq, target_freq + bandwidth)`` Hz are perturbed. The attack
    maximises ``F.cross_entropy(model(waveform), label)`` (untargeted). Across
    restarts, it keeps for each example the perturbation with the highest
    per-example loss.

    Args:
        model: Callable mapping a waveform batch to class logits.
        audio: Waveform batch of shape ``(batch, samples)``. A 1-D input is not
            supported because the spectrogram shape is unpacked as
            ``(batch, freq, frames)``.
        label: Class indices passed to ``F.cross_entropy``.
        epsilon: L-infinity bound on the phase perturbation (radians).
        alpha: Step size of each sign-gradient ascent step (radians).
        attack_iters: Number of PGD steps per restart (0 keeps the random init).
        restarts: Number of random restarts.
        target_freq: Lower edge of the perturbed band, in Hz.
        bandwidth: Width of the perturbed band, in Hz (default 2000).
        sample_rate: Sampling rate of ``audio`` in Hz (default 16000). It is used
            to convert Hz to STFT bin indices.

    Returns:
        The perturbed waveform batch, with the same shape as ``audio``. It does
        not require grad, and the model's parameter ``.grad`` fields are left
        untouched.
    """
    # STFT parameters
    n_fft = 1024
    hop_length = 512
    win_length = 1024

    dev = audio.device
    num_samples = audio.shape[-1]
    window = torch.hann_window(win_length, device=dev)
    # Shared by the forward and inverse transforms so they always agree.
    stft_kwargs = dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                       window=window, center=True, normalized=False, onesided=True)

    # Only the perturbation needs gradients; the clean spectrogram is a constant.
    with torch.no_grad():
        spec = torch.stft(audio, return_complex=True, **stft_kwargs)
    magnitude = torch.abs(spec)
    phase = torch.angle(spec)
    batch_size, num_frequencies, num_frames = magnitude.shape

    # Convert Hz to an STFT bin index (bin spacing is sample_rate / n_fft).
    # Clamp to the available bins so that a band reaching past Nyquist is
    # truncated, instead of yielding a slice shorter than the perturbation.
    index = min(int(np.ceil(target_freq * n_fft / sample_rate)), num_frequencies)
    end_index = min(index + int(np.ceil(bandwidth * n_fft / sample_rate)), num_frequencies)
    band_bins = end_index - index

    def _to_audio(delta):
        """Invert the spectrogram using the clean magnitude and the band-perturbed phase."""
        perturbed_phase = phase.clone()
        perturbed_phase[:, index:end_index, :] += delta
        perturbed_spec = magnitude * torch.exp(1j * perturbed_phase)
        # length= keeps the output as long as the input waveform, even when
        # num_samples is not a multiple of hop_length.
        return torch.istft(perturbed_spec, length=num_samples, **stft_kwargs)

    max_loss = torch.zeros(batch_size, device=dev)
    max_delta = torch.zeros(batch_size, band_bins, num_frames, device=dev)

    if band_bins == 0:
        # The band lies entirely at/above Nyquist, so there is nothing to perturb.
        with torch.no_grad():
            return _to_audio(max_delta)

    for _ in range(restarts):
        # Create the tensor directly on the target device so it stays a leaf
        # and autograd can return its gradient.
        perturbation = torch.empty(batch_size, band_bins, num_frames, device=dev)
        perturbation.uniform_(-epsilon, epsilon)
        perturbation.requires_grad_(True)

        for _ in range(attack_iters):
            loss = F.cross_entropy(model(_to_audio(perturbation)), label)
            # autograd.grad (rather than loss.backward) does not accumulate
            # gradients into the model's parameters.
            (grad,) = torch.autograd.grad(loss, perturbation)
            with torch.no_grad():
                perturbation += alpha * torch.sign(grad)
                perturbation.clamp_(-epsilon, epsilon)

        # Score the final perturbation of this restart for each example, and
        # keep it wherever it beats the best loss seen so far.
        with torch.no_grad():
            all_loss = F.cross_entropy(model(_to_audio(perturbation)), label, reduction='none')
            improved = all_loss >= max_loss
            max_delta[improved] = perturbation[improved]
            max_loss = torch.max(max_loss, all_loss)

    with torch.no_grad():
        return _to_audio(max_delta)
