from sklearn import svm
import time
from torchvision import datasets, transforms
import torch as ch
import numpy as np
import matplotlib.pyplot as plt
from matrix_utils import kernel_matrix_VICREG, kernel_matrix_contrastive
from sklearn.gaussian_process.kernels import RBF
from matplotlib.ticker import FormatStrFormatter
from tools import gen_bound, complexity


def get_kernel_matrix(kernel_fn, data, data_col):
    """Evaluate ``kernel_fn`` between two sets of samples.

    Args:
        kernel_fn: callable invoked as ``kernel_fn(data, data_col)``; in this
            file it is a scikit-learn ``RBF`` kernel object.
        data: first set of samples.
        data_col: second set of samples.

    Returns:
        Exactly what ``kernel_fn(data, data_col)`` returns.
    """
    return kernel_fn(data, data_col)


def _accuracy(predictions, targets):
    """Return the percentage of ``predictions`` equal to ``targets``."""
    return np.equal(predictions, targets).astype("float").mean() * 100


def _signed_labels(labels):
    """Binarize labels to +1 (label > 5) / -1 (otherwise) for ``complexity``."""
    return (labels > 5).astype("float") * 2 - 1


def VICReg(
    train_X,
    train_y,
    test_X,
    test_y,
    N,
    contrastive=False,
    dimension=None,
    sigma=1,
    beta=0.1,
    C=1000,
    do_baselines=True,
):
    """Compare RBF-kernel SVM baselines with a kernel SSL representation.

    ``train_X`` must stack ``len(train_X) // N`` (at least two) views of the
    same ``N`` samples, one view after the other, as ``augment_train`` builds
    them: rows ``i``, ``i + N``, ``i + 2N``, ... are views of sample ``i``.

    Args:
        train_X, train_y: stacked training views (2-D features) and labels.
        test_X, test_y: test features and labels.
        N: number of distinct samples per view.
        contrastive: use ``kernel_matrix_contrastive`` instead of
            ``kernel_matrix_VICREG`` for the SSL kernel.
        dimension: forwarded to the SSL kernel routine (presumably the
            representation dimension -- confirm in ``matrix_utils``).
        sigma: length scale of the sklearn ``RBF`` kernel.
        beta: forwarded to ``kernel_matrix_VICREG`` only.
        C: SVC regularization; ``1 / C`` is also added to the kernel diagonal
            before calling ``complexity``.
        do_baselines: compute the supervised baselines; when False their six
            entries are ``None``.

    Returns:
        A list of nine values: (train accuracy %, test accuracy %,
        ``complexity`` value) for the full oracle, the low-data oracle and the
        SSL model, in that order.

    Raises:
        ValueError: if ``len(train_X)`` is not a multiple of ``N`` spanning at
            least two views.
    """
    n_total = len(train_X)
    # Validate before any kernel work: a partial view makes the adjacency
    # matrix the wrong size, and a single view leaves it all zeros.
    if N <= 0 or n_total % N != 0 or n_total // N < 2:
        raise ValueError(
            f"train_X must stack at least two views of N={N} samples; "
            f"got {n_total} rows"
        )
    n_views = n_total // N

    perfs = []
    svm_clf = svm.SVC(kernel="precomputed", C=C)
    kernel_fn = RBF(sigma)
    K_ss = get_kernel_matrix(kernel_fn, train_X, train_X)
    K_sx_t = get_kernel_matrix(kernel_fn, train_X, test_X)

    if do_baselines:
        # full oracle: supervised SVM trained on every training row (all views)
        svm_clf.fit(K_ss, train_y)
        train_accuracy = _accuracy(svm_clf.predict(K_ss), train_y)
        test_accuracy = _accuracy(svm_clf.predict(K_sx_t.T), test_y)
        print("Oracle perf.", train_accuracy, test_accuracy)
        perfs.extend(
            [
                train_accuracy,
                test_accuracy,
                complexity(
                    K_ss + np.eye(K_ss.shape[0]) / C, _signed_labels(train_y)
                ),
            ]
        )

        # low-data oracle: trained on the first N rows only; its train
        # accuracy is still measured on every training row
        svm_clf.fit(K_ss[:N, :N], train_y[:N])
        train_accuracy = _accuracy(svm_clf.predict(K_ss[:N].T), train_y)
        test_accuracy = _accuracy(svm_clf.predict(K_sx_t[:N].T), test_y)
        print("Low Oracle perf.", train_accuracy, test_accuracy)
        perfs.extend(
            [
                train_accuracy,
                test_accuracy,
                complexity(K_ss[:N, :N] + np.eye(N) / C, _signed_labels(train_y[:N])),
            ]
        )
    else:
        perfs.extend([None] * 6)

    # SSL performance
    K_sx = get_kernel_matrix(kernel_fn, train_X, train_X[:N])

    # A[(k, a), (l, b)] = 1 iff k != l and a == b: it links different views of
    # the same sample. Its entries are already 0/1 with a zero diagonal, so no
    # further masking or rescaling is needed.
    A = np.kron(1 - np.eye(n_views), np.eye(N))

    if contrastive:
        K_inv = kernel_matrix_contrastive(K_ss, A, dimension=dimension)
    else:
        K_inv = kernel_matrix_VICREG(K_ss, A, dimension=dimension, beta=beta)

    K_tt = K_sx.T @ K_inv @ K_sx
    svm_clf.fit(K_tt, train_y[:N])
    train_accuracy = _accuracy(svm_clf.predict(K_tt), train_y[:N])

    K_tt2 = K_sx_t.T @ K_inv @ K_sx
    test_accuracy = _accuracy(svm_clf.predict(K_tt2), test_y)
    perfs.append(train_accuracy)
    perfs.append(test_accuracy)
    perfs.append(complexity(K_tt + np.eye(N) / C, _signed_labels(train_y[:N])))
    print("SSL perf.", train_accuracy, test_accuracy)
    return perfs


def create_loaders(name="MNIST", batch_size=32, seed=42):
    """Load one training batch and one test batch of a torchvision dataset.

    Args:
        name: ``"MNIST"``, ``"EMNIST"`` (``letters`` split) or ``"omni"``
            (Omniglot: background set for training, evaluation set for test).
        batch_size: number of training samples returned.
        seed: seed for torch's global RNG, set before the shuffled training
            loader is drawn from.

    Returns:
        ``[train_X, train_y], [test_X, test_y]``: the first (shuffled) training
        batch and the first (unshuffled) test batch. The test batch size is
        fixed per dataset (10000 / 20800 / 13180), presumably the whole test
        split -- confirm.

    Raises:
        ValueError: if ``name`` is not one of the supported datasets.
    """
    if name not in ("MNIST", "EMNIST", "omni"):
        raise ValueError(
            f"unknown dataset name {name!r}; expected 'MNIST', 'EMNIST' or 'omni'"
        )

    ch.manual_seed(seed)

    train_kwargs = {
        "batch_size": batch_size,
        "num_workers": 40,
        "pin_memory": False,
        "shuffle": True,
    }
    test_kwargs = {
        "batch_size": 10000,
        "num_workers": 40,
        "pin_memory": False,
        "shuffle": False,
    }

    # the same (MNIST-statistics) normalization is used for every dataset
    test_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if name == "MNIST":
        test_set = datasets.MNIST(
            "../DATASETS/MNIST/", train=False, transform=test_transform
        )
        train_set = datasets.MNIST(
            "../DATASETS/MNIST/", train=True, download=False, transform=test_transform
        )
    elif name == "EMNIST":
        test_set = datasets.EMNIST(
            "../DATASETS/EMNIST/",
            train=False,
            split="letters",
            transform=test_transform,
            download=True,
        )
        train_set = datasets.EMNIST(
            "../DATASETS/EMNIST/",
            train=True,
            split="letters",
            download=True,
            transform=test_transform,
        )
        test_kwargs["batch_size"] = 20800
    else:  # "omni" -- the only remaining name allowed by the check above
        test_set = datasets.Omniglot(
            "../DATASETS/Omniglot/",
            background=False,
            transform=test_transform,
            download=True,
        )
        train_set = datasets.Omniglot(
            "../DATASETS/Omniglot/",
            background=True,
            download=True,
            transform=test_transform,
        )
        test_kwargs["batch_size"] = 13180

    test_loader = ch.utils.data.DataLoader(test_set, **test_kwargs)
    test_X, test_y = next(iter(test_loader))

    train_loader = ch.utils.data.DataLoader(train_set, **train_kwargs)
    train_X, train_y = next(iter(train_loader))
    return [train_X, train_y], [test_X, test_y]


def augment_train(train_X, train_y, num_augs, option, seed=42):
    """Stack a batch with ``num_augs`` randomly augmented copies of itself.

    Args:
        train_X: batch of image tensors; the transform is applied to each
            sample ``train_X[i]`` separately.
        train_y: labels of ``train_X``.
        num_augs: number of augmented copies to append.
        option: ``0`` for a 5x5 Gaussian blur, ``1`` for a small random affine
            transform (rotation within +-5 degrees, up to 5% translation,
            scale in [0.95, 1.05]).
        seed: seed for torch's global RNG, set before any augmentation.

    Returns:
        ``(X, y)``: the original batch followed by each augmented copy in turn,
        so row ``i + k * len(train_X)`` is a view of sample ``i``; ``y`` repeats
        ``train_y`` the same way.

    Raises:
        ValueError: if ``option`` is not 0 or 1.
    """
    if option not in (0, 1):
        raise ValueError(f"unknown augmentation option {option!r}; expected 0 or 1")

    ch.manual_seed(seed)

    if option == 0:
        train_transform = transforms.Compose([transforms.GaussianBlur((5, 5))])
    else:
        train_transform = transforms.Compose(
            [
                transforms.RandomAffine(
                    (-5, 5), translate=(0.05, 0.05), scale=(0.95, 1.05)
                ),
            ]
        )

    base_X = train_X.clone()
    base_y = train_y.clone()
    # copies are generated in order, sample by sample, so the RNG stream
    # (and therefore the result) is fixed by ``seed``
    augmented = [
        ch.stack([train_transform(x) for x in base_X]) for _ in range(num_augs)
    ]
    X = ch.cat([base_X, *augmented])
    y = ch.cat([base_y] * (num_augs + 1))
    return X, y


def _unit_normalize(images):
    """Scale each sample of a rank-4 tensor to unit L2 norm, in place."""
    images /= (images**2).sum((1, 2, 3), keepdims=True).sqrt()
    return images


def _run_vicreg(train, test, N, **kwargs):
    """Flatten ``[X, y]`` tensor pairs to numpy and forward them to ``VICReg``."""
    return VICReg(
        train[0].flatten(1).numpy(),
        train[1].numpy(),
        test[0].flatten(1).numpy(),
        test[1].numpy(),
        N,
        **kwargs,
    )


def main():
    """Run every experiment and save each ``VICReg`` result with ``np.savez``.

    Results go to ``logs/``, which must already exist (``np.savez`` does not
    create directories). File names match the patterns ``plot`` reads for
    the bound ablation, the non-contrastive ablation and the VICReg summaries.
    """
    dimensions = [16, 32, 64, 128, 256, 512]

    for dataset in ["MNIST"]:
        train_base, test = create_loaders(dataset, batch_size=100)
        _unit_normalize(train_base[0])
        _unit_normalize(test[0])
        train = augment_train(*train_base, num_augs=20, option=1)

        print(
            "train set shape:",
            train[0].shape[0],
            "test set shape",
            test[0].shape[0],
        )

        # contrastive ablation: SVC regularization x representation dimension
        for C in [1000, 100, 10, 1, 0.1]:
            for dimension in dimensions:
                perfs = _run_vicreg(
                    train,
                    test,
                    100,
                    contrastive=True,
                    C=C,
                    dimension=dimension,
                    do_baselines=dimension == 16,
                )
                np.savez(
                    f"logs/comparaison_{dataset}_{C}_{dimension}_contrastive_data_bounds.npz",
                    perfs,
                )

        # non-contrastive ablation: beta x representation dimension. The grid
        # covers the betas ``plot`` reads (0.0001 ... 1) plus beta=10, and the
        # supervised baselines are computed once, for the first beta.
        C = 1000
        betas = [0.0001, 0.001, 0.01, 0.1, 1, 10]
        for beta in betas:
            for dimension in dimensions:
                perfs = _run_vicreg(
                    train,
                    test,
                    100,
                    contrastive=False,
                    C=C,
                    dimension=dimension,
                    do_baselines=dimension == 16 and beta == betas[0],
                    beta=beta,
                )
                np.savez(
                    f"logs/comparaison_{dataset}_{C}_{dimension}_{beta}_noncontrastive_data.npz",
                    perfs,
                )

    # VICReg accuracy vs. number of samples / augmentations / augmentation type
    Ns = [16, 64, 256]
    N_augs = [16, 32, 64, 128, 256, 512, 1024, 2048]
    options = [0, 1]
    for dataset in ["MNIST", "EMNIST"]:
        for N in Ns:
            train_base, test = create_loaders(dataset, batch_size=N)
            _unit_normalize(train_base[0])
            _unit_normalize(test[0])
            for augs in N_augs:
                # N_augs is increasing: stop once the training set gets too big
                if N * augs > 50000:
                    break
                for option in options:
                    train = augment_train(*train_base, num_augs=augs, option=option)

                    print(
                        "train set shape:",
                        train[0].shape[0],
                        "test set shape",
                        test[0].shape[0],
                    )
                    perfs = _run_vicreg(train, test, N, contrastive=False, beta=1)
                    np.savez(
                        f"logs/VIC_{dataset}_{N}_{augs}_{option}_noncontrastive_saving_data.npz",
                        perfs,
                    )


def _accuracy_entries(data):
    """Return the six accuracies stored in a saved ``VICReg`` result.

    ``VICReg`` returns nine values: a (train accuracy, test accuracy,
    complexity) triple for each of the full oracle, the low-data oracle and
    the SSL model. Results holding only the six accuracies (presumably written
    before the complexity terms were added -- confirm) are passed through
    unchanged. Either way the output is ordered
    [oracle train, oracle test, low train, low test, SSL train, SSL test].

    Raises:
        ValueError: if ``data`` holds neither six nor nine values.
    """
    data = np.asarray(data)
    if data.shape == (9,):
        # drop the complexity term (third entry) of each triple
        return data.reshape(3, 3)[:, :2].reshape(-1)
    if data.shape == (6,):
        return data
    raise ValueError(
        f"unexpected VICReg result shape {data.shape}; expected (6,) or (9,)"
    )


def _load_result(path, allow_pickle=False):
    """Return the ``arr_0`` array saved by ``np.savez`` at ``path``, closing the file."""
    with np.load(path, allow_pickle=allow_pickle) as archive:
        return archive["arr_0"]


def _plot_accuracy_grids(all_train, all_test, baseline, rows, cols, row_label, out_path):
    """Save side-by-side train (left) and test (right) accuracy heatmaps.

    ``all_train``/``all_test`` are indexed ``[row, col]``; ``baseline`` holds
    [oracle train, oracle test, low train, low test]. Test cells reaching the
    full-oracle test accuracy are written in bold.
    """
    fig, axs = plt.subplots(1, 2, figsize=(8, 4), sharex="all", sharey="all")
    panels = [
        (axs[0], all_train, f"sup. baselines: {baseline[0]}, {baseline[2]} (train)"),
        (axs[1], all_test, f"sup. baselines: {baseline[1]}, {baseline[3]} (test)"),
    ]
    for ax, grid, title in panels:
        ax.imshow(grid, interpolation="nearest", origin="lower")
        ax.set_xticks(range(len(cols)), cols, fontsize=12)
        ax.set_yticks(range(len(rows)), rows, fontsize=12)
        ax.set_xlabel("representation dimension (K)", fontsize=12)
        ax.set_title(title)
    axs[0].set_ylabel(row_label, fontsize=12)
    for j in range(len(cols)):
        for i in range(len(rows)):
            axs[0].text(
                j, i, all_train[i, j], ha="center", va="center", c="red", fontsize=12
            )
            axs[1].text(
                j,
                i,
                all_test[i, j],
                ha="center",
                va="center",
                c="red",
                fontweight="normal" if all_test[i, j] < baseline[1] else "bold",
                fontsize=12,
            )

    plt.subplots_adjust(0.13, 0.03, 0.99, 0.99, 0.02, 0.02)
    plt.savefig(out_path)
    plt.close()


