#!/usr/bin/env python
# coding: utf-8

# ## Imports and Settings

# In[1]:


import os
import random

import matplotlib.pyplot as plt
import pennylane as qml
import torch
import torch.nn as nn
import torch.nn.functional as F
from pennylane import numpy as np
from pennylane.templates import (BasicEntanglerLayers, RandomLayers,
                                 StronglyEntanglingLayers)
from torch import optim
from torch.optim.lr_scheduler import *

# In[2]:


# Folder holding the preprocessed MNIST splits (mnist_<split>.pt files).
processed_data_dir = r"./mnist/processed/"

# Circuit size. An earlier variant used 4 data qubits + 1 post-selection qubit (n_qubits = 5).
n_qubits = 4
num_encoder_layer = 8  # number of RY + CNOT layers in the AAE encoder
seed = 42

batch_size = 1
num_epochs = 10

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed every random number generator used in this notebook.
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)


# ## Data

# In[3]:


# from data_utils import MNIST

# dataset = MNIST(
#     root='./mnist',
#     train_valid_split_ratio=[0.9, 0.1],
#     digits_of_interest=[3, 6],
#     n_test_samples=3000,
#     n_train_samples=5000
# )

# dataflow = dict()
# for split in dataset:
#     sampler = torch.utils.data.RandomSampler(dataset[split])
#     dataflow[split] = torch.utils.data.DataLoader(
#         dataset[split],
#         batch_size=len(dataset[split]),
#         sampler=sampler,
#         num_workers=0,
#         pin_memory=True)

# train_data = next(iter(dataflow["train"]))
# val_data = next(iter(dataflow["valid"]))
# test_data = next(iter(dataflow["test"]))

# Per-split containers: images, digit labels, and one slot per sample for the
# trained encoder parameters (torch.nan until an encoder has been fitted).
# Only ~25M in total, which is acceptable.
# NOTE(review): `train_data`, `val_data` and `test_data` are only created by the
# commented-out loader above; with it disabled this cell raises NameError.
data = {
    split: {
        "images": split_data["image"],
        "digits": split_data["digit"],
        "encoder_params": [torch.nan] * len(split_data["digit"]),
    }
    for split, split_data in (
        ("train", train_data),
        ("valid", val_data),
        ("test", test_data),
    )
}

# for split in data.keys():
#     torch.save(data[split], os.path.join(processed_data_dir, f"mnist_{split}.pt"))


# In[39]:


# In[3]:


from torch.utils.data import DataLoader, Dataset


class MNIST_AAE_Dataset(Dataset):
    """MNIST samples paired with per-sample AAE encoder parameters.

    The processed file is a dict with keys ``"images"``, ``"digits"`` and
    ``"encoder_params"``. ``"encoder_params"`` holds one entry per sample: either
    ``torch.nan`` (no encoder fitted yet) or an encoder ``state_dict`` snapshot.
    """

    def __init__(self, processed_data_path, subset_size=None) -> None:
        """Load a processed split.

        Args:
            processed_data_path: file written by ``torch.save``; also the default
                target of ``save_dataset_to_disk``.
            subset_size: if given, keep only the first ``subset_size`` samples.
        """
        super().__init__()
        self.processed_data_path = processed_data_path
        self.data = torch.load(processed_data_path)
        if subset_size is not None:
            self.data = {
                key: self.data[key][:subset_size]
                for key in ("images", "digits", "encoder_params")
            }

    def __len__(self):
        """Number of samples, taken from the ``"digits"`` entry."""
        return len(self.data["digits"])

    def __getitem__(self, index):
        """Return one sample as a dict.

        ``"index"`` is included so a trained encoder can be written back with
        ``save_encoder_params``. ``"encoder_params"`` may be ``torch.nan``.
        """
        return {
            "index": index,
            "images": self.data["images"][index],
            "digits": self.data["digits"][index],
            "encoder_params": self.data["encoder_params"][index],
        }

    def save_encoder_params(self, index, encoder: nn.Module):
        """Record a snapshot of ``encoder``'s parameters for sample ``index``.

        ``state_dict()`` returns tensors that share storage with the live module.
        They are detached and cloned here; otherwise further training of
        ``encoder`` would silently change the stored parameters.
        """
        self.data["encoder_params"][index] = {
            key: value.detach().clone() if isinstance(value, torch.Tensor) else value
            for key, value in encoder.state_dict().items()
        }

    def save_dataset_to_disk(self, path=None):
        """Write ``self.data`` with ``torch.save``.

        Args:
            path: destination file; defaults to the file the dataset was loaded from.
        """
        if path is None:
            path = self.processed_data_path

        torch.save(self.data, path)


# In[4]:


# Training split, read in order four samples at a time for a quick look.
train_ds = MNIST_AAE_Dataset(os.path.join(processed_data_dir, "mnist_train.pt"))
train_loader = DataLoader(train_ds, batch_size=4, shuffle=False)
samples = next(iter(train_loader))


# ## Model

# In[ ]:


# ### pennylane amplitude encoding

# In[4]:


def encodelayer(inputs):
    """Angle-encode ``4 * n_qubits`` features per sample.

    On each wire ``i`` the gates RY, RZ, RX, RY are applied in that order.
    The k-th gate (k = 0..3) takes column ``i + k * n_qubits`` of ``inputs``.
    ``inputs`` is indexed as ``inputs[:, column]``, so a leading batch
    dimension is expected.
    """
    gates = (qml.RY, qml.RZ, qml.RX, qml.RY)
    for wire in range(n_qubits):
        for k, gate in enumerate(gates):
            gate(inputs[:, wire + k * n_qubits], wires=wire)


# In[5]:


# The device is built here: `q_device` is only defined further down the file,
# so referencing it raised NameError when cells were run top to bottom.
@qml.qnode(qml.device("default.qubit", wires=n_qubits), interface="torch")
@qml.simplify
def quantum_net(inputs, new_weights):
    """Classifier circuit: amplitude-embed ``inputs``, then StronglyEntanglingLayers.

    Args:
        inputs: features amplitude-embedded on ``n_qubits`` wires
            (normalized by the embedding).
        new_weights: weights for ``StronglyEntanglingLayers``.

    Returns:
        List of ``<Z_i>`` expectation values, one per wire.
    """
    qml.templates.AmplitudeEmbedding(inputs, wires=range(n_qubits), normalize=True)
    # encodelayer(inputs)  # angle-encoding alternative
    StronglyEntanglingLayers(new_weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]


# In[6]:


# Draw `quantum_net` on one 7x7-average-pooled training image with random weights.
avg_pool = nn.AvgPool2d(kernel_size=7)
# The original read `dataflow["train"]`, whose construction is commented out
# above (NameError). Take the first batch of the dataset loader instead.
image = next(iter(train_loader))["images"]
# NOTE(review): assumes images are (batch, 1, 28, 28), so one image pools to 4x4 = 16 features.
image = avg_pool(image[0]).view(1, -1)
# `num_qlayer` was never defined in this file. It only shapes the random weights
# used for this drawing; 2 is an arbitrary choice.
num_qlayer = 2
qml.draw_mpl(quantum_net, decimals=1, expansion_strategy="device")(
    image, torch.normal(0, 0.1, size=(num_qlayer, n_qubits, 3))
)


# In[ ]:


# ### AAE

# In[5]:


def RY_layer(weights):
    """Rotate wire ``k`` about Y by ``weights[k]`` for every entry of the 1-D ``weights``."""
    for wire, angle in enumerate(weights):
        qml.RY(angle, wires=wire)


def RX_layer(weights):
    """Rotate wire ``k`` about X by ``weights[k]`` for every entry of the 1-D ``weights``."""
    for wire, angle in enumerate(weights):
        qml.RX(angle, wires=wire)


def RZ_layer(weights):
    """Rotate wire ``k`` about Z by ``weights[k]`` for every entry of the 1-D ``weights``."""
    for wire, angle in enumerate(weights):
        qml.RZ(angle, wires=wire)


def EntanglingLayer(n_qubits, top_type=None):
    """Apply one layer of CNOT gates.

    Args:
        n_qubits: number of wires; only used by the default topology.
        top_type: optional explicit topology, an iterable of ``[control, target]``
            wire pairs, e.g. ``[[0, 3], [1, 3], [2, 3]]``. When ``None``, CNOTs act
            on neighbour pairs (0, 1), (2, 3), ... first, then (1, 2), (3, 4), ...
    """
    if top_type is not None:
        for pair in top_type:
            qml.CNOT(wires=pair)
        return

    # Even-offset neighbour pairs first, then odd-offset pairs.
    for offset in (0, 1):
        for control in range(offset, n_qubits - 1, 2):
            qml.CNOT(wires=[control, control + 1])


def HadamardLayer(n_qubits):
    """Apply a Hadamard gate to each of the wires 0 .. n_qubits - 1."""
    wires = range(n_qubits)
    for wire in wires:
        qml.Hadamard(wires=wire)


# In[6]:


# AAE encoder: a `default.qubit` simulator with one wire per qubit.
q_device = qml.device("default.qubit", wires=n_qubits)


# diff_method="backprop" because qml.state() is only supported for backprop.
@qml.qnode(q_device, interface="torch", diff_method="backprop")
def aae_encoder(inputs=None, weights=None):
    """Variational encoder: RY/CNOT layers applied to the all-zero state.

    Args:
        inputs: ignored. The circuit has no data input (it always starts from
            |0...0>), but the argument is required by ``qml.qnn.TorchLayer``.
        weights: rotation angles indexed as ``weights[layer][wire]``. Each row
            gives one RY layer, followed by the default CNOT entangling layer.

    Returns:
        The state vector of the ``n_qubits`` register.
    """
    # Alternative via StronglyEntanglingLayers, rotating only about Y:
    #   input_weights = torch.zeros(weights.shape + (3,))
    #   input_weights[:, :, 1] = weights
    #   StronglyEntanglingLayers(input_weights, wires=range(n_qubits))
    for layer_weights in weights:
        RY_layer(weights=layer_weights)
        EntanglingLayer(n_qubits, None)

    # Post selection to deal with negative amplitudes (disabled):
    #   qml.Hadamard(wires=n_qubits - 1)
    #   qml.measure(wires=n_qubits - 1, postselect=1)

    return qml.state()


q_device = qml.device("default.qubit", wires=n_qubits)


# diff_method="backprop" because qml.state() is only supported for backprop.
@qml.qnode(q_device, interface="torch", diff_method="backprop")
def aae_encoder_hadamard(inputs, weights):
    """Variant of the AAE encoder that ends with a Hadamard on every wire.

    Args:
        inputs: ignored; required only by ``qml.qnn.TorchLayer``. The circuit
            always starts from |0...0>.
        weights: rotation angles indexed as ``weights[layer][wire]``. Each row
            gives one RY layer, followed by the default CNOT entangling layer.

    Returns:
        The state vector of the ``n_qubits`` register after the final Hadamard layer.
    """
    # Alternative via StronglyEntanglingLayers, rotating only about Y:
    #   input_weights = torch.zeros(weights.shape + (3,))
    #   input_weights[:, :, 1] = weights
    #   StronglyEntanglingLayers(input_weights, wires=range(n_qubits))
    for layer_weights in weights:
        RY_layer(weights=layer_weights)
        EntanglingLayer(n_qubits)

    # Post selection to deal with negative amplitudes (disabled):
    #   qml.Hadamard(wires=n_qubits - 1)
    #   qml.measure(wires=n_qubits - 1, postselect=1)
    HadamardLayer(n_qubits)
    return qml.state()


# In[7]:


# test aae_encoder: draw the circuit with random angles, one per (layer, qubit)
rand_param = torch.normal(mean=0.0, std=0.5, size=(num_encoder_layer, n_qubits))
qml.draw_mpl(aae_encoder, expansion_strategy="device")(None, rand_param)


# In[8]:


# test aae_encoder: evaluate the circuit and show its output state
aae_encoder(inputs=None, weights=rand_param)


# ## Loss

# In[9]:


def MMD(x, y, kernel="rbf"):
    """Empirical maximum mean discrepancy (biased estimate) between two samples.

    The lower the result, the more evidence that the distributions are the same.

    Args:
        x: first sample, shape ``(n, d)``, from distribution P.
        y: second sample, shape ``(m, d)``, from distribution Q; ``m`` may differ from ``n``.
        kernel: ``"multiscale"`` (sum of inverse multiquadric kernels) or
            ``"rbf"`` (sum of Gaussian kernels).

    Returns:
        0-D tensor on the inputs' device and dtype.

    Raises:
        ValueError: if ``kernel`` is not ``"multiscale"`` or ``"rbf"``.
    """
    if kernel not in ("multiscale", "rbf"):
        raise ValueError(f"unknown kernel {kernel!r}; expected 'multiscale' or 'rbf'")

    xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
    # Squared norms of every sample point, read off the Gram-matrix diagonals.
    x_sq, y_sq = xx.diag(), yy.diag()

    # Pairwise squared distances ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, built by
    # broadcasting a column against a row so x and y may have different sizes.
    dxx = x_sq.unsqueeze(1) + x_sq.unsqueeze(0) - 2.0 * xx  # (n, n)
    dyy = y_sq.unsqueeze(1) + y_sq.unsqueeze(0) - 2.0 * yy  # (m, m)
    dxy = x_sq.unsqueeze(1) + y_sq.unsqueeze(0) - 2.0 * zz  # (n, m)

    # Accumulators follow the inputs' device/dtype rather than the global `device`,
    # so CPU inputs still work when CUDA is available.
    XX, YY, XY = torch.zeros_like(dxx), torch.zeros_like(dyy), torch.zeros_like(dxy)

    if kernel == "multiscale":
        for a in (0.2, 0.5, 0.9, 1.3):
            XX += a**2 * (a**2 + dxx) ** -1
            YY += a**2 * (a**2 + dyy) ** -1
            XY += a**2 * (a**2 + dxy) ** -1
    else:  # "rbf"
        for a in (10, 15, 20, 50):
            XX += torch.exp(-0.5 * dxx / a)
            YY += torch.exp(-0.5 * dyy / a)
            XY += torch.exp(-0.5 * dxy / a)

    # Average each term separately so unequal sample sizes are weighted correctly;
    # for n == m this equals mean(XX + YY - 2 * XY).
    return XX.mean() + YY.mean() - 2.0 * XY.mean()


def tqLoss(result_state, target_state):
    """Infidelity ``1 - |<result|target>|^2`` between two 1-D state vectors.

    Args:
        result_state: 1-D state vector, e.g. the complex output of ``qml.state()``.
        target_state: 1-D vector of the same length; may be real-valued.

    Returns:
        0-D real tensor. Both vectors are assumed to have unit 2-norm; no
        normalization is done here.
    """
    # torch.dot/vdot reject mixed dtypes (e.g. complex128 state vs float32 image),
    # so promote both operands to a common dtype first.
    common = torch.promote_types(result_state.dtype, target_state.dtype)
    # vdot conjugates its first argument, giving the proper complex inner product.
    overlap = torch.vdot(result_state.to(common), target_state.to(common))
    return 1 - overlap.abs() ** 2


# In[12]:


# MMD example: two 2-D Gaussian samples with different means and covariances.

# `get_ipython` only exists inside IPython/Jupyter; skip the magic when run as a script.
try:
    get_ipython().run_line_magic("matplotlib", "inline")
except NameError:
    pass
import matplotlib.pyplot as plt
import numpy as np  # NOTE(review): rebinds `np`, shadowing the `pennylane.numpy` import at the top
from scipy.stats import dirichlet, multivariate_normal
from torch.distributions.multivariate_normal import MultivariateNormal

m = 20  # sample size
x_mean = torch.zeros(2) + 1
y_mean = torch.zeros(2)
x_cov = 2 * torch.eye(2)  # IMPORTANT: covariance matrices must be positive definite
y_cov = 3 * torch.eye(2) - 1  # [[2, -1], [-1, 2]]: eigenvalues 1 and 3, positive definite

px = MultivariateNormal(x_mean, x_cov)
qy = MultivariateNormal(y_mean, y_cov)
x = px.sample([m]).to(device)
y = qy.sample([m]).to(device)

result = MMD(x, y, kernel="multiscale")

print(f"MMD result of X and Y is {result.item()}")


# In[13]:


# test of MMD: two independent draws from the same 2-D standard normal,
# so the result should be close to 0.
n_sample = 1000
X, Y = torch.randn((n_sample, 2)), torch.randn((n_sample, 2))
criterion = MMD
criterion(X, Y)


# In[14]:


# Price table from the AAE paper: 4 rows, 5 time steps each.
stock = torch.tensor(
    (
        (84.80, 90.10, 88.09, 87.87, 80.55),
        (53.19, 58.20, 57.41, 56.00, 58.75),
        (70.41, 67.03, 65.92, 60.55, 65.73),
        (28.83, 28.50, 28.24, 27.27, 25.92),
    )
)

# qSVD utils


def log_return(stock):
    """Log returns ``log(p[t + 1]) - log(p[t])`` along the time axis (dim 1).

    Args:
        stock: 2-D tensor of prices; rows are series, columns are time steps.

    Returns:
        Tensor of shape ``(rows, columns - 1)``.
    """
    return torch.log(stock[:, 1:]) - torch.log(stock[:, :-1])


def norm_return(stock):
    """Standardize each row's log returns, then scale the matrix to unit norm.

    Each row is centred on its own mean and divided by its own population
    standard deviation. The result is divided by its Frobenius norm, so
    ``(result ** 2).sum() == 1``.

    Args:
        stock: 2-D tensor of prices, shape ``(n_s, n_t)``.

    Returns:
        Tensor of shape ``(n_s, n_t - 1)``.
    """
    n_s, n_t = stock.shape

    ret = log_return(stock)
    # keepdim=True so per-row statistics broadcast along each row. A bare (n_s,)
    # vector would broadcast along the columns instead: silently wrong for
    # square input and a shape error otherwise.
    avg_ret = ret.mean(dim=1, keepdim=True)
    std_ret = torch.sqrt(((ret - avg_ret) ** 2).mean(dim=1, keepdim=True))

    # The sqrt(n_s * n_t) factor cancels in the final normalization; kept for
    # parity with the original formula.
    norm_ret = (ret - avg_ret) / std_ret / (n_s * n_t) ** 0.5
    return norm_ret / norm_ret.norm(2)


def corr_mat(stock):
    """Return ``N @ N.T`` for ``N = norm_return(stock)``: an ``(n_s, n_s)`` matrix with unit trace."""
    norm_ret = norm_return(stock)
    return torch.matmul(norm_ret, norm_ret.T)


ret = norm_return(stock)
ret.pow(2).sum()  # norm_return scales to unit norm, so this should be ~1


# In[18]:


# test tqLoss: a freshly initialised encoder against a random (unnormalised) target
samples = next(iter(train_loader))
weight_shapes = {"weights": (num_encoder_layer, n_qubits)}
encoder = qml.qnn.TorchLayer(aae_encoder, weight_shapes)  # wrap the QNode as a torch module
loss = tqLoss(
    result_state=encoder(torch.tensor([0.0])),
    target_state=torch.randn((16,)),
)
# encoder(torch.tensor([0.0]))


# In[19]:


# Inspect the encoder's trainable parameters.
for param in encoder.parameters():
    print(param)


# ## train and test

# In[10]:


def resize_and_norm(image):
    """Down-sample with 7x7 average pooling, flatten, and scale to unit 2-norm.

    For a ``(1, 1, 28, 28)`` image this yields a 16-element vector, matching a
    4-qubit state whose amplitudes have unit norm.
    """
    pooled = F.avg_pool2d(image, kernel_size=(7, 7))  # (1, 1, 28, 28) -> (1, 1, 4, 4)
    flat = pooled.view(-1)  # -> (16,)
    return flat / flat.norm(2)


