import sys

sys.path.append(".")

import ast
import os
import re

import pennylane as qml
import torch
from omegaconf import OmegaConf
from qiskit import QuantumCircuit
from torch.utils.data import DataLoader

from data_utils.aae_dataset import FractalDB_Dataset, MNIST_AAE_Dataset
from models.batch_encoders import aae_encoder_for_train
from models.state_generators import AAE_StateGenerator, StateGenerator
from utils import norm_image, resize, resize_and_norm, seed_everything, visual_compare

# Config file (relative to `config_dir`) that drives this script.
# version = "v0.0.7.yaml" # for superencoder qasm generation
version = "AAE_encoder.yaml"
config_dir = r"./configs/"

# Register Python's builtin `eval` as an OmegaConf resolver named "eval", so
# the YAML can use `${eval:...}` interpolations.
# NOTE(review): this evaluates arbitrary Python taken from the config file —
# only load trusted configs.
OmegaConf.register_new_resolver("eval", eval)
config = OmegaConf.load(os.path.join(config_dir, version))
# Echo the loaded config so each run's settings appear in its log.
print(OmegaConf.to_yaml(config))
# Seed once at import time; presumably covers torch/numpy/python RNGs
# (depends on utils.seed_everything — TODO confirm).
seed_everything(config.seed)


def get_file_names_in_folder(folder_path, suffix="4-qubits.pt"):
    """Return the base names of files under ``folder_path`` ending with ``suffix``.

    The directory tree is walked recursively. Only the *base name* of each
    match is returned, with no directory component. Callers that rebuild a
    path as ``folder_path + name`` will therefore only resolve files that sit
    directly in ``folder_path``.

    Args:
        folder_path: Root directory to search. A missing directory yields an
            empty result (``os.walk`` ignores the error).
        suffix: Required file-name ending. The default ``"4-qubits.pt"``
            selects only the 4-qubit evaluation datasets.

    Returns:
        A sorted list of matching file names. Sorting makes the result
        independent of the filesystem's directory-listing order, which
        ``os.walk`` does not guarantee and which otherwise changes the order
        (and hence the seeded random sampling) of downstream processing.
    """
    return sorted(
        file
        for _root, _dirs, files in os.walk(folder_path)
        for file in files
        if file.endswith(suffix)
    )


def preprocess_eval_dataset():
    """Map each 4-qubit evaluation dataset in ``./eval_datasets/`` to its path.

    Dataset files are discovered with ``get_file_names_in_folder``. The
    distribution name is the file name without its extension, e.g.
    ``"foo_4-qubits.pt"`` -> ``"foo_4-qubits"``. Each mapping is printed.

    Returns:
        dict mapping distribution name -> dataset file path.
    """
    folder_path = "./eval_datasets/"

    distribution_files = {}
    for file_name in get_file_names_in_folder(folder_path):
        # splitext strips only the final extension. The former
        # `split(".pt")[-2]` picked the wrong piece whenever ".pt" also
        # appeared earlier in the name.
        distribution_name = os.path.splitext(file_name)[0]
        distribution_files[distribution_name] = os.path.join(folder_path, file_name)

    for distribution_name, file_path in distribution_files.items():
        print(f"{distribution_name}: {file_path}")

    return distribution_files


def norm_data(images):
    """Flatten every sample and rescale it to unit L2 norm.

    Args:
        images: Tensor whose first dimension indexes samples; all remaining
            dimensions are flattened.

    Returns:
        Tensor of shape ``(images.shape[0], -1)`` in which each row has L2
        norm 1. An all-zero row divides by zero and yields NaN.
    """
    batch_size = images.shape[0]
    flat = images.reshape(batch_size, -1)
    row_norms = flat.norm(p=2, dim=1, keepdim=True)
    return flat / row_norms


def save_qasm(qnode: "qml.QNode", file_name: str):
    """Write the OpenQASM of an already-executed QNode to ``./data_hhl/<file_name>``.

    The QNode must have been called at least once so that ``qnode.qtape``
    is populated. The exported QASM is then post-processed so that qiskit
    can load it for the HHL pipeline:

    * PennyLane embeds torch tensor reprs as gate parameters, e.g.
      ``tensor([0.5000], device='cuda:0', grad_fn=<RemainderBackward0>)``.
      Each is replaced by the bare number (``0.5000``).
    * ``measure`` statements and ``creg`` declarations are removed, because
      they cause errors in qiskit transpilation in HHL.

    The output directory is created if it does not exist.

    Args:
        qnode: An executed PennyLane QNode.
        file_name: Output file name, relative to ``./data_hhl``.
    """
    # NOTE(review): `QNode.qtape` is deprecated/removed in recent PennyLane
    # releases; confirm against the pinned PennyLane version.
    qasm_str = qnode.qtape.to_openqasm()

    # Replace single-element torch tensor reprs with their scalar value. This
    # covers integral floats ("3."), scientific notation ("1.2346e-05"),
    # negative values and extra repr fields such as device=/dtype=/grad_fn=.
    # Without this, any of those would leave invalid QASM behind.
    tensor_repr = re.compile(
        r"tensor\(\[?\s*(-?\d+(?:\.\d*)?(?:[eE][-+]?\d+)?)\s*\]?"
        r"(?:,\s*\w+=[^,()]+)*\)"
    )
    qasm_str = tensor_repr.sub(r"\1", qasm_str)

    # remove measurements and classical registers
    # otherwise it will cause error for qiskit transpilation in HHL
    qasm_str = re.sub(r"measure .*\n", "", qasm_str)
    qasm_str = re.sub(r"creg .*\n", "", qasm_str)

    out_dir = "./data_hhl"
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, file_name), "w", encoding="utf-8") as f:
        f.write(qasm_str)


def save_state(targ_state: torch.Tensor, file_name):
    """Write a state vector to ``./data_hhl/<file_name>`` as a Python list literal.

    The tensor is converted with ``Tensor.tolist()`` and written as
    ``str(list)``, e.g. ``"[0.5, 0.25]"``. For finite values the result can
    be read back with ``ast.literal_eval``. The output directory is created
    if it does not exist.

    Args:
        targ_state: State vector to save. Callers pass ``.view(-1)`` so the
            file holds a flat list.
        file_name: Output file name, relative to ``./data_hhl``.
    """
    targ_state_list = targ_state.tolist()
    out_dir = "./data_hhl"
    # Previously a missing ./data_hhl raised FileNotFoundError.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, file_name), "w", encoding="utf-8") as f:
        f.write(str(targ_state_list))


# FIXME: this function assumes the target state vectors have already been
# generated in the `./data_hhl/` dir (e.g. by running the "superencoder" mode)
def generate_aae_qsp_qasm(sv_file_name):
    """Approximate a saved target state with the AAE generator and save the result.

    The function:

    * reads the target state vector from ``./data_hhl/<sv_file_name>``, which
      must be a Python list literal as written by ``save_state``;
    * runs ``AAE_StateGenerator`` on it;
    * prints the fidelity and the generated state;
    * writes the generated state to ``./data_hhl/<prefix>_aae.txt``.

    ``<prefix>`` is ``sv_file_name`` up to its last ``_``, e.g.
    ``"foo_orig.txt"`` -> ``"foo_aae.txt"``. If there is no ``_``, the name
    without its extension is used.

    Only the state vector is written; QASM export is currently disabled.
    """
    aae_state_gen = AAE_StateGenerator(config)

    # Parse the target state vector. `ast.literal_eval` accepts only Python
    # literals (lists, floats, complex numbers, ...), so unlike `eval` it
    # cannot execute code embedded in the file.
    sv_path = os.path.join("./data_hhl/", sv_file_name)
    with open(sv_path, "r", encoding="utf-8") as f:
        sv_list = ast.literal_eval(f.read())

    target_sv = torch.Tensor(sv_list).view(1, -1)
    result_sv = aae_state_gen(target_sv)
    print(qml.math.fidelity_statevector(target_sv, result_sv))
    print(result_sv)

    # "<prefix>_<tag>.txt" -> "<prefix>". The old `rfind` slicing silently
    # dropped the last character when no "_" was present.
    head, sep, _tail = sv_file_name.rpartition("_")
    name_prefix = head if sep else os.path.splitext(sv_file_name)[0]
    # QASM export is disabled; to re-enable:
    # save_qasm(aae_state_gen.aae_encoder, f"aae_{name_prefix}.qasm")
    # and verify it by loading with qiskit's `QuantumCircuit.from_qasm_file`
    # and comparing the simulated statevector.
    save_state(result_sv.view(-1), f"{name_prefix}_aae.txt")


def generate_superencoder_qsp_qasm(distribution_files, write_qasm=False):
    """Encode one sample per evaluation distribution with the trained SuperEncoder.

    For every ``name -> dataset path`` entry in ``distribution_files``:

    1. draw one sample from a shuffled ``DataLoader`` (batch size 1) and
       L2-normalise it with ``norm_data`` to get the target state;
    2. predict AAE circuit parameters with the SuperEncoder and print them;
    3. save the target state to ``./data_hhl/<name>_orig.txt``;
    4. save the state returned by ``super_encoder.compute_state`` to
       ``./data_hhl/<name>_superencoder.txt`` and print its fidelity with
       the target state.

    Args:
        distribution_files: Mapping of distribution name -> dataset file
            path, as returned by ``preprocess_eval_dataset``.
        write_qasm: If True, also build the AAE encoder circuit from the
            predicted parameters, execute it once, and export it with
            ``save_qasm`` to ``./data_hhl/<name>_superencoder.qasm``. When
            False (default), the circuit is not simulated at all. Its result
            was only ever needed for the QASM export.
    """
    enc_cfg = config.state_generator.aae_encoder
    n_encoder_layers = enc_cfg.n_encoder_layers
    n_qubits = enc_cfg.n_qubits

    # load super encoder
    super_encoder = StateGenerator(config=config).to(config.device)
    super_encoder.load(config.checkpoint.save_path, config.device, False)

    for name, dataset_path in distribution_files.items():
        test_loader = DataLoader(
            MNIST_AAE_Dataset(dataset_path),
            shuffle=True,
            batch_size=1,
            num_workers=0,
            pin_memory=True,
        )
        # No explicit generator is passed, so the shuffle draws from torch's
        # global RNG (seeded once at import via seed_everything).
        test_samples = next(iter(test_loader))
        target_state = norm_data(test_samples["images"]).to(config.device)

        params = super_encoder(target_state)
        print(f"predicted parameters from superencoder: {params}")

        if write_qasm:

            @qml.qnode(
                qml.device(enc_cfg.q_device, wires=n_qubits),
                interface="torch",
                diff_method="backprop",
            )
            def aae_encoder():
                # `params` is read when the QNode runs, which happens right
                # below within this same loop iteration.
                aae_encoder_for_train(params, n_encoder_layers, n_qubits, to_float=True)
                return qml.state()

            # The QNode must run once so that its tape, which save_qasm
            # reads, exists.
            aae_encoder()
            save_qasm(aae_encoder, f"{name}_superencoder.qasm")

        save_state(target_state.view(-1), f"{name}_orig.txt")

        result_state = super_encoder.compute_state(target_state)
        print(result_state)
        save_state(result_state.view(-1), f"{name}_superencoder.txt")
        fidelity = qml.math.fidelity_statevector(result_state, target_state)
        print(f"fidelity: {fidelity}")


def main():
    """Command-line entry point.

    Usage: ``<script> [superencoder|AAE]``; the default is ``superencoder``.

    * ``superencoder``: discover the 4-qubit evaluation datasets and run
      ``generate_superencoder_qsp_qasm`` on them.
    * ``AAE``: run ``generate_aae_qsp_qasm`` on every ``*_orig.txt`` file in
      ``./data_hhl``, in sorted order.

    Prints a usage line and exits with status 1 on a wrong number of
    arguments. Raises ``NotImplementedError`` for an unknown QSP type.
    """
    if len(sys.argv) == 2:
        qsp_type = sys.argv[1]
    elif len(sys.argv) == 1:
        qsp_type = "superencoder"
    else:
        print(f"Usage: {sys.argv[0]} <qsp-type: superencoder or AAE>")
        sys.exit(1)

    if qsp_type == "superencoder":
        # Dataset discovery is only needed (and only run) in this mode.
        generate_superencoder_qsp_qasm(preprocess_eval_dataset())
    elif qsp_type == "AAE":
        # Sorted because os.listdir order is arbitrary. The global RNG is
        # seeded only once and a fresh AAE generator is built per file, so
        # processing order may affect results.
        sv_files = sorted(
            f for f in os.listdir("./data_hhl") if f.endswith("_orig.txt")
        )
        for sv in sv_files:
            generate_aae_qsp_qasm(sv)
    else:
        raise NotImplementedError(f"Unsupported QSP type: {qsp_type}")


if __name__ == "__main__":
    main()
