#!/usr/bin/env python
# coding: utf-8

# In[1]:


import sys

sys.path.append(".")

import os
import time

import pennylane as qml
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from data_utils.aae_dataset import FractalDB_Dataset, MNIST_AAE_Dataset
from models.state_generators import StateGenerator
from utils import norm_image, resize, resize_and_norm, seed_everything, visual_compare

# In[2]:


# The OmegaConf YAML config path is the first command-line argument.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} CONFIG_FILE")
config_file_path = sys.argv[1]

# Enable `${eval:...}` interpolations in the config.
# NOTE(review): this executes arbitrary Python taken from the config file;
# only load trusted configs.
# `replace=True` lets this cell be re-run (e.g. in the original notebook)
# without raising because the resolver is already registered.
OmegaConf.register_new_resolver("eval", eval, replace=True)
config = OmegaConf.load(config_file_path)
print(OmegaConf.to_yaml(config))

seed_everything(config.seed)


# ## Load Model

# In[3]:

# Build the state generator on the configured device, then restore its
# weights from the checkpoint (strict=False: key mismatches are tolerated).
superEncoder = StateGenerator(config=config).to(config.device)
superEncoder.load(config.checkpoint.save_path, config.device, strict=False)

# Qubit count of the AAE encoder; selects the matching evaluation files below.
num_qubits = config.state_generator.aae_encoder.n_qubits


# ## Compute state fidelity

# In[4]:


# Number of states drawn per evaluation batch.
n_samples = 256
# Seeded RNG on the configured device. `manual_seed` returns the generator
# itself, so it can be chained onto the constructor.
# NOTE(review): this generator is not referenced elsewhere in the visible script.
generator = torch.Generator(device=config.device).manual_seed(config.seed)


# In[5]:


# FIXME: the evaluation datasets were reorganised, so the folder is no longer
# walked with os.walk; the expected file names are listed explicitly instead.
folder_path = "./eval_datasets/"

# One ".pt" file per distribution, suffixed with the encoder's qubit count.
file_names = [
    f"{stem}_{num_qubits}-qubits.pt"
    for stem in (
        "beta_a1b1",
        "exponential_rate1",
        "lognormal_mean0std1",
        "normal_mean0.3std0.5",
        "uniform_low0high1",
    )
]


# In[6]:


# Map each distribution name (the file name without its extension) to the
# path of its evaluation dataset. splitext/join replace the fragile
# `split(".pt")[-2]` and string concatenation.
distribution_files = {
    os.path.splitext(file_name)[0]: os.path.join(folder_path, file_name)
    for file_name in file_names
}

for distribution_name, file_name in distribution_files.items():
    print(f"{distribution_name}: {file_name}")


# In[7]:


import csv

import numpy as np


def norm_data(images):
    """Flatten each sample and scale it to unit L2 norm.

    Args:
        images: tensor whose first dimension indexes samples; every remaining
            dimension is flattened into a single feature axis.

    Returns:
        Tensor of shape ``(images.shape[0], -1)`` whose rows each have
        Euclidean norm 1. An all-zero row yields NaNs (0 / 0).
    """
    flat = images.reshape(images.shape[0], -1)
    row_norms = flat.norm(p=2, dim=1, keepdim=True)
    return flat / row_norms


n_samples = 256

# Evaluation only: switch any training-dependent layers (dropout, batch norm)
# to inference behaviour. Autograd is disabled per batch below.
superEncoder.eval()

# CUDA work is launched asynchronously, so wall-clock timing must synchronise
# the device before each clock read to measure the actual execution.
eval_device = torch.device(config.device)
use_cuda = eval_device.type == "cuda"

for name, dataset_path in distribution_files.items():
    test_dataset = MNIST_AAE_Dataset(dataset_path)
    test_loader = DataLoader(
        test_dataset, shuffle=True, batch_size=n_samples, num_workers=0, pin_memory=True
    )

    # One shuffled batch of at most n_samples samples.
    test_samples = next(iter(test_loader))

    with torch.no_grad():
        # Flatten each sample and scale it to unit L2 norm.
        target_state = norm_data(test_samples["images"]).to(config.device)
        encoder_params = superEncoder(target_state)

        # Time only the circuit execution, not the parameter prediction.
        if use_cuda:
            torch.cuda.synchronize(eval_device)
        start_time = time.perf_counter()
        result_state = superEncoder.qc(encoder_params)
        if use_cuda:
            torch.cuda.synchronize(eval_device)
        end_time = time.perf_counter()
        print(f"runtime: {end_time - start_time}")

        # NOTE: fidelity is computed on the states as produced; no absolute
        # value is taken of either the result or the target amplitudes.
        fidelity = qml.math.fidelity_statevector(result_state, target_state)

    # The batch can be smaller than n_samples when the dataset is small, so
    # report the number of states actually evaluated.
    n_evaluated = target_state.shape[0]
    print(
        f"average fidelity of {n_evaluated} states on {name} distribution : {fidelity.mean():.4f}"
    )


# In[ ]:
