import sys
sys.path.append("../")

import os
import time
from tqdm import tqdm
from pprint import pprint
import json

import omegaconf
from torch.utils.data import DataLoader

import pennylane as qml


from data_utils.aae_dataset import MNIST_AAE_Dataset
from models.state_generators import AAE_StateGenerator
from utils import resize_and_norm, seed_everything

from argparse import ArgumentParser

def eval_aae_fidelity(target_state, aae_config):
    """Run an AAE state generator on ``target_state`` and score the outcome.

    A new ``AAE_StateGenerator`` is constructed from ``aae_config`` on every
    call. Only the generator call itself is timed; construction is excluded.

    Args:
        target_state: State passed to the generator and used as the reference
            for the fidelity computation.
        aae_config: Configuration handed unchanged to ``AAE_StateGenerator``.

    Returns:
        A ``(fidelity, duration)`` tuple, where ``fidelity`` is the Python
        scalar obtained from ``qml.math.fidelity_statevector`` between the
        generated state and ``target_state``, and ``duration`` is the
        wall-clock time of the generator call in seconds.
    """
    generator = AAE_StateGenerator(aae_config)

    # Time just the forward call of the generator.
    t0 = time.perf_counter()
    produced_state = generator(target_state)
    elapsed = time.perf_counter() - t0

    # NOTE(review): .item() presumes the fidelity is a single-element
    # tensor/array (i.e. one target state) -- confirm against callers.
    overlap = qml.math.fidelity_statevector(produced_state, target_state)
    return overlap.item(), elapsed
    

if __name__ == "__main__":
    # Sweep the number of AAE encoder layers for a fixed qubit count, measure
    # fidelity and runtime for each depth on one MNIST test sample, and dump
    # the results to a JSON file.
    # NOTE: every path used below is relative to the current working directory.
    parser = ArgumentParser()
    parser.add_argument("n_qubits", type=int, help="Number of qubits of AAE to be eval")
    parser.add_argument("start_num_layers", type=int, help="First encoder depth to evaluate (>= 1)")
    parser.add_argument("end_num_layers", type=int, help="Last encoder depth to evaluate (inclusive)")
    args = parser.parse_args()

    # Validate through the parser instead of `assert`: assertions are removed
    # under `python -O`, which would let an invalid range through silently.
    if args.start_num_layers < 1:
        parser.error("start_num_layers must be at least 1")
    if args.start_num_layers > args.end_num_layers:
        parser.error("start_num_layers must be smaller or equal to end_num_layers")

    n_qubits = args.n_qubits

    # The config file is selected by qubit count.
    config_path = rf"../configs/AAE_encoder_{n_qubits}qubits.yaml"
    config = omegaconf.OmegaConf.load(config_path)
    print(omegaconf.OmegaConf.to_yaml(config))
    seed_everything(config.seed)

    # get a MNIST sample
    # shuffle=True: which sample is drawn depends on the RNG state
    # (presumably fixed by seed_everything above -- verify in utils).
    mnist_dir = r"../mnist/processed"
    test_ds = MNIST_AAE_Dataset(os.path.join(mnist_dir, "mnist_test.pt"))
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=True)

    samples = next(iter(test_loader))

    target_state = resize_and_norm(samples["images"], config.state_generator.aae_encoder.n_qubits).to(config.device)
    # The target is sized from the config's n_qubits; make sure it agrees with
    # the qubit count given on the command line before running the sweep.
    if target_state.shape[-1] != 2**n_qubits:
        raise ValueError(
            f"target state has dimension {target_state.shape[-1]}, "
            f"expected {2**n_qubits} for n_qubits={n_qubits}; "
            f"check n_qubits in {config_path}"
        )

    # Optional debugging visualisation (requires matplotlib as plt):
    # size = int((2**n_qubits)**0.5)
    # plt.imshow(target_state.view(size, size))
    # plt.title(samples["digits"].item())

    results = []
    for num_encoder_layers in tqdm(range(args.start_num_layers, args.end_num_layers + 1)):
        # The config object is mutated in place; each evaluation builds a new
        # generator from it with the current depth.
        config.state_generator.aae_encoder.n_encoder_layers = num_encoder_layers
        fidelity, duration = eval_aae_fidelity(target_state, config)
        results.append(
            {
                "n_qubits": n_qubits,
                "num_encoder_layers": num_encoder_layers,
                "fidelity": fidelity,
                "duration": duration,
            }
        )

    pprint(results)

    # One JSON file per qubit count; an existing file is overwritten.
    log_dir = r"../logs/eval/aae_depth/"
    file_path = os.path.join(log_dir, f"{n_qubits}qubits.json")
    os.makedirs(log_dir, exist_ok=True)
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)

    print(f"Results save to {file_path}")