import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
import matplotlib.pyplot as plt

from pado.metrics.accuracy import Accuracy
from classifier import Classifier
from dataset import VectorDataset

# Output-directory prefixes for the test-clean (TEST0) and test-other (TEST1) splits.
_RUN_NAME = "conformer-ctc-m-bn-128-b480-baseline-v3"
TEST0_PREFIX = f"{_RUN_NAME}-test-clean"
TEST1_PREFIX = f"{_RUN_NAME}-test-other"

# Alignments and the per-split "hiddens" directories all live under one output root.
_OUTPUT_ROOT = "/project/LayerWiseAttnReuse/outputs"
ALIGN_DIR = f"{_OUTPUT_ROOT}/alignments"
TEST0_HIDDEN_DIR = f"{_OUTPUT_ROOT}/{TEST0_PREFIX}/hiddens"
TEST1_HIDDEN_DIR = f"{_OUTPUT_ROOT}/{TEST1_PREFIX}/hiddens"

# Layer index handed to VectorDataset; it is also embedded in the checkpoint and .npy file names.
LAYER_ID = 8


def build_net(num_classes: int):
    """Return a newly constructed ``Classifier`` with ``num_classes`` outputs."""
    # NOTE(review): 256 is presumably the width of the hidden vectors being probed — confirm against Classifier.
    feature_dim = 256
    classifier = Classifier(feature_dim, num_classes)
    return classifier


def build_dataloader(dataset):
    """Wrap ``dataset`` in a DataLoader using this script's fixed settings.

    The loader yields batches of 128 samples, reshuffled on every pass. It uses
    2 worker processes, pinned host memory, and a 120 s worker timeout. The
    trailing incomplete batch is discarded (``drop_last``).
    """
    settings = {
        "batch_size": 128,
        "shuffle": True,
        "num_workers": 2,
        "pin_memory": True,
        "drop_last": True,
        "timeout": 120,
    }
    return DataLoader(dataset, **settings)


def _build_eval_dataloader(dataset):
    """Return a sequential loader that visits every sample of ``dataset`` exactly once."""
    # Unlike build_dataloader(): no shuffling and no drop_last. Otherwise the last
    # partial batch would be silently excluded from the reported metrics.
    return DataLoader(dataset,
                      batch_size=128, shuffle=False, num_workers=2,
                      pin_memory=True, drop_last=False, timeout=120)


def _evaluate_split(net, accuracy, dataloader, num_classes):
    """Run ``net`` over ``dataloader`` and collect metrics.

    Returns:
        (mean_accuracy, counts), where ``mean_accuracy`` is the per-batch
        ``accuracy`` value weighted by batch size. ``counts`` is a
        ``(num_classes, num_classes)`` int64 NumPy array in which
        ``counts[i, j]`` is the number of samples with label ``i`` predicted
        as class ``j``.

    Raises:
        RuntimeError: if the loader yields no samples.
    """
    # Counts stay on the GPU as integers: no per-sample host sync, and no
    # float32 saturation for very large splits.
    counts = torch.zeros((num_classes, num_classes), dtype=torch.long, device="cuda")
    acc_sum = 0.0
    num_samples = 0
    with torch.no_grad():
        for hiddens, labels in dataloader:
            hiddens = hiddens.cuda()
            labels = labels.cuda()
            batch_size = labels.shape[0]

            prediction = net(hiddens)
            acc = accuracy(prediction, labels)

            predicted = torch.argmax(prediction, dim=-1)
            # Vectorized equivalent of counts[label, pred] += 1 for every sample;
            # accumulate=True handles repeated (label, pred) pairs correctly.
            counts.index_put_((labels.long(), predicted), torch.ones_like(predicted), accumulate=True)

            acc_sum += acc.item() * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise RuntimeError("evaluation dataloader yielded no samples")
    return acc_sum / num_samples, counts.cpu().numpy()


def _quantize_confusion(counts):
    """Row-normalize ``counts`` and quantize to uint8 (a fraction of 1.0 maps to 255, truncated).

    After normalization each true-label row sums to 1. Rows with no samples
    stay all-zero instead of producing NaN.
    """
    counts = counts.astype(np.float32)
    row_sums = counts.sum(axis=-1, keepdims=True)
    normalized = np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums > 0)
    return np.clip(normalized * 255.0, 0, 255).astype(np.uint8)


def eval():
    """Evaluate the layer-``LAYER_ID`` probe checkpoint on test-clean and test-other.

    Loads ``classifier_layer_{LAYER_ID}.pth`` onto CUDA. For each split it
    prints the average accuracy and the row-normalized, uint8-quantized
    confusion matrix, then saves that matrix as
    ``confusion_test_{clean,other}[_merge]_layer_{LAYER_ID}.npy``. The
    ``_merge`` tag is added when the split's hidden dir contains "merge".

    The name shadows the builtin ``eval``; it is kept for existing callers.
    """
    print("Eval start!")
    test_clean_dataset = VectorDataset(TEST0_HIDDEN_DIR, ALIGN_DIR, layer_id=LAYER_ID, exclude_silence=False)
    test_other_dataset = VectorDataset(TEST1_HIDDEN_DIR, ALIGN_DIR, layer_id=LAYER_ID, exclude_silence=False)
    print("Test clean data:", len(test_clean_dataset))
    print("Test other data:", len(test_other_dataset))

    # Size the confusion matrix from the dataset rather than a hard-coded 37.
    num_classes = test_clean_dataset.num_classes
    net = build_net(num_classes)
    net = net.cuda().eval()
    net.load_state_dict(torch.load(f"classifier_layer_{LAYER_ID}.pth", map_location="cuda"), strict=True)

    accuracy = Accuracy().cuda()

    splits = (
        ("test-clean", "test_clean", test_clean_dataset, TEST0_HIDDEN_DIR),
        ("test-other", "test_other", test_other_dataset, TEST1_HIDDEN_DIR),
    )
    for display_name, file_tag, dataset, hidden_dir in splits:
        test_acc, counts = _evaluate_split(net, accuracy, _build_eval_dataloader(dataset), num_classes)
        print(f"Final: average {display_name} accuracy: {test_acc:.3f}")

        confusion_matrix = _quantize_confusion(counts)
        print(confusion_matrix)
        merge_tag = "_merge" if "merge" in hidden_dir else ""
        np.save(f"confusion_{file_tag}{merge_tag}_layer_{LAYER_ID}.npy", confusion_matrix)


def visualize():
    """Display the test-other confusion matrix saved by ``eval()`` as a heatmap.

    The stored uint8 values are mapped back to [0, 1] (255 -> 1.0) before plotting.
    """
    tag = "merge_" if "merge" in TEST1_HIDDEN_DIR else ""
    quantized = np.load(f"confusion_test_other_{tag}layer_{LAYER_ID}.npy")
    ratios = quantized.astype(np.float32) / 255.0

    plt.figure()
    plt.imshow(ratios, cmap="plasma")
    plt.colorbar()
    plt.show()
    plt.close()


if __name__ == "__main__":
    # visualize() loads the .npy file that eval() writes; call eval() first to (re)generate it.
    visualize()
