import argparse

import pennylane as qml
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utils.aae_dataset import MNIST_AAE_Dataset
from models.encoders import get_aae_encoder

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if __name__ == "__main__":
    # Command-line interface: one positional dataset path plus the options that
    # size the encoder circuit and control the per-sample optimisation.
    parser = argparse.ArgumentParser()

    parser.add_argument("dataset_path", type=str)
    parser.add_argument("--n_qubits", type=int, default=4)
    parser.add_argument("--n_encoder_layer", type=int, default=8)
    # NOTE(review): type=list turns the raw argument string into a list of
    # single characters, and this option is not read anywhere in the visible
    # code -- confirm the intended type before relying on it.
    parser.add_argument("--top_type", type=list, default=None)
    # The learning rate is fractional (default 3e-2), so it has to be parsed as
    # a float. With type=int, argparse rejected values such as "-lr 0.01".
    parser.add_argument("-lr", "--learning_rate", type=float, default=3e-2)
    parser.add_argument("--n_steps", type=int, default=100)
    parser.add_argument(
        "--loss_fn", type=str, choices=["tqLoss", "MSELoss"], default="tqLoss"
    )
    # Integer flag (0 / non-zero); exposed as args.save_params.
    parser.add_argument(
        "--save-params", type=int, default=1, help="Whether to save parameters"
    )

    args = parser.parse_args()

    # Circuit callable obtained for the requested number of qubits; it is
    # wrapped in a TorchLayer once per sample further down.
    aae_encoder = get_aae_encoder(args.n_qubits)

    def tqLoss(result_state, target_state):
        """Return ``1 - |<result_state, target_state>|**2``.

        Both arguments are handed to ``torch.dot``, so they must be 1-D tensors
        of the same length. The original author's note says the plain dot
        product is sufficient because ``result_state`` has unit norm; this
        function does not normalise its inputs itself.
        """
        overlap = torch.dot(result_state, target_state)
        fidelity = overlap.abs() ** 2
        return 1 - fidelity

    def resize_and_norm(image):
        """Average-pool ``image``, flatten it and scale it to unit L2 norm.

        The image is pooled with a 7x7 window (stride equal to the window), so
        a (1, 1, 28, 28) input becomes (1, 1, 4, 4) and is then flattened to a
        16-element vector. Dividing by the vector's 2-norm gives it norm 1,
        matching the normalisation of a quantum state vector.

        Returns:
            A 1-D tensor containing every pooled value, with L2 norm 1.
        """
        pooled = F.avg_pool2d(image, kernel_size=7)  # 7 -> square (7, 7) window
        flat = pooled.view(-1)  # collapse all dimensions into one
        return flat / flat.norm(2)

    def train_encoder(sample, encoder, criterion, optimizer, n_step=100, verbose=True):
        """Optimise ``encoder`` so that its output matches the image in ``sample``.

        Args:
            sample: Mapping whose ``"images"`` entry holds the image tensor.
                Assumed (per the original author's note) to contain a single
                image of shape (1, C, H, W).
            encoder: Callable module evaluated on a dummy input at every step.
            criterion: Loss called as ``criterion(encoding, target)``.
            optimizer: Optimiser that updates the encoder's parameters.
            n_step: Number of optimisation steps to run.
            verbose: When true, print the loss every 10th step on a single
                console line (carriage return, no newline).
        """
        # Target vector: pooled, flattened, L2-normalised image moved to DEVICE.
        target = resize_and_norm(sample["images"]).to(device=DEVICE)

        # The encoder is called with an input tensor even though its value is
        # irrelevant; the original author's note says TorchLayer requires one.
        dummy_input = torch.zeros((1,), dtype=torch.float32).to(device=DEVICE)

        for step in range(n_step):
            optimizer.zero_grad()
            loss = criterion(encoder(dummy_input), target)
            loss.backward()
            optimizer.step()

            if verbose and step % 10 == 0:
                print(f"loss: {loss.item()}", end="\r")

    def train_encoders_for_dataset(dataset, criterion, n_step=100):
        """Train one fresh encoder per dataset sample and store its parameters.

        Samples are visited one at a time, in dataset order. For each sample a
        new TorchLayer is built around ``aae_encoder`` with a weight tensor of
        shape (n_encoder_layer, n_qubits), trained with Adam for ``n_step``
        steps, and handed to ``dataset.save_encoder_params`` together with the
        sample's index. After all samples are processed, the dataset is written
        to disk when ``args.save_params`` is truthy.

        Returns:
            The same ``dataset`` object that was passed in.
        """
        weight_shapes = {"weights": (args.n_encoder_layer, args.n_qubits)}
        loader = DataLoader(dataset, batch_size=1, shuffle=False)

        for sample in tqdm(loader, leave=False):
            # A new layer per sample, so every image starts from freshly
            # initialised weights.
            encoder = qml.qnn.TorchLayer(
                aae_encoder, weight_shapes, init_method=nn.init.uniform_
            )
            optimizer = torch.optim.Adam(encoder.parameters(), lr=args.learning_rate)
            train_encoder(sample, encoder, criterion, optimizer, n_step, verbose=True)

            dataset.save_encoder_params(sample["index"].item(), encoder)

        if args.save_params:
            dataset.save_dataset_to_disk()
            print("Encoder params saved to disk")

        return dataset

    # Set up training: load the dataset, resolve the loss selected on the
    # command line, then train one encoder per sample.

    dataset = MNIST_AAE_Dataset(args.dataset_path)

    loss_name = args.loss_fn
    if loss_name not in (tqLoss.__name__, nn.MSELoss.__name__):
        raise ValueError("--loss_fn must be one of ['tqLoss', 'MSELoss']")
    # tqLoss is used as a plain function; MSELoss is instantiated as a module.
    criterion = tqLoss if loss_name == tqLoss.__name__ else nn.MSELoss()

    train_encoders_for_dataset(dataset, criterion, args.n_steps)
