import torch
from model.RawNet3 import RawNet3_detect
from model.RawNetBasicBlock import Bottle2neck
from dataset.dataset_aug_all import SpeechAugDataset
from torch.utils.data import DataLoader
import os
import socket
from hyperparameters import args
from regularization import mixup_data, mixup_criterion
from adversarial.frequency_attack import spectrum_attack
from lr_schedule import WarmUpCosineLR
from test_ood import test_intrain
from test_ood_with_window import testW_intrain
import wandb
import torchaudio

import warnings

warnings.filterwarnings("ignore", category=UserWarning)

# Fixed seed for reproducibility: seeds torch's CPU generator and the current CUDA device.
_SEED = 21
torch.manual_seed(_SEED)
torch.cuda.manual_seed(_SEED)

# Open the Weights & Biases run for this project and record the full hyperparameter
# namespace (`args`) as the run config.
wandb.init(project="voice-detection-robustness", config=args)


def load_model(num_classes=2, ckpt_path='model.pt'):
    """Build a RawNet3 detector and initialise it from a pretrained checkpoint.

    Args:
        num_classes: Number of output classes for the detection head (passed as ``n_cls``).
        ckpt_path: Checkpoint file whose ``'model'`` entry holds the state dict.
            Defaults to ``'model.pt'`` in the current working directory.

    Returns:
        The ``RawNet3_detect`` model with the checkpoint weights loaded. The checkpoint
        is mapped to CPU; moving the model to a device is left to the caller.
    """
    state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))['model']
    model = RawNet3_detect(encoder_type='ECA', nOut=256, sinc_stride=10, log_sinc=True, norm_sinc=True,
                           out_bn=True,
                           block=Bottle2neck, model_scale=8, context=True, summed=True, n_cls=num_classes)

    # strict=False tolerates keys present on only one side (e.g. a head that the
    # pretrained checkpoint does not have). Report them so a mismatched checkpoint
    # is visible instead of being silently half-loaded.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(f"load_model: {len(incompatible.missing_keys)} keys missing from '{ckpt_path}' "
              f"(left at their initial values): {incompatible.missing_keys}")
    if incompatible.unexpected_keys:
        print(f"load_model: {len(incompatible.unexpected_keys)} unexpected keys in '{ckpt_path}' "
              f"(ignored): {incompatible.unexpected_keys}")

    return model


def load_dataset(root_dir, noise_dir, dataset_path):
    """Create the augmented speech dataset and split it 98% / 2% into train / validation.

    The split uses ``torch.utils.data.random_split`` with torch's global RNG, so it is
    reproducible only as long as the global seed and all prior RNG consumption are
    unchanged.

    Returns:
        ``(train_subset, val_subset)``.
    """
    full_set = SpeechAugDataset(root_dir, dataset_path, noise_dir, max_frames=args.max_frames)
    n_total = len(full_set)
    n_train = int(0.98 * n_total)
    subsets = torch.utils.data.random_split(full_set, [n_train, n_total - n_train])
    return tuple(subsets)


def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, epoch, log_file, args,
                    log_interval=50):
    """Run one epoch of spectrum-attack adversarial training combined with mixup.

    For each batch an adversarial copy of the inputs is crafted with ``spectrum_attack``.
    Clean and attacked batches are each mixed up independently, and the optimised loss is
    ``clean_loss + args.gamma * attacked_loss``. The scheduler is stepped once per batch.

    Args:
        model: Network being trained (may be wrapped in ``DataParallel``).
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss function passed to ``mixup_criterion``.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Per-batch learning-rate scheduler.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index (shown 1-based in logs).
        log_file: Path of the text log appended to every ``log_interval`` steps.
        args: Hyperparameter namespace (attack, mixup, gamma, spectrum_* settings).
        log_interval: Number of batches between progress log lines.

    Raises:
        ValueError: If ``args.attack`` is not ``"spectrum"`` or mixup is disabled;
            only that combination is implemented.
    """
    # Validate the configuration once, not per batch. Explicit raises also survive
    # `python -O`, which strips assert statements.
    if args.attack != "spectrum":
        raise ValueError("Only spectrum attack is supported for now")
    if not args.mixup:
        raise ValueError("Mixup is preferred for spectrum attack")

    model.train()
    running_loss = 0.0
    num_batches = len(train_loader)
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # NOTE(review): the model is in train mode here, so any forward passes inside
        # spectrum_attack may update BatchNorm running statistics - confirm intended.
        attacked_inputs = spectrum_attack(model, audio=inputs, label=labels, epsilon=args.spectrum_epsilon,
                                          alpha=args.spectrum_alpha, attack_iters=args.spectrum_attack_iters,
                                          restarts=args.spectrum_restarts, target_freq=args.spectrum_target_freq)

        # Clear gradients only after the attack: if the attack back-propagates through
        # the model, it leaves gradients on the parameters that must not leak into this
        # training step.
        optimizer.zero_grad()

        # Mix up clean and attacked batches independently (each draws its own lambda/pairing).
        clean_inputs, clean_labels_a, clean_labels_b, clean_lam = mixup_data(inputs, labels, args.mixup_alpha)
        attacked_inputs, attacked_labels_a, attacked_labels_b, attacked_lam = mixup_data(attacked_inputs, labels,
                                                                                         args.mixup_alpha)

        outputs_clean = model(clean_inputs)
        outputs_attacked = model(attacked_inputs)

        clean_loss = mixup_criterion(criterion, outputs_clean, clean_labels_a, clean_labels_b, clean_lam)
        attacked_loss = mixup_criterion(criterion, outputs_attacked, attacked_labels_a, attacked_labels_b, attacked_lam)
        loss = clean_loss + args.gamma * attacked_loss

        loss.backward()
        optimizer.step()

        # Divide by (1 + gamma) so the logged value stays on the scale of a single loss
        # term regardless of gamma.
        running_loss += loss.item() / (1 + args.gamma)
        if (i + 1) % log_interval == 0:
            # The lr actually applied by this optimizer step. Scheduler get_lr() is not
            # guaranteed to return this outside scheduler.step().
            current_lr = optimizer.param_groups[0]["lr"]
            avg_loss = running_loss / log_interval
            log_msg = (f"Epoch [{epoch + 1}], Step [{i + 1}/ {num_batches}], "
                       f"lr:{current_lr}, loss: {avg_loss:.4f}\n")
            print(log_msg)
            with open(log_file, "a") as f:
                f.write(log_msg)
            wandb.log({"train_loss": avg_loss})
            running_loss = 0.0
        scheduler.step()


