"""
This script evaluates the performance of superencoder trained on MNIST dataset and evaluate
fidelity on MNIST test set,
serving as a motivation for enhancing loss function design.
"""

import sys

sys.path.append(".")

import os

import pennylane as qml
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from data_utils.aae_dataset import FractalDB_Dataset, MNIST_AAE_Dataset
from models.state_generators import StateGenerator
from utils import norm_image, resize, resize_and_norm, seed_everything, visual_compare

# In[2]:


# Name of the YAML config (inside ``config_dir``) selecting the run to evaluate.
# version = "v0.0.7.11.yaml" # for state MSE, res = 0.9873
version = "Fidelity_landscape.yaml"  # for state fidelity, res = 0.9908
config_dir = r"./configs/"

# SECURITY NOTE: this exposes Python's built-in ``eval`` to ``${eval:...}``
# interpolations in the config, so only load configuration files you trust.
# ``replace=True`` keeps the registration idempotent: without it OmegaConf
# raises because the resolver is already registered when this cell is executed
# a second time in the same interpreter session.
OmegaConf.register_new_resolver("eval", eval, replace=True)
config = OmegaConf.load(os.path.join(config_dir, version))
print(OmegaConf.to_yaml(config))

# Apply the configured seed (through the project's ``seed_everything`` helper)
# before the model and the data loader are created.
seed_everything(config.seed)


# ## Load Model

# In[3]:


# Build the model described by the config and move it to the target device.
super_encoder = StateGenerator(config=config).to(config.device)

# Load the checkpoint stored at ``config.checkpoint.save_path``.
# NOTE(review): the meaning of the positional ``False`` is defined by
# ``StateGenerator.load`` (not visible here) — confirm against that method.
super_encoder.load(config.checkpoint.save_path, config.device, False)

# This script only runs inference: switch to evaluation mode so that layers
# with train/eval-dependent behaviour (e.g. dropout, batch norm) act
# deterministically. Assumes StateGenerator is an ``nn.Module`` — TODO confirm.
super_encoder.eval()


# ## Compute state fidelity

# In[4]:


# Number of samples requested per batch from the test loader further below.
n_samples = 256

# Random generator living on the configured device, seeded from the config.
# ``manual_seed`` returns the generator itself, so construction and seeding
# can be written as a single chained expression.
generator = torch.Generator(device=config.device).manual_seed(config.seed)

import csv

import numpy as np


def norm_data(images, eps=1e-12):
    """Flatten every sample and rescale it to unit L2 norm.

    Args:
        images: Tensor whose first dimension indexes the samples; all
            remaining dimensions are flattened into one vector per sample.
        eps: Lower bound applied to each sample norm before dividing, so
            that an all-zero sample yields zeros instead of NaN.

    Returns:
        A 2-D tensor with ``images.shape[0]`` rows in which every row with a
        norm of at least ``eps`` has unit L2 norm.
    """
    flat = images.reshape(images.shape[0], -1)
    norms = torch.norm(flat, p=2, dim=1, keepdim=True)
    # Clamp rather than add eps: rows with a regular norm are divided by
    # exactly their norm, only degenerate rows are protected from 0/0.
    return flat / norms.clamp_min(eps)


# Number of test images used for the fidelity estimate (one batch).
n_samples = 256

test_dataset = MNIST_AAE_Dataset("./mnist/processed/mnist_test.pt")
test_loader = DataLoader(
    test_dataset, shuffle=True, batch_size=n_samples, num_workers=0, pin_memory=True
)

# Only the first shuffled batch is evaluated.
test_samples = next(iter(test_loader))

# Pure evaluation: disable autograd so no computation graph is recorded for
# the encoder and circuit forward passes.
with torch.no_grad():
    # Prepare the target states from the images with the project helpers
    # ``resize`` (given the configured qubit count) and ``norm_image``.
    target_state = resize(
        test_samples["images"], config.state_generator.aae_encoder.n_qubits
    ).to(config.device)
    target_state = norm_image(target_state)

    # The model maps target states to circuit parameters; ``qc`` turns those
    # parameters into the state that is compared against the target.
    encoder_params = super_encoder(target_state)
    result_state = super_encoder.qc(encoder_params)
    fidelity = qml.math.fidelity_statevector(result_state, target_state)

print(f"average fidelity of {n_samples} states on mnist : {fidelity.mean():.4f}")


# In[ ]:
