#%%
import sys
import random

import numpy as np
from matplotlib import pyplot as plt
from torchvision.datasets import MNIST, CIFAR10, SVHN
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
import torch
from torch import nn

sys.path.append('..')
# from nw_uncertainty.method.nw_method import  NewNW
from nuq.method import NuqClassifier
#%%
# Experiment configuration.
SEED = 1
ROTATE = True
# Rotation setting: a (min, max) pair when ROTATE is on, otherwise 0.
rotation = (30, 45) if ROTATE else 0


# Seed every random number generator the script relies on (Python, NumPy,
# PyTorch on CPU and on all CUDA devices).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Train and test images are only converted to tensors; no augmentation.
train_transforms = transforms.Compose([transforms.ToTensor()])
test_transforms = transforms.Compose([transforms.ToTensor()])

# MNIST splits (downloaded into ../checkpoint/data when missing).
mnist_train = MNIST(
    '../checkpoint/data', train=True, transform=train_transforms, download=True
)
mnist_test = MNIST(
    '../checkpoint/data', train=False, transform=test_transforms, download=True
)


# Only the training loader shuffles its samples.
train_loader = DataLoader(mnist_train, shuffle=True, batch_size=512)
test_loader = DataLoader(mnist_test, batch_size=512)

class SimpleConv(nn.Module):
    """Small convolutional classifier producing 10 logits per input image.

    Three conv/batch-norm/LeakyReLU/pooling stages reduce a single-channel
    28x28 image to a 32x4x4 map, which is flattened and projected to a
    ``width``-dimensional feature vector. The final linear layer maps that
    vector to 10 logits.

    Attributes:
        layers: feature extractor (``nn.Sequential``).
        linear: classification head on top of the extracted features.
        feature: detached copy of the features from the most recent forward
            pass, or ``None`` before the first call.
    """

    def __init__(self):
        super().__init__()
        width = 32

        def conv_stage(in_channels, out_channels, pool):
            # 3x3 convolution (spatial size preserved) -> batch norm ->
            # LeakyReLU -> the given pooling layer.
            return [
                nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(),
                pool,
            ]

        conv_part = [
            *conv_stage(1, 16, nn.MaxPool2d(2)),  # 14x14
            *conv_stage(16, 32, nn.MaxPool2d(2)),  # 7x7
            *conv_stage(32, 32, nn.AvgPool2d(2, padding=1)),  # 4x4
        ]
        dense_part = [
            nn.Flatten(),
            # 32 channels * 4 * 4 positions = 512 inputs.
            nn.Linear(512, width, bias=False),
            nn.BatchNorm1d(width),
            nn.LeakyReLU(),
        ]
        self.layers = nn.Sequential(*conv_part, *dense_part)

        self.feature = None
        self.linear = nn.Linear(width, 10)

    def forward(self, x):
        """Return the logits for ``x`` and cache the features on ``self.feature``."""
        hidden = self.layers(x)
        # Keep a copy that is cut off from the autograd graph.
        self.feature = hidden.detach().clone()
        return self.linear(hidden)

# Train the classifier on the GPU, then measure accuracy on the test loader.
epochs = 5
model = SimpleConv().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

model.train()
for e in range(epochs):
    epoch_losses = []
    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.cuda(), y_batch.cuda()

        optimizer.zero_grad()
        preds = model(x_batch)
        loss = criterion(preds, y_batch)
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())
    # Mean training loss of the epoch that just finished.
    print(np.mean(epoch_losses))

# Evaluation: predictions are moved back to the CPU so they can be compared
# with the CPU-resident labels.
model.eval()
correct = []
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        x_batch = x_batch.cuda()
        preds = model(x_batch).cpu().argmax(dim=-1)
        correct += (preds == y_batch).tolist()
print('Accuracy', np.mean(correct))
#%%
# "Corrupted" in-distribution data: the MNIST test split with a random
# rotation drawn from the `rotation` range applied to every image.
rotation = (30, 45)
corrupted_transforms = transforms.Compose(
    [transforms.RandomRotation(rotation), transforms.ToTensor()]
)

mnist_corrupted = MNIST(
    '../checkpoint/data', train=False, transform=corrupted_transforms, download=True
)
# Only the first batch (up to 10_000 samples) is used below.
corrupted_loader = DataLoader(mnist_corrupted, batch_size=10_000)
images, labels = next(iter(corrupted_loader))


with torch.no_grad():
    preds = model(images.cuda()).argmax(dim=-1).cpu()
# NOTE: accuracy on the rotated batch; the value is neither stored nor printed.
(preds == labels).numpy().mean()

