#!/usr/bin/env python
# coding: utf-8

import argparse
import os
from dataclasses import dataclass

import matplotlib.pylab as plt
import pennylane as qml
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm

from data_utils.aae_dataset import MNIST_AAE_Dataset
from models.batch_encoders import aae_encoder_for_train
from utils import append_log, resize_and_norm


def parse_args():
    """Parse the command-line options of the training script.

    Options:
        --mode: which encoder parameters are fed to the QNN; one of ``org``,
            ``pred`` or ``e2e`` (default ``org``).
        --logs: directory the log files are written to
            (default ``logs/qnn/aae/``).

    Returns:
        argparse.Namespace: namespace with the attributes ``mode`` and ``logs``.

    Raises:
        SystemExit: raised by argparse when the command line is invalid,
            including a ``--mode`` value outside the supported choices.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        type=str,
        default="org",
        # Reject typos up front: the script only compares the mode against
        # "org", so an unknown value would otherwise silently run as "pred".
        choices=("org", "pred", "e2e"),
        help="org (original encoder parameters) | "
        "pred (predicted encoder parameters) | "
        "e2e (end to end training both superencoder and qnn)",
    )
    parser.add_argument(
        "--logs", type=str, default="logs/qnn/aae/", help="Directory to save logs"
    )
    return parser.parse_args()


# Parsed once at import time; train() and valid_test() read ARGS.mode and
# ARGS.logs as module-level state.
ARGS = parse_args()


@dataclass
class Config:
    """Static settings for the data files, the circuit and the training loop."""

    # Directory that holds the preprocessed split files (mnist_train.pt, ...).
    data_dir: str = r"./mnist/processed/"
    # NOTE: no type annotation, so this is a plain class attribute rather than a
    # dataclass field; it is not part of the generated __init__ or __repr__.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Circuit width and encoder depth; both are checked against the shape of
    # the encoder weights loaded from the dataset.
    n_qubits: int = 4
    n_encoder_layers: int = 8

    # Number of layers of the trainable StronglyEntanglingLayers ansatz.
    n_ansatz_layers: int = 5

    # Training-loop and DataLoader settings.
    n_epochs: int = 10
    batch_size: int = 32
    num_workers: int = 0
    pin_memory: bool = False


# Instantiate the settings and derive the file path of every dataset split.
config = Config()

train_path, val_path, test_path = (
    os.path.join(config.data_dir, file_name)
    for file_name in ("mnist_train.pt", "mnist_valid.pt", "mnist_test.pt")
)


# ## Data
# One dataset per split, each built from its preprocessed file.
train_ds = MNIST_AAE_Dataset(train_path)
val_ds = MNIST_AAE_Dataset(val_path)
test_ds = MNIST_AAE_Dataset(test_path)

# Settings shared by all three loaders; only the training split is shuffled.
_loader_settings = dict(
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    pin_memory=config.pin_memory,
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_settings)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_settings)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_settings)


# Pull one training batch; it is reused below for the sanity checks.
samples = next(iter(train_loader))

# The stored encoder weights are unpacked as (batch, layers, qubits) and the
# last two dimensions must agree with the configured circuit.
_, n_encoder_layers, n_qubits = samples["encoder_params"]["weights"].shape

_expected_dims = (config.n_encoder_layers, config.n_qubits)
if (n_encoder_layers, n_qubits) != _expected_dims:
    raise ValueError("Unmatched num_encoder_layers or num_qubits!")


q_device = qml.device("default.qubit", wires=config.n_qubits)


# test
@qml.qnode(q_device, interface="torch")
@qml.simplify
def encoder_test(inputs):
    """Apply only the AAE encoder and return the resulting quantum state.

    Args:
        inputs: encoder parameters, forwarded unchanged to
            ``aae_encoder_for_train`` (the script passes them flattened to one
            row per sample — exact expected layout to be confirmed against
            that helper).

    Returns:
        The ``qml.state()`` measurement taken after the encoder circuit.
    """
    aae_encoder_for_train(inputs, config.n_encoder_layers, config.n_qubits)
    return qml.state()


# Visual sanity check: show one sample's processed image next to the state the
# encoder prepares from that sample's stored parameters.
index = 0

fig, axes = plt.subplots(1, 2)

# Left panel: the output of resize_and_norm for the sample, viewed as 4x4.
axes[0].imshow(resize_and_norm(samples["images"][index]).view(1, 4, 4).permute(1, 2, 0))

# Right panel: the encoder state, reshaped the same way.
# `view(1, -1)` alone flattens the sample's weights into a single row, so no
# extra `unsqueeze(0)` is needed beforehand.
# NOTE(review): if the returned state is complex, the float32 cast keeps only
# its real part — confirm that this is intended.
axes[1].imshow(
    encoder_test(samples["encoder_params"]["weights"][index].view(1, -1))
    .type(torch.float32)
    .detach()
    .view(1, 4, 4)
    .permute(1, 2, 0)
)


@qml.qnode(q_device, interface="torch")
@qml.simplify
def circuit(inputs, weights):
    """QNN circuit: AAE encoder followed by a trainable entangling ansatz.

    Args:
        inputs: encoder parameters, forwarded unchanged to
            ``aae_encoder_for_train``.
        weights: ansatz parameters passed to ``qml.StronglyEntanglingLayers``;
            the script declares them with shape
            ``(n_ansatz_layers, n_qubits, 3)``.

    Returns:
        A list with the Pauli-Z expectation value of every qubit.
    """
    aae_encoder_for_train(inputs, config.n_encoder_layers, config.n_qubits)

    # Trainable part, applied on all wires after the state has been encoded.
    qml.StronglyEntanglingLayers(weights, wires=range(config.n_qubits))
    return [qml.expval(qml.PauliZ(i)) for i in range(config.n_qubits)]


# Draw the full circuit for the first sample with small random ansatz weights.
_draw_circuit = qml.draw_mpl(circuit, decimals=2, expansion_strategy="device")
_example_inputs = samples["encoder_params"]["weights"][0].unsqueeze(0).view(1, -1)
_example_weights = torch.normal(
    0, 0.1, size=(config.n_ansatz_layers, config.n_qubits, 3)
)
_draw_circuit(_example_inputs, _example_weights)

# NOTE: the parameters shown in the figure are inconsistent with
# encoder_params, even though the encoder test above passes.
# Shape of the trainable ansatz weights handed to StronglyEntanglingLayers.
weight_shapes = {"weights": (config.n_ansatz_layers, config.n_qubits, 3)}

qlayer = qml.qnn.TorchLayer(circuit, weight_shapes, init_method=nn.init.uniform_)
# expecting (B, n_encoder_layers, n_qubits)
model = torch.nn.Sequential(nn.Flatten(), qlayer, nn.LogSoftmax(dim=-1)).to(
    config.device
)


# Smoke test: one forward pass on the sample batch. The inputs are moved to
# config.device first, exactly as train() and valid_test() do, so the check
# also works when the model has been placed on a GPU.
model(samples["encoder_params"]["weights"].to(config.device))


# ## Train and Eval
# The model ends in LogSoftmax, so the negative log-likelihood loss is used.
criterion = nn.NLLLoss()
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.n_epochs)


# Loss of the untrained model on the sample batch. Inputs and targets are moved
# to config.device, exactly as train() and valid_test() do, so that the model,
# its inputs and the labels all live on the same device.
pred = model(samples["encoder_params"]["weights"].to(config.device))
print(criterion(pred, samples["digits"].to(config.device)))


def train(dataloader, model, optimizer):
    """Run one optimisation pass over ``dataloader``.

    For every batch the encoder parameters selected by ``ARGS.mode`` are fed
    through ``model``, the module-level ``criterion`` is evaluated against the
    batch's ``digits`` and one optimizer step is taken. The loss of every step
    is shown on the progress bar and passed to ``append_log`` together with the
    path ``<ARGS.logs>/train_loss.txt``.

    Args:
        dataloader: iterable of batches (dicts with ``digits`` plus either
            ``encoder_params``/``weights`` or ``pred_params``).
        model: callable mapping the inputs to the values expected by
            ``criterion``.
        optimizer: optimizer holding the parameters of ``model``.
    """
    loss_log = os.path.join(ARGS.logs, "train_loss.txt")
    use_original_params = ARGS.mode == "org"

    with tqdm(dataloader) as progress:
        for batch in progress:
            # "org" trains on the stored encoder parameters; every other mode
            # uses the predicted ones.
            source = (
                batch["encoder_params"]["weights"]
                if use_original_params
                else batch["pred_params"]
            )
            inputs = source.to(config.device)
            targets = batch["digits"].to(config.device)

            # Forward pass, then back-propagate starting from cleared gradients.
            loss = criterion(model(inputs), targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            progress.set_postfix(loss=loss_value)
            append_log(loss_log, loss_value)


def valid_test(dataloader, split, model):
    """Evaluate ``model`` on ``dataloader`` and log its accuracy and loss.

    All batches are run without gradient tracking, their outputs and targets
    are concatenated, and top-1 accuracy plus the module-level ``criterion``
    are computed over the whole split. Both numbers are printed and passed to
    ``append_log`` with the paths ``<ARGS.logs>/<split>_acc.txt`` and
    ``<ARGS.logs>/<split>_loss.txt``.

    Args:
        dataloader: iterable of batches (dicts with ``digits`` plus either
            ``encoder_params``/``weights`` or ``pred_params``).
        split: name of the evaluated split; used in the printed messages and
            in the log file names.
        model: callable mapping the inputs to per-class scores. An
            ``nn.Module`` is put into eval mode for the duration of the call.
    """
    acc_path = os.path.join(ARGS.logs, f"{split}_acc.txt")
    loss_path = os.path.join(ARGS.logs, f"{split}_loss.txt")

    # Evaluate in eval mode so mode-dependent layers (dropout, batch norm)
    # behave deterministically, and restore training mode afterwards because
    # this function is called between training epochs.
    restore_training = isinstance(model, nn.Module) and model.training
    if restore_training:
        model.eval()

    target_all = []
    output_all = []
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader):
                if ARGS.mode == "org":
                    inputs = batch["encoder_params"]["weights"].to(config.device)
                else:
                    inputs = batch["pred_params"].to(config.device)
                targets = batch["digits"].to(config.device)

                prediction = model(inputs)
                target_all.append(targets)
                output_all.append(prediction)
    finally:
        if restore_training:
            model.train()

    target_all = torch.cat(target_all, dim=0)
    output_all = torch.cat(output_all, dim=0)

    # Top-1 accuracy: compare the highest-scoring class with the target.
    _, indices = output_all.topk(1, dim=1)
    masks = indices.eq(target_all.view(-1, 1).expand_as(indices))
    size = target_all.shape[0]
    corrects = masks.sum().item()
    accuracy = corrects / size
    loss = criterion(output_all, target_all).item()

    print(f"{split} set accuracy: {accuracy}")
    print(f"{split} set loss: {loss}")

    append_log(acc_path, accuracy)
    append_log(loss_path, loss)


# Build a fresh model, loss, optimizer and scheduler for the actual training
# run; the objects created earlier were only exercised on the sample batch.
weight_shapes = {"weights": (config.n_ansatz_layers, config.n_qubits, 3)}

qlayer = qml.qnn.TorchLayer(circuit, weight_shapes, init_method=nn.init.uniform_)
# expecting (B, n_encoder_layers, n_qubits)
_layers = [nn.Flatten(), qlayer, nn.LogSoftmax(dim=-1)]
model = nn.Sequential(*_layers).to(config.device)

criterion = nn.NLLLoss()

_adam_settings = {"lr": 5e-3, "weight_decay": 1e-4}
optimizer = torch.optim.Adam(model.parameters(), **_adam_settings)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.n_epochs)


# Start every run with fresh log files. `os.rmdir` only removes *empty*
# directories, so it raised OSError as soon as a previous run had written
# anything. Instead, keep the directory and delete just the files this script
# writes (presumably append_log appends to them — stale values would otherwise
# be mixed into the new run).
os.makedirs(ARGS.logs, exist_ok=True)
for _log_name in (
    "train_loss.txt",
    "valid_acc.txt",
    "valid_loss.txt",
    "test_acc.txt",
    "test_loss.txt",
):
    try:
        os.remove(os.path.join(ARGS.logs, _log_name))
    except FileNotFoundError:
        # Nothing to clean up for this file on a first run.
        pass

# Main loop: one training pass and one validation pass per epoch, then advance
# the cosine learning-rate schedule by one step.
for epoch in range(1, config.n_epochs + 1):
    # train
    print(f"Epoch {epoch}:")
    train(train_loader, model, optimizer)

    # valid
    valid_test(val_loader, "valid", model)
    scheduler.step()

# Final evaluation on the held-out test split, run once after all epochs.


valid_test(test_loader, "test", model)


# ## Result

# **pennylane amplitude encoding**
#
# Epoch 1:
# 0.005 0.8584052324295044
# valid set accuracy: 0.6730290456431536
# valid set loss: 0.9551709294319153
#
# Epoch 2:
# 0.0048776412907378846457
# valid set accuracy: 0.8929460580912864
# valid set loss: 0.7974672913551331
#
# Epoch 3:
# 0.0045225424859373685027
# valid set accuracy: 0.9352697095435685
# valid set loss: 0.753818690776825
#
# Epoch 4:
# 0.0039694631307311836213
# valid set accuracy: 0.9485477178423236
# valid set loss: 0.7371468544006348
#
# Epoch 5:
# 0.0032725424859373687498
# valid set accuracy: 0.9609958506224067
# valid set loss: 0.727253794670105
#
# Epoch 6:
# 0.00250.6985626816749573
# valid set accuracy: 0.9643153526970955
# valid set loss: 0.721916913986206
#
# Epoch 7:
# 0.0017274575140626316482
# valid set accuracy: 0.970954356846473
# valid set loss: 0.7195770144462585
#
# Epoch 8:
# 0.0010305368692688174391
# valid set accuracy: 0.9701244813278008
# valid set loss: 0.7184303998947144
#
# Epoch 9:
# 0.0004774575140626316366
# valid set accuracy: 0.9734439834024896
# valid set loss: 0.7179780006408691
#
# Epoch 10:
# 0.0001223587092621161729
# valid set accuracy: 0.9742738589211618
# valid set loss: 0.7178515195846558
#
# test set accuracy: 0.9715447154471545
# test set loss: 0.7131120562553406
#

# **pennylane angle encoding**
#
# Epoch 1:
# 0.005 1.0613005161285421
# valid set accuracy: 0.8721991701244813
# valid set loss: 0.9659674167633057
#
# Epoch 2:
# 0.0048776412907378846388
# valid set accuracy: 0.9294605809128631
# valid set loss: 0.7519037127494812
#
# Epoch 3:
# 0.0045225424859373685668
# valid set accuracy: 0.9427385892116182
# valid set loss: 0.6945566534996033
#
# Epoch 4:
# 0.0039694631307311838341
# valid set accuracy: 0.9477178423236514
# valid set loss: 0.6593191027641296
#
# Epoch 5:
# 0.0032725424859373687941
# valid set accuracy: 0.941908713692946
# valid set loss: 0.6438149809837341
#
# Epoch 6:
# 0.00250.6027146577835083
# valid set accuracy: 0.9394190871369295
# valid set loss: 0.637258768081665
#
# Epoch 7:
# 0.0017274575140626316727
# valid set accuracy: 0.9402489626556016
# valid set loss: 0.6341438889503479
#
# Epoch 8:
# 0.0010305368692688174798
# valid set accuracy: 0.9402489626556016
# valid set loss: 0.63267982006073
#
# Epoch 9:
# 0.0004774575140626316303
# valid set accuracy: 0.9410788381742738
# valid set loss: 0.6321021318435669
#
# Epoch 10:
# 0.0001223587092621161779
# valid set accuracy: 0.9410788381742738
# valid set loss: 0.6319551467895508
#
# test set accuracy: 0.9242886178861789
# test set loss: 0.639994740486145
#

# **AAE**
#
# Epoch 1:
# 100%|██████████| 157/157 [00:17<00:00,  9.23it/s, loss=0.81]
#
# 100%|██████████| 38/38 [00:03<00:00, 11.31it/s]
# valid set accuracy: 0.9518672199170124
# valid set loss: 0.806155800819397
#
# Epoch 2:
# 100%|██████████| 157/157 [00:16<00:00,  9.33it/s, loss=0.685]
#
# 100%|██████████| 38/38 [00:03<00:00, 11.88it/s]
# valid set accuracy: 0.9759336099585062
# valid set loss: 0.7286494374275208
#
# Epoch 3:
# 100%|██████████| 157/157 [00:16<00:00,  9.37it/s, loss=0.75]
#
# 100%|██████████| 38/38 [00:03<00:00, 12.11it/s]
# valid set accuracy: 0.975103734439834
# valid set loss: 0.6749855279922485
#
# Epoch 4:
# 100%|██████████| 157/157 [00:16<00:00,  9.26it/s, loss=0.632]
#
# 100%|██████████| 38/38 [00:03<00:00, 12.25it/s]
# valid set accuracy: 0.9684647302904564
# valid set loss: 0.6676006317138672
#
# Epoch 5:
# 100%|██████████| 157/157 [00:16<00:00,  9.47it/s, loss=0.824]
#
# 100%|██████████| 38/38 [00:03<00:00, 12.14it/s]
# valid set accuracy: 0.983402489626556
# valid set loss: 0.6617234349250793
#
# Epoch 6:
# 100%|██████████| 157/157 [00:16<00:00,  9.44it/s, loss=0.74]
#
# 100%|██████████| 38/38 [00:03<00:00, 12.12it/s]
# valid set accuracy: 0.9842323651452282
# valid set loss: 0.6582380533218384
#
# Epoch 7:
# 100%|██████████| 157/157 [00:16<00:00,  9.50it/s, loss=0.746]
#
# 100%|██████████| 38/38 [00:03<00:00, 12.31it/s]
# valid set accuracy: 0.979253112033195
# valid set loss: 0.6556654572486877
#
# Epoch 8:
# 100%|██████████| 157/157 [00:16<00:00,  9.44it/s, loss=0.691]
#
# 100%|██████████| 38/38 [00:03<00:00, 11.44it/s]
# valid set accuracy: 0.9784232365145228
# valid set loss: 0.6543428897857666
#
# Epoch 9:
# 100%|██████████| 157/157 [00:16<00:00,  9.43it/s, loss=0.814]
#
# 100%|██████████| 38/38 [00:03<00:00, 12.05it/s]
# valid set accuracy: 0.979253112033195
# valid set loss: 0.6536799669265747
#
# Epoch 10:
# 100%|██████████| 157/157 [00:16<00:00,  9.55it/s, loss=0.6]
#
# 100%|██████████| 38/38 [00:03<00:00, 11.86it/s]
# valid set accuracy: 0.979253112033195
# valid set loss: 0.6535065770149231
#
#
# **Test set result**
# 100%|██████████| 62/62 [00:05<00:00, 12.17it/s]
# test set accuracy: 0.9801829268292683
# test set loss: 0.6503280401229858
#
#
