import os
import sys

# for macOS: presumably works around Intel OpenMP (KMP) aborting when more than
# one copy of its runtime is loaded into the process.
# NOTE(review): this hides the duplicate-runtime conflict rather than fixing it.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

# Put the current working directory first on the import path so the local
# modules imported below (optimizer_lib, dmd_method) resolve when the script is
# launched from the directory that contains them.
sys.path.insert(0, os.getcwd())
import numpy as np
import time
import matplotlib.pyplot as plt
import matplotlib
import argparse
import pickle
import json
# import cPickle as pickle
from tempfile import TemporaryFile

# NOTE(review): `outfile` does not appear to be used by the visible code;
# main() binds its own local `outfile` in each `with open(...)` statement.
outfile = TemporaryFile()
# Global matplotlib font settings applied to every figure this script draws.
font = {'weight': 'normal',
        'size': 18}
matplotlib.rc('font', **font)
# amsmath LaTeX preamble; presumably only relevant when matplotlib's usetex
# rendering is enabled -- TODO confirm.
matplotlib.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

from functools import partial
from qiskit.algorithms import VQE
from qiskit.circuit.library import RealAmplitudes
from qiskit.providers.aer import StatevectorSimulator, QasmSimulator
from qiskit.utils import QuantumInstance, algorithm_globals
from qiskit.opflow.primitive_ops import PauliOp
from qiskit.quantum_info.operators import Pauli
from qiskit.opflow.gradients import NaturalGradient
from qiskit.algorithms.optimizers import GradientDescent
from qiskit import Aer
from qiskit.providers.aer.noise import NoiseModel
from qiskit.providers.fake_provider import FakeMontreal, FakeGuadalupe, FakeManila, FakeLima
from qiskit.utils.mitigation import CompleteMeasFitter

# Project-local helpers (made importable by the sys.path tweak above).
from optimizer_lib import gd_callback_all
from dmd_method import natural_grad_dmd, plot_energies, plot_energies_errorbar


def ising_transverse_field(num_qubits: int, h: float, pbc: bool = True, ):
    r"""
        The 1d ising model with a transverse field

        H = - \sum_{<i, j>} Z_i Z_j - h \sum_i X_i

        Args:
            num_qubits (int): number of qubits (must be at least 2)
            h (float): the intensity of the transverse field
            pbc (bool): is the periodic boundary condition used

        Returns:
            H (PauliSumOp): the PauliSumOp form of the hamiltonian
            gs_energy (float): ground state energy of H (smallest eigenvalue
                of the dense matrix of H)

        Raises:
            ValueError: if num_qubits < 2, since no ZZ bond can be formed.
    """
    if num_qubits < 2:
        raise ValueError(
            "num_qubits must be >= 2, got {}".format(num_qubits))

    # Open-chain couplings: -Z Z on every adjacent pair of positions in the
    # Pauli string.
    H = PauliOp(Pauli('ZZ' + 'I' * (num_qubits - 2)), -1.0)
    for i in range(1, num_qubits - 1):
        H += PauliOp(Pauli('I' * i + 'ZZ' + 'I' * (num_qubits - i - 2)), -1.0)

    # Periodic boundary: couple the first and last positions.
    # NOTE(review): for num_qubits == 2 this is the same 'ZZ' string as the
    # open bond, so that coupling is counted twice.
    if pbc:
        H += PauliOp(Pauli('Z' + 'I' * (num_qubits - 2) + 'Z'), -1.0)

    # Transverse field: -h X on every position.
    for i in range(num_qubits):
        H += PauliOp(Pauli('I' * i + 'X' + 'I' * (num_qubits - i - 1)), -h)

    # eigvalsh returns eigenvalues in ascending order without computing
    # eigenvectors, so element 0 is the ground-state energy.
    gs_energy = np.linalg.eigvalsh(H.to_matrix())[0]
    return H, gs_energy


def set_seed(seed):
    """Seed NumPy's global random generator and print the seed being used.

    Args:
        seed: value handed to ``np.random.seed``.
    """
    banner = 'seed random seed'
    print(banner, seed)
    np.random.seed(seed)


def _resolve_num_pieces(requested, maxiter, num_iters_sim, num_iters_dmd):
    """Return how many simulate+DMD pieces to run.

    ``requested == 0`` means "auto": use just enough pieces of
    ``num_iters_sim + num_iters_dmd`` iterations to cover ``maxiter``.
    Any other value is returned unchanged.
    """
    if requested == 0:
        return int(np.ceil(maxiter / (num_iters_sim + num_iters_dmd)))
    return requested


def _save_pickle(obj, path):
    """Pickle ``obj`` to ``path`` using the highest available protocol."""
    with open(path, 'wb') as outfile:
        pickle.dump(obj, outfile, pickle.HIGHEST_PROTOCOL)


def _run_full_vqe(H, ansatz, initial_point, maxiter, lr, opt_method):
    """Run one uninterrupted natural-gradient VQE and return its callback log.

    Returns:
        dict with list values under 'nfev', 'parameters', 'energy' and
        'stepsize'. The dict is handed to ``gd_callback_all``, which is
        presumably what appends to it (confirm in optimizer_lib).

    Raises:
        ValueError: if ``opt_method`` is not 'natural_grad', the only method
            configured here.
    """
    if opt_method != 'natural_grad':
        raise ValueError(
            "full VQE is only implemented for opt_method='natural_grad', "
            "got {!r}".format(opt_method))

    intermediate_info = {
        'nfev': [],
        'parameters': [],
        'energy': [],
        'stepsize': []
    }
    gd_callback = partial(gd_callback_all, intermediate_info=intermediate_info)
    optimizer = GradientDescent(maxiter=maxiter, learning_rate=lr, callback=gd_callback)
    gradient = NaturalGradient(
        grad_method='lin_comb',
        qfi_method='lin_comb_full',
        regularization='perturb_diag',
    )
    qi = StatevectorSimulator()
    vqe = VQE(
        ansatz=ansatz,
        initial_point=initial_point,
        optimizer=optimizer,
        gradient=gradient,
        quantum_instance=qi,
    )

    run_time_start = time.time()
    result = vqe.compute_minimum_eigenvalue(operator=H)
    run_time_end = time.time()
    print("Time elapsed (seconds):", run_time_end - run_time_start)
    print(result)
    return intermediate_info


def _plot_full_vqe(intermediate_info, gs_energy, opt_method, file_name):
    """Plot the full-VQE energy trace against the exact ground-state energy.

    Saves ``file_name + "full_vqe.png"`` and ``file_name + "full_vqe.pdf"``.
    For 'natural_grad' the x values are ``intermediate_info['nfev']``.
    Otherwise energies are plotted by index with
    ``intermediate_info['stddev']`` as error bars.
    """
    fig, ax1 = plt.subplots(1, 1, sharex=True, figsize=(12, 8))
    if opt_method == 'natural_grad':
        ax1.plot(intermediate_info['nfev'], intermediate_info['energy'],
                 marker='.', ms=10., ls="None", color='r', )
    else:
        ax1.errorbar(
            np.arange(len(intermediate_info['energy'])),
            intermediate_info['energy'],
            yerr=intermediate_info['stddev'],
            marker='.', ms=10., ls="None", color='r',
        )
    ax1.axhline(y=gs_energy, color='k')  # exact reference energy
    plt.grid(ls="--", lw=2, alpha=0.25)
    plt.ylabel("Cost function", fontsize=12)
    plt.xlabel("Iteration", fontsize=12)
    fig.savefig(file_name + "full_vqe.png")
    fig.savefig(file_name + "full_vqe.pdf")
    plt.close(fig)


def main(args):
    """Compare a full VQE run with DMD-accelerated VQE runs.

    Steps:
      1. Build the transverse-field Ising Hamiltonian, save its exact
         ground-state energy to ./data/gs_energy.dat, and build a
         RealAmplitudes ansatz.
      2. Run a full VQE, or reload its trace from ./record/ when
         ``args.skip_vqe != 0``. Plot it under ./plot/ and pickle it under
         ./data/.
      3. Unless ``args.skip_vanilla`` is set, run the vanilla DMD stage
         (window size 1) and write ./data/vanilla_*.pkl.
      4. Always run the sliding-window DMD stage (window size
         ``args.window_size``) and write ./data/sw_*.pkl.

    Args:
        args: parsed command-line namespace. Attributes read here: seed,
            num_qubits, h, lr, opt_pred, pbc, reps, entanglement, maxiter,
            opt_method, neural, batchnorm, svdonencoder, debug, skip_vqe,
            shots, skip_vanilla, num_iters_sim, num_iters_dmd, num_pieces,
            window_size, num_iters_sim_sw, num_iters_dmd_sw. ``args`` is also
            forwarded to the DMD routine.

    Raises:
        ValueError: if a full VQE run is requested for an opt_method other
            than 'natural_grad'.
        KeyError: if no DMD routine is registered for ``args.opt_method``.
    """
    #################  simulation setups  #############################
    seed = args.seed
    set_seed(seed)

    num_qubits = args.num_qubits
    h = args.h
    lr = args.lr
    opt_pred = args.opt_pred
    opt_method = args.opt_method

    # Every output below is written into these directories; create them so a
    # fresh checkout does not fail on the first save.
    os.makedirs('./data', exist_ok=True)
    os.makedirs('./plot', exist_ok=True)

    H, gs_energy = ising_transverse_field(num_qubits, h=h, pbc=args.pbc)
    np.savetxt('./data/gs_energy.dat', np.array([gs_energy]))
    ansatz = RealAmplitudes(num_qubits, reps=args.reps, entanglement=args.entanglement, insert_barriers=True)
    maxiter = args.maxiter
    file_name = "./plot/" + args.opt_method + "_" + "neural" + str(args.neural) + "_" + "bn" + str(args.batchnorm) + "_"
    file_name += "svdenc" + str(args.svdonencoder) + "_"
    if args.debug:
        maxiter = 10
    # Drawn after set_seed so the starting parameters are reproducible.
    initial_point = np.random.random(ansatz.num_parameters)
    algorithm_globals.random_seed = seed

    #################  full vqe optimization #############################
    if args.skip_vqe == 0:
        intermediate_info = _run_full_vqe(H, ansatz, initial_point, maxiter, lr, opt_method)
    else:
        record_path = ('./record/full_vqe_intermediate_info_' + str(args.opt_method) + '_n' + str(args.num_qubits)
                       + '_h' + str(args.h) + '_shots' + str(args.shots) + '.pkl')
        # NOTE: pickle.load can execute arbitrary code; only load record
        # files produced by this project.
        with open(record_path, 'rb') as f:
            intermediate_info = pickle.load(f)
    sim_energies_total = intermediate_info['energy']

    _plot_full_vqe(intermediate_info, gs_energy, opt_method, file_name)
    _save_pickle(sim_energies_total, './data/full_vqe_energy.pkl')
    _save_pickle(intermediate_info, './data/full_vqe_intermediate_info.pkl')

    # Defined outside the vanilla branch because the sliding-window stage
    # needs it even when the vanilla stage is skipped.
    dict_run_fn = {
        'natural_grad': natural_grad_dmd,
    }

    def _run_dmd_stage(tag, window_size, num_iters_sim, num_iters_dmd):
        """Run one DMD-accelerated VQE variant, plot it and pickle results.

        ``tag`` ('vanilla' or 'sw') is appended to the plot prefix and is the
        prefix of the ./data/<tag>_*.pkl output files, so the two stages
        never overwrite each other's results.
        """
        num_pieces = _resolve_num_pieces(args.num_pieces, maxiter, num_iters_sim, num_iters_dmd)
        energies_pieces, optimal_vqe_start_list, intermediate_info_pieces = dict_run_fn[opt_method](
            H, ansatz, seed, window_size, num_iters_sim, num_iters_dmd, num_pieces, args, lr, opt_pred)
        if opt_method == 'natural_grad':
            plot_energies(energies_pieces, sim_energies_total,
                          gs_energy, maxiter, num_iters_sim, num_iters_dmd, num_pieces, opt_method,
                          optimal_vqe_start_list, opt_pred, file_name + tag)
        else:
            plot_energies_errorbar(intermediate_info_pieces, intermediate_info, energies_pieces, sim_energies_total,
                                   gs_energy, maxiter, num_iters_sim, num_iters_dmd, num_pieces, opt_method,
                                   optimal_vqe_start_list, opt_pred, file_name + tag)
        _save_pickle(energies_pieces, './data/' + tag + '_energy.pkl')
        _save_pickle(intermediate_info_pieces, './data/' + tag + '_intermediate_info.pkl')
        _save_pickle(optimal_vqe_start_list, './data/' + tag + '_optimal_list.pkl')

    if not args.skip_vanilla:
        ##################  standard dmd #############################
        _run_dmd_stage('vanilla', 1, args.num_iters_sim, args.num_iters_dmd)

    ##################  tensor dmd #############################
    _run_dmd_stage('sw', args.window_size, args.num_iters_sim_sw, args.num_iters_dmd_sw)


