import argparse
import math
import os
from dataclasses import dataclass

import pennylane as qml
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utils.aae_dataset import MNIST_AAE_Dataset
from models.batch_encoders import aae_encoder, aae_encoder_for_train
from models.superencoders import MLP
from utils import append_log, resize_and_norm


def parse_args():
    """Parse the command line and attach derived paths and the torch device.

    Returns:
        argparse.Namespace: the parsed CLI options plus four derived
        attributes: ``data_dir``, ``device``, ``logs`` and ``save_path``.
    """

    def str2bool(value):
        # ``type=bool`` would map every non-empty string (including "False")
        # to True, so the textual value is interpreted explicitly instead.
        text = value.strip().lower()
        if text in ("true", "t", "yes", "y", "1"):
            return True
        if text in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value, got {value!r}"
        )

    parser = argparse.ArgumentParser()

    # "inference" mode uses the trained model to predict encoder params for
    # every dataset split found in data_dir.
    parser.add_argument("mode", type=str, help="train | inference")

    parser.add_argument("--seed", type=int, default=42)

    # model args
    parser.add_argument(
        "--super_encoder_type",
        type=str,
        default="mlp",
        help="mlp| ...(not support yet)",
    )
    parser.add_argument("--version", type=str, default="v0.0.1")
    parser.add_argument(
        "--n_qubits", type=int, default=4, help="number of qubits to use"
    )
    parser.add_argument(
        "--n_encoder_layers",
        type=int,
        default=8,
        help="number of layers to use in aae encoder",
    )
    parser.add_argument(
        "--n_ansatz_layers",
        type=int,
        default=5,
        help="number of layers to use in ansatz",
    )

    # train args
    parser.add_argument(
        "--digits_used",
        type=str,
        default=r"digits[3, 6]",
        help="mnist digits used for training",
    )
    parser.add_argument(
        "--n_epochs", type=int, default=10, help="number of epochs to train"
    )
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--learning_rate", type=float, default=5e-3)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--pin_memory", type=str2bool, default=False)

    args = parser.parse_args()

    # Derived settings: everything below is keyed on the chosen digit subset.
    args.data_dir = rf"./mnist/processed/{args.digits_used}"
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.logs = rf"./logs/e2e/{args.digits_used}"
    args.save_path = rf"./trained_models/e2e/{args.digits_used}/{args.version}_{args.super_encoder_type}.pt"

    return args


def seed_everything(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch with ``seed``."""
    import random

    import numpy as np

    # Seeding order: stdlib RNG, then NumPy, then torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)


def get_super_encoder(config):
    """Build the super encoder selected by ``config.super_encoder_type``.

    The input width is the pixel count of a square single-channel image whose
    side is ``floor(sqrt(2 ** n_qubits))``; the output width is
    ``n_qubits * n_encoder_layers``. The module is moved to ``config.device``.

    Raises:
        NotImplementedError: for any encoder type other than ``"mlp"``.
    """
    side = math.floor((2**config.n_qubits) ** 0.5)
    channels = 1

    if config.super_encoder_type != "mlp":
        raise NotImplementedError(
            f"Unsupported super encoder type: {config.super_encoder_type}"
        )

    in_features = side * side * channels
    out_features = config.n_qubits * config.n_encoder_layers
    return MLP(in_features, out_features).to(config.device)


def get_inputs(batch, config):
    """Return ``batch["images"]`` passed through ``resize_and_norm``.

    Raises:
        ValueError: when ``config.n_qubits`` is below 4.
    """
    n_qubits = config.n_qubits

    if n_qubits == 4:
        # Author's note: the regular resize performs poorly when shrinking to
        # (4, 4) (about 92%~93% on test), so the older default path of
        # resize_and_norm is kept here to stay around 99%.
        return resize_and_norm(batch["images"])

    if n_qubits > 4:
        return resize_and_norm(batch["images"], n_qubits)

    raise ValueError("Not Support to down sample to below (4, 4)")


def train(dataloader, model, optimizer, criterion, config):
    """Run one pass over ``dataloader``, stepping ``optimizer`` on every batch.

    The loss of each batch is shown on the progress bar and appended to
    ``train_loss.txt`` inside ``config.logs``.
    """
    loss_log = os.path.join(config.logs, "train_loss.txt")

    with tqdm(dataloader) as progress:
        for batch in progress:
            images = get_inputs(batch, config).to(config.device)
            labels = batch["digits"].to(config.device)

            # Clear stale gradients, then forward / backward / update.
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            progress.set_postfix(loss=loss_value)
            append_log(loss_log, loss_value)


def valid_test(dataloader, split, model, criterion, config):
    """Evaluate ``model`` on ``dataloader`` and log accuracy and loss.

    Predictions and targets of all batches are concatenated, top-1 accuracy
    and the criterion value are computed over the whole set, printed, and
    appended to ``{split}_acc.txt`` / ``{split}_loss.txt`` in ``config.logs``.

    Args:
        dataloader: iterable of batches with ``"images"`` and ``"digits"``.
        split: label used in the printed messages and log file names.
        model: ``torch.nn.Module`` to evaluate.
        criterion: loss callable applied to (outputs, targets).
        config: namespace providing ``logs``, ``device`` and ``n_qubits``.
    """
    acc_path = os.path.join(config.logs, f"{split}_acc.txt")
    loss_path = os.path.join(config.logs, f"{split}_loss.txt")
    target_all = []
    output_all = []

    # Evaluate in eval mode so train-only layers (dropout, batch norm, ...)
    # behave deterministically, then restore whatever mode the caller had so
    # a following training epoch is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader):
                targets = batch["digits"].to(config.device)
                inputs = get_inputs(batch, config).to(config.device)
                prediction = model(inputs)
                target_all.append(targets)
                output_all.append(prediction)

            target_all = torch.cat(target_all, dim=0)
            output_all = torch.cat(output_all, dim=0)
    finally:
        model.train(was_training)

    # Top-1 accuracy: compare the highest-scoring class with the target.
    _, indices = output_all.topk(1, dim=1)
    masks = indices.eq(target_all.view(-1, 1).expand_as(indices))
    size = target_all.shape[0]
    corrects = masks.sum().item()
    accuracy = corrects / size
    loss = criterion(output_all, target_all).item()

    print(f"{split} set accuracy: {accuracy}")
    print(f"{split} set loss: {loss}")

    append_log(acc_path, accuracy)
    append_log(loss_path, loss)


def inference(args, model, loader: DataLoader, dataset: MNIST_AAE_Dataset):
    """Predict parameters for every sample in ``loader`` and store them.

    Each prediction is handed to ``dataset.save_pred_params`` together with
    the sample's entry from ``batch["index"]``; once all batches are done the
    dataset is persisted via ``dataset.save_dataset_to_disk()``.

    Args:
        args: namespace providing ``n_qubits`` and ``device``.
        model: callable mapping a batch of preprocessed images to predictions.
        loader: yields batches with ``"images"`` and ``"index"`` entries.
        dataset: dataset object that receives the predicted parameters.

    Raises:
        ValueError: if the model returns a different number of predictions
            than there are indices in the batch, or (via ``get_inputs``) if
            ``args.n_qubits`` is below 4.
    """
    # No gradients are needed here; disabling autograd also keeps the stored
    # predictions from holding on to the computation graph.
    with torch.no_grad():
        for batch in tqdm(loader, leave=False):
            # Use the same preprocessing as training/validation so the model
            # sees inputs prepared the way it was trained on (this matters
            # for the n_qubits == 4 special case in get_inputs).
            images = get_inputs(batch, args).to(args.device)

            pred = model(images)

            indices = batch["index"]
            if len(pred) != len(indices):
                raise ValueError(
                    f"got {len(pred)} predictions for {len(indices)} samples"
                )

            for idx, params in zip(indices, pred):
                dataset.save_pred_params(idx, params)

    dataset.save_dataset_to_disk()


if __name__ == "__main__":
    # Script entry point: build data loaders and the hybrid model, then either
    # train + evaluate ("train") or run the saved super encoder over every
    # split ("inference").
    args = parse_args()

    seed_everything(args.seed)

    # data
    train_path = os.path.join(args.data_dir, "mnist_train.pt")
    val_path = os.path.join(args.data_dir, "mnist_valid.pt")
    test_path = os.path.join(args.data_dir, "mnist_test.pt")

    train_ds = MNIST_AAE_Dataset(train_path)
    val_ds = MNIST_AAE_Dataset(val_path)
    test_ds = MNIST_AAE_Dataset(test_path)

    # Only the training loader shuffles.
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
    )

    # super encoder
    super_encoder = get_super_encoder(args)

    # qlayer
    q_device = qml.device("default.qubit", wires=args.n_qubits)

    @qml.qnode(q_device, interface="torch")
    @qml.simplify
    def quantum_model(inputs, ansatz_weights):
        """Encode ``inputs``, apply the trainable ansatz, measure Z per wire.

        ``inputs`` are the encoder parameters produced by the super encoder;
        the circuit returns one PauliZ expectation value for each qubit.
        """
        aae_encoder_for_train(inputs, args.n_encoder_layers, args.n_qubits)
        qml.StronglyEntanglingLayers(ansatz_weights, wires=range(args.n_qubits))
        return [qml.expval(qml.PauliZ(i)) for i in range(args.n_qubits)]

    weight_shapes = {"ansatz_weights": (args.n_ansatz_layers, args.n_qubits, 3)}
    qlayer = qml.qnn.TorchLayer(
        quantum_model, weight_shapes, init_method=nn.init.uniform_
    )

    # Full pipeline: flatten -> super encoder -> quantum layer -> log-probs.
    model = torch.nn.Sequential(
        nn.Flatten(),
        super_encoder,
        qlayer,
        nn.LogSoftmax(dim=-1),
    ).to(args.device)

    criterion = nn.NLLLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.n_epochs
    )

    if args.mode == "train":
        # Start from an empty log directory on every training run.
        if os.path.exists(args.logs):
            import shutil

            shutil.rmtree(args.logs)
        os.makedirs(args.logs, exist_ok=True)

        # train and eval
        for epoch in range(1, args.n_epochs + 1):
            # train
            print(f"Epoch {epoch}:")
            train(train_loader, model, optimizer, criterion, args)

            # valid
            valid_test(val_loader, "valid", model, criterion, args)
            scheduler.step()

        valid_test(test_loader, "test", model, criterion, args)

        if args.save_path:
            os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
            torch.save(model[1], args.save_path)  # only save super encoder

    elif args.mode == "inference":
        # The checkpoint holds only the pickled super encoder (model[1] in the
        # train branch). map_location lets a checkpoint written on a GPU
        # machine load when this run's device is the CPU.
        # NOTE(review): this unpickles a whole module, so only load trusted
        # checkpoints; newer PyTorch releases default to weights_only=True and
        # may refuse such files — confirm against the installed version.
        loaded_encoder = torch.load(args.save_path, map_location=args.device)

        # In training the super encoder always sits behind nn.Flatten();
        # recreate that so it receives the same input layout here.
        model = nn.Sequential(nn.Flatten(), loaded_encoder).to(args.device)
        model.eval()

        inference(args, model, train_loader, train_ds)
        inference(args, model, val_loader, val_ds)
        inference(args, model, test_loader, test_ds)
    else:
        raise ValueError("mode need to be 'train' or 'inference' ")
