import torch
import json
import numpy as np
from model.RawNet3 import RawNet3_detect
from model.RawNetBasicBlock import Bottle2neck
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, roc_curve
import os
from torch.utils.data import Dataset
import librosa
from torch.nn import functional as F
import random
import torchaudio

# Module-wide compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Waveform amplitude bounds (not referenced by the code visible in this file).
upper_limit = 1
lower_limit = -1


def attack_spec(model, audio, label, epsilon, alpha, attack_iters, restarts, target_freq, sample_rate=16000):
    """Run a PGD-style attack on the STFT magnitude at and above ``target_freq``.

    The waveform is transformed with an STFT, an additive perturbation bounded
    by ``[-epsilon, epsilon]`` is optimised on the magnitude of every frequency
    bin from the cut-off bin upwards (the phase is left untouched), and the
    perturbed spectrogram is converted back to a waveform with the inverse STFT.

    Args:
        model: Callable that takes the re-synthesised waveform batch and returns
            logits usable by ``F.cross_entropy`` together with ``label``.
        audio: Waveform batch; must be 2-D (batch, samples) because the STFT
            result is unpacked as (batch, frequencies, frames).
        label: Class-index targets for ``F.cross_entropy``, on the same device
            as ``audio``.
        epsilon: Maximum absolute value of the magnitude perturbation.
        alpha: Step size of each signed-gradient update.
        attack_iters: Number of gradient steps per restart (0 evaluates the
            random initialisation only).
        restarts: Number of random restarts; the per-sample perturbation with
            the highest loss over all restarts is kept.
        target_freq: Frequency in Hz from which bins are perturbed.
        sample_rate: Sample rate in Hz used to convert ``target_freq`` into an
            STFT bin index. Defaults to 16000.

    Returns:
        The adversarial waveform batch produced by ``torch.istft``. No
        ``length`` is passed to ``torch.istft``, so the number of samples is
        whatever the frame count yields and may differ from the input length.

    Raises:
        ValueError: If ``target_freq`` maps to a bin outside the spectrogram.
    """
    # STFT parameters
    n_fft = 1024
    hop_length = 512
    win_length = 1024

    # Work on the device of the input rather than on a module-level global.
    dev = audio.device
    window = torch.hann_window(win_length, device=dev)
    spec = torch.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window,
                      center=True, normalized=False, onesided=True, return_complex=True)

    magnitude = torch.abs(spec)
    # The phase never changes, so its complex exponential is computed once.
    phase_factor = torch.exp(1j * torch.angle(spec))

    batch_size, num_frequencies, num_frames = magnitude.shape

    # First perturbed bin: bin k of the STFT is centred on k * sample_rate / n_fft Hz.
    index = int(np.ceil(target_freq * n_fft / sample_rate))
    if not 0 <= index <= num_frequencies:
        raise ValueError(
            f"target_freq={target_freq} maps to bin {index}, outside the valid range [0, {num_frequencies}]"
        )
    band_shape = (batch_size, num_frequencies - index, num_frames)

    def synthesize(delta):
        # Add delta to the attacked band and go back to the time domain.
        perturbed_magnitude = magnitude.clone()
        perturbed_magnitude[:, index:, :] += delta
        return torch.istft(perturbed_magnitude * phase_factor, n_fft=n_fft, hop_length=hop_length,
                           win_length=win_length, window=window, center=True, normalized=False,
                           onesided=True)

    max_loss = torch.zeros(batch_size, device=dev)
    max_delta = torch.zeros(band_shape, device=dev, dtype=magnitude.dtype)

    for _ in range(restarts):
        # Created directly on the target device so that it is a leaf tensor.
        perturbation = torch.empty(band_shape, device=dev, dtype=magnitude.dtype)
        perturbation.uniform_(-epsilon, epsilon).requires_grad_(True)

        for _ in range(attack_iters):
            loss = F.cross_entropy(model(synthesize(perturbation)), label)
            # Differentiate w.r.t. the perturbation only, so that no gradients
            # are accumulated in the model parameters.
            grad, = torch.autograd.grad(loss, perturbation)
            with torch.no_grad():
                stepped = perturbation + alpha * torch.sign(grad)
                perturbation.copy_(torch.clamp(stepped, -epsilon, epsilon))

        # Score the *final* perturbation of this restart (after the last step),
        # without building an autograd graph.
        final_delta = perturbation.detach()
        with torch.no_grad():
            all_loss = F.cross_entropy(model(synthesize(final_delta)), label, reduction='none')
        improved = all_loss >= max_loss
        max_delta[improved] = final_delta[improved]
        max_loss = torch.max(max_loss, all_loss)

    with torch.no_grad():
        max_perturbed_audio = synthesize(max_delta)

    return max_perturbed_audio


class SpeechDataset(Dataset):
    """Dataset of audio file paths labelled as real (0) or fake (1).

    Files are discovered recursively under ``source_dir`` (label 0) and
    ``fake_dir`` (label 1) when the dataset is constructed; the audio itself is
    only decoded when an item is requested.
    """

    # File extensions (compared case-insensitively) that are treated as audio.
    _AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac', '.m4a')

    def __init__(self, source_dir=None, fake_dir=None):
        """Collect audio paths from the given directories.

        Args:
            source_dir: Root directory of real recordings, or ``None`` to skip.
            fake_dir: Root directory of fake recordings, or ``None`` to skip.
        """
        self.data = []
        self.labels = []
        self.source_dir = source_dir
        self.fake_dir = fake_dir
        self.load_data()

    def _collect(self, root, label):
        """Append every audio file found below ``root`` with the given label."""
        for dirpath, dirnames, filenames in os.walk(root):
            # Sorting in place makes os.walk descend in a deterministic order,
            # so the sample order does not depend on the filesystem.
            dirnames.sort()
            for file in sorted(filenames):
                if file.lower().endswith(self._AUDIO_EXTENSIONS):
                    self.data.append(os.path.join(dirpath, file))
                    self.labels.append(label)

    def load_data(self):
        """Populate ``self.data`` / ``self.labels`` from the configured directories."""
        if self.source_dir is not None:
            self._collect(self.source_dir, 0)  # 0 for real audio
        if self.fake_dir is not None:
            self._collect(self.fake_dir, 1)  # 1 for fake audio

    def __len__(self):
        """Return the number of discovered audio files."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one item.

        Returns:
            ``(audio, sample_rate, label)`` where ``audio`` is the waveform
            returned by ``librosa.load`` at 16 kHz, or ``None`` when the file
            cannot be decoded (the failure is printed).

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # Looked up outside the try block: an invalid index must surface as
        # IndexError instead of being reported as a decoding failure.
        wav_path = self.data[idx]
        label = self.labels[idx]
        try:
            # mono=True already down-mixes, so no separate to_mono step is needed.
            audio, sample_rate = librosa.load(wav_path, sr=16000, mono=True)
        except Exception as exc:
            print(f"Error processing file {wav_path}: {exc}")
            return None
        return audio, sample_rate, label


def _predict_from_windows(model, waveform, window_size, step, max_windows=24):
    """Classify a 1-D waveform by averaging softmax scores over sliding windows.

    A waveform shorter than ``window_size`` is scored as one (shorter) window.
    Returns the predicted class index as a Python int.
    """
    if waveform.shape[0] < window_size:
        num_windows = 1
    else:
        num_windows = min((waveform.shape[0] - window_size) // step + 1, max_windows)

    windows = torch.stack(
        [waveform[k * step: k * step + window_size] for k in range(num_windows)], dim=0
    )
    with torch.no_grad():
        outputs = model(windows)
    scores = F.softmax(outputs, dim=1).mean(dim=0)
    return torch.argmax(scores).item()


def _class_accuracy(y_true, y_pred, cls):
    """Accuracy restricted to samples whose true label is ``cls`` (NaN if there are none)."""
    mask = y_true == cls
    if not mask.any():
        return float("nan")
    return accuracy_score(y_true[mask], y_pred[mask])


def main():
    """Attack every test utterance and report detector accuracy per class.

    For each attack cut-off frequency, every utterance from the real and fake
    test directories is perturbed with ``attack_spec`` and then classified by
    averaging the detector's softmax scores over sliding windows.

    Returns:
        Two dicts mapping attack frequency (Hz) to accuracy: the first for real
        utterances (label 0), the second for fake utterances (label 1).
    """
    # Imported locally so the file's import block does not need to change.
    from torch.utils.data.dataloader import default_collate

    model_path = "/epoch_11.pth"

    source_dir = "./test_set/real_audio/"
    fake_dir = "./test_set/fake_audio/"

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the model
    model = RawNet3_detect(encoder_type='ECA', nOut=256, sinc_stride=10, log_sinc=True, norm_sinc=True,
                           out_bn=True,
                           block=Bottle2neck, model_scale=8, context=True, summed=True)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    model = model.to(device)

    # Evaluate the model
    model.eval()
    # Only the input is attacked; freezing the weights avoids computing and
    # storing parameter gradients during the attack.
    for param in model.parameters():
        param.requires_grad_(False)

    def collate_skip_failed(batch):
        # The dataset yields None for files it cannot decode; drop those so the
        # default collate does not fail on them.
        loaded = [item for item in batch if item is not None]
        return default_collate(loaded) if loaded else None

    # Load the data
    dataset = SpeechDataset(source_dir=source_dir, fake_dir=fake_dir)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_skip_failed)

    # Attack parameters
    epsilon = 0.005
    alpha = 0.002
    restarts = 1
    attack_iters = 2

    # Audio is capped at 2400 frames of 160 samples; scoring uses 600-frame
    # windows with a 300-frame hop.
    max_samples = 2400 * 160
    window_size = 600 * 160
    step = 300 * 160

    attack_results_real = {}
    attack_results_fake = {}

    for attack_freq in range(4000, 4100, 100):

        y_true = []
        y_pred = []

        for batch in dataloader:
            if batch is None:
                continue
            audio, _sample_rate, label = batch
            audio = audio.to(device)
            label = label.to(device)

            if audio.shape[1] > max_samples:
                audio = audio[:, :max_samples]

            # Attack the audio
            max_perturbed_audio = attack_spec(model, audio, label, epsilon, alpha, attack_iters, restarts,
                                              target_freq=attack_freq)
            attacked_audio = torch.clamp(max_perturbed_audio, -1, 1).squeeze(0)

            # Compute the prediction for the attacked utterance
            y_true.append(label.item())
            y_pred.append(_predict_from_windows(model, attacked_audio, window_size, step))

        y_true = np.array(y_true)
        y_pred = np.array(y_pred)

        real_accuracy = _class_accuracy(y_true, y_pred, 0)
        fake_accuracy = _class_accuracy(y_true, y_pred, 1)

        attack_results_real[attack_freq] = real_accuracy
        attack_results_fake[attack_freq] = fake_accuracy

        print(f"Attack frequency: {attack_freq}, Real accuracy: {real_accuracy}, Fake accuracy: {fake_accuracy}")

    return attack_results_real, attack_results_fake


# Script entry point: importing this module does not start the evaluation.
if __name__ == "__main__":
    # Bind both per-frequency accuracy dicts returned by main() at module level.
    attack_results_real, attack_results_fake = main()
