import torch
import json
import numpy as np
from model.RawNet3 import RawNet3_detect as RawNet3_detect
from model.RawNetBasicBlock import Bottle2neck
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, roc_curve
import os
from torch.utils.data import Dataset
import librosa
from hyperparameters import args
import socket


class SpeechValDataset(Dataset):
    """Dataset of audio files labelled 0 (source/real) or 1 (fake).

    Files are discovered under ``root_dir/<dir>/<speaker>/<file>`` for every
    ``<dir>`` in ``source_dir_lists`` (label 0) and ``fake_dir_list`` (label 1).
    Only files ending in one of ``AUDIO_EXTENSIONS`` are kept.

    Args:
        root_dir: Base directory that every entry of the two lists is joined to.
        source_dir_lists: Directories (relative to ``root_dir``) holding label-0 audio.
        fake_dir_list: Directories (relative to ``root_dir``) holding label-1 audio.
        speaker_ids: Speaker sub-folder names to include. ``None`` means
            "every entry found in each scanned directory".
        max_frames: Stored on the instance. NOTE(review): ``__getitem__`` pads
            using ``args.max_frames`` and does not read this value — confirm
            which of the two is meant to be authoritative.
    """

    # File-name suffixes (case-sensitive) that are treated as audio.
    AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac', '.m4a')

    def __init__(self, root_dir, source_dir_lists, fake_dir_list, speaker_ids, max_frames=600):
        self.data = []
        self.labels = []
        self.root_dir = root_dir
        self.source_dir_lists = source_dir_lists
        self.fake_dir_list = fake_dir_list
        self.speaker_ids = speaker_ids
        self.max_frames = max_frames
        self.load_data()

    def _collect(self, sub_dir, label):
        """Append every audio file below ``root_dir/sub_dir/<speaker>/`` with ``label``."""
        base_dir = os.path.join(self.root_dir, sub_dir)
        # Resolve the speaker list per directory. Caching the first directory's
        # listing on the instance would make every later directory reuse the
        # wrong speaker names and silently drop its files.
        speakers = self.speaker_ids
        if speakers is None:
            speakers = os.listdir(base_dir)
        for speaker in speakers:
            speaker_path = os.path.join(base_dir, speaker)
            if not os.path.isdir(speaker_path):
                continue
            for file_name in os.listdir(speaker_path):
                if file_name.endswith(self.AUDIO_EXTENSIONS):
                    self.data.append(os.path.join(speaker_path, file_name))
                    self.labels.append(label)

    def load_data(self):
        """Populate ``self.data`` / ``self.labels``: source dirs first (0), then fake dirs (1)."""
        for source_dir in self.source_dir_lists:
            self._collect(source_dir, 0)
        for fake_dir in self.fake_dir_list:
            self._collect(fake_dir, 1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load one file and return ``(waveform, label, sample_rate)``.

        The waveform is a 1-D float32 tensor loaded at ``args.audio_sample_rate``.
        Clips with at most ``args.max_frames * 160`` samples are wrap-padded to
        exactly one sample more than that length; longer clips are returned
        untouched.

        Returns ``None`` (after printing a message) when the file cannot be
        loaded or padded. An out-of-range ``idx`` raises ``IndexError``.
        """
        # Looked up outside the try block so the error message below can always
        # name the file, and so a bad index surfaces as a normal IndexError.
        wav_path = self.data[idx]
        try:
            audio, sample_rate = librosa.load(wav_path, sr=args.audio_sample_rate)

            # convert to mono
            if audio.ndim > 1:
                audio = librosa.to_mono(audio)

            # Optional perturbations for robustness experiments (time stretch,
            # pitch shift, down-sampling, gain change, compression) would be
            # applied to `audio` at this point; none is enabled.

            # Maximum audio length
            max_audio = args.max_frames * 160

            # wrap-pad if the audio is too short
            audiosize = audio.shape[0]
            if audiosize <= max_audio:
                shortage = max_audio - audiosize + 1
                audio = np.pad(audio, (0, shortage), 'wrap')
        except Exception as exc:
            # Best-effort: report the unreadable file instead of aborting, but do
            # not swallow KeyboardInterrupt/SystemExit as a bare `except:` would.
            print(f"Error processing file {wav_path}: {exc!r}")
            return None
        return torch.tensor(audio, dtype=torch.float32), self.labels[idx], sample_rate


def load_model(model_path):
    """Build a ``RawNet3_detect`` network and load its weights from a checkpoint.

    Args:
        model_path: Path to a checkpoint file readable by ``torch.load`` whose
            top-level object has a ``"model_state_dict"`` entry.

    Returns:
        The model with the checkpoint weights loaded. It is left on the CPU;
        the caller moves it to the target device.

    Raises:
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
        RuntimeError: If the state dict does not match the architecture built
            here exactly (``strict=True``).
    """
    model = RawNet3_detect(encoder_type='ECA', nOut=256, sinc_stride=10, log_sinc=True, norm_sinc=True,
                           out_bn=True,
                           block=Bottle2neck, model_scale=8, context=True, summed=True)
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # host and avoids allocating GPU memory just to copy the weights.
    checkpoint = torch.load(model_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    return model


def load_test_dataset(root_dir, source_dir_lists, fake_dir_list, speaker_annos):
    """Create a ``SpeechValDataset`` over the given real and fake directories.

    Args:
        root_dir: Base directory the listed directories are relative to.
        source_dir_lists: Directories with label-0 (source) audio.
        fake_dir_list: Directories with label-1 (fake) audio.
        speaker_annos: Accepted for interface compatibility; currently unused.
            The dataset is always built with ``speaker_ids=None``.

    Returns:
        The constructed ``SpeechValDataset``.
    """
    return SpeechValDataset(
        root_dir,
        source_dir_lists,
        fake_dir_list,
        speaker_ids=None,
    )


def compute_eer(y_true, y_scores):
    """Estimate the equal error rate from labels and positive-class scores.

    The ROC curve is computed with label 1 as the positive class. The operating
    point where the false-negative rate and the false-positive rate are closest
    is selected, and the false-positive rate at that point is reported as the EER.

    Args:
        y_true: Ground-truth labels.
        y_scores: Scores for the positive class (label 1).

    Returns:
        ``(eer, eer_threshold)``: the false-positive rate at the selected
        operating point and the score threshold that produces it.
    """
    false_pos, true_pos, cutoffs = roc_curve(y_true, y_scores, pos_label=1)
    false_neg = 1 - true_pos
    # Index of the ROC point where FNR and FPR are nearest to each other.
    crossing = np.nanargmin(np.abs(false_neg - false_pos))
    return false_pos[crossing], cutoffs[crossing]


def test(model, test_loader, device):
    """Score every clip in ``test_loader`` with sliding-window inference.

    Each clip is truncated to at most 24 window lengths and cut into windows of
    ``args.max_frames * 160`` samples taken every ``args.step_size * 160``
    samples (a trailing partial window is dropped, and at most 24 windows are
    kept). All windows of a clip go through the model as one batch; the
    per-window softmax outputs are averaged to give the clip-level scores.

    Args:
        model: Network called as ``model(windows, eval=True)`` and returning one
            row of two class logits per window. It is switched to eval mode and
            is not switched back.
        test_loader: Iterable yielding ``(inputs, labels, sample_rate)`` with a
            batch size of exactly 1.
        device: Device the audio is moved to before inference.

    Returns:
        ``(all_labels, all_predictions, all_scores)``: three lists with one entry
        per clip: the ground-truth label, the predicted class (1 when the
        class-1 score is strictly greater than the class-0 score, else 0), and
        the averaged class-1 score.

    Raises:
        AssertionError: If a batch holds more than one clip.
        RuntimeError: If a clip is shorter than a single window.
    """
    samples_per_frame = 160  # factor used to turn frame counts from `args` into sample counts
    max_windows = 24         # upper bound on windows scored per clip
    window_size = args.max_frames * samples_per_frame
    step = args.step_size * samples_per_frame

    model.eval()
    all_labels = []
    all_predictions = []
    all_scores = []
    with torch.no_grad():
        for inputs, labels, _sample_rate in test_loader:
            assert inputs.shape[0] == 1, "Only support batch size 1"
            # Drop only the batch dimension; a bare squeeze() would also
            # collapse any other size-1 dimension.
            audio = inputs.to(device)[:, :window_size * max_windows].squeeze(0)

            if audio.shape[0] < window_size:
                raise RuntimeError(
                    f"Clip has {audio.shape[0]} samples, fewer than one window "
                    f"of {window_size} samples"
                )

            # unfold yields every full window of `window_size` samples spaced
            # `step` apart, shape (num_windows, window_size); a shorter final
            # window is ignored.
            windows = audio.unfold(0, window_size, step)[:max_windows].contiguous()
            outputs = model(windows, eval=True)

            # Softmax per window, then average across windows.
            softmax_outputs = torch.nn.functional.softmax(outputs, dim=1)
            scores = torch.mean(softmax_outputs, dim=0)
            negative_score = scores[0].item()
            positive_score = scores[1].item()

            all_labels.append(labels.item())
            all_scores.append(positive_score)
            all_predictions.append(1 if positive_score > negative_score else 0)

    return all_labels, all_predictions, all_scores


def testW_intrain(root_dir, model):
    """Evaluate ``model`` on the hard-coded real and fake test directories.

    Every directory is evaluated on its own. Real directories are randomly
    sub-sampled to 20% of their files (without replacement); fake directories
    are used in full. Per-directory accuracy is printed and appended to
    ``<args.ckp_dir>/log.txt``. The pooled labels, predictions and scores are
    written to ``plot_data.json`` in the current working directory. The overall
    accuracy is printed and logged; the overall EER is printed only.

    NOTE: ``test`` puts the model in eval mode and nothing here switches it
    back, so a caller that continues training must call ``model.train()``.

    Args:
        root_dir: Base directory that the test directories are relative to.
        model: Network to evaluate; expected to already live on the device
            selected here (``cuda:0`` when available, else CPU).

    Returns:
        ``(accuracy, eer)`` over all evaluated clips.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    real_dirs = ["test_dataset/Custom-Audio-Dataset/Real Audio"]
    fake_dirs = ["test_dataset/Custom-Audio-Dataset/Deepfake Audio",
                 "test_dataset/Deepfake-playHT/",
                 "test_dataset/elevenlabs_testo",
                 "test_dataset/Celebrity_deepfakes/",
                 "test_dataset/Resemble 2",
                 "test_dataset/football_stars"]

    # Pooled results across every directory.
    labels, predictions, scores = [], [], []
    # Accuracy log (EER is printed but not written here).
    log_file = f"{args.ckp_dir}/log.txt"

    # Real directories first, then fake ones; the flag records which is which.
    jobs = [(directory, True) for directory in real_dirs]
    jobs += [(directory, False) for directory in fake_dirs]

    for directory, is_real in jobs:
        print(f"Testing on {directory}")
        if is_real:
            dataset = load_test_dataset(root_dir, [directory], [], speaker_annos=None)
            # Keep a random 20% of the real clips.
            total = len(dataset)
            keep = np.random.choice(total, int(total * 0.2), replace=False)
            dataset = torch.utils.data.Subset(dataset, keep)
        else:
            dataset = load_test_dataset(root_dir, [], [directory], speaker_annos=None)

        loader = DataLoader(dataset, batch_size=1)
        dir_labels, dir_predictions, dir_scores = test(model, loader, device)

        # Accuracy for this directory alone.
        dir_accuracy = accuracy_score(dir_labels, dir_predictions)
        summary = f"Accuracy for {directory}: {dir_accuracy:.4f}"
        print(summary)
        with open(log_file, "a") as f:
            f.write(summary + "\n")

        labels += dir_labels
        predictions += dir_predictions
        scores += dir_scores

    # Persist the pooled raw results for later plotting.
    with open("plot_data.json", "w") as f:
        json.dump({"labels": labels, "predictions": predictions, "scores": scores}, f)

    accuracy = accuracy_score(labels, predictions)
    print(f"Overall Accuracy: {accuracy:.4f}")
    eer, eer_threshold = compute_eer(labels, scores)
    print(f"Equal Error Rate (EER): {eer:.4f} at threshold {eer_threshold:.4f}")

    with open(log_file, "a") as f:
        f.write(f"Overall Accuracy with window: {accuracy:.4f}\n")
    return accuracy, eer


def main():
    """Load a checkpoint, move the model to the available device and evaluate it.

    Both the checkpoint path and the dataset root are empty placeholders and
    must be filled in before the script can run.
    """
    # 10, 11, 13, 15  (NOTE(review): presumably checkpoint identifiers that were tried; confirm)
    checkpoint_path = ""
    dataset_root = ""

    net = load_model(checkpoint_path)
    target = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(target)
    testW_intrain(dataset_root, net)


if __name__ == "__main__":
    main()
