#%%
import sys
import torch
import numpy as np
from tqdm import tqdm
import torch
from scipy.special import softmax
import matplotlib.pyplot as plt

sys.path.append(".")
from image_uncertainty.cifar.cifar_evaluate import (
    load_model, get_eval_args, described_plot, cifar_test,
    misclassification_detection
)
from image_uncertainty.cifar.cifar_datasets import get_training_dataloader
from image_uncertainty.cifar import settings
from nuq import NuqClassifier, get_kernel
from experiments.imagenet_discrete import dump_ues
from spectral_normalized_models.ddu import (
    gmm_fit, gmm_evaluate, logsumexp, entropy
)


def get_embeddings(model, loader, features=False, gpu=None):
    """Run ``model`` over ``loader`` and collect its outputs as NumPy arrays.

    Args:
        model: Callable applied to every image batch. When ``features`` is
            true, it must expose an attribute ``feature`` after the forward
            pass; that attribute is collected instead of the return value.
        loader: Iterable yielding ``(images, batch_labels)`` pairs.
        features: If true, collect ``model.feature``; otherwise collect the
            value returned by ``model(images)``.
        gpu: Whether to move each image batch to CUDA before the forward
            pass. ``None`` (the default) keeps the previous behaviour of
            reading the module-level ``args.gpu`` set by the script entry
            point, so existing callers are unaffected.

    Returns:
        A tuple ``(embeddings, labels)``: the per-batch outputs joined with
        ``np.concatenate`` and an ``np.array`` of all labels in iteration
        order.
    """
    if gpu is None:
        # Backward-compatible fallback: the script stores its parsed CLI
        # options in the module-level name ``args``.
        gpu = args.gpu

    labels = []
    embeddings = []
    for images, batch_labels in tqdm(loader):
        with torch.no_grad():
            if gpu:
                images = images.cuda()
            if features:
                # The forward pass is run only for its side effect of
                # populating ``model.feature``; its return value is unused.
                model(images)
                embeddings.append(model.feature.cpu().numpy())
            else:
                embeddings.append(model(images).cpu().numpy())
        labels.extend(batch_labels.tolist())

    return np.concatenate(embeddings), np.array(labels)



class GmmNuq(NuqClassifier):
    """``NuqClassifier`` whose ``get_kde`` is answered by a supplied GMM object."""

    def __init__(self, gmm, **kwargs):
        """Store ``gmm`` and forward all remaining keyword arguments to the base class."""
        # The attribute is assigned before the base constructor runs.
        self.gmm = gmm
        super().__init__(**kwargs)

    def get_kde(self, x, batch_size=10):
        """Return ``logsumexp(gmm.log_prob(x))`` as a float64 NumPy column.

        ``batch_size`` is accepted for signature compatibility and is not
        used: the whole of ``x`` is evaluated in one call.
        """
        # Insert a singleton axis after the first one before querying the GMM.
        points = torch.tensor(x).float()[:, None, :]
        component_log_probs = self.gmm.log_prob(points)
        # Re-add a trailing axis so the result is a column.
        log_density = logsumexp(component_log_probs)[:, None]
        as_numpy = log_density.numpy()
        return as_numpy.astype(np.double)


from nuq.method import log_asymptotic_var, log_half_gaussian_mean
from nuq.method import logsumexp as lse



