# Necessary imports

import numpy as np
import matplotlib.pyplot as plt
import argparse
import pickle
import json
# import cPickle as pickle
import os
import sys
from tempfile import TemporaryFile
outfile = TemporaryFile()
from functools import partial

from qiskit import Aer, QuantumCircuit
from qiskit.utils import QuantumInstance, algorithm_globals
from qiskit.opflow import AerPauliExpectation, PauliSumOp, PauliOp
from qiskit.quantum_info import SparsePauliOp
from qiskit.circuit import Parameter
from qiskit.circuit.library import RealAmplitudes, ZZFeatureMap
from qiskit_machine_learning.neural_networks import CircuitQNN, TwoLayerQNN
from qiskit_machine_learning.connectors import TorchConnector

from torch import Tensor
from torch.nn import Linear, CrossEntropyLoss, MSELoss
from torch.optim import LBFGS
# Additional torch-related imports
import torch
from torch import cat, no_grad, manual_seed, nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.optim as optim
from torch.nn import (
    Module,
    Conv2d,
    Linear,
    Dropout2d,
    NLLLoss,
    CrossEntropyLoss,
    MaxPool2d,
    Flatten,
    Sequential,
    ReLU,
)
import torch.nn.functional as F

from optimizer_lib import gd_callback_all, test_accuracy, train_model, evaluate_model, plot_evaluate_model


class Dict2Class:
    """Expose the entries of a dict as attributes of a plain object.

    Every key of ``my_dict`` becomes an attribute holding the matching value,
    e.g. ``Dict2Class({"cnn": 1}).cnn == 1``.
    """

    def __init__(self, my_dict):
        for name, value in my_dict.items():
            setattr(self, name, value)

def concatenate_array_list(array_list):
    """Concatenate a non-empty sequence of arrays along their first axis.

    Args:
        array_list: sequence of array-likes whose shapes agree on every axis
            except the first (e.g. the per-piece loss histories produced by
            ``qml_dmd`` in this script).

    Returns:
        A single ``np.ndarray`` holding all pieces, in order.

    Raises:
        ValueError: if ``array_list`` is empty.
    """
    if len(array_list) == 0:
        raise ValueError("concatenate_array_list needs at least one array")
    # One np.concatenate call copies each piece once; the previous pairwise
    # loop re-copied the growing result on every iteration (quadratic).
    return np.concatenate(array_list)

# parameter setup
seed = 42

# Positional integer command-line arguments, in the order they must be given.
# argparse turns a missing or non-integer argument into a usage message
# instead of a bare IndexError/ValueError traceback.
_CLI_INT_ARGS = (
    ("batch_size", "DataLoader batch size (e.g. 1)"),
    ("n_train_samples", "training images kept per digit class 0 and 1 (e.g. 100)"),
    ("n_test_samples", "test images kept per digit class 0 and 1 (e.g. 50)"),
    ("epochs", "number of training epochs (e.g. 5)"),
    ("num_pieces", "forwarded to qml_dmd"),
    ("num_iters_sim", "forwarded to qml_dmd"),
    ("num_iters_dmd", "forwarded to qml_dmd"),
    ("window_size", "forwarded to qml_dmd"),
    ("neural", "stored as args.neural"),
    ("cnn", "1 selects dmd_method2, any other value selects dmd_method"),
)

_parser = argparse.ArgumentParser(
    description="Hybrid CNN/QNN MNIST (digits 0 vs 1) training script."
)
for _arg_name, _arg_help in _CLI_INT_ARGS:
    _parser.add_argument(_arg_name, type=int, help=_arg_help)
# parse_known_args: extra trailing arguments are ignored, exactly as they were
# when sys.argv was indexed directly.
_cli, _unused_argv = _parser.parse_known_args()

batch_size = _cli.batch_size
n_train_samples = _cli.n_train_samples
n_test_samples = _cli.n_test_samples
epochs = _cli.epochs
num_pieces = _cli.num_pieces
num_iters_sim = _cli.num_iters_sim
num_iters_dmd = _cli.num_iters_dmd
window_size = _cli.window_size
neural = _cli.neural
cnn = _cli.cnn
full_train = 0  # 1: plain train_model run; otherwise the qml_dmd run

# Option bag handed to qml_dmd; exposed as attributes via Dict2Class.
args = Dict2Class({
    "neural": neural,
    "cnn": cnn,
    "neural_steps": 30000,
    "batchnorm": 0,
    "svdonencoder": 0,
    "transformer": 0,
})

# The cnn flag selects which DMD implementation module is imported.
if args.cnn == 1:
    from dmd_method2 import natural_grad_dmd, qml_dmd, plot_energies, plot_energies_errorbar, params_to_array, array_to_dict, count_params_trainable, count_params
