import torch
from model.RawNet3 import RawNet3_detect
from model.RawNetBasicBlock import Bottle2neck
from dataset.datasets_aug import SpeechDataset
from torch.utils.data import DataLoader
import os
import socket
from hyperparameters import args
from regularization import mixup_data, mixup_criterion


def load_model(num_classes=2):
    """Build a ``RawNet3_detect`` network and initialise it from ``model.pt``.

    The checkpoint is read from the current working directory onto the CPU and
    its ``'model'`` entry is used as the state dict.

    Args:
        num_classes: value forwarded to the network as ``n_cls``.

    Returns:
        The constructed network with the checkpoint weights applied.
    """
    checkpoint = torch.load('model.pt', map_location=torch.device('cpu'))
    weights = checkpoint['model']

    detector = RawNet3_detect(
        encoder_type='ECA',
        nOut=256,
        sinc_stride=10,
        log_sinc=True,
        norm_sinc=True,
        out_bn=True,
        block=Bottle2neck,
        model_scale=8,
        context=True,
        summed=True,
        n_cls=num_classes,
    )
    # strict=False: keys that do not line up between checkpoint and network
    # are tolerated instead of raising.
    detector.load_state_dict(weights, strict=False)
    return detector


def load_dataset(root_dir, noise_dir, dataset_path):
    """Create the speech dataset and split it randomly 80/20 into train/val.

    Args:
        root_dir: passed to ``SpeechDataset`` as its first argument.
        noise_dir: passed to ``SpeechDataset`` as its third argument.
        dataset_path: passed to ``SpeechDataset`` as its second argument.

    Returns:
        ``(train_dataset, val_dataset)`` as produced by ``random_split``; the
        training part holds ``int(0.8 * len(dataset))`` items and the
        validation part holds the remainder.
    """
    full_set = SpeechDataset(root_dir, dataset_path, noise_dir, max_frames=args.max_frames)

    n_total = len(full_set)
    n_train = int(0.8 * n_total)
    n_val = n_total - n_train

    train_part, val_part = torch.utils.data.random_split(full_set, [n_train, n_val])
    return train_part, val_part


def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, log_file, args,
                    log_interval=200):
    """Run one training pass over ``train_loader`` and log the mean loss periodically.

    Args:
        model: network to optimise; it is switched to training mode here.
        train_loader: iterable of ``(inputs, labels)`` batches supporting ``len()``.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer whose ``zero_grad()`` / ``step()`` run once per batch.
        device: device the batch tensors are moved to.
        epoch: zero-based epoch index (printed one-based).
        log_file: path of the text log that progress lines are appended to.
        args: settings object; ``args.mixup`` enables mixup and
            ``args.mixup_alpha`` is forwarded to ``mixup_data``.
        log_interval: number of optimisation steps averaged into each
            progress line. Must be positive.

    Raises:
        ValueError: if ``log_interval`` is not positive.
    """
    if log_interval <= 0:
        raise ValueError(f"log_interval must be positive, got {log_interval}")

    model.train()
    running_loss = 0.0
    steps_since_log = 0
    length = len(train_loader)

    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        if args.mixup:
            inputs, labels_a, labels_b, lam = mixup_data(inputs, labels, args.mixup_alpha)
            try:
                outputs = model(inputs)
                loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
            except Exception as exc:
                # Best effort is kept: a failing mixup batch still ends the
                # epoch early, but the reason is now recorded instead of being
                # swallowed, and KeyboardInterrupt/SystemExit propagate.
                err_msg = f"Epoch [{epoch + 1}], Step [{i + 1}/ {length}] aborted: {exc!r}\n"
                print(err_msg)
                with open(log_file, "a") as f:
                    f.write(err_msg)
                break
        else:
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        steps_since_log += 1

        # Divide by the number of steps actually accumulated so the reported
        # value is a true mean over the logging window.
        if steps_since_log == log_interval:
            log_msg = (f"Epoch [{epoch + 1}], Step [{i + 1}/ {length}] "
                       f"loss: {running_loss / steps_since_log:.4f}\n")
            print(log_msg)
            with open(log_file, "a") as f:
                f.write(log_msg)
            running_loss = 0.0
            steps_since_log = 0


