"""
Usage: python $0 <config-file-path>
"""

import os
import sys
import time

sys.path.append(".")

import numpy as np
import pennylane as qml
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utils.aae_dataset import FractalDB_Dataset, MNIST_AAE_Dataset
from models.state_generators import AAE_StateGenerator, StateGenerator
from utils import norm_image, resize, resize_and_norm, seed_everything, visual_compare

# In[2]:


folder_path = "./eval_datasets/"

# Fail with the documented usage line instead of a bare IndexError.
if len(sys.argv) < 2:
    sys.exit(f"Usage: python {sys.argv[0]} <config-file-path>")
config_file_path = sys.argv[1]

# Register an "eval" resolver so config values can be computed via interpolation.
# NOTE(review): this runs arbitrary Python found in the config file — only load
# trusted configs.
try:
    OmegaConf.register_new_resolver("eval", eval)
except ValueError:
    # Presumably raised because the resolver is already registered (e.g. the
    # cell/script was re-run in the same interpreter); safe to ignore.
    pass
config = OmegaConf.load(config_file_path)
print(OmegaConf.to_yaml(config))

seed_everything(config.seed)


# ## Load Synthetic Dataset

# In[3]:


def norm_data(images):
    """Flatten each sample and scale it to unit L2 norm.

    The first dimension is kept as the sample axis and every remaining
    dimension is flattened into one. A row whose norm is zero becomes NaN
    (0 / 0); no guard is applied.

    Args:
        images: tensor whose first dimension indexes samples.

    Returns:
        Tensor of shape ``(images.shape[0], -1)`` whose rows have unit L2 norm.
    """
    flat = images.reshape(images.shape[0], -1)
    row_norms = flat.norm(p=2, dim=1, keepdim=True)
    return flat / row_norms


num_qubits = config.state_generator.aae_encoder.n_qubits
file_names = [
    f"beta_a1b1_{num_qubits}-qubits.pt",
    f"exponential_rate1_{num_qubits}-qubits.pt",
    f"lognormal_mean0std1_{num_qubits}-qubits.pt",
    f"normal_mean0.3std0.5_{num_qubits}-qubits.pt",
    f"uniform_low0high1_{num_qubits}-qubits.pt",
]
# Map each distribution name (file name without its extension) to its path.
# splitext only strips the final ".pt", unlike split(".pt"), which would break
# on any name containing ".pt" elsewhere.
distribution_files = {
    os.path.splitext(file_name)[0]: os.path.join(folder_path, file_name)
    for file_name in file_names
}


n_samples = 32
# One shuffled loader per synthetic distribution, keyed by distribution name.
test_loaders = {}
for name, path in distribution_files.items():
    test_dataset = MNIST_AAE_Dataset(path)
    test_loaders[name] = DataLoader(
        test_dataset, shuffle=True, batch_size=n_samples, num_workers=0, pin_memory=True
    )


# Run the generator on one batch from every distribution and collect the
# produced states alongside the normalized targets.
aae_state_generator = AAE_StateGenerator(config)
print(aae_state_generator)
results = {}
targets = {}
for key, loader in test_loaders.items():
    results[key] = []
    targets[key] = []
    print(f"Testing Dataset: {key}")
    samples = next(iter(loader))
    target_states = samples["images"]
    for target_state in tqdm(target_states):
        # NOTE(review): norm_data flattens from dim 1, so each sample is assumed
        # to carry a leading dim of size 1 — confirm against MNIST_AAE_Dataset.
        target_state = norm_data(target_state).to(config.device)
        result_state = aae_state_generator(target_state)
        # .cpu() is required before .numpy() when config.device is a GPU; it is
        # a no-op for tensors already on the CPU.
        results[key].append(result_state.detach().cpu().numpy())
        targets[key].append(target_state.detach().cpu().numpy())
    results[key] = np.vstack(results[key])
    targets[key] = np.vstack(targets[key])


# Mean state fidelity between generated and target states, per distribution.
fidelities = {
    key: qml.math.fidelity_statevector(results[key], targets[key]).mean()
    for key in results
}

print(aae_state_generator)
# qml.specs returns a resource summary; print it so it is visible when this runs
# as a script. `samples` is the batch from the last distribution processed above.
print(qml.specs(aae_state_generator.aae_encoder)(samples["images"][0]))
print(fidelities)


# Time each call to `aae_state_generator(target_state)` on the first 8 samples of
# the first distribution, then print the mean duration in seconds.
aae_state_generator = AAE_StateGenerator(config)
print(aae_state_generator)

# CUDA work is asynchronous: synchronize around the timed region so the
# measurement covers completed work, not just kernel launches.
use_cuda = torch.device(config.device).type == "cuda"

durations = []
# Only the first distribution is benchmarked.
key, loader = next(iter(test_loaders.items()))
print(f"Testing Dataset: {key}")
samples = next(iter(loader))
target_states = samples["images"]
for target_state in tqdm(target_states[:8]):
    target_state = norm_data(target_state).to(config.device)
    if use_cuda:
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    result_state = aae_state_generator(target_state)
    if use_cuda:
        torch.cuda.synchronize()
    end_time = time.perf_counter()
    durations.append(end_time - start_time)
print(sum(durations) / len(durations))


# Time `aae_state_generator.compute_state()` (which takes no input) over as many
# iterations as there are samples in the first 8 of the first distribution's
# batch, then print the mean duration in seconds.
aae_state_generator = AAE_StateGenerator(config)
print(aae_state_generator)

# CUDA work is asynchronous: synchronize around the timed region so the
# measurement covers completed work, not just kernel launches.
use_cuda = torch.device(config.device).type == "cuda"

durations = []
# Only the first distribution is benchmarked.
key, loader = next(iter(test_loaders.items()))
print(f"Testing Dataset: {key}")
# The batch is still drawn so the iteration count (and the loader's shuffle
# RNG consumption) matches the forward-call benchmark; the samples themselves
# are not needed because compute_state() takes no input.
samples = next(iter(loader))
target_states = samples["images"]
for _ in tqdm(target_states[:8]):
    if use_cuda:
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    result_state = aae_state_generator.compute_state()
    if use_cuda:
        torch.cuda.synchronize()
    end_time = time.perf_counter()
    durations.append(end_time - start_time)
print(sum(durations) / len(durations))