else:
    from dmd_method import natural_grad_dmd, qml_dmd, plot_energies, plot_energies_errorbar, params_to_array, array_to_dict, count_params_trainable, count_params

# declare quantum instance
manual_seed(seed)
algorithm_globals.random_seed = seed
qi = QuantumInstance(Aer.get_backend("aer_simulator_statevector"))


def _load_binary_mnist(train, n_per_class):
    """Load an MNIST split, keeping only the first images of digits 0 and 1.

    Images are resized to 14x14 and converted to tensors.

    Args:
        train: True for the training split, False for the test split.
        n_per_class: maximum number of images kept for each of the digits 0
            and 1 (all kept 0s come first, then all kept 1s).

    Returns:
        ``(dataset, idx)``: the filtered dataset and the indices into the
        full split that were kept.
    """
    dataset = datasets.MNIST(
        root="../../",
        train=train,
        download=True,
        transform=transforms.Compose([transforms.Resize(14), transforms.ToTensor()]),
    )
    idx = np.append(
        np.where(dataset.targets == 0)[0][:n_per_class],
        np.where(dataset.targets == 1)[0][:n_per_class],
    )
    dataset.data = dataset.data[idx]
    dataset.targets = dataset.targets[idx]
    return dataset, idx


######### Train / Test Datasets #############
# shuffle=True draws from torch's global RNG, which is seeded above by
# manual_seed(seed) for reproducibility.
X_train, _ = _load_binary_mnist(True, n_train_samples)
train_loader = DataLoader(X_train, batch_size=batch_size, shuffle=True)

# NOTE(review): shuffling is not needed for evaluation; kept for identical
# behavior. ``idx`` stays bound to the test-set indices, as before.
X_test, idx = _load_binary_mnist(False, n_test_samples)
test_loader = DataLoader(X_test, batch_size=batch_size, shuffle=True)

### plotshow data
#n_samples_show = 6
#data_iter = iter(train_loader)
#fig, axes = plt.subplots(nrows=1, ncols=n_samples_show, figsize=(10, 3))
#while n_samples_show > 0:
#    images, targets = data_iter.__next__()
#
#    axes[n_samples_show - 1].imshow(images[0, 0].numpy().squeeze(), cmap="gray")
#    axes[n_samples_show - 1].set_xticks([])
#    axes[n_samples_show - 1].set_yticks([])
#    axes[n_samples_show - 1].set_title("Labeled: {}".format(targets[0].item()))
#    print(images[0,0].numpy().shape)
#
#    n_samples_show -= 1

# Test Dataset
# -------------

# Set test shuffle seed (for reproducibility)
#manual_seed(seed)


######### Define Hybrid Circuit #############
# Define and create QNN
#ob = PauliSumOp.from_list([("ZI",1),("IZ",1)])
#ob2 = SparsePauliOp(["ZI","IZ"],coeffs=[1+0j,1+0j])
#ob3 = SparsePauliOp.from_sparse_list([("ZX", [0, 1], 1), ("YY", [0, 1], 2)], num_qubits=2)
#assert False
def create_qnn():
    """Build the 2-qubit TwoLayerQNN that this script plugs into ``Net``.

    A ZZFeatureMap (one repetition) encodes the two classical inputs. A
    RealAmplitudes ansatz (one repetition) carries the trainable weights.
    Expectation values use AerPauliExpectation on the module-level quantum
    instance ``qi``.
    """
    n_qubits = 2
    encoder = ZZFeatureMap(n_qubits, reps=1)
    trainable = RealAmplitudes(n_qubits, reps=1)
    # input_gradients=True is mandatory: it enables hybrid gradient backprop
    # from the QNN into the preceding classical layers.
    return TwoLayerQNN(
        num_qubits=n_qubits,
        feature_map=encoder,
        ansatz=trainable,
        exp_val=AerPauliExpectation(),
        quantum_instance=qi,
        input_gradients=True,
    )


qnn8 = create_qnn()
print(qnn8.operator)
#classifier = ZZFeatureMap(3) + EfficientSU2(3)

# Define torch NN module
class Net(Module):
    """Hybrid classifier: a small CNN front end feeding a 2-input QNN.

    ``forward`` turns the QNN output ``q`` into the pair ``(q, 1 - q)``,
    concatenated on the last axis, which serves as the two class scores.

    Expects inputs whose conv2 output flattens to 2 features per sample
    (e.g. 1x14x14 images) -- TODO confirm if the input size changes.

    Submodules must keep this creation order: it fixes the random
    initialisation sequence and the ``state_dict`` parameter order.
    """

    def __init__(self, qnn):
        super().__init__()
        self.conv1 = Conv2d(1, 2, kernel_size=5)
        self.conv2 = Conv2d(2, 2, kernel_size=5)
        self.dropout = Dropout2d()  # registered but not used in forward()
        self.fc1 = Linear(2, 2)
        self.fc2 = Linear(2, 2)  # yields the 2-dimensional QNN input
        # TorchConnector wraps the QNN as a torch module (its weights are
        # chosen uniformly at random from [-1, 1]).
        self.qnn = TorchConnector(qnn)

    def forward(self, x):
        features = F.max_pool2d(F.relu(self.conv1(x)), 2)
        features = F.relu(self.conv2(features))
        flat = features.view(features.shape[0], -1)
        q = self.qnn(self.fc2(F.relu(self.fc1(flat))))
        return cat((q, 1 - q), -1)

model_vanilla = Net(qnn8)
print("number of parameters: ", count_params_trainable(model_vanilla))

# Round-trip sanity check: flatten the parameters into one array, rebuild the
# per-name dict from it, and print each entry's difference norm against the
# live state_dict (presumably 0 if array_to_dict inverts params_to_array).
key_list, shape_list, nsize_list, params_arr = params_to_array(model_vanilla)
dd = array_to_dict(key_list, shape_list, nsize_list, params_arr)
live_state = model_vanilla.state_dict()
for entry_name in live_state:
    print(torch.linalg.norm(dd[entry_name] - live_state[entry_name]))
print("number of params from dict: ", params_arr.shape)

#dd['conv1.weight'] *= 2
#model_vanilla.load_state_dict(dd)
#assert False

######### Training #############
# Define optimizer and loss function for the model built above.
optimizer = optim.Adam(model_vanilla.parameters(), lr=0.0001)
loss_func = nn.CrossEntropyLoss()

# Results are written below ./data and ./model. Create them up front so a
# long run does not crash at the very end on a missing directory.
os.makedirs("./data", exist_ok=True)
os.makedirs("./model", exist_ok=True)

if full_train == 1:
    train_loss_arr, test_loss_arr, params_arr_arr, model_vanilla = train_model(
        model_vanilla, optimizer, loss_func, epochs, train_loader, test_loader, batch_size
    )
else:
    (
        train_loss_pieces,
        test_loss_pieces,
        params_pieces,
        optimal_vqe_start_list,
        model_vanilla,
    ) = qml_dmd(
        model_vanilla, optimizer, loss_func, train_model, train_loader, test_loader,
        batch_size, window_size, num_iters_sim, num_iters_dmd, num_pieces, args, seed,
        opt_pred=True,
    )
    # Each pickle is named after the object it stores. Older runs wrote
    # params_pieces.pkl and optimal_vqe_start_list.pkl with swapped payloads.
    for pickle_path, payload in (
        ("./data/train_loss_pieces.pkl", train_loss_pieces),
        ("./data/test_loss_pieces.pkl", test_loss_pieces),
        ("./data/params_pieces.pkl", params_pieces),
        ("./data/optimal_vqe_start_list.pkl", optimal_vqe_start_list),
    ):
        with open(pickle_path, "wb") as outfile:
            pickle.dump(payload, outfile, pickle.HIGHEST_PROTOCOL)

    # Stitch the per-piece histories into single arrays for saving/plotting.
    train_loss_arr = concatenate_array_list(train_loss_pieces)
    test_loss_arr = concatenate_array_list(test_loss_pieces)

np.savetxt("./data/train_loss_arr.dat", train_loss_arr)
np.savetxt("./data/test_loss_arr.dat", test_loss_arr)
torch.save(model_vanilla.state_dict(), "./model/model_vanilla.pt")

# Plot loss/accuracy curves. Create the output directory first so savefig
# does not fail after training has finished.
os.makedirs("./plot", exist_ok=True)


def _save_curve(values, title, ylabel, path):
    """Plot ``values`` against their index and save the figure to ``path``."""
    plt.plot(values)
    plt.title(title)
    plt.xlabel("Training Iterations")
    plt.ylabel(ylabel)
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


_save_curve(
    train_loss_arr,
    "Hybrid NN Training Convergence",
    "Neg. Log Likelihood Loss",
    "./plot/training_loss.png",
)
# Per the plot titles, column 0 of test_loss_arr is a loss and column 1 an
# accuracy.
_save_curve(
    test_loss_arr[:, 0],
    "Hybrid NN Training Loss",
    "Neg. Log Likelihood Loss",
    "./plot/training_test_loss.png",
)
# This curve is an accuracy, not a loss, so label the y-axis accordingly
# (it was previously labelled "Neg. Log Likelihood Loss").
_save_curve(
    test_loss_arr[:, 1],
    "Hybrid NN Training Accuracy",
    "Accuracy",
    "./plot/training_test_accuracy.png",
)

plot_evaluate_model(model_vanilla, test_loader)