def train_encoder(sample, encoder, criterion, optimizer, n_step=100, verbose=True):
    """Fit ``encoder``'s output to one down-sampled, normalized image.

    Args:
        sample: batch dict with an ``"images"`` entry holding a single image
            (assumed shape ``(1, C, H, W)``; TODO confirm for other loaders).
        encoder: callable, e.g. a ``qml.qnn.TorchLayer``, mapping a dummy input
            tensor to a state.
        criterion: loss called as ``criterion(encoding, target)``.
        optimizer: optimizer over ``encoder``'s parameters.
        n_step: number of optimization steps.
        verbose: print the loss every 10 steps and once more at the end.

    Returns:
        The loss of the last step as a Python float, or ``None`` if ``n_step`` is 0.
    """
    target = resize_and_norm(sample["images"])
    # The AAE encoder circuits ignore their inputs, but TorchLayer still
    # requires an input tensor.
    dummy_inputs = torch.zeros((1,), dtype=torch.float32)

    # An optional CosineAnnealingLR(optimizer, T_max=n_step) scheduler was tried here.
    loss = None
    for step in range(n_step):
        encoding = encoder(dummy_inputs)
        loss = criterion(encoding, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if verbose and step % 10 == 0:
            print(f"loss: {loss.item()}", end="\r")

    if loss is None:
        return None
    final_loss = loss.item()
    if verbose:
        print(f"loss: {final_loss}")
    return final_loss


# In[11]:


# Loss used to match the encoder output to the image (tqLoss is the alternative).
# criterion = tqLoss
criterion = nn.MSELoss()

# One RY angle per (layer, qubit).
weight_shapes = {"weights": (num_encoder_layer, n_qubits)}

# Quantum circuit wrapped as a torch module, weights initialised with nn.init.uniform_.
encoder = qml.qnn.TorchLayer(aae_encoder, weight_shapes, init_method=nn.init.uniform_)

optimizer = optim.Adam(params=encoder.parameters(), lr=1e-2)

# Single-sample batches in dataset order.
train_loader = DataLoader(train_ds, batch_size=1, shuffle=False)
sample = next(iter(train_loader))


# In[12]:


train_encoder(sample=sample, encoder=encoder, criterion=criterion, optimizer=optimizer, n_step=100)

# cpu: AMD R7 5700U, gpu: None
# log: MSELoss, 100 steps, 5.6s loss: 0.00013334548566490412
# converges very fast
# tqLoss converges slightly faster than MSELoss
# what about 16x16 data? That needs 8 qubits -- still fast enough?


# In[13]:


# Store the fitted encoder parameters under this sample's index and show them.
sample_index = int(sample["index"])
train_ds.save_encoder_params(sample_index, encoder)
print(train_ds.data["encoder_params"][sample_index])


# In[14]:


print(resize_and_norm(image=sample["images"]))  # target amplitudes
print(encoder(torch.tensor(0.0)))  # encoder output state; the circuit ignores the input value


# In[15]:


# Side by side as 4x4 grids: down-sampled target image (left), encoder output state (right).
fig, axes = plt.subplots(1, 2)
axes[0].imshow(resize_and_norm(sample["images"]).reshape(4, 4))
# qml.state() gives a complex-dtype tensor, which imshow cannot render. With only
# RY/CNOT gates acting on |0...0> the amplitudes are real, so the real part loses nothing.
axes[1].imshow(encoder(torch.tensor(0.0)).detach().real.reshape(4, 4))


# In[ ]:


from tqdm import tqdm


def _fresh_optimizer_factory(optimizer):
    """Return a ``params -> optimizer`` callable derived from ``optimizer``.

    An optimizer instance is used as a template: the callable builds a new
    optimizer of the same class with the same default hyper-parameters. Any
    other value is assumed to already be such a callable and is returned as is.
    """
    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer_cls = type(optimizer)
        hyperparams = dict(optimizer.defaults)
        return lambda params: optimizer_cls(params, **hyperparams)
    return optimizer


def train_encoders_for_dataset(dataset, criterion, optimizer, n_step=100):
    """Fit one AAE encoder per sample, store its parameters, and save the dataset.

    Args:
        dataset: ``MNIST_AAE_Dataset``. Each fitted encoder is written back with
            ``save_encoder_params``, then ``save_dataset_to_disk`` is called.
        criterion: loss passed to ``train_encoder``.
        optimizer: either an optimizer instance used as a template (a fresh
            optimizer of the same class and hyper-parameters is built for each
            new encoder) or a callable ``params -> optimizer``.
        n_step: optimization steps per sample.

    Returns:
        ``dataset``, updated in place.
    """
    make_optimizer = _fresh_optimizer_factory(optimizer)
    weight_shapes = {"weights": (num_encoder_layer, n_qubits)}
    loader = DataLoader(dataset, 1, False)

    for sample in tqdm(loader, leave=False):
        encoder = qml.qnn.TorchLayer(
            aae_encoder, weight_shapes, init_method=nn.init.uniform_
        )
        # Each new encoder has its own parameters and needs its own optimizer.
        # Reusing the caller's instance would step a different model and leave
        # this encoder untrained.
        encoder_optimizer = make_optimizer(encoder.parameters())
        train_encoder(
            sample, encoder, criterion, encoder_optimizer, n_step, verbose=False
        )

        dataset.save_encoder_params(sample["index"].item(), encoder)

    dataset.save_dataset_to_disk()
    print("Encoder params saved to disk")

    return dataset


# In[ ]:


# encoder.load_state_dict()
# encoder.requires_grad_(False)


# In[ ]:


# ## Results
#
#

# **pennylane amplitude encoding**
#
# Epoch 1:
# 0.005 0.8584052324295044
# valid set accuracy: 0.6730290456431536
# valid set loss: 0.9551709294319153
#
# Epoch 2:
# 0.0048776412907378846457
# valid set accuracy: 0.8929460580912864
# valid set loss: 0.7974672913551331
#
# Epoch 3:
# 0.0045225424859373685027
# valid set accuracy: 0.9352697095435685
# valid set loss: 0.753818690776825
#
# Epoch 4:
# 0.0039694631307311836213
# valid set accuracy: 0.9485477178423236
# valid set loss: 0.7371468544006348
#
# Epoch 5:
# 0.0032725424859373687498
# valid set accuracy: 0.9609958506224067
# valid set loss: 0.727253794670105
#
# Epoch 6:
# 0.00250.6985626816749573
# valid set accuracy: 0.9643153526970955
# valid set loss: 0.721916913986206
#
# Epoch 7:
# 0.0017274575140626316482
# valid set accuracy: 0.970954356846473
# valid set loss: 0.7195770144462585
#
# Epoch 8:
# 0.0010305368692688174391
# valid set accuracy: 0.9701244813278008
# valid set loss: 0.7184303998947144
#
# Epoch 9:
# 0.0004774575140626316366
# valid set accuracy: 0.9734439834024896
# valid set loss: 0.7179780006408691
#
# Epoch 10:
# 0.0001223587092621161729
# valid set accuracy: 0.9742738589211618
# valid set loss: 0.7178515195846558
#
# test set accuracy: 0.9715447154471545
# test set loss: 0.7131120562553406
#

# **pennylane angle encoding**
#
# Epoch 1:
# 0.005 1.0613005161285421
# valid set accuracy: 0.8721991701244813
# valid set loss: 0.9659674167633057
#
# Epoch 2:
# 0.0048776412907378846388
# valid set accuracy: 0.9294605809128631
# valid set loss: 0.7519037127494812
#
# Epoch 3:
# 0.0045225424859373685668
# valid set accuracy: 0.9427385892116182
# valid set loss: 0.6945566534996033
#
# Epoch 4:
# 0.0039694631307311838341
# valid set accuracy: 0.9477178423236514
# valid set loss: 0.6593191027641296
#
# Epoch 5:
# 0.0032725424859373687941
# valid set accuracy: 0.941908713692946
# valid set loss: 0.6438149809837341
#
# Epoch 6:
# 0.00250.6027146577835083
# valid set accuracy: 0.9394190871369295
# valid set loss: 0.637258768081665
#
# Epoch 7:
# 0.0017274575140626316727
# valid set accuracy: 0.9402489626556016
# valid set loss: 0.6341438889503479
#
# Epoch 8:
# 0.0010305368692688174798
# valid set accuracy: 0.9402489626556016
# valid set loss: 0.63267982006073
#
# Epoch 9:
# 0.0004774575140626316303
# valid set accuracy: 0.9410788381742738
# valid set loss: 0.6321021318435669
#
# Epoch 10:
# 0.0001223587092621161779
# valid set accuracy: 0.9410788381742738
# valid set loss: 0.6319551467895508
#
# test set accuracy: 0.9242886178861789
# test set loss: 0.639994740486145
#

# In[ ]:


# In[ ]:
