import torch
import json
import numpy as np
from model.RawNet3 import RawNet3_detect
from model.RawNetBasicBlock import Bottle2neck
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, roc_curve
import os
from torch.utils.data import Dataset
import librosa
from audiomentations import *

import warnings

# Suppress every warning of category UserWarning for the whole process so the
# long evaluation log printed by this script stays readable.
warnings.filterwarnings("ignore", category=UserWarning)


class SpeechDataset(Dataset):
    """Dataset of audio file paths labelled as real (0) or fake (1).

    Both directories are walked recursively at construction time; only the
    file paths are stored, audio is decoded lazily in ``__getitem__``.

    Args:
        source_dir: Root directory of real recordings, or ``None`` to skip.
        fake_dir: Root directory of fake recordings, or ``None`` to skip.
    """

    # File suffixes (compared case-insensitively) that count as audio.
    AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac', '.m4a')

    def __init__(self, source_dir=None, fake_dir=None):
        self.data = []  # audio file paths
        self.labels = []  # label of the path at the same index in self.data
        self.source_dir = source_dir
        self.fake_dir = fake_dir
        self.load_data()

    def _collect(self, root, label):
        """Append every audio file found below *root* with the given label."""
        for dirpath, dirnames, filenames in os.walk(root):
            # Sorting in place makes os.walk descend in a stable order, so the
            # sample order does not depend on the filesystem's listing order.
            dirnames.sort()
            for file in sorted(filenames):
                if file.lower().endswith(self.AUDIO_EXTENSIONS):
                    self.data.append(os.path.join(dirpath, file))
                    self.labels.append(label)

    def load_data(self):
        """Index real files (label 0) first, then fake files (label 1)."""
        if self.source_dir is not None:
            self._collect(self.source_dir, 0)  # 0 for real audio
        if self.fake_dir is not None:
            self._collect(self.fake_dir, 1)  # 1 for fake audio

    def __len__(self):
        """Return the number of indexed audio files."""
        return len(self.data)

    def __getitem__(self, idx):
        """Decode one file.

        Returns:
            ``(audio, sample_rate, label)`` where ``audio`` is the waveform
            returned by ``librosa.load`` at the file's native sample rate, or
            ``None`` if the file could not be decoded.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # Looked up outside the try block so an invalid index surfaces as a
        # normal IndexError instead of being mistaken for a decoding failure.
        wav_path = self.data[idx]
        try:
            audio, sample_rate = librosa.load(wav_path, sr=None)
            if audio.ndim > 1:
                audio = librosa.to_mono(audio)
        except Exception as exc:
            # Best effort: report the cause and let the caller skip the sample.
            print(f"Error processing file {wav_path}: {exc!r}")
            return None
        return audio, sample_rate, self.labels[idx]


def _build_augmentation(name, options):
    """Instantiate only the audiomentations transform registered under *name*.

    Transforms are wrapped in zero-argument factories so that nothing is
    constructed unless it is requested; in particular the noise-based
    transforms, which need a directory of noise files, are not touched when
    another corruption is evaluated.

    Every tunable value is read from *options*. When a key is present the same
    value is used for both the ``min_*`` and ``max_*`` bound (a fixed setting);
    when it is absent the two defaults below define the random range.

    Raises:
        ValueError: If no transform is registered under *name*.
    """
    opt = options.get
    factories = {
        'background_noise': lambda: AddBackgroundNoise(
            sounds_path=opt("noise_files", "/local/rcs/zz3093/noise_v1"),
            min_snr_in_db=opt("background_noise_snr", 3.0),
            max_snr_in_db=opt("background_noise_snr", 30.0), p=1),

        'color_noise': lambda: AddColorNoise(
            p=1,
            min_snr_db=opt("color_noise_snr", 5.0),
            max_snr_db=opt("color_noise_snr", 40.0),
            min_f_decay=opt("color_noise_f_decay", -6.0),
            max_f_decay=opt("color_noise_f_decay", 6.0)),

        'aliasing': lambda: Aliasing(
            min_sample_rate=opt("aliasing_sample_rate", 4000),
            max_sample_rate=opt("aliasing_sample_rate", 16000), p=1),

        'gaussian_snr': lambda: AddGaussianSNR(
            min_snr_db=opt("gaussian_snr", 5.0),
            max_snr_db=opt("gaussian_snr", 40.0), p=1),

        'add_short_noises': lambda: AddShortNoises(
            sounds_path=opt("noise_files", "/local/rcs/zz3093/noise_v1"), p=1,
            min_snr_in_db=opt("short_noises_snr", 3.0),
            max_snr_in_db=opt("short_noises_snr", 30.0),
            min_time_between_sounds=opt("short_noises_time_between_sounds", 1.0),
            max_time_between_sounds=opt("short_noises_time_between_sounds", 4.0)),

        'air_absorption': lambda: AirAbsorption(
            p=1,
            min_distance=opt("air_absorption_distance", 10.0),
            max_distance=opt("air_absorption_distance", 100.0),
            min_temperature=opt("air_absorption_temperature", 10.0),
            max_temperature=opt("air_absorption_temperature", 20.0),
            min_humidity=opt("air_absorption_humidity", 30.0),
            max_humidity=opt("air_absorption_humidity", 90.0)),

        'bandpass_filter': lambda: BandPassFilter(
            min_center_freq=opt("bandpass_freq", 200),
            max_center_freq=opt("bandpass_freq", 4000),
            min_bandwidth_fraction=opt("bandpass_bandwidth_fraction", 0.5),
            max_bandwidth_fraction=opt("bandpass_bandwidth_fraction", 1.99),
            min_rolloff=opt("bandpass_rolloff", 12),
            max_rolloff=opt("bandpass_rolloff", 24),
            p=1.0),

        'bandstop_filter': lambda: BandStopFilter(
            min_center_freq=opt("bandstop_freq", 200),
            max_center_freq=opt("bandstop_freq", 4000),
            min_bandwidth_fraction=opt("bandstop_bandwidth_fraction", 0.5),
            max_bandwidth_fraction=opt("bandstop_bandwidth_fraction", 1.99),
            min_rolloff=opt("bandstop_rolloff", 12),
            max_rolloff=opt("bandstop_rolloff", 24),
            p=1.0),

        'bitcrush': lambda: BitCrush(
            p=1.0,
            min_bit_depth=opt("bitcrush_bit_depth", 5),
            max_bit_depth=opt("bitcrush_bit_depth", 14)),

        # clip_boundary is the (negative) lower bound; the upper bound mirrors it.
        'clip': lambda: Clip(
            a_min=opt("clip_boundary", -1.0),
            a_max=-1 * opt("clip_boundary", -1.0), p=1.0),

        'clipping_distortion': lambda: ClippingDistortion(
            min_percentile_threshold=opt("clipping_threshold", 0),
            max_percentile_threshold=opt("clipping_threshold", 40), p=1.0),

        'gain': lambda: Gain(
            min_gain_db=opt("gain_db", -12.0),
            max_gain_db=opt("gain_db", 12.0), p=1.0),

        'gain_transition': lambda: GainTransition(
            p=1.0,
            min_duration=opt("gain_transition_duration", 0.2),
            max_duration=opt("gain_transition_duration", 6.0),
            min_gain_db=opt("gain_transition_db", -24.0),
            max_gain_db=opt("gain_transition_db", 6.0),
            duration_unit="seconds"),

        'highpass_filter': lambda: HighPassFilter(
            p=1.0,
            min_cutoff_freq=opt("highpass_freq", 20.0),
            max_cutoff_freq=opt("highpass_freq", 2400.0),
            min_rolloff=opt("highpass_rolloff", 12),
            max_rolloff=opt("highpass_rolloff", 24)),

        'highshelf_filter': lambda: HighShelfFilter(
            p=1.0,
            min_gain_db=opt("highshelf_gain", -18.0),
            max_gain_db=opt("highshelf_gain", 18.0),
            min_center_freq=opt("highshelf_freq", 300.0),
            max_center_freq=opt("highshelf_freq", 7500.0),
            min_q=opt("highshelf_q", 0.1),
            max_q=opt("highshelf_q", 0.999)),

        'limiter': lambda: Limiter(
            p=1.0,
            min_threshold_db=opt("limiter_threshold", -24.0),
            max_threshold_db=opt("limiter_threshold", -2.0),
            min_attack=opt("limiter_attack", 0.0005),
            max_attack=opt("limiter_attack", 0.005),
            min_release=opt("limiter_release", 0.05),
            max_release=opt("limiter_release", 0.7)),

        'lowpass_filter': lambda: LowPassFilter(
            p=1.0,
            min_cutoff_freq=opt("lowpass_freq", 150.0),
            max_cutoff_freq=opt("lowpass_freq", 7500.0),
            min_rolloff=opt("lowpass_rolloff", 12),
            max_rolloff=opt("lowpass_rolloff", 24)),

        'lowshelf_filter': lambda: LowShelfFilter(
            p=1.0,
            min_center_freq=opt("lowshelf_freq", 300.0),
            max_center_freq=opt("lowshelf_freq", 7500.0),
            min_gain_db=opt("lowshelf_gain", -18.0),
            max_gain_db=opt("lowshelf_gain", 18.0),
            min_q=opt("lowshelf_q", 0.1),
            max_q=opt("lowshelf_q", 0.999)),

        'peaking_filter': lambda: PeakingFilter(
            p=1.0,
            min_center_freq=opt("peaking_freq", 50.0),
            max_center_freq=opt("peaking_freq", 7500.0),
            min_gain_db=opt("peaking_gain", -24.0),
            max_gain_db=opt("peaking_gain", 24.0),
            min_q=opt("peaking_q", 0.5),
            max_q=opt("peaking_q", 5.0)),

        'polarity_inversion': lambda: PolarityInversion(p=1.0),

        'room_simulator': lambda: RoomSimulator(p=1.0),

        'seven_band_parametric_eq': lambda: SevenBandParametricEQ(
            p=1.0,
            min_gain_db=opt("seven_band_gain", -12.0),
            max_gain_db=opt("seven_band_gain", 12.0)),

        'tanh_distortion': lambda: TanhDistortion(
            p=1.0,
            min_distortion=opt("tanh_distortion_gain", 0.01),
            max_distortion=opt("tanh_distortion_gain", 0.7)),

        'time_mask': lambda: TimeMask(
            p=1.0,
            min_band_part=opt("time_mask_band_part_width", 0),
            max_band_part=opt("time_mask_band_part_width", 0.5),
            fade=opt("time_mask_fade", False)),

        'time_stretch': lambda: TimeStretch(
            p=1.0,
            min_rate=opt("time_stretch_rate", 0.5),
            max_rate=opt("time_stretch_rate", 2)),
    }

    factory = factories.get(name)
    if factory is None:
        supported = ", ".join(["normal"] + sorted(factories))
        raise ValueError(
            f"Unsupported corruption type {name!r}; supported types: {supported}")
    return factory()


def _label_accuracy(y_true, y_pred, label):
    """Return the fraction of samples with true class *label* predicted as *label*.

    Returns NaN when no sample of that class was evaluated.
    """
    mask = y_true == label
    if not mask.any():
        return float("nan")
    return float(np.mean(y_pred[mask] == label))


def eval_corrupt(model, dataloader, device, type="normal", **kwargs):
    """Evaluate *model* on *dataloader* after applying one audio corruption.

    Each item is brought to 16 kHz, optionally corrupted, cut into overlapping
    windows, and classified by averaging the per-window softmax scores.

    Args:
        model: Callable mapping a ``(num_windows, num_samples)`` float tensor
            to ``(num_windows, num_classes)`` logits.
        dataloader: Iterable yielding ``(audio, sample_rate, label)`` with a
            batch size of one (``audio`` and ``label`` are tensors).
        device: Device the windows are moved to before calling the model.
        type: ``"normal"`` for no corruption, otherwise the name of a
            registered corruption (the parameter name shadows the builtin but
            is kept because callers pass it by keyword).
        **kwargs: Corruption settings; see ``_build_augmentation``.

    Returns:
        ``(real_accuracy, fake_accuracy)``: accuracy on label-0 and label-1
        samples respectively; NaN for a class with no samples.

    Raises:
        ValueError: If *type* is not supported, or if ``type == "aliasing"``
            and ``aliasing_sample_rate`` is missing.
    """
    target_sr = 16000
    # 96000-sample windows (6 s at 16 kHz) taken every 48000 samples (3 s).
    window_size = 600 * 160
    step = 300 * 160
    max_windows = 24

    augment = None
    if type != "normal":
        if type == "aliasing" and "aliasing_sample_rate" not in kwargs:
            raise ValueError(
                "aliasing corruption requires the aliasing_sample_rate argument")
        augment = _build_augmentation(type, kwargs)
    print(f"Evaluating {type} corrupted data")

    y_true = []
    y_pred = []

    for audio, sr, label in dataloader:
        # convert to numpy array
        audio = audio.numpy().squeeze()
        # The collated sample rate is a one-element tensor; librosa needs a
        # plain number.
        sr = int(sr)

        # Always resample from the file's real rate. The aliasing rate is only
        # a parameter of the Aliasing transform; treating it as the source rate
        # would mislabel the audio's rate and so time/pitch-warp it.
        if sr != target_sr:
            audio = librosa.resample(y=audio, orig_sr=sr, target_sr=target_sr)

        if augment is not None:
            audio = augment(samples=audio, sample_rate=target_sr)

        audio = torch.tensor(audio, dtype=torch.float32).to(device)

        if audio.shape[0] < window_size:
            # Clip shorter than one window: classify it as a single window.
            num_windows = 1
        else:
            num_windows = min((audio.shape[0] - window_size) // step + 1,
                              max_windows)

        windows = torch.stack(
            [audio[i * step: i * step + window_size] for i in range(num_windows)])

        with torch.no_grad():
            outputs = model(windows)
        softmax_outputs = torch.nn.functional.softmax(outputs, dim=1)
        scores = torch.mean(softmax_outputs, dim=0)
        predicted = torch.argmax(scores)
        y_true.append(label.item())
        y_pred.append(predicted.item())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    real_accuracy = _label_accuracy(y_true, y_pred, 0)
    fake_accuracy = _label_accuracy(y_true, y_pred, 1)

    return real_accuracy, fake_accuracy


def _float_grid(start, stop, step):
    """Return ``np.arange(start, stop, step)`` as Python floats rounded to 3 decimals.

    Rounding removes accumulated floating-point noise (e.g. 0.06999999999999999)
    so printed values and JSON keys are clean.
    """
    return [round(float(value), 3) for value in np.arange(start, stop, step)]


def _sweep_corruption(evaluate, corruption, param, values, message, **fixed):
    """Evaluate one corruption once per value of a single swept parameter.

    Args:
        evaluate: Callable forwarding keyword arguments to ``eval_corrupt`` and
            returning ``(real_accuracy, fake_accuracy)``.
        corruption: Corruption name passed as ``type``.
        param: Name of the keyword argument that is swept.
        values: Values to assign to *param*, one evaluation each.
        message: ``str.format`` template with one ``{}`` placeholder for the
            swept value, used as the prefix of the printed result line.
        **fixed: Extra keyword arguments held constant across the sweep.

    Returns:
        ``{"real": {value: accuracy}, "fake": {value: accuracy}}``.
    """
    accuracies = {"real": {}, "fake": {}}
    for value in values:
        real_accuracy, fake_accuracy = evaluate(type=corruption, **{param: value}, **fixed)
        print(f"{message.format(value)}: Real accuracy: {real_accuracy}, Fake accuracy: {fake_accuracy}")
        accuracies["real"][value] = real_accuracy
        accuracies["fake"][value] = fake_accuracy
    return accuracies


def main():
    """Evaluate the detector on clean and corrupted test audio and save the results.

    Loads the RawNet3 checkpoint, measures real/fake accuracy on the unmodified
    test set, then repeats the measurement for a grid of settings of each
    corruption and writes all accuracies to ``results.json``.
    """
    model_path = "/epoch_11.pth"
    source_dir = "./test_set/real_audio/"
    fake_dir = "./test_set/fake_audio/"

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the model
    model = RawNet3_detect(encoder_type='ECA', nOut=256, sinc_stride=10, log_sinc=True, norm_sinc=True,
                           out_bn=True,
                           block=Bottle2neck, model_scale=8, context=True, summed=True)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    model = model.to(device)

    # Load the data
    dataset = SpeechDataset(source_dir=source_dir, fake_dir=fake_dir)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    # Evaluate the model
    model.eval()

    def evaluate(**options):
        """Run eval_corrupt on the shared model, data and device."""
        return eval_corrupt(model, dataloader, device, **options)

    def sweep(corruption, param, values, message, **fixed):
        """Sweep one parameter of one corruption over *values*."""
        return _sweep_corruption(evaluate, corruption, param, values, message, **fixed)

    # Evaluate the original data
    real_accuracy, fake_accuracy = evaluate(type="normal")
    print(f"Original data: Real accuracy: {real_accuracy}, Fake accuracy: {fake_accuracy}")
    origin_acc = {"real": real_accuracy, "fake": fake_accuracy}

    # Evaluate the attacked data

    # Downsample
    downsample_acc = sweep("aliasing", "aliasing_sample_rate", range(4000, 16000, 1000),
                           "Downsample to {} Hz")

    # Noise
    # add Background noise
    noise_acc = sweep("background_noise", "background_noise_snr", range(1, 31, 2),
                      "Add background noise with SNR {} dB",
                      noise_files="/local/rcs/zz3093/noise_v1")

    # add color noise
    # Spectral decay per octave: brown -6.02 dB, pink -3.01 dB, white 0.0 dB,
    # blue +3.01 dB, violet +6.02 dB.
    color_noise_acc = {"real": {}, "fake": {}}
    snr = 10
    for f_decay in [-6.02, -3.01, 0.0, 3.01, 6.02]:
        real_accuracy, fake_accuracy = evaluate(type="color_noise",
                                                color_noise_snr=snr, color_noise_f_decay=f_decay)
        print(
            f"Add color noise with decay {f_decay} dB/octave: Real accuracy: {real_accuracy}, Fake accuracy: {fake_accuracy}")
        # Results are nested as [f_decay][snr].
        color_noise_acc["real"].setdefault(f_decay, {})[snr] = real_accuracy
        color_noise_acc["fake"].setdefault(f_decay, {})[snr] = fake_accuracy

    # add gaussian snr
    gaussian_snr_acc = sweep("gaussian_snr", "gaussian_snr", range(1, 31, 2),
                             "Add gaussian SNR with SNR {} dB")

    # add short noises
    short_noises_acc = sweep("add_short_noises", "short_noises_time_between_sounds", range(1, 5, 1),
                             "Add short noises with time between sounds {} s",
                             short_noises_snr=10)

    # add air absorption
    air_absorption_acc = sweep("air_absorption", "air_absorption_distance", range(10, 101, 5),
                               "Add air absorption with distance {} m",
                               air_absorption_temperature=20,
                               air_absorption_humidity=60)

    # add bandpass filter
    bandpass_filter_acc = sweep("bandpass_filter", "bandpass_freq", range(200, 4001, 200),
                                "Add bandpass filter with frequency {} Hz",
                                bandpass_bandwidth_fraction=1,
                                bandpass_rolloff=18)

    # add bandstop filter
    bandstop_filter_acc = sweep("bandstop_filter", "bandstop_freq", range(200, 4001, 200),
                                "Add bandstop filter with frequency {} Hz",
                                bandstop_bandwidth_fraction=1,
                                bandstop_rolloff=18)

    # add bitcrush
    bitcrush_acc = sweep("bitcrush", "bitcrush_bit_depth", range(5, 15, 1),
                         "Add bitcrush with bit depth {}")

    # add clip
    clip_acc = sweep("clip", "clip_boundary", [-0.5, -0.4, -0.3, -0.2, -0.1],
                     "Add clip with boundary {}")

    # add clipping distortion
    clipping_distortion_acc = sweep("clipping_distortion", "clipping_threshold", range(0, 41, 2),
                                    "Add clipping distortion with threshold {}")

    # add gain
    gain_acc = sweep("gain", "gain_db", range(-12, 13, 1),
                     "Add gain with gain {} dB")

    # add gain transition (durations 0.2 s .. 3.0 s in 0.1 s steps)
    gain_transition_acc = sweep("gain_transition", "gain_transition_duration",
                                [tenths / 10 for tenths in range(2, 31, 1)],
                                "Add gain transition with duration {} s",
                                gain_transition_db=3.0)

    # add highpass filter
    highpass_filter_acc = sweep("highpass_filter", "highpass_freq", range(20, 2401, 100),
                                "Add highpass filter with frequency {} Hz",
                                highpass_rolloff=18)

    # add highshelf filter
    highshelf_filter_acc = sweep("highshelf_filter", "highshelf_freq", range(300, 7501, 500),
                                 "Add highshelf filter with frequency {} Hz",
                                 highshelf_gain=0,
                                 highshelf_q=0.5)

    # add limiter
    limiter_acc = sweep("limiter", "limiter_threshold", range(-24, -2, 10),
                        "Add limiter with threshold {} dB",
                        limiter_attack=0.015,
                        limiter_release=0.3)

    # add lowpass filter
    lowpass_filter_acc = sweep("lowpass_filter", "lowpass_freq", range(150, 7501, 500),
                               "Add lowpass filter with frequency {} Hz",
                               lowpass_rolloff=18)

    # add lowshelf filter
    lowshelf_filter_acc = sweep("lowshelf_filter", "lowshelf_freq", range(300, 7501, 500),
                                "Add lowshelf filter with frequency {} Hz",
                                lowshelf_gain=6,
                                lowshelf_q=0.5)

    # add peaking filter
    peaking_filter_acc = sweep("peaking_filter", "peaking_freq", range(50, 7501, 500),
                               "Add peaking filter with frequency {} Hz",
                               peaking_gain=6,
                               peaking_q=1.0)

    # add polarity inversion
    real_accuracy, fake_accuracy = evaluate(type="polarity_inversion")
    print(f"Add polarity inversion: Real accuracy: {real_accuracy}, Fake accuracy: {fake_accuracy}")
    polarity_inversion_acc = {"real": real_accuracy, "fake": fake_accuracy}

    # add room simulator
    real_accuracy, fake_accuracy = evaluate(type="room_simulator")
    print(f"Add room simulator: Real accuracy: {real_accuracy}, Fake accuracy: {fake_accuracy}")
    room_simulator_acc = {"real": real_accuracy, "fake": fake_accuracy}

    # add seven band parametric eq
    seven_band_parametric_eq_acc = sweep("seven_band_parametric_eq", "seven_band_gain", range(-12, 13, 1),
                                         "Add seven band parametric eq with gain {} dB")

    # add tanh distortion
    tanh_distortion_acc = sweep("tanh_distortion", "tanh_distortion_gain", _float_grid(0.01, 0.71, 0.03),
                                "Add tanh distortion with distortion {}")

    # add time mask
    time_mask_acc = sweep("time_mask", "time_mask_band_part_width", _float_grid(0, 0.51, 0.025),
                          "Add time mask with band part width {}")

    # add time stretch
    time_stretch_acc = sweep("time_stretch", "time_stretch_rate", _float_grid(0.5, 2.1, 0.05),
                             "Add time stretch with rate {}")

    # save all results
    results = {
        "origin": origin_acc,
        "downsample": downsample_acc,
        "noise": noise_acc,
        "color_noise": color_noise_acc,
        "gaussian_snr": gaussian_snr_acc,
        "short_noises": short_noises_acc,
        "air_absorption": air_absorption_acc,
        "bandpass_filter": bandpass_filter_acc,
        "bandstop_filter": bandstop_filter_acc,
        "bitcrush": bitcrush_acc,
        "clip": clip_acc,
        "clipping_distortion": clipping_distortion_acc,
        "gain": gain_acc,
        "gain_transition": gain_transition_acc,
        "highpass_filter": highpass_filter_acc,
        "highshelf_filter": highshelf_filter_acc,
        "limiter": limiter_acc,
        "lowpass_filter": lowpass_filter_acc,
        "lowshelf_filter": lowshelf_filter_acc,
        "peaking_filter": peaking_filter_acc,
        "polarity_inversion": polarity_inversion_acc,
        "room_simulator": room_simulator_acc,
        "seven_band_parametric_eq": seven_band_parametric_eq_acc,
        "tanh_distortion": tanh_distortion_acc,
        "time_mask": time_mask_acc,
        "time_stretch": time_stretch_acc
    }

    # Numeric dictionary keys are written as strings by json.dump.
    with open("results.json", "w") as f:
        json.dump(results, f, indent=4)


# Run the full evaluation only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == '__main__':
    main()
