#!/usr/bin/env python
# coding: utf-8

# In[1]:


import sys

sys.path.append(".")
sys.path.append("./loss-landscapes")

import os

import loss_landscapes
import loss_landscapes.metrics
import numpy as np
import pennylane as qml
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm

from data_utils.aae_dataset import MNIST_AAE_Dataset
from data_utils.plot import plot_2d
from loss import FidLossDotProd, dot_product_loss, fidelity_loss
from models.state_generators import StateGenerator
from utils import add_noise, append_log, norm_image, resize

# In[2]:


# Lookup table from `config.state_generator.loss` to the loss class that
# train() instantiates (the class itself, not an instance).
LOSS_FN = dict(state=FidLossDotProd, MSE=torch.nn.MSELoss)


def seed_everything(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``.

    Args:
        seed: integer seed applied to every generator, in that order.
    """
    import random

    import numpy as np
    import torch

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


def train(config, loader: DataLoader):
    """Train a ``StateGenerator`` on ``loader`` and return it.

    Per-step losses are passed to ``append_log`` for
    ``<config.checkpoint.logs>/loss.txt`` (an existing file is emptied first)
    and written to TensorBoard under ``<config.checkpoint.logs>/tensorboard``.

    Args:
        config: OmegaConf config. Reads ``checkpoint.logs``, ``device``,
            ``state_generator.loss`` (a key of ``LOSS_FN``),
            ``state_generator.aae_encoder.n_qubits``, ``optimizer`` (keyword
            arguments for ``torch.optim.Adam``), ``n_epochs`` and
            ``dataloader.batch_size``.
        loader: yields dicts with ``"images"`` and
            ``["encoder_params"]["weights"]`` entries.

    Returns:
        The trained ``StateGenerator``.
    """
    loss_path = os.path.join(config.checkpoint.logs, "loss.txt")
    # Empty a loss log left over from a previous run; append_log is
    # presumably append-only, so stale values would otherwise accumulate.
    if os.path.exists(loss_path):
        with open(loss_path, "w"):
            pass

    superencoder = StateGenerator(config).to(config.device)

    loss_fn = LOSS_FN[config.state_generator.loss]().to(config.device)
    print(f"Using loss function: {loss_fn}")
    optimizer = torch.optim.Adam(superencoder.parameters(), **config.optimizer)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.n_epochs
    )

    writer = SummaryWriter(os.path.join(config.checkpoint.logs, "tensorboard"))
    # Monotonic optimizer-step counter for "Loss/Step", shared across epochs.
    global_step = 0
    try:
        for epoch in range(config.n_epochs):
            n_batches = 0  # batches actually used this epoch (NaN batches skipped)
            epoch_loss_sum = 0.0
            with tqdm(loader, leave=False) as bar:
                for batch in bar:
                    images = batch["images"]
                    # Targets must live on the same device as the predictions
                    # for the MSE branch below.
                    encoder_params = (
                        batch["encoder_params"]["weights"]
                        .view((-1, config.dataloader.batch_size))
                        .to(config.device)
                    )

                    images = resize(
                        images, config.state_generator.aae_encoder.n_qubits
                    ).to(config.device)
                    # Noise augmentation (add_noise with config.noise_factor /
                    # config.noisy_probability) is currently disabled.

                    images = norm_image(images)
                    # Some inputs contain NaN after resize + norm_image (cause
                    # unknown; with the current config it happened at batch 69),
                    # so such batches are skipped.
                    if images.isnan().any():
                        continue

                    pred = superencoder(images)

                    if config.state_generator.loss == "MSE":
                        loss = loss_fn(pred, encoder_params)
                    else:
                        loss = loss_fn(pred, superencoder.qc, images)

                    optimizer.zero_grad()
                    loss.backward()

                    # Gradient clipping: the loss was observed to become NaN
                    # after some training time without it.
                    torch.nn.utils.clip_grad.clip_grad_norm_(
                        superencoder.parameters(), 1.0
                    )
                    optimizer.step()

                    loss_value = loss.item()
                    n_batches += 1
                    global_step += 1
                    epoch_loss_sum += loss_value

                    bar.set_postfix(loss=loss_value)
                    writer.add_scalar("Loss/Step", loss_value, global_step)
                    append_log(loss_path, loss_value)

            scheduler.step()

            # Mean over the batches that were actually trained on; guard the
            # degenerate case where every batch was skipped.
            epoch_loss = epoch_loss_sum / max(n_batches, 1)
            writer.add_scalar("Loss/Epoch", epoch_loss, epoch)
            print(f"Epoch [{epoch+1}/{config.n_epochs}], Loss: {epoch_loss:.4f}")
    finally:
        # Flush pending TensorBoard events even if training raises.
        writer.close()

    return superencoder


# # MSE Loss on MNIST

# In[10]:


# Allow `${eval:...}` expressions in the YAML config.
# SECURITY: this resolver runs arbitrary Python taken from the config file, so
# only load trusted configs. replace=True keeps the cell re-runnable, because
# registering the same resolver name twice otherwise raises.
OmegaConf.register_new_resolver("eval", eval, replace=True)
version = "MSE_landscape.yaml"
config_dir = "./configs/"
config = OmegaConf.load(os.path.join(config_dir, version))


# In[11]:


seed_everything(config.seed)

dataset = MNIST_AAE_Dataset(config.dataset.root)
loader = DataLoader(dataset, shuffle=True, **config.dataloader)

# The log directory doubles as the "already trained" marker.
# NOTE(review): if the directory exists but the checkpoint at
# config.checkpoint.save_path does not, the load below will fail.
if os.path.exists(config.checkpoint.logs):
    model = superEncoder = StateGenerator(config=config).to(config.device)
    model.load(config.checkpoint.save_path, config.device)
    print(f"{config.version} Model loaded...")
else:
    os.makedirs(config.checkpoint.logs)
    print("Train a new model...")
    model = train(config, loader)
    model.save(config.checkpoint.save_path)


# In[12]:


# Draw a SINGLE batch so the inputs and their regression targets belong
# together. Calling next(iter(loader)) twice on a shuffled loader would pair
# the images of one batch with the encoder parameters of another.
batch = next(iter(loader))
images = batch["images"]
# Targets go to the same device as the model output so the MSE metric works
# on GPU. NOTE(review): train() reshapes with config.dataloader.batch_size
# while 32 is hard-coded here; confirm the two agree.
encoder_params = batch["encoder_params"]["weights"].view((-1, 32)).to(config.device)
images = resize(images, config.state_generator.aae_encoder.n_qubits).to(config.device)
# Noise augmentation (add_noise) is left disabled here, as in train().
images = norm_image(images)


# In[13]:


# Mean-squared-error metric on the fixed probe batch. The loss-landscapes
# `Loss` metric presumably evaluates loss_fn(model(images), encoder_params).
loss_fn = torch.nn.MSELoss(reduction="mean")
metric = loss_landscapes.metrics.Loss(loss_fn, images, encoder_params)


# In[14]:


# Sample the metric over a random 2-D plane through the trained weights,
# using STEPS grid points per axis.
STEPS = 100
loss_data_fin = loss_landscapes.random_plane(model, metric, steps=STEPS)
# Optional: normalise to [0, 1] with `loss_data_fin / np.max(loss_data_fin)`.


# In[15]:


import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

################# Matplotlib Global Conf #########################
fontsize = 16

# Render all text through LaTeX (needs a working LaTeX installation).
plt.rcParams["text.usetex"] = True
plt.rcParams["xtick.labelsize"] = fontsize - 2
plt.rcParams["ytick.labelsize"] = fontsize - 2
plt.rcParams["axes.labelsize"] = fontsize
plt.rcParams["axes.labelweight"] = "bold"
################# Matplotlib Global Conf #########################

plt.contour(loss_data_fin, levels=30)
# Save BEFORE showing: plt.show() closes/flushes the figure (always with the
# inline notebook backend, and after the window closes in scripts), so a
# savefig issued afterwards writes an empty page.
plt.savefig("mnist_landscape_mse_loss.pdf")
plt.show()


# In[16]:


import numpy as np

# The colour scale is fixed to [0, 1]; surface values above 1 saturate at the
# top colour even though the z-axis extends to 2.
norm = Normalize(vmin=0, vmax=1)
# Build the grid from the landscape itself so it always matches the `steps`
# used in random_plane. X[i, j] == j and Y[i, j] == i, as before.
Y, X = np.indices(np.shape(loss_data_fin))
STEPS = X.shape[1]
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.plot_surface(
    X,
    Y,
    loss_data_fin,
    rstride=1,
    cstride=1,
    cmap="viridis",
    edgecolor="none",
    norm=norm,
)
ax.set_zlim(0, 2)
# Save this specific figure before displaying it. plt.show() (unlike
# fig.show()) also keeps the window open when run as a script.
fig.savefig("mnist_3d_landscape_mse_loss.pdf")
plt.show()


# In[ ]:
