import torch
from torch.optim import SGD
from torch.utils.data.dataloader import DataLoader
from torch.nn import CrossEntropyLoss

from pado.metrics.accuracy import Accuracy
from classifier import Classifier
from dataset import VectorDataset

# Shared fragments from which the public constants below are assembled.
_MODEL_TAG = "conformer-ctc-m-bn-128-b480-baseline-v3"
_OUTPUTS_ROOT = "/project/LayerWiseAttnReuse/outputs"

# Run-name prefixes, one per data split. Note that the two "TRAIN" prefixes
# point at the dev-clean / dev-other runs.
TRAIN0_PREFIX = f"{_MODEL_TAG}-dev-clean"
TRAIN1_PREFIX = f"{_MODEL_TAG}-dev-other"
TEST0_PREFIX = f"{_MODEL_TAG}-test-clean"
TEST1_PREFIX = f"{_MODEL_TAG}-test-other"

# Alignment directory, shared by every VectorDataset built in this script.
ALIGN_DIR = f"{_OUTPUTS_ROOT}/alignments"

# Per-split directories holding the dumped hidden vectors.
TRAIN0_HIDDEN_DIR = f"{_OUTPUTS_ROOT}/{TRAIN0_PREFIX}/hiddens"
TRAIN1_HIDDEN_DIR = f"{_OUTPUTS_ROOT}/{TRAIN1_PREFIX}/hiddens"
TEST0_HIDDEN_DIR = f"{_OUTPUTS_ROOT}/{TEST0_PREFIX}/hiddens"
TEST1_HIDDEN_DIR = f"{_OUTPUTS_ROOT}/{TEST1_PREFIX}/hiddens"

# Layer index handed to VectorDataset and embedded in the checkpoint filename.
LAYER_ID = 16


def build_net(num_classes: int):
    """Instantiate the ``Classifier`` used by this script.

    Args:
        num_classes: Forwarded as the second positional argument of
            ``Classifier``.

    Returns:
        The freshly constructed ``Classifier`` instance.
    """
    # The first argument is fixed at 256 -- presumably the width of the hidden
    # vectors being classified; confirm against the Classifier definition.
    classifier = Classifier(256, num_classes)
    return classifier


def build_dataloader(dataset):
    """Wrap ``dataset`` in a shuffling ``DataLoader`` with fixed settings.

    Batches hold 128 samples, the final incomplete batch is dropped, two
    worker processes are used, memory pinning is enabled, and fetching a
    batch from the workers times out after 120 seconds.

    Args:
        dataset: Any dataset object accepted by ``DataLoader``.

    Returns:
        The configured ``DataLoader``.
    """
    loader_options = dict(
        batch_size=128,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True,
        timeout=120,
    )
    return DataLoader(dataset, **loader_options)


def build_optimizer(net):
    """Create the SGD optimizer for all parameters of ``net``.

    Hyper-parameters are fixed: learning rate 0.1, momentum 0.9 and weight
    decay 1e-3.

    Args:
        net: Module whose ``parameters()`` will be optimized.

    Returns:
        The configured ``SGD`` optimizer.
    """
    trainable_params = net.parameters()
    sgd = SGD(trainable_params, weight_decay=1e-3, momentum=0.9, lr=0.1)
    return sgd


def _build_eval_dataloader(dataset):
    """Return a ``DataLoader`` that visits every sample of ``dataset`` once.

    Same batch size, worker count, pinning and timeout as
    ``build_dataloader``, but with shuffling disabled and the final short
    batch kept, so an accuracy computed over it covers the whole dataset and
    is reproducible between runs.
    """
    return DataLoader(dataset,
                      batch_size=128, shuffle=False, num_workers=2,
                      pin_memory=True, drop_last=False, timeout=120)


def _average_accuracy(weighted_sum, num_samples):
    """Return ``weighted_sum / num_samples``, or NaN if no sample was seen."""
    if num_samples == 0:
        # A loader can yield zero batches (e.g. drop_last=True with fewer
        # samples than one batch); report NaN instead of dividing by zero.
        return float("nan")
    return weighted_sum / num_samples


def _evaluate(net, dataloader, accuracy):
    """Return the sample-weighted mean of ``accuracy`` over ``dataloader``.

    Gradients are disabled for the duration of the loop. The caller is
    responsible for switching ``net`` to eval mode beforehand.
    """
    weighted_sum = 0.0
    num_samples = 0
    with torch.no_grad():
        for hiddens, labels in dataloader:
            hiddens = hiddens.cuda()
            labels = labels.cuda()
            batch_size = labels.shape[0]

            prediction = net(hiddens)
            acc = accuracy(prediction, labels)

            # Weight by batch size so a short final batch counts proportionally.
            weighted_sum += acc.item() * batch_size
            num_samples += batch_size
    return _average_accuracy(weighted_sum, num_samples)


def train():
    """Train the classifier, save its weights, then report test accuracy.

    Steps:
      1. Build the two training datasets and concatenate them; build the
         test-clean and test-other datasets.
      2. Train on the GPU for 15 epochs with SGD and cross-entropy loss,
         multiplying the learning rate by 0.1 at epochs 3, 6, 9 and 12.
      3. Save the state dict to ``classifier_layer_{LAYER_ID}.pth`` in the
         current working directory.
      4. Evaluate on test-clean and test-other and print both accuracies.

    Evaluation uses an ordered loader that keeps the last partial batch, so
    every test sample contributes to the reported accuracy.
    """
    print("Train start!")
    train0_dataset = VectorDataset(TRAIN0_HIDDEN_DIR, ALIGN_DIR, layer_id=LAYER_ID, exclude_silence=False)
    train1_dataset = VectorDataset(TRAIN1_HIDDEN_DIR, ALIGN_DIR, layer_id=LAYER_ID, exclude_silence=False)
    train_dataset = train0_dataset + train1_dataset
    print("Train data:", len(train_dataset))

    test_clean_dataset = VectorDataset(TEST0_HIDDEN_DIR, ALIGN_DIR, layer_id=LAYER_ID, exclude_silence=False)
    test_other_dataset = VectorDataset(TEST1_HIDDEN_DIR, ALIGN_DIR, layer_id=LAYER_ID, exclude_silence=False)
    print("Test clean data:", len(test_clean_dataset))
    print("Test other data:", len(test_other_dataset))

    # Training shuffles and drops the last short batch; evaluation must not.
    train_dataloader = build_dataloader(train_dataset)
    test_clean_dataloader = _build_eval_dataloader(test_clean_dataset)
    test_other_dataloader = _build_eval_dataloader(test_other_dataset)

    # The class count is read from the test-clean dataset (37 according to
    # the original author's note).
    net = build_net(test_clean_dataset.num_classes)
    net = net.cuda().train()

    optimizer = build_optimizer(net)

    criterion = CrossEntropyLoss().cuda()
    accuracy = Accuracy().cuda()

    for epoch in range(15):
        print(f"Epoch {epoch} start")

        # Step decay: shrink the learning rate tenfold every third epoch.
        if (epoch > 0) and (epoch % 3 == 0):
            for param_group in optimizer.param_groups:
                param_group["lr"] *= 0.1

        epoch_acc = 0.0
        epoch_samples = 0
        for hiddens, labels in train_dataloader:
            hiddens = hiddens.cuda()
            labels = labels.cuda()
            batch_size = labels.shape[0]

            optimizer.zero_grad(set_to_none=True)

            prediction = net(hiddens)
            loss = criterion(prediction, labels)
            # NOTE(review): assumes Accuracy returns a scalar tensor holding
            # the batch-mean accuracy -- confirm against pado.metrics.
            acc = accuracy(prediction, labels)
            loss.backward()

            optimizer.step()

            epoch_acc += acc.item() * batch_size
            epoch_samples += batch_size

        train_acc = _average_accuracy(epoch_acc, epoch_samples)
        print(f"... Epoch {epoch}, average train accuracy: {train_acc:.3f}")
    torch.save(net.state_dict(), f"classifier_layer_{LAYER_ID}.pth")

    net.eval()
    test_acc = _evaluate(net, test_clean_dataloader, accuracy)
    print(f"Final: average test-clean accuracy: {test_acc:.3f}")

    test_acc = _evaluate(net, test_other_dataloader, accuracy)
    print(f"Final: average test-other accuracy: {test_acc:.3f}")


# Run the train-then-evaluate routine only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    train()
