#!/usr/bin/env python
# coding: utf-8

# In[66]:


# %cd ..
import sys

sys.path.append(".")

import os

import pennylane as qml
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from data_utils.aae_dataset import FractalDB_Dataset, MNIST_AAE_Dataset
from models.state_generators import StateGenerator
from utils import norm_image, resize, resize_and_norm, seed_everything, visual_compare

# In[67]:


# Alternative config: "v0.0.8.0.yaml"
version = "v0.0.7.16.3.yaml"
config_dir = r"./configs/"

# Expose Python's builtin `eval` to the YAML configs as the "eval" resolver.
# OmegaConf raises ValueError for a resolver name that is already registered
# (e.g. when this cell is re-run), which is deliberately ignored here.
# NOTE(review): `eval` executes arbitrary expressions found in the config
# files -- only load trusted configs.
try:
    OmegaConf.register_new_resolver("eval", eval)
except ValueError:
    pass

config_path = os.path.join(config_dir, version)
config = OmegaConf.load(config_path)
print(OmegaConf.to_yaml(config))

# Project helper; presumably seeds Python/NumPy/torch RNGs for reproducibility.
seed_everything(config.seed)


# ## Load Data

# In[68]:


from torchvision.transforms.functional import to_pil_image

# FractalDB data; dataset and loader keyword arguments come from the config.
fractal_dataset = FractalDB_Dataset(**config.dataset)
fractal_loader = DataLoader(fractal_dataset, shuffle=True, **config.dataloader)


# In[69]:


fractal_samples = next(iter(fractal_loader))
# Peek at one sample resized/normalized for a 4-qubit state.
print(resize_and_norm(fractal_samples["images"][1], 4))


# In[70]:


# `matplotlib.pyplot` is the supported plotting interface; `matplotlib.pylab`
# is discouraged by the matplotlib docs.
import matplotlib.pyplot as plt

# A state on an even number q of qubits has 2**q amplitudes, which fill a
# square image of side 2**(q // 2) (8 qubits -> 256 amplitudes -> 16 x 16).
preview_qubits = 8
preview_side = 2 ** (preview_qubits // 2)
plt.imshow(
    resize_and_norm(fractal_samples["images"][1], preview_qubits).view(
        preview_side, preview_side
    )
)


# In[71]:


to_pil_image(fractal_samples["images"][1][0])


# ## Load Model

# In[72]:


# Build the state generator on the configured device, then restore its
# weights from the checkpoint path given in the config (strict=True).
device = config.device
superEncoder = StateGenerator(config=config).to(device)
superEncoder.load(config.checkpoint.save_path, device, strict=True)

# Qubit count of the encoded states; reused by the evaluation cells below.
n_qubits = config.state_generator.aae_encoder.n_qubits


# ## Compute state fidelity on:
# - Arbitrary State
# - PDE State
# - QML State
# - QEC State
#
#
# **Consider using more samples; the results are unstable.**

# In[73]:


# Number of states drawn per evaluation batch.
n_samples = 256

# Dedicated RNG on the model's device, seeded from the config, so that the
# randomly generated test states are reproducible.
# (`Generator.manual_seed` returns the generator itself.)
generator = torch.Generator(device=config.device).manual_seed(config.seed)


# ## Fidelity on Training set
#
#
# - Maybe add some noise to the training samples? Their pixel values are too typical.

# In[74]:


# Misspelled duplicate of `n_samples`; kept only as an alias because a later
# cell in this notebook may still refer to it.
n_sampels = n_samples

fractal_samples = next(iter(fractal_loader))
# Only the real part of the state is encoded.
target_state = resize_and_norm(fractal_samples["images"], n_qubits).to(config.device)
encoder_params = superEncoder(target_state)
result_state = superEncoder.qc(encoder_params)
fidelity = qml.math.fidelity_statevector(result_state, target_state)

# The batch size here comes from `config.dataloader`, not from `n_samples`,
# so report the number of states that were actually evaluated.
print(f"average fidelity of {len(target_state)} states: {fidelity.mean():.4f}")

print("-" * 20)
print("4 examples: ")
visual_compare(result_state[:4], target_state[:4])


# ## Arbitrary State (uniformly sampled)

# In[75]:


# Arbitrary real-valued target states: entries drawn uniformly from [0, 1)
# with the seeded generator, then passed through `norm_image`
# (presumably normalization to a valid state vector -- confirm in utils).
state_dim = config.state_generator.super_encoder.in_dim
raw_amplitudes = torch.rand(
    (n_samples, state_dim), generator=generator, device=config.device
)
target_state = norm_image(raw_amplitudes)

encoder_params = superEncoder(target_state)
result_state = superEncoder.qc(encoder_params)
fidelity = qml.math.fidelity_statevector(result_state, target_state)

mean_fidelity = fidelity.mean()
print(f"average fidelity of {n_samples} states: {mean_fidelity:.4f}")

print("-" * 20)
print("4 examples: ")
visual_compare(result_state[:4], target_state[:4])


# ## QML State

# In[78]:


# QML-style targets: one shuffled batch of MNIST test images.
mnist_dir = r"./mnist/processed"
mnist_test_file = os.path.join(mnist_dir, "mnist_test.pt")
test_ds = MNIST_AAE_Dataset(mnist_test_file)
test_loader = DataLoader(test_ds, batch_size=n_samples, shuffle=True)
samples = next(iter(test_loader))

# Resize/normalize the images into real-valued states on `n_qubits` qubits.
mnist_images = samples["images"]
target_state = resize_and_norm(mnist_images, n_qubits).to(config.device)

encoder_params = superEncoder(target_state)
result_state = superEncoder.qc(encoder_params)
fidelity = qml.math.fidelity_statevector(result_state, target_state)

mean_fidelity = fidelity.mean()
print(f"average fidelity of {n_samples} states: {mean_fidelity:.4f}")

print("-" * 20)
print("4 examples: ")
visual_compare(result_state[:4], target_state[:4])


# ## Multiple Synthetic Test Datasets

# In[77]:


from utils import get_test_loaders

# One loader per synthetic evaluation dataset, requesting `n_samples` samples
# per dataset (use the canonical name, not the misspelled `n_sampels`).
test_loaders = get_test_loaders(
    "./eval_datasets/", n_qubits, n_samples_per_ds=n_samples
)

results = {}
targets = {}
fidelities = {}
for key, loader in test_loaders.items():
    print(f"Testing Dataset: {key}")
    samples = next(iter(loader))
    target_states = resize_and_norm(samples["images"], n_qubits).to(config.device)
    encoder_params = superEncoder(target_states)
    result_states = superEncoder.qc(encoder_params)
    results[key] = result_states
    targets[key] = target_states
    # Mean state fidelity over this dataset's batch.
    fidelities[key] = qml.math.fidelity_statevector(result_states, target_states).mean()

for key, value in fidelities.items():
    print(f"{key}: {value.item()}")

# Unweighted mean over datasets; guard against an empty eval directory, which
# would otherwise raise ZeroDivisionError.
if fidelities:
    print("Mean Fidelity:", (sum(fidelities.values()) / len(fidelities)).item())
else:
    print("Mean Fidelity: no test datasets found")


# In[ ]:
