#%%
import os

import numpy as np

import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

from src.models.ConvAE import ConvAE_anomaly_detector, LatentSpacePredictor
from src.utils.SubsetDigits import SubsetDigits
#from src.utils.get_sample import get_sample_from_class
from src.plotting_functions.digit_visualization import visualize_reconstruction, visualize_sample

from src.plotting_functions.plot_data_and_detector import plot_data_and_detector

#%% plot settings:

# Font sizes shared by all figures below; the per-element breakdown follows
# https://stackoverflow.com/questions/3899980/how-to-change-the-font-size-on-a-matplotlib-plot
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 16

# Apply every rc setting in a single update rather than one plt.rc call per group.
plt.rcParams.update({
    "font.size": SMALL_SIZE,          # default text size
    "axes.titlesize": SMALL_SIZE,     # axes titles
    "axes.labelsize": MEDIUM_SIZE,    # x/y axis labels
    "xtick.labelsize": SMALL_SIZE,    # x tick labels
    "ytick.labelsize": SMALL_SIZE,    # y tick labels
    "legend.fontsize": SMALL_SIZE,    # legend text
    "figure.titlesize": BIGGER_SIZE,  # figure suptitle
})

#%% train network on sample of digits

# ---- experiment switches -------------------------------------------------
# Pick the final model by lowest test loss. This only has an effect when
# training actually runs (recalculate_models=True or no cached checkpoints).
select_on_test_loss = False

# Digits used as "normal" training data, and digits held out as anomalies.
selected_digits = [5, 7, 4]
anomaly_digits = []

# When False, previously saved checkpoints are reused if all of them exist.
recalculate_models = False

# Seed torch's global RNG so the run is reproducible.
seed = 0
torch.manual_seed(seed)

# ---- latent-space plotting window ----------------------------------------
plot_min_x = plot_min_y = -12
plot_max_x = plot_max_y = 12
points_per_dimension = 100
log_scale = True

# ---- output locations ----------------------------------------------------
model_ID = f"MNIST_{selected_digits}_seed_{seed}"
base_model_path = os.path.join("saved_models", model_ID)
base_figure_path = os.path.join("Figures", "MNIST_ConvAE", model_ID)
os.makedirs(base_model_path, exist_ok=True)
os.makedirs(base_figure_path, exist_ok=True)

# ---- autoencoder hyper-parameters ----------------------------------------
input_channels = 1
encoding_dim = 2
learning_rate = 0.0001

# Checkpoint epochs are written 1-based for readability and stored 0-based.
save_epochs = [e - 1 for e in (1, 2, 3, 4, 5, 6, 10, 15, 16, 17, 18, 19, 20)]
# Train exactly long enough to reach the last checkpoint.
epochs = max(save_epochs) + 1
# epochs = 2  # quick debug run

# Load the MNIST dataset with the selected subset of digits.
# The data is meant to be on a 0-1 scale; Normalize((0,), (1,)) subtracts 0
# and divides by 1, so it leaves the tensor unchanged.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0,), (1,))])

# Training split restricted to the selected digits.
train_dataset = SubsetDigits(root='./data', train=True, download=True, transform=transform, digits=selected_digits)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Held-out split with the same digits, tracked under the "Test" key.
test_dataset = SubsetDigits(root='./data', train=False, download=True, transform=transform, digits=selected_digits)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

additional_loaders = dict(Test=test_loader)

# One extra held-out loader per anomaly digit, keyed by the digit as a string.
for anomaly_digit in anomaly_digits:
    anomaly_dataset = SubsetDigits(root='./data', train=False, download=True, transform=transform, digits=[anomaly_digit])
    anomaly_loader = DataLoader(anomaly_dataset, batch_size=64, shuffle=True)
    additional_loaders[f"{anomaly_digit}"] = anomaly_loader

detector = ConvAE_anomaly_detector(input_channels=input_channels, encoding_dim=encoding_dim, learning_rate=learning_rate, epochs=epochs, save_epochs=save_epochs, ID=model_ID, checkpoint_loss="Test")

# NOTE(review): the checks below assume the detector writes its checkpoints to
# base_model_path as epoch_<n>.pth (0-based n) -- confirm against ConvAE.
final_save_epoch = max(save_epochs)

# Reuse cached weights only when every requested checkpoint exists on disk and
# the last one is the final training epoch. The second condition fails if
# `epochs` was overridden (e.g. a short debug run), which forces a retrain.
all_checkpoints_saved = all(
    os.path.exists(os.path.join(base_model_path, f"epoch_{epoch}.pth"))
    for epoch in save_epochs
)
if not recalculate_models and all_checkpoints_saved and final_save_epoch == (epochs - 1):
    print("Reloading model")

    detector.load_model(os.path.join(base_model_path, f"epoch_{final_save_epoch}.pth"))
else:
    # Initialize and train the AE_anomaly_detector; the held-out loaders are
    # passed in so their losses can be plotted next to the training loss.
    detector.fit(train_loader, additional_loaders=additional_loaders)

    # Plot the training loss + additional losses
    plt.figure()
    detector.plot_loss(plot_additionals=["Test"] + [str(digit) for digit in anomaly_digits])
    plt.legend()

    fig_name = os.path.join(base_figure_path, "loss_plot")
    plt.savefig(fig_name + ".png", format="png")
    plt.savefig(fig_name + ".pdf", format="pdf")
    plt.show()

    if select_on_test_loss:
        # NOTE(review): this requires a checkpoint file for the best epoch; confirm
        # the detector saves one even when that epoch is not in save_epochs.
        print("Reloading model at lowest test loss, epoch: " + str(detector.minimum_checkpoint_loss_epoch))
        detector.load_model(os.path.join(base_model_path, f"epoch_{detector.minimum_checkpoint_loss_epoch}.pth"))

latent_predictor = LatentSpacePredictor(detector)


#%% test whether network is reconstructing sensibly

# Take the training image at index 1; its label is not needed here.
sample, _ = train_dataset[1]

# Prepend a batch dimension so the autoencoder receives a batch of one.
sample_tensor = sample[None]

# Run the trained autoencoder without tracking gradients.
with torch.no_grad():
    reconstructed_sample = detector.autoencoder(sample_tensor)

# Drop the batch dimension and convert to NumPy for plotting.
reconstructed_sample = reconstructed_sample[0].detach().numpy()

# Show the original next to its reconstruction.
plt.figure()
visualize_reconstruction(sample.numpy(), reconstructed_sample)
plt.show()

#%% per-sample reconstruction losses (training set, plus anomaly digits if any)

def _per_sample_reconstruction_losses(model, dataset):
    """Score every sample of ``dataset`` by how well ``model`` reconstructs it.

    Each sample is passed through ``model`` on its own (batch of one, no
    gradients) and scored with ``nn.MSELoss(reduction="sum")``, i.e. the summed
    squared error between the input and its reconstruction.

    Returns ``(losses, labels)``: two NumPy arrays with one entry per sample,
    in dataset index order.
    """
    criterion = nn.MSELoss(reduction="sum")  # stateless, so build it once
    losses = []
    labels = []
    with torch.no_grad():
        # Index explicitly: map-style datasets guarantee __len__/__getitem__,
        # not iteration.
        for i in range(len(dataset)):
            sample, label = dataset[i]
            sample_tensor = sample.unsqueeze(0)  # add batch dimension
            reconstructed = model(sample_tensor)
            losses.append(criterion(reconstructed, sample_tensor).item())
            labels.append(label)
    return np.array(losses), np.array(labels)


reconstruction_losses, all_labels = _per_sample_reconstruction_losses(detector.autoencoder, train_dataset)

if len(anomaly_digits) > 0:
    anomaly_dataset = SubsetDigits(root='./data', train=False, download=True, transform=transform, digits=anomaly_digits)
    reconstruction_losses_anomalies, all_labels_anomalies = _per_sample_reconstruction_losses(detector.autoencoder, anomaly_dataset)


#%% plot histogram of recon losses of the different classes

# Get unique labels
unique_labels = np.unique(all_labels)

# Use one set of 50 bin edges for every class (normal and anomalous). With
# per-class bins the overlapping histograms had different bin widths, so
# their bar heights could not be compared.
loss_arrays = [reconstruction_losses]
if len(anomaly_digits) > 0:
    loss_arrays.append(reconstruction_losses_anomalies)
shared_bins = np.histogram_bin_edges(np.concatenate(loss_arrays), bins=50)

# Plot overlapping histograms
plt.figure(figsize=(10, 6))
for label in unique_labels:
    plt.hist(reconstruction_losses[all_labels == label], bins=shared_bins, alpha=0.5, label=f"Class {label}")


if len(anomaly_digits) > 0:
    unique_labels_anomalies = np.unique(all_labels_anomalies)
    for label in unique_labels_anomalies:
        plt.hist(reconstruction_losses_anomalies[all_labels_anomalies == label], bins=shared_bins, alpha=0.5, label=f"Class {label}")

plt.xlabel("Reconstruction Loss")
plt.ylabel("Frequency")
plt.legend()

fig_name = os.path.join(base_figure_path, "MSE_histogram")
plt.savefig(fig_name + ".png", format="png")
plt.savefig(fig_name + ".pdf", format="pdf")
plt.show()

#%% lowest-loss reconstructions of each class

# Get unique labels
unique_labels = np.unique(all_labels)

# Number of images to display per class
num_images = 3

# Iterate over each class
for label in unique_labels:
    # Get indices of samples belonging to the current class
    class_indices = np.where(all_labels == label)[0]

    # Order this class's indices by increasing reconstruction loss
    sorted_indices = class_indices[np.argsort(reconstruction_losses[class_indices])]

    # Keep the num_images samples with the lowest reconstruction loss
    top_indices = sorted_indices[:num_images]

    # Plot the original and reconstructed images. squeeze=False keeps `axes`
    # 2-D even when num_images == 1, so axes[i, j] indexing always works.
    fig, axes = plt.subplots(num_images, 2, figsize=(8, 4 * num_images), squeeze=False)
    fig.suptitle(f"Class {label} - Lowest Reconstruction Losses", fontsize=16)

    for i, idx in enumerate(top_indices):
        sample, _ = train_dataset[idx]
        sample_tensor = sample.unsqueeze(0)
        with torch.no_grad():
            reconstructed_sample = detector.autoencoder(sample_tensor)
            # Compare against the batched input, as the other loss calls in this
            # file do. The un-batched `sample` presumably lacks the output's batch
            # dimension, which relied on broadcasting and made PyTorch warn about
            # mismatched sizes.
            loss = nn.MSELoss(reduction="sum")(reconstructed_sample, sample_tensor).item()
        print(loss)
        # Original Sample
        axes[i, 0].imshow(sample.squeeze(), cmap="gray")
        axes[i, 0].set_title("Original Sample")
        axes[i, 0].axis("off")

        # Reconstructed Sample
        axes[i, 1].imshow(reconstructed_sample.squeeze().detach().numpy(), cmap="gray")
        axes[i, 1].set_title("Reconstructed Sample MSE: {:.2f}".format(loss))
        axes[i, 1].axis("off")

    plt.tight_layout()
    fig_name = os.path.join(base_figure_path, "lowest_MSE_label_" + str(label))
    plt.savefig(fig_name + ".png", format="png", bbox_inches='tight')
    plt.savefig(fig_name + ".pdf", format="pdf", bbox_inches='tight')
    plt.show()



#%% latent representations

# Encode the training data to get their latent representations

def get_latent_representations(detector, loader):
    """Map every batch from ``loader`` into the autoencoder's latent space.

    Each batch of images goes through ``detector.autoencoder.encoder``, is
    flattened to ``(batch, -1)``, and then goes through
    ``detector.autoencoder.bottleneck``. Gradients are disabled.

    Returns a tuple ``(latents, labels)`` of NumPy arrays, concatenated over
    all batches in loader order.
    """
    autoencoder = detector.autoencoder
    latent_chunks, label_chunks = [], []
    for images, batch_labels in loader:
        with torch.no_grad():
            features = autoencoder.encoder(images)
            flat = features.view(features.size(0), -1)
            latent_chunks.append(autoencoder.bottleneck(flat).numpy())
        label_chunks.append(batch_labels.numpy())
    return np.concatenate(latent_chunks), np.concatenate(label_chunks)

#%% Show recon profile over multiple epochs

# Each checkpoint is loaded into latent_predictor.detector, and the latents are
# encoded with that same object. The scatter and the landscape then always
# come from one set of weights. (Whether LatentSpacePredictor shares or copies
# `detector` is not visible here.)
for epoch in latent_predictor.detector.save_epochs:
    model_path = os.path.join(base_model_path, f"epoch_{epoch}.pth")

    latent_predictor.detector.load_model(model_path)

    latent_representations, labels = get_latent_representations(latent_predictor.detector, train_loader)

    plt.figure()
    plot_data_and_detector(latent_predictor, normal_data=latent_representations, normal_labels=labels, anomalies=None, plot_eigenvector=False, resolution=points_per_dimension,
                                plot_min_x=plot_min_x, plot_max_x=plot_max_x, plot_min_y=plot_min_y, plot_max_y=plot_max_y, log_scale=log_scale, colorbar_spacing="uniform", normal_data_markersize=5, normal_data_opacity=0.5)
    # `epoch` is a 0-based index: the title shows it 1-based, the file name keeps it 0-based.
    plt.title("epoch: " + str(epoch + 1))
    plt.xlabel(r"$\boldsymbol{Y}_1$")
    plt.ylabel(r"$\boldsymbol{Y}_2$")

    fig_name = os.path.join(base_figure_path, "loss_landscape_and_data_epoch_" + str(epoch))
    plt.tight_layout()
    plt.savefig(fig_name + ".png", format="png")
    plt.savefig(fig_name + ".pdf", format="pdf")
    plt.show()

#%% Only does something sensible for [4,5,7]
epoch = 19  # 0-based checkpoint index chosen for digits 4,5,7
model_path = os.path.join(base_model_path, f"epoch_{epoch}.pth")

# Load into latent_predictor.detector and encode with that same object, so
# the scatter and the landscape use identical weights.
latent_predictor.detector.load_model(model_path)

latent_representations, labels = get_latent_representations(latent_predictor.detector, train_loader)

plt.figure()

plot_data_and_detector(latent_predictor, normal_data=latent_representations, normal_labels=labels, plot_eigenvector=False, resolution=points_per_dimension,
                            plot_min_x=plot_min_x, plot_max_x=plot_max_x, plot_min_y=plot_min_y, plot_max_y=plot_max_y, log_scale=log_scale, colorbar_spacing="uniform", normal_data_markersize=5, normal_data_opacity=0.5)


fig_name = os.path.join(base_figure_path, "loss_landscape_data_anomaly_epoch_" + str(epoch))
plt.tight_layout()
plt.savefig(fig_name + ".png", format="png")
plt.savefig(fig_name + ".pdf", format="pdf")
plt.show()




#%% make plot for 7,4,5 including zoom
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset, zoomed_inset_axes


epoch = 19
model_path = os.path.join(base_model_path, f"epoch_{epoch}.pth")

# Load into latent_predictor.detector and encode with that same object, so
# the scatter, the landscape and the zoomed inset all use identical weights.
latent_predictor.detector.load_model(model_path)

latent_representations, labels = get_latent_representations(latent_predictor.detector, train_loader)

plt.figure()
# Capture the main axes before plotting; the inset is attached to it below.
ax = plt.gca()

