
import os
import sys

sys.path.insert(0, os.getcwd())
import numpy as np
import time
import matplotlib.pyplot as plt
import matplotlib
import copy

from functools import partial
from qiskit.algorithms import VQE
from qiskit.providers.aer import StatevectorSimulator, QasmSimulator
from qiskit.utils import QuantumInstance, algorithm_globals
from qiskit.opflow.gradients import NaturalGradient
from qiskit.algorithms.optimizers import GradientDescent
from qiskit import Aer
from qiskit.providers.aer.noise import NoiseModel
from qiskit.providers.fake_provider import FakeMontreal, FakeGuadalupe, FakeManila, FakeLima
from qiskit.utils.mitigation import CompleteMeasFitter

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear, CrossEntropyLoss, MSELoss
from torch.optim import LBFGS
from torch import cat, no_grad, manual_seed
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.optim as optim
from torch.nn import (
    Module,
    Conv2d,
    Linear,
    Dropout2d,
    NLLLoss,
    MaxPool2d,
    Flatten,
    Sequential,
    ReLU,
)
from optimizer_lib import gd_callback_all, vqe_callback_all, test_accuracy, train_model, evaluate_model, params_to_array, array_to_dict, count_params, count_params_trainable


# Module-wide matplotlib defaults: 18pt, normal-weight font for every figure.
font = dict(weight='normal', size=18)
matplotlib.rc('font', **font)
# Load amsmath for LaTeX-rendered labels (presumably only relevant when
# LaTeX text rendering is enabled elsewhere -- confirm).
matplotlib.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

# Seed torch's global RNG so network initialisation (e.g. Encoder weights)
# is reproducible between runs.
torch.manual_seed(0)


def adjust_learning_rate(max_steps, warmup_fraction, base_lr, optimizer, step, end_lr=0.0):
    """Write a linear-warmup + cosine-decay learning rate into ``optimizer``.

    For the first ``warmup_fraction * max_steps`` steps the rate rises linearly
    from 0 to ``base_lr``. Afterwards it follows a half cosine from ``base_lr``
    down to ``end_lr``, reached at ``step == max_steps`` and held beyond it.

    Args:
        max_steps: total number of scheduled steps.
        warmup_fraction: fraction of ``max_steps`` spent in linear warmup.
        base_lr: peak learning rate, reached at the end of warmup.
        optimizer: object exposing ``param_groups`` (e.g. a torch optimizer);
            the ``'lr'`` entry of every group is overwritten.
        step: current (0-based) step.
        end_lr: learning rate at the end of the cosine decay (default 0).

    Returns:
        The learning rate written to the optimizer.
    """
    warmup_steps = warmup_fraction * max_steps
    if step < warmup_steps:
        lr = base_lr * step / warmup_steps
    else:
        decay_steps = max_steps - warmup_steps
        if decay_steps <= 0:
            # No room for a decay phase (e.g. warmup_fraction >= 1); this used
            # to divide by zero. Treat the schedule as finished.
            lr = end_lr
        else:
            # Clamp so steps past max_steps stay at end_lr instead of the
            # cosine swinging back up towards base_lr.
            progress = min((step - warmup_steps) / decay_steps, 1.0)
            q = 0.5 * (1 + math.cos(math.pi * progress))
            lr = base_lr * q + end_lr * (1 - q)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


class Encoder(nn.Module):
    """Residual MLP followed by a bias-free linear "Koopman" head.

    The MLP maps ``input_dim -> hidden_dim -> ... -> hidden_dim -> input_dim``
    using ``depth`` Linear layers. Each layer is followed by ELU, optionally
    preceded by BatchNorm1d. The input is added back as a residual.
    ``forward(x, koopman=True)`` then applies ``self.koopman``, a bias-free
    ``Linear(input_dim, output_dim)``.

    Args:
        input_dim: feature size of the input (and of the residual output).
        hidden_dim: width of the hidden layers.
        output_dim: feature size produced by the Koopman head.
        args: unused; kept for call-site compatibility.
        depth: number of Linear layers in the residual MLP; must be >= 2.
        batch_norm: insert BatchNorm1d before each ELU.

    Raises:
        ValueError: if ``depth < 2``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, args, depth=2, batch_norm=False):
        super().__init__()
        if depth < 2:
            # The layer list below always contains at least the in->hidden and
            # hidden->in layers, but forward() only runs ``depth`` of them, so
            # depth < 2 would skip the projection back to input_dim.
            raise ValueError(f"Encoder depth must be >= 2, got {depth}")
        self.layers = nn.ModuleList([nn.Linear(input_dim, hidden_dim)])
        for _ in range(1, depth - 1):
            self.layers.append(nn.Linear(hidden_dim, hidden_dim))
        self.layers.append(nn.Linear(hidden_dim, input_dim))
        if batch_norm:
            # One BatchNorm per layer, sized to that layer's output width.
            self.bn = nn.ModuleList([nn.BatchNorm1d(hidden_dim) for _ in range(depth - 1)])
            self.bn.append(nn.BatchNorm1d(input_dim))

        self.depth = depth
        self.batch_norm = batch_norm
        self.koopman = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, x, koopman=True):
        """Apply the residual MLP and, if ``koopman`` is True, the Koopman head.

        Args:
            x: tensor of shape (batch, input_dim).
            koopman: apply the bias-free ``koopman`` map to the residual output.

        Returns:
            Tensor of shape (batch, output_dim) if ``koopman`` else (batch, input_dim).
        """
        tmp = x
        for i in range(self.depth):
            tmp = self.layers[i](tmp)
            if self.batch_norm:
                tmp = self.bn[i](tmp)
            tmp = F.elu(tmp)

        # Residual connection: the MLP learns a correction to the input.
        x = tmp + x

        if koopman:
            return self.koopman(x)
        return x


def _rollout_linear(step_fn, x_last, num_params, pred_time_len):
    """Autoregressively roll a window-stacked column forward with ``step_fn``.

    ``x_last`` is a ``(window_size * num_params, 1)`` column holding the most
    recent window. ``step_fn`` maps such a column to the next ``(num_params, 1)``
    vector. After each step the oldest ``num_params`` entries are dropped and
    the new prediction is appended. At least one prediction is always produced.

    Returns:
        numpy array of shape ``(num_params, max(pred_time_len, 1))``.
    """
    preds = [step_fn(x_last)]
    for _ in range(pred_time_len - 1):
        x_last = np.concatenate((x_last[num_params:, :], preds[-1]), axis=0)
        preds.append(step_fn(x_last))
    preds = np.array(preds)
    return preds.reshape((preds.shape[0], preds.shape[1])).transpose()


def sliding_window_dmd_nonsq(params, window_size: int, start_time: int, end_time: int, pred_time_len: int, neural: bool,
                             args, expansion_factor=1, lr=0.001, batchnorm=False, name=0):
    """Fit a (linear or neural) sliding-window DMD model and extrapolate parameters.

    Each training input stacks ``window_size`` consecutive parameter vectors
    ``params[t], ..., params[t + window_size - 1]``. Its target is
    ``params[t + window_size]``.

    Args:
        params (numpy array): shape (num_iters, num_params); row t is the
            parameter vector at optimisation iteration t.
        window_size (int): number of consecutive iterates stacked per input.
        start_time (int): first window index used for fitting (inclusive).
        end_time (int): window index where fitting stops (exclusive). The
            window starting at ``end_time`` seeds the forecast, so
            ``end_time <= num_iters - window_size`` is required.
        pred_time_len (int): number of future iterates to predict.
        neural (bool): False fits a linear operator by least squares (pseudo-
            inverse). True trains an ``Encoder`` with Adam and either rolls it
            out directly or, if ``args.svdonencoder``, refits the linear map on
            the encoder features.
        args: namespace; ``args.neural_steps`` and ``args.svdonencoder`` are
            read in the neural path.
        expansion_factor (int): hidden-width multiplier of the Encoder.
        lr (float): peak Adam learning rate (warmup + cosine schedule).
        batchnorm (bool): use BatchNorm1d inside the Encoder.
        name: tag used in the saved loss file / plot names.

    Returns:
        X_pred: numpy array of shape (num_params, pred_time_len), for
        pred_time_len >= 1. Column k predicts ``params[end_time + window_size + k]``.

    Raises:
        ValueError: if ``neural`` and ``args.neural_steps < 1``.
    """
    num_iters, num_params = params.shape
    # Column t stacks params[t], ..., params[t + window_size - 1] vertically:
    # shape (window_size * num_params, num_iters - window_size + 1).
    params_in_window = np.array([
        np.concatenate([params[t + k, :] for k in range(window_size)])
        for t in range(num_iters - window_size + 1)
    ]).transpose()

    # Input window t is paired with the iterate directly following it.
    X = params_in_window[:, start_time:end_time]
    X_prime = params.transpose()[:, start_time + window_size:end_time + window_size]
    # The window right after the fitting range seeds the forecast in every
    # branch, so the first prediction is the first unseen iterate.
    X_last = params_in_window[:, end_time].reshape((-1, 1))

    if not neural:
        A = X_prime @ np.linalg.pinv(X)
        return _rollout_linear(lambda col: A @ col, X_last, num_params, pred_time_len)

    if args.neural_steps < 1:
        # The loss/lr reporting below needs at least one training step.
        raise ValueError(f"args.neural_steps must be >= 1, got {args.neural_steps}")

    X_t = torch.from_numpy(X).float().T  # (num_windows, window_size * num_params)
    X_prime_t = torch.from_numpy(X_prime).float().T  # (num_windows, num_params)
    net = Encoder(X_t.shape[1], X_t.shape[1] * expansion_factor, num_params, args=args, batch_norm=batchnorm)
    net.train()

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    losses = []
    for i in range(args.neural_steps):
        lr2 = adjust_learning_rate(args.neural_steps, 0.3, lr, optimizer, i)
        prediction = net.forward(X_t, koopman=True)
        loss = F.mse_loss(prediction, X_prime_t)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    print(f'loss: {loss.item():.12f} | lr: {lr2:.12f}')
    os.makedirs('./data', exist_ok=True)
    np.savetxt('./data/losses_'+str(window_size)+'_'+str(name)+'.dat', np.array([loss.item()]))

    plt.switch_backend('agg')
    plt.plot(losses)
    plt.yscale('log')
    os.makedirs('./plot', exist_ok=True)
    plt.savefig(f'./plot/losses_{window_size}_{name}.png')
    plt.close()

    # evaluation
    net.eval()
    if not args.svdonencoder:
        # Roll the full network (encoder + learned Koopman head) forward.
        last = torch.from_numpy(X_last).float().T  # (1, window_size * num_params)
        predictions = []
        with torch.no_grad():
            predictions.append(net(last))
            for _ in range(pred_time_len - 1):
                last = torch.cat([last[:, num_params:], predictions[-1]], dim=1)
                predictions.append(net(last))
        return torch.cat(predictions, dim=0).T.detach().numpy()

    # Refit the Koopman map in closed form on the encoder's residual features.
    with torch.no_grad():
        features = net.forward(X_t, koopman=False).T.cpu().numpy()
    A = X_prime @ np.linalg.pinv(features)

    def encode_then_apply(col):
        """Encode one window column and map it to the next parameter vector."""
        with torch.no_grad():
            feat = net.forward(torch.from_numpy(col).float().T, koopman=False)
        return A @ feat.T.cpu().numpy()

    print('Xlast shape', X_last.shape)
    return _rollout_linear(encode_then_apply, X_last, num_params, pred_time_len)


def qml_dmd(model, optimizer, loss_func, train_model, train_loader, test_loader, batch_size, window_size, num_iters_sim, num_iters_dmd, num_pieces, args, seed, opt_pred=True):
    """Alternate real training with DMD extrapolation of a model's parameters.

    Each of the ``num_pieces`` pieces:
      1. trains ``model`` for ``num_iters_sim`` iterations via ``train_model``,
         recording the per-iteration parameter vectors;
      2. fits ``sliding_window_dmd_nonsq`` to those vectors and predicts
         ``num_iters_dmd`` further ones;
      3. loads each prediction into ``model`` and evaluates it.
    Every piece after the first restarts training from the prediction with the
    lowest training loss (``opt_pred=True``) or from the last prediction.

    Args:
      model: model exposing ``load_state_dict``; all parameters must be trainable.
      optimizer: optimizer passed through to ``train_model``.
      loss_func: loss passed to ``train_model`` and ``evaluate_model``.
      train_model: callable ``(model, optimizer, loss_func, num_iters,
          train_loader, test_loader, batch_size)`` returning
          ``(train_losses, test_losses, params_per_iter, model)``; presumably
          ``params_per_iter`` is (num_iters, num_params) -- confirm.
      train_loader, test_loader, batch_size: forwarded to training/evaluation.
      window_size (int): DMD window size (e.g. 1, 12, 20).
      num_iters_sim (int): training iterations per piece (e.g. 20, 25).
      num_iters_dmd (int): DMD-predicted iterates per piece (e.g. 20, 25).
      num_pieces (int): number of train-then-extrapolate pieces.
      args: namespace; ``args.neural`` and ``args.batchnorm`` are read here and
          the object is forwarded to ``sliding_window_dmd_nonsq``.
      seed: random seed for NumPy and ``algorithm_globals``.
      opt_pred: restart from the lowest-loss prediction instead of the last.

    Returns:
      (train_loss_pieces, test_loss_pieces, params_pieces,
      optimal_vqe_start_list, model). The piece lists alternate
      training-piece and DMD-piece entries.

    Raises:
      ValueError: if the model has non-trainable parameters.
    """
    run_time_start = time.time()
    train_loss_pieces = []
    test_loss_pieces = []
    params_pieces = []
    optimal_vqe_start_list = []
    num_params = count_params_trainable(model)
    if num_params != count_params(model):
        # Predicted vectors are sliced to the trainable count, so they must
        # cover every parameter of the model.
        raise ValueError("trainable params not match!")

    for i in range(num_pieces):
        if i == 0:
            np.random.seed(seed)
            # Only the layout (keys/shapes/sizes) is needed to turn flat
            # vectors back into a state dict; piece 0 trains the model as given.
            key_list, shape_list, nsize_list, _ = params_to_array(model)
        else:
            # Restart from the selected DMD prediction. It must be loaded
            # explicitly: otherwise the model keeps whatever the evaluation
            # loop loaded last and ``opt_pred`` has no effect.
            restart_col = optimal_vqe_start if opt_pred else -1
            init_params = X_pred[-num_params:, restart_col]
            model.load_state_dict(array_to_dict(key_list, shape_list, nsize_list, init_params))

        algorithm_globals.random_seed = seed

        train_loss_arr, test_loss_arr, params_arr_arr, model = train_model(model, optimizer, loss_func, num_iters_sim, train_loader, test_loader, batch_size)
        train_loss_pieces.append(train_loss_arr)
        test_loss_pieces.append(test_loss_arr)
        params_pieces.append(params_arr_arr)

        X_pred = sliding_window_dmd_nonsq(
            params_arr_arr,
            window_size=window_size,
            start_time=0,
            end_time=num_iters_sim - window_size,
            pred_time_len=num_iters_dmd,
            neural=args.neural,
            args=args,
            batchnorm=args.batchnorm,
            name=i
        )

        params_pieces.append(X_pred)

        # Score every predicted parameter vector on the data.
        dmd_train_loss = []
        dmd_test_loss = []
        for t in range(num_iters_dmd):
            params = X_pred[-num_params:, t]
            pred_dict = array_to_dict(key_list, shape_list, nsize_list, params)
            model.load_state_dict(pred_dict)
            train_loss, test_loss_list, model = evaluate_model(model, loss_func, train_loader, test_loader, batch_size)
            dmd_train_loss.append(train_loss)
            dmd_test_loss.append(test_loss_list)

        optimal_vqe_start = np.argmin(dmd_train_loss)
        optimal_vqe_start_list.append(optimal_vqe_start)
        train_loss_pieces.append(np.array(dmd_train_loss))
        test_loss_pieces.append(np.array(dmd_test_loss))

    run_time_end = time.time()

    print("Time elapsed (seconds):", run_time_end - run_time_start)

    return train_loss_pieces, test_loss_pieces, params_pieces, optimal_vqe_start_list, model


def natural_grad_dmd(H, ansatz, seed, window_size, num_iters_sim, num_iters_dmd, num_pieces, args, lr, opt_pred=True):
    """Alternate natural-gradient VQE runs with DMD extrapolation of the ansatz parameters.

    Each of the ``num_pieces`` pieces:
      1. runs ``GradientDescent`` with a ``NaturalGradient`` gradient for
         ``num_iters_sim`` iterations on a ``StatevectorSimulator``;
      2. fits ``sliding_window_dmd_nonsq`` to the recorded parameter trajectory;
      3. evaluates the energy at each of the ``num_iters_dmd`` predictions.
    Every piece after the first starts from the last DMD prediction.

    Args:
      H: operator whose minimum eigenvalue is sought (passed to VQE).
      ansatz: parameterized circuit; ``ansatz.num_parameters`` sets the
        parameter count.
      seed: random seed. It seeds NumPy for the first piece's random initial
        point and is assigned to ``algorithm_globals.random_seed`` every piece.
      window_size (int): DMD window size (e.g. 1, 12, 20).
      num_iters_sim (int): optimizer iterations per piece (e.g. 20, 25).
      num_iters_dmd (int): DMD-predicted iterates per piece (e.g. 20, 25).
      num_pieces (int): number of simulate-then-extrapolate pieces.
      args: namespace. ``args.neural`` and ``args.batchnorm`` are read here;
        the whole object is forwarded to ``sliding_window_dmd_nonsq``.
      lr: GradientDescent learning rate.
      opt_pred: currently ignored. The lowest-energy restart selection is
        commented out, and every piece restarts from the last prediction.

    Returns:
      (energies_pieces, optimal_vqe_start_list, intermediate_info_pieces):
      ``energies_pieces`` and ``intermediate_info_pieces`` alternate entries
      for the simulated piece and the DMD piece. ``optimal_vqe_start_list``
      holds, per piece, the index of the lowest DMD energy.
    """

    run_time_start = time.time()
    energies_pieces = []
    # NOTE(review): params_pieces is filled below but never returned.
    params_pieces = []
    optimal_vqe_start_list = []
    intermediate_info_pieces = []
    num_params = ansatz.num_parameters

    for i in range(num_pieces):
        # for i in range(1):
        if i == 0:
            seed = seed  # no-op self-assignment
            np.random.seed(seed)

            # Random initial point in [0, 1) for the first piece.
            initial_point = np.random.random(ansatz.num_parameters)
            init_params = initial_point
        else:
            # Always restart from the last DMD prediction (opt_pred is unused).
            init_params = X_pred[-num_params:, -1]
            # if opt_pred:
            #  init_params = X_pred[-num_params:, optimal_vqe_start]
            # else:
            #  init_params = X_pred[-num_params:, -1]

        algorithm_globals.random_seed = seed

        # Filled by gd_callback_all during the optimizer run; the 'energy' and
        # 'parameters' lists are read back below.
        gd_intermediate_info = {
            'nfev': [],
            'parameters': [],
            'energy': [],
            'stepsize': []
        }
        gd_callback = partial(gd_callback_all, intermediate_info=gd_intermediate_info)
        optimizer = GradientDescent(maxiter=num_iters_sim, learning_rate=lr, callback=gd_callback)
        qi = StatevectorSimulator()
        gradient = NaturalGradient(
            grad_method='lin_comb',
            qfi_method='lin_comb_full',
            regularization='perturb_diag',
        )
        vqe = VQE(
            ansatz=ansatz,
            initial_point=init_params,
            optimizer=optimizer,
            gradient=gradient,
            quantum_instance=qi,
            #         callback=callback,
        )

        # The return value is unused; the trajectory comes from the callback.
        result = vqe.compute_minimum_eigenvalue(operator=H)
        intermediate_info_pieces.append(gd_intermediate_info)

        sim_energies = gd_intermediate_info['energy']
        sim_params = gd_intermediate_info['parameters']

        sim_energies = np.array(sim_energies)
        sim_params = np.array(sim_params)
        energies_pieces.append(sim_energies)
        params_pieces.append(sim_params)

        X_pred = sliding_window_dmd_nonsq(
            sim_params,
            window_size=window_size,
            start_time=0,
            end_time=num_iters_sim - window_size,
            pred_time_len=num_iters_dmd,
            neural=args.neural,
            args=args,
            batchnorm=args.batchnorm,
            name=i
        )

        params_pieces.append(X_pred[-num_params:, :])
        dmd_energies = []
        # Filled by vqe_callback_all during the energy evaluations below.
        intermediate_info = {
            'nfev': [],
            'parameters': [],
            'energy': [],
            'stddev': []
        }
        callback = partial(vqe_callback_all, intermediate_info=intermediate_info)

        # Score each predicted parameter vector by its energy. No optimization
        # happens here: the energy-evaluation callable is applied to params.
        for t in range(num_iters_dmd):
            vqe_single = VQE(
                ansatz=ansatz,
                quantum_instance=qi,
                callback=callback,
            )
            params = X_pred[-num_params:, t]
            dmd_energies.append(vqe_single.get_energy_evaluation(operator=H)(params))

        optimal_vqe_start = np.argmin(dmd_energies)
        optimal_vqe_start_list.append(optimal_vqe_start)
        energies_pieces.append(np.array(dmd_energies))
        intermediate_info_pieces.append(intermediate_info)

    run_time_end = time.time()

    print("Time elapsed (seconds):", run_time_end - run_time_start)

    return energies_pieces, optimal_vqe_start_list, intermediate_info_pieces


def plot_energies(energies_pieces, sim_energies_total, gs_energy, maxiter, num_iters_sim, num_iters_dmd, num_pieces,
                  opt_method, optimal_vqe_start_list, opt_pred, file_name):
    """Plot piecewise simulated/DMD energies against a full simulation.

    Simulated pieces are blue. DMD pieces are red; when ``opt_pred`` is set,
    the red curve is solid up to and including the chosen restart index and
    dashed after it. The full simulation is black and the reference energy is
    a yellow horizontal line. The figure is saved as ``file_name + ".png"``
    and ``file_name + ".pdf"``.

    Args:
      energies_pieces: list alternating simulated-piece and DMD-piece energies.
      sim_energies_total: energies of one uninterrupted simulation.
      gs_energy: true ground-state energy.
      maxiter: unused.
      num_iters_sim (int): iterations per simulated piece (e.g. 20, 25).
      num_iters_dmd (int): iterations per DMD piece (e.g. 20, 25).
      num_pieces (int): number of simulated/DMD piece pairs.
      opt_method: unused.
      optimal_vqe_start_list: per-piece restart index within the DMD piece.
      opt_pred: split DMD curves into solid/dashed at the restart index.
      file_name: output path without extension.

    Returns:
      None
    """
    offset = 0
    for piece in range(num_pieces):
        first = piece == 0
        # Only the first piece carries legend labels.
        sim_kw = {'label': 'Sim (Piecewise)'} if first else {}
        dmd_kw = {'label': 'DMD (Piecewise)'} if first else {}

        sim_end = offset + num_iters_sim
        plt.plot(np.arange(offset, sim_end), energies_pieces[2 * piece], color='b', **sim_kw)

        offset = sim_end
        dmd_end = offset + num_iters_dmd
        dmd_energies = energies_pieces[2 * piece + 1]
        if not opt_pred:
            plt.plot(np.arange(offset, dmd_end), dmd_energies, color='r', **dmd_kw)
        else:
            split = optimal_vqe_start_list[piece] + 1
            plt.plot(np.arange(offset, offset + split), dmd_energies[:split], color='r', **dmd_kw)
            plt.plot(np.arange(offset + split, dmd_end), dmd_energies[split:], ls='--', color='r')
        offset = dmd_end

    plt.plot(np.arange(0, len(sim_energies_total)), sim_energies_total, color='k', label='Sim (Total)')
    plt.axhline(y=gs_energy, color='y', label='Real')

    plt.ylabel('Eigenvalue')
    plt.xlabel('Iterations')

    plt.legend()
    plt.tight_layout()
    for extension in (".png", ".pdf"):
        plt.savefig(file_name + extension)
    plt.close()

    return None


def plot_energies_errorbar(intermediate_info_pieces, intermediate_info_vqe_total, energies_pieces, sim_energies_total,
                           gs_energy, maxiter, num_iters_sim, num_iters_dmd, num_pieces, opt_method,
                           optimal_vqe_start_list, opt_pred, file_name):
    """Plot piecewise simulated/DMD energies with error bars against a full simulation.

    Energies and error bars come from the ``'energy'`` and ``'stddev'`` lists
    of each intermediate-info dict. Simulated pieces are blue and DMD pieces
    red. When ``opt_pred`` is set, the red curve is solid up to and including
    the chosen restart index and dashed after it; this matches
    ``plot_energies``, whereas previously the split ignored ``opt_pred``. The
    figure is saved as ``file_name + ".png"`` and ``file_name + ".pdf"``.

    Args:
      intermediate_info_pieces: list alternating simulated-piece and DMD-piece
        dicts with ``'energy'`` and ``'stddev'`` entries.
      intermediate_info_vqe_total: the same kind of dict for one uninterrupted run.
      energies_pieces: unused.
      sim_energies_total: unused.
      gs_energy: true ground-state energy.
      maxiter: unused.
      num_iters_sim (int): iterations per simulated piece (e.g. 20, 25).
      num_iters_dmd (int): iterations per DMD piece (e.g. 20, 25).
      num_pieces (int): number of simulated/DMD piece pairs.
      opt_method: unused.
      optimal_vqe_start_list: per-piece restart index within the DMD piece.
      opt_pred: split DMD curves into solid/dashed at the restart index.
      file_name: output path without extension.
    """
    offset = 0
    for i in range(num_pieces):
        # Only the first piece carries legend labels.
        sim_kw = {'label': 'Sim (Piecewise)'} if i == 0 else {}
        dmd_kw = {'label': 'DMD (Piecewise)'} if i == 0 else {}

        sim_info = intermediate_info_pieces[2 * i]
        sim_end = offset + num_iters_sim
        plt.errorbar(
            np.arange(offset, sim_end),
            sim_info['energy'],
            yerr=sim_info['stddev'],
            color='b',
            **sim_kw,
        )

        offset = sim_end
        dmd_end = offset + num_iters_dmd
        dmd_info = intermediate_info_pieces[2 * i + 1]
        if opt_pred:
            split = optimal_vqe_start_list[i] + 1
            plt.errorbar(
                np.arange(offset, offset + split),
                dmd_info['energy'][:split],
                yerr=dmd_info['stddev'][:split],
                color='r',
                **dmd_kw,
            )
            plt.errorbar(
                np.arange(offset + split, dmd_end),
                dmd_info['energy'][split:],
                yerr=dmd_info['stddev'][split:],
                ls='--',
                color='r',
            )
        else:
            plt.errorbar(
                np.arange(offset, dmd_end),
                dmd_info['energy'],
                yerr=dmd_info['stddev'],
                color='r',
                **dmd_kw,
            )
        offset = dmd_end

    plt.errorbar(
        np.arange(len(intermediate_info_vqe_total['energy'])),
        intermediate_info_vqe_total['energy'],
        yerr=intermediate_info_vqe_total['stddev'],
        color='k',
        label='Sim (Total)',
    )

    plt.axhline(y=gs_energy, color='y', label='Real')

    plt.ylabel('Eigenvalue')
    plt.xlabel('Iterations')
    plt.legend()
    plt.tight_layout()
    plt.savefig(file_name + ".png")
    plt.savefig(file_name + ".pdf")

    plt.close()


