import torch
import json
import numpy as np
from model.RawNet3 import RawNet3_detect
from model.RawNetBasicBlock import Bottle2neck
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, roc_curve
import os
from torch.utils.data import Dataset
import librosa
from torch.nn import functional as F
import random
import torchaudio
from model.rawnet2.model import RawNet
import yaml

# import warnings
#
# warnings.filterwarnings("ignore", category=UserWarning)

# Default compute device, chosen once at import time.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Valid waveform sample range; attack_pgd keeps X + delta inside [lower_limit, upper_limit].
lower_limit = -1
upper_limit = 1


def clamp(X, lower_limit, upper_limit):
    """Clip ``X`` element-wise into ``[lower_limit, upper_limit]``.

    Both bounds are expected to be tensors broadcastable against ``X`` (per-element
    bounds are allowed). The upper bound is applied first, so if a lower bound
    exceeds its upper bound the lower bound wins.
    """
    capped = torch.minimum(X, upper_limit)
    return torch.maximum(capped, lower_limit)


def attack_pgd(model, X, y, epsilon, alpha, attack_iters, restarts, norm):
    """Run projected gradient descent (PGD) with random restarts; return the best perturbation.

    Args:
        model: callable mapping a batch of inputs to logits accepted by
            ``F.cross_entropy`` together with ``y``. Presumably already in eval
            mode (the caller's responsibility).
        X: input batch; dim 0 is the batch dimension (``(B, L)`` in ``main``).
        y: target class indices, one per sample.
        epsilon: radius of the perturbation ball.
        alpha: step size of each ascent step.
        attack_iters: number of gradient steps per restart.
        restarts: number of random initialisations tried.
        norm: ``"l_inf"`` or ``"l_2"``, the geometry of the ball and of the step.

    Returns:
        A tensor shaped like ``X``. For every sample, it holds the perturbation whose
        final per-sample cross-entropy loss was highest across restarts. If
        ``restarts == 0`` it is all zeros.

    Raises:
        ValueError: if ``norm`` is not one of the two supported values.
    """
    if norm not in ("l_inf", "l_2"):
        raise ValueError(f"unsupported norm: {norm!r}")

    batch = X.shape[0]
    # Shape that broadcasts a per-sample scalar against X, for any input rank.
    per_sample = (batch,) + (1,) * (X.dim() - 1)

    max_loss = torch.zeros(y.shape[0], device=X.device)  # best loss per sample so far
    max_delta = torch.zeros_like(X)  # perturbation that achieved it

    for _ in range(restarts):
        delta = torch.zeros_like(X)
        if norm == "l_inf":
            delta.uniform_(-epsilon, epsilon)
        else:
            # Random direction, scaled to a radius drawn uniformly from [0, epsilon].
            delta.normal_()
            n = delta.reshape(batch, -1).norm(p=2, dim=1).view(per_sample)
            r = torch.zeros_like(n).uniform_(0, 1)
            delta *= r / n * epsilon
        # Keep X + delta inside the valid sample range.
        delta = clamp(delta, lower_limit - X, upper_limit - X)
        delta.requires_grad_(True)

        for _ in range(attack_iters):
            loss = F.cross_entropy(model(X + delta), y)
            # Differentiate w.r.t. delta only, so the model's parameter .grad
            # buffers are not accumulated on every call.
            (grad,) = torch.autograd.grad(loss, delta)
            with torch.no_grad():
                if norm == "l_inf":
                    d = torch.clamp(delta + alpha * torch.sign(grad), min=-epsilon, max=epsilon)
                else:
                    g_norm = grad.reshape(batch, -1).norm(dim=1).view(per_sample)
                    d = delta + grad / (g_norm + 1e-10) * alpha
                    # Project each sample back onto the L2 ball of radius epsilon.
                    d = d.reshape(batch, -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(delta)
                d = clamp(d, lower_limit - X, upper_limit - X)
                delta.copy_(d)

        # Score the final perturbation of this restart once, after all steps.
        # This also defines all_loss when attack_iters == 0.
        with torch.no_grad():
            all_loss = F.cross_entropy(model(X + delta), y, reduction="none")
            improved = all_loss >= max_loss
            max_delta[improved] = delta[improved]
            max_loss = torch.max(max_loss, all_loss)
    return max_delta


class SpeechDataset(Dataset):
    """Audio files labelled real (0) or fake (1) by the directory tree they live in.

    Every file under ``source_dir`` whose name ends with one of
    ``AUDIO_EXTENSIONS`` gets label 0. Every such file under ``fake_dir`` gets
    label 1. Either directory may be ``None`` to skip it. Real files are listed
    before fake files, each in ``os.walk`` order.
    """

    # File-name suffixes treated as audio (matched case-sensitively via str.endswith).
    AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac', '.m4a')

    def __init__(self, source_dir=None, fake_dir=None):
        self.data = []  # file paths
        self.labels = []  # parallel list of 0 (real) / 1 (fake)
        self.source_dir = source_dir
        self.fake_dir = fake_dir
        self.load_data()

    def _collect(self, root, label):
        """Append every audio file found recursively under ``root`` with ``label``."""
        for dirpath, _dirnames, filenames in os.walk(root):
            for file in filenames:
                if file.endswith(self.AUDIO_EXTENSIONS):
                    self.data.append(os.path.join(dirpath, file))
                    self.labels.append(label)

    def load_data(self):
        """Populate ``data`` and ``labels`` from the configured directories."""
        if self.source_dir is not None:
            self._collect(self.source_dir, 0)  # 0 for real audio
        if self.fake_dir is not None:
            self._collect(self.fake_dir, 1)  # 1 for fake audio

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(audio, sample_rate, label)``, resampled to 16 kHz.

        Returns ``None`` if the file cannot be decoded. The default DataLoader
        collate function cannot batch ``None``, so callers should filter these
        out with a custom ``collate_fn``.

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        # Index outside the try: an out-of-range idx must raise IndexError
        # (the sequence protocol relies on it), not be swallowed.
        wav_path = self.data[idx]
        try:
            audio, sample_rate = librosa.load(wav_path, sr=16000)
            if audio.ndim > 1:
                audio = librosa.to_mono(audio)
        except Exception as err:  # best-effort: report and skip unreadable files
            print(f"Error processing file {wav_path}: {err}")
            return None
        return audio, sample_rate, self.labels[idx]


def _load_detector(checkpoint_path, device):
    """Build the RawNet3 detector configuration used by this script and load its weights.

    The checkpoint is read as a dict and its ``"model_state_dict"`` entry is loaded
    with ``strict=True``. The model is returned on ``device``.
    """
    net = RawNet3_detect(encoder_type='ECA', nOut=256, sinc_stride=10, log_sinc=True, norm_sinc=True,
                         out_bn=True,
                         block=Bottle2neck, model_scale=8, context=True, summed=True)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint["model_state_dict"], strict=True)
    return net.to(device)