# Out-of-distribution data: SVHN converted to grayscale and resized to 28x28
# so that it matches the classifier's input format.
ood_transforms = transforms.Compose(
    [transforms.Grayscale(), transforms.Resize((28, 28)), transforms.ToTensor()]
)

# cifar = CIFAR10(
#     '../checkpoint/data', download=True, train=False, transform=cifar_transforms
# )
ood_dataset = SVHN('./data', split='test', transform=ood_transforms, download=True)
# As above, only the first batch of up to 10_000 samples is taken.
ood_loader = DataLoader(ood_dataset, batch_size=10_000)
images_ood, labels_ood = next(iter(ood_loader))

def entropy(x):
    """Return the Shannon entropy of the softmax distribution over the last axis.

    Args:
        x: tensor of unnormalised logits; the last dimension indexes classes.

    Returns:
        Tensor with the last dimension reduced, holding the entropy (in nats)
        of ``softmax(x)`` for every row.
    """
    # log_softmax keeps the log-probabilities finite even when a probability
    # underflows to 0, so the product below is 0 rather than 0 * -inf = NaN.
    log_probs = torch.log_softmax(x, dim=-1)
    probs = log_probs.exp()
    return torch.sum(-probs * log_probs, dim=-1)

# Logits for the rotated MNIST batch and for the SVHN batch.
with torch.no_grad():
    preds = model(images.cuda()).cpu()
    preds_ood = model(images_ood.cuda()).cpu()

# Logits and labels of the training set, gathered batch by batch in the order
# the (shuffling) train loader yields them, then joined once at the end.
train_logit_chunks = []
train_label_chunks = []
with torch.no_grad():
    for x_batch, y_batch in train_loader:
        preds_batch = model(x_batch.cuda()).cpu()
        train_logit_chunks.append(preds_batch)
        train_label_chunks.append(y_batch)

preds_train = torch.cat(train_logit_chunks, dim=0).numpy()
y_train = torch.cat(train_label_chunks).numpy()


# Max-probability uncertainty: one minus the largest softmax probability.
ue_mnist = 1 - torch.softmax(preds, dim=-1).max(dim=-1).values
ue_ood = 1 - torch.softmax(preds_ood, dim=-1).max(dim=-1).values
# In-distribution scores come first, out-of-distribution scores second.
ue = torch.cat((ue_mnist, ue_ood)).numpy()
# Entropy-based uncertainty in the same order; stays a torch tensor.
ue_entropy = torch.cat((entropy(preds), entropy(preds_ood)))


#%%
# Feature space used by NUQ: the cached `model.feature` embeddings when True,
# the classifier logits when False.
EMBEDDINGS = False

if not EMBEDDINGS:
    # Fit NUQ on the training-set logits.
    nuq = NuqClassifier(strategy="isj", tune_bandwidth=True, n_neighbors=50)
    nuq.fit(X=preds_train, y=y_train)
    print('Fitted bandwidth', nuq.bandwidth)

    # 'total' uncertainty for the MNIST batch followed by the SVHN batch.
    ue_nuq = np.concatenate([
        nuq.predict_uncertainty(logits.numpy())['total']
        for logits in (preds, preds_ood)
    ])
else:
    # Rebuild the train loader with batch_size=60_000 and take a single batch.
    train_loader = DataLoader(mnist_train, batch_size=60_000, shuffle=True)
    print(train_loader.batch_size)
    with torch.no_grad():
        train_images, y_train = next(iter(train_loader))
        y_train = y_train.numpy()

        # Each forward pass refreshes `model.feature`, which is read right after.
        model(train_images.cuda())
        train_embeddings = model.feature.cpu().numpy()

        model(images.cuda())
        embeddings = model.feature.cpu().numpy()
        print(embeddings[:5, :3])

        model(images_ood.cuda())
        embeddings_ood = model.feature.cpu().numpy()
        print(embeddings_ood[:5, :3])

        # Fit NUQ on the training embeddings.
        nuq = NuqClassifier(strategy="isj", tune_bandwidth=True, n_neighbors=20)
        nuq.fit(X=train_embeddings, y=y_train)
        print('Fitted bandwidth', nuq.bandwidth)

        # 'total' uncertainty for the MNIST batch followed by the SVHN batch.
        ue_nuq = np.concatenate([
            nuq.predict_uncertainty(features)['total']
            for features in (embeddings, embeddings_ood)
        ])


#%%
# nuq = NuqClassifier(bandwidth=np.ones(32), strategy="classification", tune_bandwidth=True, n_neighbors=20)
# nuq.fit(X=train_embeddings, y=y_train)
# print('Fitted bandwidth', nuq.bandwidth)
#
# ue_nuq = np.concatenate((
#     nuq.predict_uncertainty(embeddings)['total'],
#     nuq.predict_uncertainty(embeddings_ood)['total'],
# ))