def val_one_epoch(model, val_loader, criterion, device, log_file):
    """Evaluate ``model`` on ``val_loader`` without gradients and log loss/accuracy.

    Args:
        model: Classifier producing per-class scores of shape ``(batch, n_classes)``.
        val_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss returning the batch mean (e.g. ``CrossEntropyLoss`` with its
            default ``reduction='mean'``).
        device: Device the batches are moved to.
        log_file: Path of the text log the summary line is appended to.

    Returns:
        ``(val_loss, val_acc)``: the per-sample mean loss and the fraction of correct
        argmax predictions. For an empty loader this is ``(nan, 0.0)``.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            batch_size = labels.size(0)
            # criterion returns the batch *mean*. Re-weight by batch size so that a
            # smaller final batch does not skew the epoch average.
            loss_sum += criterion(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total:
        val_loss = loss_sum / total
        val_acc = correct / total
    else:
        # Nothing evaluated: avoid ZeroDivisionError and report "no loss" explicitly.
        val_loss = float("nan")
        val_acc = 0.0

    log_msg = f"Validation loss: {val_loss:.4f}, Validation accuracy: {val_acc:.4f}\n"
    print(log_msg)
    with open(log_file, "a") as f:
        f.write(log_msg)
    return val_loss, val_acc


def _latest_epoch_checkpoint(save_dir):
    """Return the name of the highest-numbered ``epoch_<N>.pth`` file in *save_dir*, or None.

    Files that match the prefix and suffix but do not have an integer ``N`` are ignored
    rather than crashing the resume logic.
    """
    epochs_by_name = {}
    for name in os.listdir(save_dir):
        if name.startswith("epoch_") and name.endswith(".pth"):
            number = name[len("epoch_"):-len(".pth")]
            if number.isdigit():
                epochs_by_name[name] = int(number)
    if not epochs_by_name:
        return None
    return max(epochs_by_name, key=epochs_by_name.get)


def _snapshot_source_files(dest_dir):
    """Copy every ``*.py`` file in the current working directory into *dest_dir*.

    Done in-process (no shell), so directory names containing spaces or shell
    metacharacters work and copy failures raise instead of being silently ignored.
    """
    import glob
    import shutil

    os.makedirs(dest_dir, exist_ok=True)
    for path in glob.glob("*.py"):
        if os.path.isfile(path):
            shutil.copy(path, dest_dir)


def _unwrap_model(model):
    """Return the wrapped module if *model* is a (Distributed)DataParallel wrapper, else *model*."""
    if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
        return model.module
    return model


def train(model, train_loader, val_loader, device, epochs=10, save_dir="checkpoints", args=None):
    """Train ``model`` for ``epochs`` epochs with validation, checkpointing and OOD evaluation.

    Each epoch:
    1. Trains one epoch.
    2. Validates.
    3. Updates ``best_model.pth`` on improvement.
    4. Writes ``epoch_<N>.pth`` (model/optimizer/scheduler state plus ``best_acc``).
    5. Runs ``test_intrain`` / ``testW_intrain``.
    Metrics are logged to wandb and appended to ``<save_dir>/log.txt``. A copy of the
    ``*.py`` files in the working directory is saved to ``<save_dir>/code``.

    Relies on the module-level ``root_dir`` (set by ``main()``) for the OOD evaluations.

    Args:
        model: Network to train (may be wrapped in ``DataParallel``).
        train_loader: Training batches.
        val_loader: Validation batches.
        device: Device used for training and for mapping resumed checkpoints.
        epochs: Total number of epochs (including any already completed when resuming).
        save_dir: Directory for logs, checkpoints and the code snapshot.
        args: Hyperparameter namespace (lr, warmup/min lr, warmup_epochs, resume, ...).

    Raises:
        ValueError: If ``args`` is not provided.
    """
    if args is None:
        raise ValueError("train() requires the hyperparameter namespace via `args`")

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # The scheduler is stepped once per batch, so its horizon is measured in batches.
    steps_per_epoch = len(train_loader)
    scheduler = WarmUpCosineLR(optimizer,
                               warmup_steps=int(args.warmup_epochs * steps_per_epoch),
                               total_steps=int(steps_per_epoch * epochs),
                               warmup_lr=args.warmup_lr, max_lr=args.lr, min_lr=args.min_lr)

    best_acc = 0.0
    os.makedirs(save_dir, exist_ok=True)
    log_file = f"{save_dir}/log.txt"

    # Record the run configuration and a snapshot of the code that produced it.
    with open(log_file, "a") as f:
        f.write(str(args) + "\n")
    _snapshot_source_files(f"{save_dir}/code")

    start_epoch = 0
    if args.resume:
        latest_checkpoint = _latest_epoch_checkpoint(save_dir)
        if latest_checkpoint is not None:
            checkpoint = torch.load(os.path.join(save_dir, latest_checkpoint), map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            start_epoch = checkpoint['epoch'] + 1
            best_acc = checkpoint['best_acc']
            print(f"Resuming from epoch {start_epoch} using {latest_checkpoint}")

    # The evaluation helpers receive the bare module, not the DataParallel wrapper.
    eval_model = _unwrap_model(model)

    for epoch in range(start_epoch, epochs):
        print(f"Epoch {epoch + 1}")
        with open(log_file, "a") as f:
            f.write(f"Epoch {epoch + 1}\n")
        train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, epoch, log_file, args)
        val_loss, val_acc = val_one_epoch(model, val_loader, criterion, device, log_file)
        wandb.log({"val_acc": val_acc, "val_loss": val_loss})

        # Update the best model BEFORE writing the epoch checkpoint, so the stored
        # best_acc includes this epoch. Otherwise a resume from this checkpoint would
        # restore a stale best_acc and could overwrite best_model.pth with a worse model.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), f"{save_dir}/best_model.pth")

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_acc': best_acc,
        }, f"{save_dir}/epoch_{epoch + 1}.pth")

        # Out-of-distribution evaluation (plain and windowed variants).
        accuracy, eer = test_intrain(root_dir, eval_model)
        accuracy_w, eer_w = testW_intrain(root_dir, eval_model)
        wandb.log({"ood_accuracy": accuracy, "ood_eer": eer})
        wandb.log({"ood_window_accuracy": accuracy_w, "ood_window_eer": eer_w})

    print(args)


def main():
    """Build the detector, datasets and loaders, then run training.

    ``root`` (dataset root) and ``noise_dir`` (augmentation noise) are placeholders that
    must be filled in before running. ``root`` is also published as the module-level
    ``root_dir``, which ``train`` uses for its OOD evaluations.
    """
    nclasses = 2
    model = load_model(num_classes=nclasses)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Using {device} device")

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)

    model.to(device)

    num_workers = 4
    # Pinned host memory only speeds up host->GPU copies; on a CPU-only run it is
    # pointless and triggers a warning, so enable it only when training on CUDA.
    pin_memory = device.type == "cuda"
    noise_dir = ""  # set path to noise dir for augmentation
    root = ""  # set dataset root path
    # Sub-directories (relative to `root`) grouped by label ('real' / 'fake') and by
    # whether speaker identity is available ('speaker' / 'no_speaker').
    dataset_path = {
        'real': {
            'speaker': [
                "VCTK/VCTK",
                "Librispeech/train-clean-100/",
                "Librispeech/train-clean-360/",
                "Librispeech/train-other-500/",
                "Wilds/wild",

                "ASRspoof2019/2019_LA_train/bonafide/",
                "ASRspoof2019/2019_LA_dev/bonafide/",
                "ASRspoof2019/2019_LA_eval/bonafide/",
                "voxceleb1/train_wav/",
                "voxceleb2/dev"

            ],
            'no_speaker': ['Narration/narration',
                           "ASRspoof2021/flac",
                           ]

        },
        'fake': {
            'speaker': [
                "VCTK/VCTK_metavoice/",
                "VCTK/VCTK_stylettsv2/",
                "VCTK/VCTK_voicecraft/",
                "VCTK/VCTK_whisperspeech/",
                "VCTK/VCTK_vokantts/",
                "VCTK/VCTK_xtts/",

                "Wilds/wild_metavoice/",
                "Wilds/wild_stylettsv2/",
                "Wilds/wild_voicecraft/",
                "Wilds/wild_whisperspeech/",
                "Wilds/wild_vokantts/",
                "Wilds/wild_xtts/",

                "Librispeech/train-clean-100_stylettsv2/",
                "Librispeech/train-clean-100_vokantts/",
                "Librispeech/train-clean-100_voicecraft/",

                "ASRspoof2019/2019_LA_train/spoof/",
                "ASRspoof2019/2019_LA_dev/spoof/",
                "ASRspoof2019/2019_LA_eval/spoof/",
                # "LibriSpeech/validonly_train-clean-100-elevenlab_apiGen_paid_chuck_1024/"

                "wavefake/generated_audio/",

                "ElevenLabs/",
                "eleven_paid_gen_ours/",
            ],
            'no_speaker': [
                "Narration/narration_stylettsv2/",
                "Narration/narration_vokantts/",
                "Narration/narration_whisperspeech/",
                "Narration/narration_voicecraft/",
            ]
        }
    }
    # train() reads this module-level name for its OOD evaluations.
    global root_dir
    root_dir = root

    train_dataset, val_dataset = load_dataset(root_dir, noise_dir, dataset_path)
    train_loader = DataLoader(train_dataset, batch_size=args.bs, num_workers=num_workers, shuffle=True,
                              pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=args.bs, num_workers=num_workers, shuffle=False,
                            pin_memory=pin_memory)
    train(model, train_loader, val_loader, device, epochs=args.epochs, save_dir=args.ckp_dir, args=args)


# Script entry point. Note that importing this module already runs the module-level
# setup above (seeding, wandb.init); only main() is guarded here.
if __name__ == "__main__":
    main()
