import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
import librosa
from hyperparameters import args
import soundfile as sf


# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def time_stertch_batch(Xs, factor):
    """Time-stretch every item of a batch with librosa and re-batch the result.

    NOTE: the function name is misspelled ("stertch"); it is kept as-is so
    existing callers keep working.

    Args:
        Xs: Batch tensor indexed along dimension 0. Each item is squeezed
            before being handed to librosa (presumably leaving a 1-D
            waveform -- confirm against the caller).
        factor: Stretch rate forwarded to ``librosa.effects.time_stretch``
            as ``rate``.

    Returns:
        A tensor on the module-level ``device`` with one stretched item per
        row. Each item is truncated to at most
        ``args.max_frames * 160 + 240`` samples; shorter results are NOT
        padded.
    """
    # Loop-invariant upper bound on the number of samples kept per item.
    max_samples = args.max_frames * 160 + 240

    results = []
    for item in Xs:
        # detach() so inputs that carry autograd history can still be
        # converted to NumPy (``.numpy()`` raises on such tensors).
        x = item.squeeze().detach().cpu().numpy()
        x = librosa.effects.time_stretch(x, rate=factor)
        results.append(x[:max_samples])

    if not results:
        # np.stack rejects an empty sequence; keep returning an empty tensor.
        return torch.tensor(results).to(device)

    # Stack into one contiguous array first: torch.tensor() on a list of
    # ndarrays converts element by element and is extremely slow.
    return torch.from_numpy(np.stack(results)).to(device)


def pitch_shift_batch(Xs, sample_rates, n_steps):
    """Pitch-shift every item of a batch with librosa and re-batch the result.

    Args:
        Xs: Batch tensor indexed along dimension 0. Each item is squeezed
            before being handed to librosa (presumably leaving a 1-D
            waveform -- confirm against the caller).
        sample_rates: Indexable collection; ``sample_rates[i]`` is passed as
            ``sr`` for item ``i``.
        n_steps: Shift amount forwarded to ``librosa.effects.pitch_shift``
            as ``n_steps``.

    Returns:
        A tensor on the module-level ``device`` with one shifted item per
        row, each truncated to at most ``args.max_frames * 160 + 240``
        samples.
    """
    # Loop-invariant upper bound on the number of samples kept per item.
    max_samples = args.max_frames * 160 + 240

    results = []
    for i, item in enumerate(Xs):
        # detach() so inputs that carry autograd history can still be
        # converted to NumPy (``.numpy()`` raises on such tensors).
        x = item.squeeze().detach().cpu().numpy()
        x = librosa.effects.pitch_shift(x, sr=sample_rates[i], n_steps=n_steps)
        results.append(x[:max_samples])

    if not results:
        # np.stack rejects an empty sequence; keep returning an empty tensor.
        return torch.tensor(results).to(device)

    # Stack into one contiguous array first: torch.tensor() on a list of
    # ndarrays converts element by element and is extremely slow.
    return torch.from_numpy(np.stack(results)).to(device)


def attack_discrete(model, X, y, sample_rates):
    """Pick the random time-stretch / pitch-shift of ``X`` with the highest loss.

    With probability ``1 - args.discrete_attack_ratio`` the batch is returned
    untouched. Otherwise ``args.attack_iters`` random transformations are
    tried (each one is either a time stretch or a pitch shift, chosen with
    equal probability) and the candidate with the largest cross-entropy
    loss against ``y`` is returned.

    Side effects:
        Every iteration writes the first item of the candidate batch to
        ``attack_results/<0..100>.wav`` so the attacks can be listened to.
        The model is switched to eval mode while candidates are scored and
        is put back into its previous train/eval mode afterwards.

    Args:
        model: Callable ``torch.nn.Module`` producing logits for a batch.
        X: Input batch.
        y: Targets passed to ``F.cross_entropy``.
        sample_rates: Per-item sample rates; ``sample_rates[i]`` belongs to
            ``X[i]``.

    Returns:
        The highest-loss transformed batch, or ``X`` itself when the attack
        is skipped or no candidate produced a comparable (non-NaN) loss.
    """
    if random.random() >= args.discrete_attack_ratio:
        return X

    # Start below every attainable loss so the first valid candidate is
    # always accepted (starting at 0 rejected candidates whose loss was 0).
    max_loss = float("-inf")
    # Fall back to the clean batch instead of returning None when there are
    # no iterations or every loss is NaN.
    X_adv_results = X

    # Remember the current mode: leaving the model in eval mode would make
    # the caller's training behaviour depend on the random early return.
    was_training = model.training
    model.eval()
    try:
        # attack steps, only keep the best attack which has the highest loss
        for _ in range(args.attack_iters):
            # Both parameters are always drawn so the sequence of random
            # numbers does not depend on which branch is taken.
            time_stretch_factor = random.uniform(args.time_stretch_min, args.time_stretch_max)
            n_steps = random.uniform(args.pitch_shift_min, args.pitch_shift_max)

            # time stretch or pitch shift
            if random.random() < 0.5:
                X_adv = time_stertch_batch(X, time_stretch_factor)
            else:
                X_adv = pitch_shift_batch(X, sample_rates, n_steps)

            # forward
            with torch.no_grad():
                outputs = model(X_adv)
                loss = F.cross_entropy(outputs, y).item()
            if loss > max_loss:
                max_loss = loss
                X_adv_results = X_adv

            # save some attack results to listen
            os.makedirs("attack_results", exist_ok=True)
            x = X_adv[0].squeeze().cpu().numpy()
            name = random.randint(0, 100)
            # Use the first item's own sample rate rather than assuming
            # 16 kHz, so the written file plays back at the right speed.
            sf.write(f"attack_results/{name}.wav", x, int(sample_rates[0]))
    finally:
        model.train(was_training)

    return X_adv_results