class NuqModified(NuqClassifier):
    """``NuqClassifier`` variant used by the experiments in this script.

    What this subclass changes, as visible in this file:

    * ``get_kde`` takes its value from a supplied GMM object when one is
      given, and defers to the base class otherwise.
    * ``predict_proba`` is computed from a softmax over the rows of its
      input instead of delegating to the base class.
    * ``predict_uncertainty`` can replace the sigma term by an array of
      ones (``one_sigma``) and can compute the aleatoric term directly from
      the softmax of the input (``max_prob``).
    """

    def __init__(self, gmm=None, max_prob=False, one_sigma=False, **kwargs):
        """Store the experiment switches, then initialise the base class.

        Args:
            gmm: Optional object exposing ``log_prob``; when not ``None`` it
                is used by ``get_kde``.
            max_prob: If true, the aleatoric term is computed from the
                softmax of the input batch.
            one_sigma: If true, the sigma term is replaced by ones.
            **kwargs: Forwarded unchanged to ``NuqClassifier.__init__``.
        """
        self.gmm = gmm
        self.max_prob = max_prob
        self.one_sigma = one_sigma
        super().__init__(**kwargs)

    def get_kde(self, x, batch_size=10):
        """Return the density term for ``x``.

        With a GMM attached this is ``logsumexp(gmm.log_prob(x))`` as a
        float64 NumPy column (``batch_size`` is ignored on that path);
        without one the base-class implementation is used.
        """
        if self.gmm is not None:
            x_tensor = torch.tensor(x)[:, None, :].float()
            log_p = logsumexp(self.gmm.log_prob(x_tensor))[:, None]
            return log_p.numpy().astype(np.double)
        return super().get_kde(x, batch_size)

    def predict_uncertainty(self, X, batch_size=50000):
        """Compute epistemic, aleatoric and total uncertainty for ``X``.

        ``X`` is processed in consecutive slices of ``batch_size`` rows and
        the per-slice results are concatenated.

        Args:
            X: Array-like supporting ``len`` and slicing; rows are passed to
                ``get_kde`` and ``predict_proba``.
            batch_size: Number of rows handled per slice.

        Returns:
            Dict with keys ``"epistemic"``, ``"aleatoric"`` and ``"total"``.
            For empty ``X`` each value is an empty array.
        """
        ue_parts = []
        ua_parts = []
        ut_parts = []
        for start in range(0, len(X), batch_size):
            X_batch = X[start: start + batch_size]
            f_hat_x = self.get_kde(X_batch, batch_size=batch_size)
            output = self.predict_proba(X_batch, batch_size=batch_size)
            f_hat_y_x = output["probs"]
            f1_hat_y_x = output["probsm1"]

            sigma_hat_est = np.max(f_hat_y_x + f1_hat_y_x, axis=1, keepdims=True)
            if self.one_sigma:
                sigma_hat_est = np.ones(sigma_hat_est.shape)

            if not self.use_uniform_prior:
                # Combine with log(coeff) through log-sum-exp over a new
                # leading axis.
                broadcast_shape = (1, sigma_hat_est.shape[0], sigma_hat_est.shape[1])
                sigma_hat_est = lse(np.concatenate(
                    [sigma_hat_est[None], np.log(self.coeff) * np.ones(shape=broadcast_shape)],
                    axis=0), axis=0)
            log_as_var = log_asymptotic_var(log_sigma_est=sigma_hat_est, log_f_est=f_hat_x,
                                            bandwidth=self.bandwidth,
                                            n=self.n_neighbors, dim=self.training_embeddings_.shape[1],
                                            squared_kernel_int=self.squared_kernel_int)

            # ``squeeze`` turns a one-sample batch into a 0-d array, which
            # cannot be concatenated with the other batches; ``atleast_1d``
            # restores the sample axis and leaves every other shape as is.
            Ue = np.atleast_1d(
                log_half_gaussian_mean(asymptotic_var=log_as_var).squeeze()
            )

            if self.max_prob:
                Ua = np.log(np.min((1 - softmax(X_batch, axis=-1) + self.coeff), axis=-1))
            else:
                Ua = np.min(f1_hat_y_x, axis=1, keepdims=True)
                if not self.use_uniform_prior:
                    Ua = lse(
                        np.concatenate([Ua[None], np.log(self.coeff) * np.ones(shape=broadcast_shape)], axis=0),
                        axis=0)
                Ua = np.atleast_1d(Ua.squeeze())

            total_uncertainty = np.atleast_1d(
                lse(np.concatenate([Ua[None], Ue[None]], axis=0), axis=0).squeeze()
            )

            ue_parts.append(Ue)
            ua_parts.append(Ua)
            ut_parts.append(total_uncertainty)

        if not ue_parts:
            # No rows at all: keep returning empty arrays.
            return {
                "epistemic": np.array([]),
                "aleatoric": np.array([]),
                "total": np.array([]),
            }
        return {
            "epistemic": np.concatenate(ue_parts),
            "aleatoric": np.concatenate(ua_parts),
            "total": np.concatenate(ut_parts),
        }

    def predict_proba(self, X, batch_size=50000):
        """Return log-terms derived from a softmax over the last axis of ``X``.

        ``batch_size`` is accepted for signature compatibility and unused.

        Returns:
            Dict with ``"probs" = log(1 - p + coeff)`` and
            ``"probsm1" = log(p + coeff)`` where ``p = softmax(X)``.
        """
        x = softmax(X, axis=-1)
        # NOTE(review): the two keys hold the opposite quantities compared
        # with the earlier variant {"probs": log(x), "probsm1": log(1-x)};
        # presumably intentional for this experiment, confirm before reuse.
        output = {"probs": np.log(1-x+self.coeff), "probsm1": np.log(x+self.coeff)}
        return output

