#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import time
import os 
import argparse 

# PyTorch
import torch 
import torch.nn as nn 
import torch.optim as optim 
from torch.optim import lr_scheduler 
import torchvision 
from torchvision import datasets, transforms 
import torch.nn.functional as F 

# Pennylane 
import pennylane as qml 
from pennylane import numpy as np 

from tc.tc_fc import TTLinear


# Command-line options controlling the model size and the training run.
parser = argparse.ArgumentParser(
    description='Training a Dense_VQC model on the MNIST dataset')
parser.add_argument(
    '--save_path', default='dense_vqc', metavar='DIR', help='saved model path')
parser.add_argument(
    '--num_qubits', type=int, default=8, help='The number of qubits')
parser.add_argument(
    '--batch_size', type=int, default=30, help='the batch size')
parser.add_argument(
    '--num_epochs', type=int, default=20, help='The number of epochs')
parser.add_argument(
    '--depth_vqc', type=int, default=6, help='The depth of VQC')
parser.add_argument(
    '--lr', type=float, default=0.005, help='Learning rate')
parser.add_argument(
    '--feat_dims', type=int, default=784, help='The dimensions of features')
parser.add_argument(
    '--n_class', type=int, default=10, help='number of classification classes')

# Parsed at import time: the two devices below are built from these values
# and the rest of the module reads ``args`` as a global.
args = parser.parse_args()

# PennyLane simulator with one wire per requested qubit.
dev = qml.device("default.qubit", wires=args.num_qubits)

# Torch device: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

def RY_layer(w):
    """
    Apply one RY rotation per entry of ``w``.

    The entry at position ``i`` is used as the rotation angle on wire ``i``.
    """
    for wire, angle in enumerate(w):
        qml.RY(angle, wires=wire)
        

def RX_layer(w):
    """
    Apply one RX rotation per entry of ``w``.

    The entry at position ``i`` is used as the rotation angle on wire ``i``.
    """
    for wire, angle in enumerate(w):
        qml.RX(angle, wires=wire)
        

def RZ_layer(w):
    """
    Apply one RZ rotation per entry of ``w``.

    The entry at position ``i`` is used as the rotation angle on wire ``i``.
    """
    for wire, angle in enumerate(w):
        qml.RZ(angle, wires=wire)
    
        
        
def entangling_layer(nqubits):
    """
    Apply a ring of CNOT gates over wires ``0 .. nqubits - 1``.

    Wire ``i`` controls wire ``(i + 1) % nqubits``, so the last wire wraps
    around and controls wire 0.

    Args:
        nqubits: number of wires to entangle.

    With fewer than two wires there is nothing to entangle and no gate is
    applied.
    """
    # Guard the degenerate case: for a single wire the wrap-around would ask
    # for a CNOT whose control and target are the same wire.
    if nqubits < 2:
        return
    for control in range(nqubits):
        qml.CNOT(wires=[control, (control + 1) % nqubits])
        
        
@qml.qnode(dev, interface="torch")
def quantum_net(q_input_features, q_weights_flat, q_depth=6, n_qubits=8):
    """
    Variational quantum circuit, registered as a QNode with the torch interface.

    Args:
        q_input_features: angles passed to the initial RY embedding layer,
            one per wire.
        q_weights_flat: flat vector of trainable angles; it is reshaped to
            ``(3, q_depth, n_qubits)``, so it must hold exactly
            ``3 * q_depth * n_qubits`` values.
        q_depth: number of variational layers.
        n_qubits: number of wires that are entangled and measured.

    Returns:
        A tuple with the PauliZ expectation value of each of the
        ``n_qubits`` wires.
    """
    # View the flat parameter vector as (rotation axis, layer, wire).
    q_weights = q_weights_flat.reshape(3, q_depth, n_qubits)
    rx_angles = q_weights[0]
    ry_angles = q_weights[1]
    rz_angles = q_weights[2]

    # Encode the classical features as RY rotation angles.
    RY_layer(q_input_features)

    # Each variational layer: entangle, then rotate about x, y and z.
    for layer in range(q_depth):
        entangling_layer(n_qubits)
        RX_layer(rx_angles[layer])
        RY_layer(ry_angles[layer])
        RZ_layer(rz_angles[layer])

    # Measure every wire in the Z basis.
    return tuple(qml.expval(qml.PauliZ(wire)) for wire in range(n_qubits))


class DressedQuantumNet(nn.Module):
    """
    Torch module implementing the *dressed* quantum net.

    A classical pre-processing layer produces the rotation angles for a
    variational quantum circuit; the circuit's measured expectation values
    are then mapped to class scores by a classical linear layer.
    """
    def __init__(self, input_dims, n_class, n_qubits, q_depth):
        """
        Definition of the *dressed* layout.

        Args:
            input_dims: dimensionality of the input features. Not consulted
                at the moment (see the note on ``pre_net`` below).
            n_class: number of output classes of the final linear layer.
            n_qubits: number of qubits (wires) of the quantum circuit.
            q_depth: number of variational layers of the quantum circuit.
        """
        super().__init__()
        self.n_qubits = n_qubits
        self.q_depth = q_depth
        # NOTE(review): the tensor-train shapes are hard-coded, so neither
        # ``input_dims`` nor ``n_qubits`` influences this layer. Presumably
        # 7*16*7 = 784 input features and 2*2*2 = 8 outputs -- confirm against
        # tc.tc_fc.TTLinear. The dense equivalent would be
        # nn.Linear(input_dims, n_qubits).
        self.pre_net = TTLinear([7, 16, 7], [2, 2, 2], tt_rank=[1, 2, 2, 1])
        # One trainable angle per (rotation axis, layer, wire), stored flat.
        self.q_params = nn.Parameter(0.01 * torch.randn(q_depth * n_qubits * 3))
        self.post_net = nn.Linear(n_qubits, n_class)

    def forward(self, input_features):
        """
        Defining how tensors are supposed to move through the *dressed* quantum net.

        Args:
            input_features: batch of input features fed to ``pre_net``.

        Returns:
            Log-probabilities over the classes (``log_softmax`` along
            dim 1), one row per batch element.
        """
        # Reduce the input features to the values fed to the quantum circuit
        # and scale them by pi/2 to use them as rotation angles.
        pre_out = self.pre_net(input_features).to(device)
        q_in = pre_out * np.pi / 2.0

        # Run the circuit once per batch element. Rows are collected and
        # concatenated a single time instead of re-copying the growing result
        # on every iteration. The empty seed row keeps the (0, n_qubits)
        # result for an empty batch.
        q_rows = [torch.empty(0, self.n_qubits, device=device)]
        for elem in q_in:
            q_out_elem = quantum_net(elem, self.q_params, self.q_depth, self.n_qubits)
            q_rows.append(q_out_elem.float().unsqueeze(0).to(device))
        q_out = torch.cat(q_rows)

        # Map the circuit outputs to per-class log-probabilities.
        return F.log_softmax(self.post_net(q_out), dim=1)


if __name__ == "__main__":
    # Script entry point: train the dressed quantum net on MNIST, evaluate it
    # on the test split after every epoch, then save the trained weights.
    model = DressedQuantumNet(args.feat_dims, args.n_class, args.num_qubits, args.depth_vqc).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Identical preprocessing for both splits: convert to a tensor, then
    # normalise with mean 0.1307 and std 0.3081.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=mnist_transform),
        batch_size=args.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=mnist_transform),
        batch_size=args.batch_size, shuffle=True)

    for epoch in range(1, args.num_epochs + 1):
        # Re-enter training mode every epoch: the evaluation pass at the end
        # of the previous epoch left the model in eval mode.
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            # Flatten each image into a feature vector of length feat_dims.
            data = data.view(-1, args.feat_dims)
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            # The model already returns log-probabilities, hence nll_loss.
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))

        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data = data.view(-1, args.feat_dims)
                data, target = data.to(device), target.to(device)
                output = model(data)
                # Sum up the batch loss; averaged over the dataset below.
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                # Index of the max log-probability is the predicted class.
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                test_loss, correct, len(test_loader.dataset),
                100. * correct / len(test_loader.dataset)))

    # --save_path is a directory: create it if needed and store the weights
    # inside it, rather than gluing the directory name onto the file name.
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(args.save_path, "mnist_vqc.pt"))






