import torch
import numpy as np
from model.RawNet3 import RawNet3_detect
from model.RawNetBasicBlock import Bottle2neck
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, roc_curve
import os
from torch.utils.data import Dataset
import librosa
import socket
from hyperparameters import args


class SpeechValDataset(Dataset):
    """Evaluation dataset of real (label 0) and fake (label 1) audio clips.

    Audio files are discovered under ``root_dir/<dir>/<speaker>/`` for every
    ``<dir>`` in ``source_dir_lists`` (real) and ``fake_dir_list`` (fake).

    Args:
        root_dir: Directory that every entry of both lists is joined onto.
        source_dir_lists: Sub-directories holding real speech.
        fake_dir_list: Sub-directories holding fake speech.
        speaker_ids: Speaker sub-directory names to include. ``None`` means
            "every entry of each directory", listed separately per directory.
        max_frames: Clip length in frames; one frame is 160 samples.
    """

    # File suffixes treated as audio files.
    _AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac')

    def __init__(self, root_dir, source_dir_lists, fake_dir_list, speaker_ids, max_frames=600):
        self.data = []    # audio file paths
        self.labels = []  # parallel list: 0 = real, 1 = fake
        self.root_dir = root_dir
        self.source_dir_lists = source_dir_lists
        self.fake_dir_list = fake_dir_list
        self.speaker_ids = speaker_ids
        self.max_frames = max_frames
        self.load_data()

    def load_data(self):
        """Populate ``self.data`` / ``self.labels`` from the real and fake directories."""
        for source_dir in self.source_dir_lists:
            self._add_directory(source_dir, label=0)  # real class
        for fake_dir in self.fake_dir_list:
            self._add_directory(fake_dir, label=1)  # fake class

    def _add_directory(self, sub_dir, label):
        """Append every audio file under ``root_dir/sub_dir/<speaker>/`` with ``label``."""
        directory = os.path.join(self.root_dir, sub_dir)
        # Resolve speakers per directory so that, when speaker_ids is None,
        # each directory contributes all of its own speaker folders.
        speakers = self.speaker_ids if self.speaker_ids is not None else os.listdir(directory)
        for speaker in speakers:
            speaker_path = os.path.join(directory, speaker)
            if not os.path.isdir(speaker_path):
                continue
            for wav_file in os.listdir(speaker_path):
                if wav_file.endswith(self._AUDIO_EXTENSIONS):
                    self.data.append(os.path.join(speaker_path, wav_file))
                    self.labels.append(label)

    def loadWAV(self, filepath, max_frames):
        """Load ``filepath`` and return a ``max_frames * 160``-sample clip and its sample rate.

        Audio no longer than the target is wrap-padded; longer audio is cropped
        at a random offset, so repeated calls may return different windows.
        """
        max_audio = max_frames * 160  # target clip length in samples

        # Load at the configured sample rate; fold any extra channels to mono.
        audio, sample_rate = librosa.load(filepath, sr=args.audio_sample_rate)
        if audio.ndim > 1:
            audio = librosa.to_mono(audio)

        audiosize = audio.shape[0]
        if audiosize <= max_audio:
            # Wrap-pad to max_audio + 1 samples, which forces the crop below to start at 0.
            shortage = max_audio - audiosize + 1
            audio = np.pad(audio, (0, shortage), 'wrap')
            audiosize = audio.shape[0]

        startframe = np.int64(np.random.rand() * (audiosize - max_audio))
        audio = audio[startframe:startframe + max_audio]
        return audio, sample_rate

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(waveform_tensor, label, sample_rate)`` for ``idx``, or ``None`` if loading fails.

        An out-of-range ``idx`` raises ``IndexError``. NOTE: the default
        DataLoader collate cannot batch ``None``, so unreadable files need a
        collate_fn that filters them out.
        """
        # Index outside the try so a bad index surfaces as IndexError.
        wav_path = self.data[idx]
        try:
            waveform, sample_rate = self.loadWAV(wav_path, self.max_frames)
        except Exception as err:  # best effort: report and skip unreadable files
            print(f"Error processing file {wav_path}: {err}")
            return None
        return torch.tensor(waveform, dtype=torch.float32), self.labels[idx], sample_rate


def load_model(model_path):
    """Build the RawNet3 detector and load its weights from ``model_path``.

    The checkpoint is mapped onto the CPU so it also loads on machines
    without the GPU it was saved from; the caller moves the model to its
    device afterwards. ``strict=True`` makes any missing or unexpected
    state-dict key an error.
    """
    model = RawNet3_detect(encoder_type='ECA', nOut=256, sinc_stride=10, log_sinc=True, norm_sinc=True,
                           out_bn=True,
                           block=Bottle2neck, model_scale=8, context=True, summed=True)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict, strict=True)
    return model


def load_test_dataset(root_dir, source_dir_lists, fake_dir_list, speaker_annos):
    """Create a SpeechValDataset over the given real and fake sub-directories.

    ``speaker_annos`` is accepted for interface compatibility but ignored:
    the dataset is always built with ``speaker_ids=None``, i.e. every
    speaker folder found in each directory.
    """
    return SpeechValDataset(
        root_dir,
        source_dir_lists,
        fake_dir_list,
        speaker_ids=None,
    )


def compute_eer(y_true, y_scores):
    """Return ``(eer, threshold)``: the equal error rate and the score threshold it occurs at.

    Label 1 (fake) is the positive class. The EER point is the ROC threshold
    where the miss rate (1 - TPR) is closest to the false-positive rate
    (NaNs ignored); the FPR at that point is reported as the EER.
    """
    false_pos_rate, true_pos_rate, candidate_thresholds = roc_curve(y_true, y_scores, pos_label=1)
    gap = np.abs((1 - true_pos_rate) - false_pos_rate)
    crossing = np.nanargmin(gap)
    return false_pos_rate[crossing], candidate_thresholds[crossing]


def test(model, test_loader, device):
    """Run ``model`` over ``test_loader``; collect labels, predictions and class-1 scores.

    Each batch is unpacked as ``(inputs, labels, sample_rate)``; the sample
    rate is ignored. Inputs are truncated to ``args.max_frames * 160 * 24``
    samples before the forward pass, which runs under ``torch.no_grad()``.

    Returns:
        Three numpy arrays: true labels, argmax predictions, and the softmax
        probability of class 1 (fake) for every example.
    """
    model.eval()
    collected_labels, collected_preds, collected_scores = [], [], []
    with torch.no_grad():
        for batch in test_loader:
            inputs, labels, _sample_rate = batch
            inputs = inputs[:, :args.max_frames * 160 * 24].to(device)
            labels = labels.to(device)

            logits = model(inputs, eval=True)
            fake_prob = torch.nn.functional.softmax(logits, dim=1)[:, 1]
            _, predicted = torch.max(logits.data, 1)

            collected_labels.extend(labels.cpu().numpy())
            collected_preds.extend(predicted.cpu().numpy())
            collected_scores.extend(fake_prob.cpu().numpy())

    return np.array(collected_labels), np.array(collected_preds), np.array(collected_scores)


def _evaluate_subset(model, root_dir, real_dirs, fake_dirs, name, batch_size, device, log_file):
    """Evaluate one sub-dataset, print and log its accuracy, and return (labels, predictions, scores)."""
    test_dataset = load_test_dataset(root_dir, real_dirs, fake_dirs, speaker_annos=None)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    all_labels, all_predictions, all_scores = test(model, test_loader, device)

    accuracy = accuracy_score(all_labels, all_predictions)
    print(f"Accuracy for {name}: {accuracy:.4f}")
    with open(log_file, "a") as f:
        f.write(f"Accuracy for {name}: {accuracy:.4f}\n")
    return all_labels, all_predictions, all_scores


def test_intrain(root_dir, model):
    """Evaluate ``model`` on each held-out test set, then overall, logging to ``args.ckp_dir/log.txt``.

    Every real and every fake directory is evaluated on its own (per-set
    accuracy), and the results are pooled for the overall accuracy and EER.
    The model is assumed to already be on the selected device.

    Returns:
        ``(overall_accuracy, eer)``; ``eer`` is NaN when it cannot be computed.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    source_dir_lists = ["test_dataset/Custom-Audio-Dataset/Real Audio"]
    fake_dir_list = ["test_dataset/Custom-Audio-Dataset/Deepfake Audio",
                     "test_dataset/Deepfake-playHT/",
                     "test_dataset/elevenlabs_testo",
                     "test_dataset/Celebrity_deepfakes/",
                     "test_dataset/Resemble 2",
                     "test_dataset/football_stars"]

    labels = []
    predictions = []
    scores = []

    # Accuracy and EER are appended to this log file.
    log_file = f"{args.ckp_dir}/log.txt"

    # Evaluate each dataset separately: real sets use batch size 64, fake sets 32.
    subsets = [([real_dir], [], real_dir, 64) for real_dir in source_dir_lists]
    subsets += [([], [fake_dir], fake_dir, 32) for fake_dir in fake_dir_list]
    for real_dirs, fake_dirs, name, batch_size in subsets:
        subset_labels, subset_predictions, subset_scores = _evaluate_subset(
            model, root_dir, real_dirs, fake_dirs, name, batch_size, device, log_file)
        labels.extend(subset_labels)
        predictions.extend(subset_predictions)
        scores.extend(subset_scores)

    # Overall accuracy across all datasets.
    accuracy = accuracy_score(labels, predictions)
    print(f"Overall Accuracy: {accuracy:.4f}")

    # Overall EER; compute_eer raises ValueError on degenerate input
    # (e.g. an all-NaN ROC gap), in which case NaN is logged instead.
    try:
        eer, eer_threshold = compute_eer(labels, scores)
        print(f"Equal Error Rate (EER): {eer:.4f} at threshold {eer_threshold:.4f}")
    except ValueError as err:
        eer = eer_threshold = float('nan')
        print(f'EER error: {err}')

    with open(log_file, "a") as f:
        f.write(f"Overall Accuracy without window: {accuracy:.4f}\n")
        f.write(f"Equal Error Rate (EER) without window: {eer:.4f} at threshold {eer_threshold:.4f}\n")

    return accuracy, eer


def main(model_path="", root_dir=""):
    """Load the checkpoint at ``model_path`` and evaluate it on the test sets under ``root_dir``.

    Both paths default to empty placeholders; set them here or pass them in
    before running.
    """
    model = load_model(model_path)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    test_intrain(root_dir, model)


if __name__ == "__main__":
    main()