def _plot_ssl_summary(path_template, Ns, N_augs, options, out_path):
    """Plot test accuracy against log2(# augmentations) and save it to ``out_path``.

    One column per ``N`` and one row per augmentation option (so ``options``
    must be 0/1). Each panel shows the full-oracle, low-data-oracle and SSL
    test accuracies. ``path_template`` is formatted with ``N``, ``augs`` and
    ``option`` to locate each result; augmentation counts stop at the first
    one with ``N * augs > 50000``.
    """
    fig, axs = plt.subplots(2, len(Ns), sharex="col", figsize=(2.5 * len(Ns), 5))
    # (column in the accuracy layout, color, legend label)
    curves = [
        (1, "black", "test oracle full"),
        (3, "blue", "test oracle min"),
        (5, "red", "test SSL"),
    ]
    for c, N in enumerate(Ns):
        axs[0, c].set_title(f"N={N}")
        axs[1, c].set_xlabel("# augmentations (log2)", fontsize=12)
        for option in options:
            ax = axs[option, c]
            ax.yaxis.set_major_formatter(FormatStrFormatter("%.0f"))
            ax.xaxis.set_major_formatter(FormatStrFormatter("%.0f"))
            datas = []
            for augs in N_augs:
                if N * augs > 50000:
                    break
                path = path_template.format(N=N, augs=augs, option=option)
                datas.append(_accuracy_entries(_load_result(path)))
            datas = np.stack(datas)
            x = np.log2(N_augs[: len(datas)])
            for column, color, label in curves:
                ax.plot(x, datas[:, column], "-o", color=color, label=label)
        axs[1, c].tick_params(axis="y", labelsize=12)
        axs[0, c].tick_params(axis="y", labelsize=12)
        axs[1, c].tick_params(axis="x", labelsize=12)
    axs[0, 0].set_ylabel("Test accuracy", fontsize=12)
    axs[1, 0].set_ylabel("Test accuracy", fontsize=12)
    plt.subplots_adjust(0.08, 0.12, 0.99, 0.95, 0.20, 0.05)
    plt.savefig(out_path)
    plt.close()