# from spectral_normalized_models.ddu import (
#     gmm_fit, logsumexp
# )
# gaussians_model, jitter_eps = gmm_fit(
#     embeddings=torch.tensor(train_embeddings), labels=torch.tensor(y_train), num_classes=100
# )
#
# ues_test_ddu = gaussians_model.log_prob(torch.tensor(embeddings)[:, None, :].float())
# ues_test_ddu = -logsumexp(ues_test_ddu).numpy().flatten()
# ues_ood_ddu = gaussians_model.log_prob(torch.tensor(embeddings_ood)[:, None, :].float())
# ues_ood_ddu = -logsumexp(ues_ood_ddu).numpy().flatten()
# ue_ddu = np.concatenate((ues_test_ddu, ues_ood_ddu))


#%%

# Ground-truth indicator for the concatenated scores: 10000 zeros
# (in-distribution part) followed by 10000 ones (out-of-distribution part).
ue_labels = np.repeat([0.0, 1.0], 10000)
# Selection sizes 0, 200, ..., 20000.
xs = np.arange(0, 20001, 200)


def fractions(uncertainty):
    """Count label-1 samples among the least-uncertain samples.

    Samples are ranked by ascending ``uncertainty``; for every selection size
    in ``xs`` the entries of ``ue_labels`` belonging to the first that many
    ranked samples are summed.

    Args:
        uncertainty: one score per sample, aligned with ``ue_labels``.

    Returns:
        A list with one sum per entry of ``xs``.
    """
    order = np.argsort(uncertainty)
    ranked_labels = ue_labels[order]
    counts = []
    for cutoff in xs:
        counts.append(np.sum(ranked_labels[:cutoff]))
    return counts

# One curve per score: a random baseline, max-probability, entropy, NUQ, and
# the ground-truth labels themselves (the best achievable ordering).
# ddu_sums = fractions(ue_ddu)
randomed_sums, maxprobed_sums, entropy_sums, nwed_sums, optimal_sums = [
    fractions(scores)
    for scores in (np.random.random(20_000), ue, ue_entropy, ue_nuq, ue_labels)
]
#%%
# Global matplotlib font settings used by the figures in this script.
font = {
    # 'normal' is a font *weight*, not a family: asking for it makes matplotlib
    # emit "findfont: Font family ... not found" warnings before falling back
    # to its default font. Request the generic sans-serif family instead.
    'family': 'sans-serif',
    # 'weight' : 'bold',
    'weight': 'normal',
    'size': 18,
}

import matplotlib
matplotlib.rc('font', **font)

# Shared line width, also read by the accuracy plot further down.
linewidth = 3.5

# OOD detection curves: number of SVHN samples among the k least-uncertain
# samples, for every selection size k in `xs`.
plt.figure(figsize=(6, 5), dpi=150)
plt.subplots_adjust(left=0.21, bottom=0.13, right=0.93)
plt.title('Rotated MNIST vs grayscale SVHN')
plt.ylabel('SVHN objects included')
plt.xlabel('Total objects included')
# Baselines are drawn faded so the actual methods stand out.
plt.plot(xs, randomed_sums, label='Random', alpha=0.3, linewidth=linewidth)
plt.plot(xs, optimal_sums, label='Optimal', alpha=0.3, linewidth=linewidth)
plt.plot(xs, maxprobed_sums, label='MaxProb', linestyle='--', color='tab:green', linewidth=linewidth)
plt.plot(xs, entropy_sums, label='Entropy', linestyle=':', color='tab:red', linewidth=linewidth)
# plt.plot(xs, nwed_sums, label='DDU', linestyle='-', color='tab:orange')
plt.plot(xs, nwed_sums, label='NUQ', linestyle='-.', color='tab:purple', linewidth=linewidth)
plt.legend()
plt.show()
print('done')

#%%
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# Sanity check: k-nearest-neighbours accuracy in the feature space NUQ uses.
# `train_embeddings` / `embeddings` are only defined when EMBEDDINGS is True;
# otherwise use the logits (`preds_train` / `preds`), which is what NUQ was
# fitted on in that case. `y_train` is aligned with the chosen training
# features in both cases.
if EMBEDDINGS:
    knn_train_features, knn_test_features = train_embeddings, embeddings
else:
    knn_train_features, knn_test_features = preds_train, preds.numpy()

kn_model = KNeighborsClassifier()
kn_model.fit(knn_train_features, y_train)

# accuracy_score(kn_m)
#%%
# k-NN accuracy on the rotated MNIST batch.
accuracy_score(labels, kn_model.predict(knn_test_features))