plot_data_and_detector(latent_predictor, normal_data=latent_representations, normal_labels=labels, plot_eigenvector=False, resolution=points_per_dimension,
                            plot_min_x=plot_min_x, plot_max_x=plot_max_x, plot_min_y=plot_min_y, plot_max_y=plot_max_y, log_scale=log_scale, colorbar_spacing="uniform", normal_data_markersize=5, normal_data_opacity=0.5)
plt.xlabel(r"$\boldsymbol{Y}_1$")
plt.ylabel(r"$\boldsymbol{Y}_2$")

# First zoom: redraw the same landscape and data into a small inset, then crop
# it with the axis limits below.
ax_inset = inset_axes(ax, width="25%", height="25%", loc='lower right')
plot_data_and_detector(latent_predictor, normal_data=latent_representations, normal_labels=labels, plot_eigenvector=False, resolution=points_per_dimension,
                            plot_min_x=plot_min_x, plot_max_x=plot_max_x, plot_min_y=plot_min_y, plot_max_y=plot_max_y, log_scale=log_scale, colorbar_spacing="uniform", normal_data_markersize=10, normal_data_opacity=1, pure_plot=True, scatter_edgecolor='k', scatter_edgewidth=0.2)


# Set the zoomed-in limits
ax_inset.set_xlim(-5, -3)
ax_inset.set_ylim(-7, -4)

ax_inset.set_facecolor("#FFFFFF")

# Outline the zoomed region on the main axes and connect it to the inset.
mark_inset(ax, ax_inset, loc1=1, loc2=3, fc="none", ec="0")

fig_name = os.path.join(base_figure_path, "loss_landscape_data_anomaly_with_zoom_epoch_" + str(epoch))
plt.tight_layout()
plt.savefig(fig_name + ".png", format="png")
plt.savefig(fig_name + ".pdf", format="pdf")
plt.show()

#%% adversarial sample decoded from a hand-picked latent point
# Generate a specific adversarial sample with a low reconstruction loss.
# The latent point is fixed (not random); it lies inside the zoomed inset
# window plotted above (x in [-5, -3], y in [-7, -4]).
adversarial_latent_vector = torch.tensor([[-4.2, -5.2]], dtype=torch.float32)
#adversarial_latent_vector = torch.tensor([[-3.5, -4]], dtype=torch.float32) # for epoch 19 of 4,5,7 seed 0

# Decode with latent_predictor.detector, the object the last checkpoint was
# loaded into, so the sample matches the plotted landscape.
adversarial_autoencoder = latent_predictor.detector.autoencoder

# Decode the chosen latent vector to generate an artificial sample
with torch.no_grad():
    artificial_sample = adversarial_autoencoder.decoder_input(adversarial_latent_vector)
    # NOTE(review): 32x7x7 is hard-coded here and must match the decoder's
    # expected feature-map shape in ConvAE -- confirm if the architecture changes.
    artificial_sample = artificial_sample.view(artificial_sample.size(0), 32, 7, 7)
    artificial_sample = adversarial_autoencoder.decoder(artificial_sample)

# Check how well the autoencoder reconstructs its own decoded output.
with torch.no_grad():
    reconstructed_sample = adversarial_autoencoder(artificial_sample)
    loss = nn.MSELoss(reduction="sum")(reconstructed_sample, artificial_sample).item()

print(loss)

# Convert the artificial sample to a NumPy array and visualize it
artificial_sample = artificial_sample.detach().numpy()[0]
plt.figure()
visualize_sample(artificial_sample)
# `epoch` comes from the previous cell (the checkpoint that was loaded last).
fig_name = os.path.join(base_figure_path, "adversarial_anomaly" + str(epoch))
plt.tight_layout()
plt.savefig(fig_name + ".png", format="png", bbox_inches='tight')
plt.savefig(fig_name + ".pdf", format="pdf", bbox_inches='tight')
plt.show()