def plot():
    """Render the experiment figures from the ``.npz`` results under ``logs/``.

    Writes ``contrastive_ablation.png``, ``contrastive_bound_ablation.png``,
    ``noncontrastive_ablation.png`` and ``summary_SSL[_nc]_{dataset}.png`` to
    the working directory. Results may use either the six-accuracy or the
    nine-value (with complexity) ``VICReg`` layout.

    NOTE(review): ``main`` in this file does not write the
    ``comparaison_MNIST_*_contrastive_data.npz`` files or the contrastive
    ``*_saving_data.npz`` files read for the first and fourth figures; they
    presumably come from an earlier version of the experiments -- confirm.
    """
    Ds = [16, 32, 64, 128, 256, 512]
    Cs = [1000, 100, 10, 1, 0.1]

    # ablation on contrastive learning dimension + regularization
    all_train = []
    all_test = []
    for C in Cs:
        for dimension in Ds:
            data = _accuracy_entries(
                _load_result(
                    f"logs/comparaison_MNIST_{C}_{dimension}_contrastive_data.npz",
                    allow_pickle=True,
                )
            )
            if C == 1000 and dimension == 16:
                baseline = data[:4]
            all_train.append(data[-2])
            all_test.append(data[-1])
    _plot_accuracy_grids(
        np.array(all_train).reshape((len(Cs), len(Ds))).round(2),
        np.array(all_test).reshape((len(Cs), len(Ds))).round(2),
        baseline.round(2),
        Cs,
        Ds,
        r"$inv. \ell_2$ regularization",
        "contrastive_ablation.png",
    )

    # ablation (bound) on contrastive learning dimension + regularization;
    # these files use the nine-value layout, with complexities at 2, 5 and -1
    all_bounds = []
    for C in Cs:
        for dimension in Ds:
            data = _load_result(
                f"logs/comparaison_MNIST_{C}_{dimension}_contrastive_data_bounds.npz",
                allow_pickle=True,
            )
            if C == 1000 and dimension == 16:
                baseline = np.array([data[2], data[5]])
            all_bounds.append(data[-1])
    baseline = baseline.round(2)
    all_bounds = np.array(all_bounds).reshape((len(Cs), len(Ds))).astype("int")
    fig, ax = plt.subplots(1, 1, figsize=(4, 4), sharex="all", sharey="all")

    ax.imshow(all_bounds, interpolation="nearest", origin="lower")
    ax.set_xticks(range(len(Ds)), Ds, fontsize=12)
    ax.set_yticks(range(len(Cs)), Cs, fontsize=12)
    ax.set_xlabel("representation dimension (K)", fontsize=12)
    ax.set_ylabel(r"$inv. \ell_2$ regularization", fontsize=12)
    for j in range(len(Ds)):
        for i in range(len(Cs)):
            ax.text(
                j,
                i,
                all_bounds[i, j],
                ha="center",
                va="center",
                c="red",
                fontsize=8,
            )

    ax.set_title(f"sup. baselines: {baseline[0]}, {baseline[1]}")
    plt.subplots_adjust(0.25, 0.05, 0.99, 0.95, 0.02, 0.02)
    plt.savefig("contrastive_bound_ablation.png")
    plt.close()

    # ablation on non-contrastive learning dimension + beta
    all_train = []
    all_test = []
    C = 1000
    Bs = [0.0001, 0.001, 0.01, 0.1, 1]
    for B in Bs:
        for dimension in Ds:
            data = _accuracy_entries(
                _load_result(
                    f"logs/comparaison_MNIST_{C}_{dimension}_{B}_noncontrastive_data.npz",
                    allow_pickle=True,
                )
            )
            if B == 0.0001 and dimension == 16:
                baseline = data[:4]
            all_train.append(data[-2])
            all_test.append(data[-1])
    _plot_accuracy_grids(
        np.array(all_train).reshape((len(Bs), len(Ds))).round(2),
        np.array(all_test).reshape((len(Bs), len(Ds))).round(2),
        baseline.round(2),
        Bs,
        Ds,
        r"$\beta$ parameter",
        "noncontrastive_ablation.png",
    )

    Ns = [16, 64, 256]
    N_augs = [16, 32, 64, 128, 256, 512, 1024, 2048]
    options = [0, 1]

    # contrastive classification tasks (MNIST files carry no dataset prefix)
    for dataset in ["MNIST", "EMNIST"]:
        prefix = f"logs/{dataset}_" if dataset == "EMNIST" else "logs/"
        _plot_ssl_summary(
            prefix + "{N}_{augs}_{option}_saving_data.npz",
            Ns,
            N_augs,
            options,
            f"summary_SSL_{dataset}.png",
        )

    # noncontrastive classification tasks
    for dataset in ["MNIST", "EMNIST"]:
        _plot_ssl_summary(
            f"logs/VIC_{dataset}_" + "{N}_{augs}_{option}_noncontrastive_saving_data.npz",
            Ns,
            N_augs,
            options,
            f"summary_SSL_nc_{dataset}.png",
        )


if __name__ == "__main__":
    # Run the experiments first, then render the figures from the saved logs.
    for stage in (main, plot):
        stage()
