import torch
import json
from model.RawNet3 import RawNet3_detect
from model.RawNetBasicBlock import Bottle2neck
from dataset.datasets_aug import SpeechValDataset
from torch.utils.data import DataLoader

def load_model(model_path, map_location='cpu'):
    """Build the RawNet3_detect network and load trained weights into it.

    Args:
        model_path: Path to a checkpoint whose ``torch.load`` result is a
            state dict matching the architecture built here exactly (it is
            loaded with ``strict=True``, so missing/unexpected keys raise).
        map_location: Forwarded to ``torch.load``. Defaults to CPU so a
            checkpoint saved on a GPU can be restored on a CPU-only machine;
            the caller is responsible for moving the model to its target
            device afterwards.

    Returns:
        The constructed model with the checkpoint weights loaded.
    """
    model = RawNet3_detect(encoder_type='ECA', nOut=256, sinc_stride=10, log_sinc=True, norm_sinc=True, out_bn=True,
                           block=Bottle2neck, model_scale=8, context=True, summed=True)
    # NOTE: torch.load unpickles the file -- only load checkpoints from
    # trusted sources.
    state_dict = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state_dict, strict=True)
    return model

def load_test_dataset(root_dir, source_dir_lists, fake_dir_list, speaker_annos):
    """Build the evaluation dataset restricted to the test-split speakers.

    Args:
        root_dir: Dataset root, forwarded unchanged to ``SpeechValDataset``.
        source_dir_lists: List of source directory names, forwarded unchanged.
        fake_dir_list: List of fake/generated directory names, forwarded
            unchanged.
        speaker_annos: Path to a JSON file containing an object with a
            ``"test"`` key; that value is passed as ``speaker_ids``.

    Returns:
        A ``SpeechValDataset`` built with the test speaker ids.

    Raises:
        KeyError: If the annotations object has no ``"test"`` entry.
        json.JSONDecodeError: If the annotations file is not valid JSON.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(speaker_annos, "r", encoding="utf-8") as f:
        speaker_ids = json.load(f)
    test_speaker_ids = speaker_ids["test"]
    return SpeechValDataset(root_dir, source_dir_lists, fake_dir_list, speaker_ids=test_speaker_ids)

def test(model, test_loader, device, max_frames=160):
    """Evaluate ``model`` on ``test_loader`` and report classification accuracy.

    Each batch's inputs are truncated along dim 1 to
    ``max_frames * 160 + 240`` samples before the forward pass. The predicted
    class is the argmax of the model output over dim 1.

    Args:
        model: Module producing per-class scores; assumed shape
            (batch, num_classes) given the argmax over dim 1 -- TODO confirm.
        test_loader: Iterable yielding ``(inputs, labels)`` pairs.
        device: Device every batch is moved to before the forward pass.
        max_frames: Frame budget used to derive the truncation length
            (the default of 160 gives 25840 samples).

    Returns:
        The accuracy as a float in [0, 1]; it is also printed.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    # Same formula the script always used; 160 and 240 presumably are the
    # front end's hop size and extra context samples -- TODO confirm.
    max_audio = max_frames * 160 + 240
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs[:, :max_audio]
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # Previously surfaced as an opaque ZeroDivisionError.
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")
    test_acc = correct / total
    print(f"Test accuracy: {test_acc:.4f}")
    return test_acc


# A function named ``test`` would otherwise be collected by pytest (and fail
# on missing fixtures) whenever this module is imported into a test file.
test.__test__ = False

def main():
    """Evaluate the best saved checkpoint on the test-speaker dataset.

    Loads ``checkpoints/best_model.pth``, places the model on the first CUDA
    device when one is available (CPU otherwise), builds the evaluation set
    from the hard-coded directory configuration below, and runs ``test``
    with a batch size of 1.
    """
    target_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = load_model("checkpoints/best_model.pth")
    net.to(target_device)

    eval_set = load_test_dataset(
        root_dir='',
        source_dir_lists=['source'],
        fake_dir_list=['metavoice', 'stylettsv2', 'xtts'],
        speaker_annos="speaker_ids.json",
    )
    test(net, DataLoader(eval_set, batch_size=1), target_device)

# Run the evaluation only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()