#%%
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np

# Predicted class per rotated MNIST image, and the images as 28x28 arrays.
predictions = preds.argmax(dim=-1).numpy()
imgs = images.numpy().reshape(-1, 28, 28)

fig = plt.figure(figsize=(12., 3.5))
fig.set_tight_layout({"pad": 2})

# One row of four axes (similar to subplot(111)), 0.5 inch apart.
grid = ImageGrid(fig, 111, nrows_ncols=(1, 4), axes_pad=0.5)

# Iterating over the grid yields its Axes; each one shows a randomly picked image.
for i, ax in enumerate(grid):
    ax.axis('off')
    random_index = np.random.randint(len(imgs))
    ax.imshow(imgs[random_index], cmap='gray')

plt.show()

#%%
# np.mean()
# Per-sample correctness of the classifier on the rotated MNIST batch.
corrects = (preds.argmax(dim=-1) == labels).numpy()

# Selection sizes 0, 400, ..., 10000 (rebinds the global `xs`).
xs = np.arange(0, 10001, 400)


def splits(ues):
    """Accuracy among the least-uncertain samples for each selection size.

    Samples are ranked by ascending ``ues``; for every non-zero size in ``xs``
    the mean of ``corrects`` over the first that many ranked samples is taken.

    Args:
        ues: one uncertainty score per sample, aligned with ``corrects``.

    Returns:
        A list with one value per entry of ``xs``; the first entry (size 0)
        is fixed to 1.
    """
    order = np.argsort(ues)
    ranked_corrects = corrects[order]
    accuracies = [1]
    for count in xs[1:]:
        accuracies.append(np.mean(ranked_corrects[:count]))
    return accuracies

# Accuracy on rotated MNIST as a function of how many of the least-uncertain
# samples are kept, one curve per uncertainty score.
plt.figure(figsize=(6, 5), dpi=150)
plt.subplots_adjust(left=0.15, bottom=0.13, right=0.95)
plt.title('Accuracy, MNIST rotated')
plt.ylabel("Accuracy")
plt.xlabel("Samples selected")

# (scores, legend label, line style, colour); only the first 10000 entries of
# the concatenated scores belong to the MNIST batch.
curves = (
    (ue_mnist, 'MaxProb', '--', 'tab:green'),
    (ue_entropy[:10000].numpy(), 'Entropy', ':', 'tab:red'),
    (ue_nuq[:10000], 'NUQ', '-.', 'tab:purple'),
)
for curve_scores, curve_label, curve_style, curve_color in curves:
    plt.plot(
        xs,
        splits(curve_scores),
        label=curve_label,
        linestyle=curve_style,
        color=curve_color,
        linewidth=linewidth,
    )

plt.legend()
plt.show()

#%%
# Fraction of SVHN samples whose predicted class equals the SVHN label.
# NOTE: the value is neither stored nor printed.
(preds_ood.argmax(dim=-1) == labels_ood).numpy().mean()

def panel(imgs, num=4):
    """Show ``num`` randomly chosen images from ``imgs`` in a single row.

    Every axis draws its own random index, so the same image can appear more
    than once.

    Args:
        imgs: indexable collection of images accepted by ``imshow``.
        num: number of images (grid columns) to display.
    """
    figure = plt.figure(figsize=(num*3.0, 3.5))
    figure.set_tight_layout({"pad": 2})

    # One row of `num` axes (similar to subplot(111)), 0.5 inch apart.
    axes_row = ImageGrid(figure, 111, nrows_ncols=(1, num), axes_pad=0.5)

    # Iterating over the grid yields its Axes.
    for ax in axes_row:
        ax.axis('off')
        pick = np.random.randint(len(imgs))
        ax.imshow(imgs[pick], cmap='gray')

    plt.show()

#%%
from skimage.filters import gaussian
from image_uncertainty.datasets.smooth_random import SmoothRandom

# Smoothed uniform-noise images.
# np.random.random is called directly instead of `from numpy.random import
# random`: that import rebinds the module-level name `random`, which the
# seeding code at the top of this file expects to be the stdlib module.
num = 10
image_size = (32, 64, 3)
noise_images = np.random.random((num, *image_size))
# One blur radius per image, drawn uniformly from [1, 2.5).
radiuses = 1.5 * np.random.random(num) + 1
# NOTE(review): newer scikit-image releases replace `multichannel` with
# `channel_axis` - confirm against the installed version.
smoothed = [gaussian(img, r, multichannel=3) for img, r in zip(noise_images, radiuses)]
# panel() shows its default number of randomly picked images from the list.
panel(smoothed)