#%%

def _plot_uncertainties(ues_test, ues_ood, args):
    """Call ``described_plot`` for the epistemic, total and aleatoric scores.

    If plotting raises an ordinary exception, an ``ipdb`` session is opened
    instead of aborting the run. ``KeyboardInterrupt`` and ``SystemExit``
    are deliberately not caught so the script can still be interrupted.
    """
    try:
        for ue_type in ['epistemic', 'total', 'aleatoric']:
            print(ue_type)
            described_plot(
                ues_test[ue_type], ues_ood[ue_type], args.ood_name, args.net,
                title_extras='Nadaray-Watson'
            )
    except Exception:
        import ipdb
        ipdb.set_trace()


def main(args):
    """Run one uncertainty experiment selected by ``args.exp``.

    Embeddings for the train, test and OOD sets are either computed with the
    model and written to ``t_*.npy`` files in the working directory, or read
    back from those files when ``args.cached`` is set.

    Attributes read from ``args``: ``mode`` (``'features'`` or ``'logits'``),
    ``cached``, ``b``, ``ood_name``, ``data_seed``, ``net``, ``weights``,
    ``gpu`` and ``exp``. Recognised ``exp`` values are ``'basic'``, ``'log'``,
    ``'gmm'``, ``'gmm-kernel'``, ``'sigma'`` and ``'mis'``; any other value
    runs no experiment.
    """
    mode = {
        'features': True,
        'logits': False
    }[args.mode]

    if not args.cached:
        train_loader, _ = get_training_dataloader(
            settings.CIFAR100_TRAIN_MEAN,
            settings.CIFAR100_TRAIN_STD,
            num_workers=4,
            batch_size=args.b,
            shuffle=True,
            ood_name=args.ood_name,
            seed=args.data_seed
        )

        test_loader = cifar_test(args.b, False, args.ood_name)
        ood_loader = cifar_test(args.b, True, args.ood_name)

        model = load_model(args.net, args.weights, args.gpu)
        model.eval()

        x_train, y_train = get_embeddings(model, train_loader, features=mode)
        x_test, y_test = get_embeddings(model, test_loader, features=mode)
        # OOD labels are never used, so only the embeddings are kept.
        x_ood, _ = get_embeddings(model, ood_loader, features=mode)

        to_cache = {
            't_x_train.npy': x_train,
            't_y_train.npy': y_train,
            't_x_test.npy': x_test,
            't_y_test.npy': y_test,
            't_x_ood.npy': x_ood,
        }
        for file_name, array in to_cache.items():
            with open(file_name, 'wb') as f:
                np.save(f, array)
    else:
        base_dir = './'

        def load_cached(file_name):
            with open(f'{base_dir}{file_name}', 'rb') as f:
                return np.load(f)

        x_train = load_cached('t_x_train.npy')
        y_train = load_cached('t_y_train.npy')
        x_test = load_cached('t_x_test.npy')
        y_test = load_cached('t_y_test.npy')
        x_ood = load_cached('t_x_ood.npy')

    # Only consumed by the 'mis' experiment below.
    correct = (y_test == np.argmax(x_test, axis=-1))

    bandwidth = np.std(x_train, axis=0) / (2*np.pi)
    kernel = 'RBF'
    # Disabled baseline experiment with the unmodified classifier, kept for
    # reference:
    # nuq = NuqClassifier(
    #     bandwidth=bandwidth, tune_bandwidth=False, n_neighbors=50, precise_computation=True,
    #     kernel_type=kernel
    # )
    #
    # nuq.fit(X=x_train, y=y_train)
    # print(nuq.bandwidth[:3])
    #
    # ues_test = nuq.predict_uncertainty(x_test)
    # ues_ood = nuq.predict_uncertainty(x_ood)
    # try:
    #     for ue_type in ['epistemic', 'total', 'aleatoric']:
    #         print(ue_type)
    #         described_plot(
    #             ues_test[ue_type], ues_ood[ue_type], args.ood_name, args.net,
    #             title_extras='Nadaray-Watson'
    #         )
    # except:
    #     import ipdb; ipdb.set_trace()

    if args.exp == 'basic':
        print('\nBasic')

        nuq = NuqModified(
            bandwidth=bandwidth, tune_bandwidth=False, n_neighbors=50, precise_computation=True,
            kernel_type=kernel
        )
        nuq.fit(X=x_train, y=y_train)
        ues_test = nuq.predict_uncertainty(x_test)
        ues_ood = nuq.predict_uncertainty(x_ood)
        _plot_uncertainties(ues_test, ues_ood, args)

    if args.exp == 'log':
        print('\n Logistic kernel')
        nuq = NuqModified(
            bandwidth=bandwidth, kernel_type='logistic', tune_bandwidth=False, n_neighbors=50, precise_computation=True,
        )
        nuq.fit(X=x_train, y=y_train)
        ues_test = nuq.predict_uncertainty(x_test, batch_size=5000)
        ues_ood = nuq.predict_uncertainty(x_ood, batch_size=5000)
        _plot_uncertainties(ues_test, ues_ood, args)

    if args.exp in ['gmm', 'gmm-kernel', 'sigma']:
        # These three experiments share one GMM fitted on the training set.
        embeddings = torch.tensor(x_train)
        labels = torch.tensor(y_train)
        print('1')
        gaussians_model, _ = gmm_fit(embeddings=embeddings, labels=labels, num_classes=100)

    if args.exp == 'gmm':
        print('2')
        print(x_test.shape)
        logits = gaussians_model.log_prob(torch.tensor(x_test)[:, None, :].float())
        ood_logits = gaussians_model.log_prob(torch.tensor(x_ood)[:, None, :].float())
        print(logits.shape)

        print("GMM")
        print('DDU score')
        described_plot(
            -logsumexp(logits), -logsumexp(ood_logits), args.ood_name, args.net,
            title_extras='Nadaray-Watson'
        )

    if args.exp == 'gmm-kernel':
        print('\n With gmm')
        nuq = NuqModified(
            gaussians_model,
            bandwidth=bandwidth, tune_bandwidth=False, n_neighbors=1, precise_computation=True,
        )
        nuq.fit(X=x_train, y=y_train)
        ues_test = nuq.predict_uncertainty(x_test, batch_size=5000)
        ues_ood = nuq.predict_uncertainty(x_ood, batch_size=5000)
        _plot_uncertainties(ues_test, ues_ood, args)

    if args.exp == 'sigma':
        print('\n Single sigma')
        nuq = NuqModified(
            gaussians_model, one_sigma=True,
            bandwidth=bandwidth, tune_bandwidth=False, n_neighbors=50, precise_computation=True,
        )

        nuq.fit(X=x_train, y=y_train)
        ues_test = nuq.predict_uncertainty(x_test)
        ues_ood = nuq.predict_uncertainty(x_ood)
        _plot_uncertainties(ues_test, ues_ood, args)

    if args.exp == 'mis':
        # This experiment uses the unscaled per-dimension standard deviation.
        bandwidth = np.std(x_train, axis=0)

        nuq = NuqModified(
            max_prob=True,
            bandwidth=bandwidth, tune_bandwidth=False, n_neighbors=50, precise_computation=True,
        )

        nuq.fit(X=x_train, y=y_train)
        ues_test = nuq.predict_uncertainty(x_test)
        # ues_ood = nuq.predict_uncertainty(x_ood)
        misclassification_detection(correct, ues_test['total'])
        # Disabled histogram comparison of the two uncertainty components:
        # ua = ues_test['aleatoric']
        # ue = ues_test['epistemic']
        # x_min = min(np.min(ue), np.min(ua))
        # x_max = max(np.max(ue), np.max(ua))
        # bins = np.linspace(x_min, x_max, 40)
        # plt.title(kernel)
        # plt.hist(ua[:400], bins=bins, alpha=0.4)
        # plt.hist(ue[:400], bins=bins, alpha=0.4)
        # plt.show()
        # print(ua[:20])
        # print(ue[:20])


if __name__ == '__main__':
    # ``args`` is bound at module level on purpose: get_embeddings() reads
    # it as a global, so it must not become a local of some wrapper.
    args = get_eval_args()
    # Echo the parsed options and the weights path before starting.
    print(vars(args))
    print(args.weights)
    main(args)