def _skip_failed_collate(batch):
    """Drop samples the dataset could not load (``None``) before default collation.

    Returns ``None`` when every sample in the batch failed.
    """
    batch = [item for item in batch if item is not None]
    if not batch:
        return None
    return torch.utils.data.dataloader.default_collate(batch)


def _sliding_windows(audio, window_size, step, max_windows):
    """Cut a 1-D signal into overlapping windows, stacked along a new dim 0.

    A signal shorter than ``window_size`` yields a single (short) window.
    Otherwise only full windows are taken, capped at ``max_windows``.
    """
    if audio.shape[0] < window_size:
        num_windows = 1
    else:
        num_windows = min((audio.shape[0] - window_size) // step + 1, max_windows)
    return torch.stack([audio[k * step: k * step + window_size] for k in range(num_windows)])


def main():
    """Evaluate a detector on audio perturbed by PGD against a surrogate detector.

    Perturbations are crafted on ``attack_model`` (loaded from ``attack_path``)
    and scored by ``model`` (loaded from ``model_path``). Each perturbed clip is
    split into overlapping windows; the softmax outputs are averaged across
    windows and the argmax is taken. Per-class accuracy for real (0) and fake (1)
    audio is printed.
    """
    model_path = ".pth"
    attack_path = ".pth"

    source_dir = "./test_set/real_audio/"
    fake_dir = "./test_set/fake_audio/"

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Model under evaluation, and the surrogate whose gradients drive the attack.
    model = _load_detector(model_path, device)
    attack_model = _load_detector(attack_path, device)

    # dir_yaml = 'model/rawnet2/model_config_RawNet2.yaml'
    #
    # with open(dir_yaml, 'r') as f_yaml:
    #     parser1 = yaml.load(f_yaml, Loader=yaml.SafeLoader)
    #
    # attack_model = RawNet(parser1['model'], device='cuda').to('cuda')

    # Load the data; unreadable files come back as None and are dropped by the collate_fn.
    dataset = SpeechDataset(source_dir=source_dir, fake_dir=fake_dir)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=_skip_failed_collate)

    model.eval()
    attack_model.eval()

    # Attack parameters
    epsilon = 0.0005
    alpha = 0.0001
    restarts = 5
    attack_iters = 10
    norm = "l_2"

    # Lengths in samples (the dataset resamples to 16 kHz): clips are truncated to
    # 24 s, then scored over 6 s windows with a 3 s hop, at most 24 windows.
    max_samples = 2400 * 160
    window_size = 600 * 160
    step = 300 * 160
    max_windows = 24

    y_true = []
    y_pred = []
    for idx, batch in enumerate(dataloader):
        if batch is None:
            continue  # the only sample in this batch failed to load
        audio, _sample_rate, label = batch
        audio = audio.to(device)
        label = label.to(device)

        audio = audio[:, :max_samples]

        # Attack the audio, then keep the result inside the valid sample range.
        max_delta = attack_pgd(attack_model, audio, label, epsilon, alpha, attack_iters, restarts, norm)
        attacked_audio = torch.clamp(audio + max_delta, -1, 1).squeeze(0)

        # save the adversarial audio, 16k sample rate, 1%, use torchaudio
        # if random.random() < 0.01:
        #     os.makedirs("attack_results/pgd_step10_restart5_l2_A/", exist_ok=True)
        #     # save original audio
        #     torchaudio.save(f"attack_results/pgd_step10_restart5_l2_A/original_audio_{idx}.wav", audio.cpu(), sample_rate=16000)
        #     torchaudio.save(f"attack_results/pgd_step10_restart5_l2_A/attacked_audio_{idx}.wav", attacked_audio.unsqueeze(0).cpu(),
        #                     sample_rate=16000)

        windows = _sliding_windows(attacked_audio, window_size, step, max_windows)
        with torch.no_grad():
            outputs = model(windows)
        # Average class probabilities over windows, then pick the most likely class.
        scores = torch.nn.functional.softmax(outputs, dim=1).mean(dim=0)
        predicted = torch.argmax(scores)
        y_true.append(label.item())
        y_pred.append(predicted.item())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    real_indices = y_true == 0
    fake_indices = y_true == 1

    # Report NaN explicitly when a class has no samples instead of scoring an empty selection.
    real_accuracy = accuracy_score(y_true[real_indices], y_pred[real_indices]) if real_indices.any() else float("nan")
    fake_accuracy = accuracy_score(y_true[fake_indices], y_pred[fake_indices]) if fake_indices.any() else float("nan")

    print("real_accuracy:", real_accuracy)
    print("fake_accuracy:", fake_accuracy)


if __name__ == "__main__":
    main()