def val_one_epoch(model, val_loader, criterion, device, log_file):
    """Evaluate ``model`` on ``val_loader`` and append a summary line to ``log_file``.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        val_loader: iterable of ``(inputs, labels)`` batches supporting ``len()``.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        device: device the batch tensors are moved to.
        log_file: path of the text log the summary line is appended to.

    Returns:
        ``(val_loss, val_acc)``: the mean of the per-batch losses and the
        fraction of samples whose highest-scoring output index equals the label.
    """
    model.eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss_sum += criterion(outputs, labels).item()

            # Index of the largest score along dim 1 is the predicted class.
            predicted = torch.max(outputs.data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    val_loss = loss_sum / len(val_loader)
    val_acc = n_correct / n_seen

    log_msg = f"Validation loss: {val_loss:.4f}, Validation accuracy: {val_acc:.4f}\n"
    print(log_msg)
    with open(log_file, "a") as log:
        log.write(log_msg)

    return val_loss, val_acc


from test_ood import test_intrain
from test_ood_with_window import testW_intrain


def train(model, train_loader, val_loader, device, epochs=4, save_dir="checkpoints", args=None):
    """Train ``model`` for ``epochs`` epochs with per-epoch validation and checkpoints.

    Each epoch trains, validates, saves ``epoch_<n>.pth``, runs the
    ``test_intrain`` / ``testW_intrain`` hooks, and refreshes ``best_model.pth``
    whenever validation accuracy improves. Progress is appended to
    ``<save_dir>/log.txt``.

    Args:
        model: network to train.
        train_loader: batches handed to ``train_one_epoch``.
        val_loader: batches handed to ``val_one_epoch``.
        device: device forwarded to the per-epoch helpers.
        epochs: number of epochs to run.
        save_dir: directory for checkpoints and the log; created if missing.
        args: settings object. Although it defaults to ``None``, ``args.lr`` is
            read unconditionally, so a real settings object must be supplied.

    Note:
        The test hooks receive the module-level global ``root_dir``, which
        ``main()`` assigns before calling this function.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    adam = torch.optim.Adam(model.parameters(), lr=args.lr)
    top_acc = 0.0
    os.makedirs(save_dir, exist_ok=True)

    log_file = f"{save_dir}/log.txt"

    for epoch in range(epochs):
        banner = f"Epoch {epoch + 1}"
        print(banner)
        with open(log_file, "a") as log:
            log.write(banner + "\n")

        train_one_epoch(model, train_loader, loss_fn, adam, device, epoch, log_file, args)
        _, val_acc = val_one_epoch(model, val_loader, loss_fn, device, log_file)

        # A checkpoint is written after every epoch, regardless of accuracy.
        torch.save(model.state_dict(), f"{save_dir}/epoch_{epoch + 1}.pth")

        test_intrain(root_dir, model)
        testW_intrain(root_dir, model)

        # Keep a separate copy of the weights with the best validation accuracy.
        if val_acc > top_acc:
            top_acc = val_acc
            torch.save(model.state_dict(), f"{save_dir}/best_model.pth")

    print(args)


def main():
    """Entry point: build the model, assemble the datasets and start training.

    The dataset root is also stored in the module-level global ``root_dir``,
    which ``train()`` passes to its test hooks.
    """
    global root_dir

    model = load_model(num_classes=2)

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    # Wrap in DataParallel when more than one GPU is visible.
    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print("Let's use", gpu_count, "GPUs!")
        model = torch.nn.DataParallel(model)

    model.to(device)

    noise_dir = ""  # set path to noise dir for augmentation
    dataset_root = ""  # set dataset root path

    real_speaker = [
        "VCTK/VCTK",
        "Librispeech/train-clean-100/",
        "Librispeech/train-clean-360/",
        "Librispeech/train-other-500/",
        "Wilds/wild",
        "ASRspoof2019/2019_LA_train/bonafide/",
        "ASRspoof2019/2019_LA_dev/bonafide/",
        "ASRspoof2019/2019_LA_eval/bonafide/",
        "voxceleb1/train_wav/",
        "voxceleb2/dev",
    ]
    real_no_speaker = [
        'Narration/narration',
        "ASRspoof2021/flac",
    ]
    fake_speaker = [
        "VCTK/VCTK_metavoice/",
        "VCTK/VCTK_stylettsv2/",
        "VCTK/VCTK_voicecraft/",
        "VCTK/VCTK_whisperspeech/",
        "VCTK/VCTK_vokantts/",
        "VCTK/VCTK_xtts/",
        "Wilds/wild_metavoice/",
        "Wilds/wild_stylettsv2/",
        "Wilds/wild_voicecraft/",
        "Wilds/wild_whisperspeech/",
        "Wilds/wild_vokantts/",
        "Wilds/wild_xtts/",
        "Librispeech/train-clean-100_stylettsv2/",
        "Librispeech/train-clean-100_vokantts/",
        "Librispeech/train-clean-100_voicecraft/",
        "ASRspoof2019/2019_LA_train/spoof/",
        "ASRspoof2019/2019_LA_dev/spoof/",
        "ASRspoof2019/2019_LA_eval/spoof/",
        # disabled: "LibriSpeech/validonly_train-clean-100-elevenlab_apiGen_paid_chuck_1024/"
        "wavefake/generated_audio/",
        "ElevenLabs/",
        "eleven_paid_gen_ours/",
    ]
    fake_no_speaker = [
        "Narration/narration_stylettsv2/",
        "Narration/narration_vokantts/",
        "Narration/narration_whisperspeech/",
        "Narration/narration_voicecraft/",
    ]
    dataset_path = {
        'real': {'speaker': real_speaker, 'no_speaker': real_no_speaker},
        'fake': {'speaker': fake_speaker, 'no_speaker': fake_no_speaker},
    }

    root_dir = dataset_root

    train_dataset, val_dataset = load_dataset(root_dir, noise_dir, dataset_path)

    # Both loaders share batch size and worker count; only training shuffles.
    loader_kwargs = {"batch_size": args.bs, "num_workers": 4}
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    train(model, train_loader, val_loader, device, epochs=args.epochs, save_dir=args.ckp_dir, args=args)


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
