import logging
import os
import sys
import time

import pennylane as qml
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utils.aae_dataset import MNIST_AAE_Dataset
from models.state_generators import StateGenerator
from utils import append_log, resize_and_norm, seed_everything

# The script expects exactly one command-line argument: the path to the
# training config file.
# NOTE: this check runs at module import time, so importing this module with
# any other argument count prints the usage line and raises SystemExit(1).
if len(sys.argv) != 2:
    print(f"Usage: python {sys.argv[0]} <config-path> e.g., configs/train_qnn.yaml")
    sys.exit(1)

config_path = sys.argv[1]
# Hard-coded alternatives, kept for quick local runs without the CLI argument:
# config_path = r"./configs/train_qnn.yaml"
# config_path = r"./configs/train_qnn_6qubits.yaml"
# config_path = r"./configs/train_qnn_8qubits.yaml"

# Registry mapping the `encoder.type` config value to an encoder class.
# Only "SuperEncoder" is backed by a class; the remaining entries are None
# placeholders and cannot be instantiated.
ENCODERS = {"SuperEncoder": StateGenerator, "AM": None, "AG": None, "AAE": None}


def load_data(config):
    """Create the train, validation and test DataLoaders.

    The three splits are read from ``mnist_train.pt``, ``mnist_valid.pt`` and
    ``mnist_test.pt`` under ``config.data.root``. Every loader receives the
    keyword arguments in ``config.data.dataloader``; only the training loader
    shuffles.

    Args:
        config: Configuration object exposing ``data.root`` and
            ``data.dataloader``.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple.
    """
    split_files = ("mnist_train.pt", "mnist_valid.pt", "mnist_test.pt")
    shuffle_flags = (True, False, False)

    # Build every dataset first, then wrap them, mirroring the split order.
    datasets = [
        MNIST_AAE_Dataset(os.path.join(config.data.root, file_name))
        for file_name in split_files
    ]
    train_loader, val_loader, test_loader = [
        DataLoader(dataset, shuffle=do_shuffle, **config.data.dataloader)
        for dataset, do_shuffle in zip(datasets, shuffle_flags)
    ]

    return train_loader, val_loader, test_loader


class QML_Pipeline(nn.Module):
    """
    inputs -> Encoder -> state -> QNN Ansatz -> outputs

    The encoder class is looked up in ``ENCODERS`` by ``config.encoder.type``
    and its weights are loaded from ``config.encoder.save_path``. The ansatz is
    a PennyLane QNode (amplitude embedding followed by strongly entangling
    layers) wrapped in a ``qml.qnn.TorchLayer`` and measured with one PauliZ
    expectation value per qubit.
    """

    def __init__(self, config) -> None:
        """Build the encoder and the trainable ansatz layer.

        Args:
            config: Configuration object exposing ``encoder.type``,
                ``encoder.save_path``,
                ``encoder.state_generator.aae_encoder.{n_qubits, q_device}``
                and ``ansatz.n_ansatz_layers``.

        Raises:
            KeyError: If ``config.encoder.type`` is not a key of ``ENCODERS``.
            NotImplementedError: If the encoder type is registered in
                ``ENCODERS`` but has no implementation (its entry is None).
        """
        super().__init__()
        self.config = config
        self.n_qubits = config.encoder.state_generator.aae_encoder.n_qubits
        q_device = self.config.encoder.state_generator.aae_encoder.q_device

        encoder_cls = ENCODERS[config.encoder.type]
        if encoder_cls is None:
            # Fail with an actionable message instead of the opaque
            # "'NoneType' object is not callable" raised by calling None.
            implemented = sorted(
                name for name, cls in ENCODERS.items() if cls is not None
            )
            raise NotImplementedError(
                f"Encoder type {config.encoder.type!r} is registered but not "
                f"implemented; implemented encoder types: {implemented}"
            )
        self.encoder = encoder_cls(config.encoder)
        self.encoder.load(config.encoder.save_path, strict=False)

        @qml.qnode(device=qml.device(q_device, wires=self.n_qubits), interface="torch")
        @qml.simplify
        def ansatz(inputs, weights):
            # PennyLane needs an explicit embedding to start the circuit from a
            # state other than |0...0>.
            qml.AmplitudeEmbedding(inputs, wires=range(self.n_qubits))
            qml.StronglyEntanglingLayers(weights, wires=range(self.n_qubits))
            return [qml.expval(qml.PauliZ(i)) for i in range(self.n_qubits)]

        # One (layer, qubit, 3) rotation-angle tensor for the entangling layers.
        self.ansatz_weight_shapes = {
            "weights": (config.ansatz.n_ansatz_layers, self.n_qubits, 3)
        }
        self.ansatz = qml.qnn.TorchLayer(
            ansatz, self.ansatz_weight_shapes, init_method=nn.init.uniform_
        )

    def forward(self, inputs):
        """Encode ``inputs`` into states and run them through the ansatz.

        Args:
            inputs: Batch passed to the encoder. The expected shape depends on
                the encoder, which is defined outside this file.

        Returns:
            The output of the ansatz TorchLayer for the encoded states
            (presumably one PauliZ expectation per qubit for each sample;
            TODO confirm the exact shape).
        """
        # some pennylane template can only accept state one at a time?
        # The encoder returns a pair; only its first element (the states) is
        # used. NOTE(review): the meaning of the positional ``True`` flag is
        # defined by the encoder - confirm.
        states, _ = self.encoder(inputs, True)

        # FIXME: these are just for runtime measurement
        #  delete them when actually training the qnn model
        start_time = time.perf_counter()
        pred = self.ansatz(states)
        end_time = time.perf_counter()
        print(f"forward time: {end_time - start_time}")
        return pred


def train_qnn(config):
    """Train the QNN ansatz, validating every epoch and testing at the end.

    Only the parameters of ``pipeline.ansatz`` are handed to the optimizer
    (Adam, lr=5e-3, weight_decay=1e-4), with a cosine-annealing schedule over
    ``config.n_epochs``. Per-epoch train/validation metrics and the final test
    metrics are written to the log file at ``config.logs``.

    Args:
        config: Configuration object exposing ``n_epochs``, ``logs`` and the
            fields required by ``load_data``, ``QML_Pipeline``, ``train`` and
            ``valid_test``.
    """
    train_loader, val_loader, test_loader = load_data(config)

    pipeline = QML_Pipeline(config)
    criterion = nn.NLLLoss()

    optimizer = torch.optim.Adam(
        pipeline.ansatz.parameters(), lr=5e-3, weight_decay=1e-4
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.n_epochs
    )

    # A bare file name has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    log_dir = os.path.dirname(config.logs)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    logging.basicConfig(filename=config.logs, encoding="utf-8", level=logging.INFO)

    # Sentinel so the final log line stays well-defined when n_epochs == 0
    # (the loop variable would otherwise be unbound).
    epoch = -1
    for epoch in range(config.n_epochs):
        print("Epoch: ", epoch)
        loss_epoch = train(config, train_loader, pipeline, criterion, optimizer)
        logging.info(f"Epoch: {epoch} -- Train_loss: {loss_epoch}")

        loss_epoch_val, acc_val = valid_test(
            config, val_loader, pipeline, criterion, "Val"
        )
        logging.info(
            f"Epoch: {epoch} -- Val_loss: {loss_epoch_val} -- Val_acc: {acc_val}"
        )

        scheduler.step()

    loss_test, acc_test = valid_test(config, test_loader, pipeline, criterion, "Test")
    logging.info(f"Epoch: {epoch} -- Test_loss: {loss_test} -- Test_acc: {acc_test}")


def train(config, dataloader, model, criterion, optimizer):
    """Run one training epoch and return the mean batch loss.

    Args:
        config: Configuration object; ``config.device`` is the device the
            inputs and targets are moved to.
        dataloader: Iterable of batches, each a mapping with ``"images"`` and
            ``"digits"`` entries.
        model: Model exposing ``n_qubits`` and callable on the prepared inputs.
        criterion: Loss applied to the log-softmax of the model output.
        optimizer: Optimizer stepped once per batch.

    Returns:
        The average of the per-batch losses, or ``nan`` when ``dataloader``
        yields no batches.
    """
    total_loss = 0.0
    n_batches = 0
    with tqdm(dataloader) as bar:
        for n_batches, batch in enumerate(bar, start=1):
            inputs = resize_and_norm(batch["images"], model.n_qubits).to(config.device)
            targets = batch["digits"].to(config.device)

            # The criterion is fed log-probabilities, so normalise the raw
            # model outputs with log_softmax first.
            prediction = F.log_softmax(model(inputs), dim=1)
            loss = criterion(prediction, targets)

            # calculate gradients via back propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once and reuse it.
            batch_loss = loss.item()
            total_loss += batch_loss

            # set_postfix replaces the whole postfix on every call, so both
            # values must go in one call to be displayed together.
            bar.set_postfix({"Loss": batch_loss, "avg_loss": total_loss / n_batches})

    if n_batches == 0:
        # No batches were seen: the mean is undefined; avoid a NameError /
        # ZeroDivisionError and signal it explicitly.
        return float("nan")
    return total_loss / n_batches


def valid_test(config, dataloader, model, criterion, split="val"):
    """Evaluate ``model`` on ``dataloader`` and report loss and top-1 accuracy.

    All batches are run under ``torch.no_grad()``; the log-softmax outputs and
    targets are concatenated and the loss is computed once over the full set.
    The metrics are also printed, prefixed with ``split``.

    Args:
        config: Configuration object; ``config.device`` is the device the
            inputs and targets are moved to.
        dataloader: Iterable of batches, each a mapping with ``"images"`` and
            ``"digits"`` entries.
        model: Model exposing ``n_qubits`` and callable on the prepared inputs.
        criterion: Loss applied to the concatenated log-probabilities.
        split: Label used as prefix in the printed summary.

    Returns:
        A ``(loss, accuracy)`` tuple of Python floats.
    """
    collected_targets = []
    collected_log_probs = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            images = resize_and_norm(batch["images"], model.n_qubits).to(config.device)
            labels = batch["digits"].to(config.device)

            log_probs = F.log_softmax(model(images), dim=1)

            collected_targets.append(labels)
            collected_log_probs.append(log_probs)

        all_targets = torch.cat(collected_targets, dim=0)
        all_log_probs = torch.cat(collected_log_probs, dim=0)

    # Top-1 prediction per sample, kept as a column so it lines up with the
    # targets reshaped to (N, 1).
    top1 = all_log_probs.topk(1, dim=1).indices
    n_correct = top1.eq(all_targets.view(-1, 1)).sum().item()
    accuracy = n_correct / all_targets.shape[0]
    loss = criterion(all_log_probs, all_targets).item()
    print(f"{split}_loss: {loss} -- {split}_acc: {accuracy}")

    return loss, accuracy


if __name__ == "__main__":
    # Expose Python's built-in eval as an OmegaConf resolver named "eval", so
    # config values can use "${eval:...}" interpolations.
    # NOTE(security): this executes arbitrary Python expressions found in the
    # config file - only load config files from trusted sources.
    OmegaConf.register_new_resolver("eval", eval)
    config = OmegaConf.load(config_path)
    # Seed before any data loader or model is built (presumably seeds the
    # Python/NumPy/torch RNGs - see utils.seed_everything to confirm).
    seed_everything(config.seed)
    train_qnn(config)
