from typing import Dict, Optional, Tuple
import time
import logging
import numpy as np
from project_qsl import QCircuit, GHZSimCircuit, minimize_scipyopt
from distance import distance2norm_scipyopt, fidelity
from qiskit_aer import AerSimulator

# Public API: only the driver is exported; ``single_step_optimization`` is not
# picked up by ``from <module> import *``.
__all__ = [
    "quantum_state_learning",
]


def single_step_optimization(
        step_num: int,
        num_qubits: int,
        prev_params: np.ndarray,
        new_init_params: np.ndarray,
        num_shots: int = 1024,
        qiskit_backend: Optional[AerSimulator] = None,
        optim_kwargs: Optional[Dict] = None
) -> Tuple[float, QCircuit, list]:
    """Run one COBYLA optimization after appending a fresh block to the circuit.

    Builds a ``GHZSimCircuit`` with ``step_num - 1`` blocks, loads
    ``prev_params`` into it via ``set_all_params``, appends one more block and
    minimizes ``distance2norm_scipyopt`` with ``minimize_scipyopt`` starting
    from ``new_init_params`` (presumably only the new block's parameters are
    trained -- verify against ``minimize_scipyopt``).

    Args:
        step_num: 1-based step index; the final circuit has ``step_num`` blocks.
        num_qubits: Number of qubits of the simulated circuit.
        prev_params: Parameters for the first ``step_num - 1`` blocks.
        new_init_params: Starting point handed to the optimizer.
        num_shots: Shots per circuit execution.
        qiskit_backend: Backend forwarded to ``minimize_scipyopt``.
        optim_kwargs: Extra keyword arguments forwarded to
            ``minimize_scipyopt``; ``None`` (default) means no extras. Keys
            ``shots``, ``qiskit_backend``, ``error_mitigator`` and ``method``
            are already passed explicitly, so supplying them raises
            ``TypeError``.

    Returns:
        ``(loss, circuit, loss_history)`` from ``minimize_scipyopt`` (its third
        return value is discarded).
    """
    # Avoid a shared mutable default; the dict is only unpacked, never mutated.
    if optim_kwargs is None:
        optim_kwargs = {}

    cir = GHZSimCircuit(num_qubits)
    # Rebuild the blocks learned in earlier steps, then fix their parameters.
    for _ in range(step_num - 1):
        cir.add_block()
    cir.set_all_params(prev_params)
    # The block added last is the one this step optimizes.
    cir.add_block()
    loss_v, cir, _, loss_history = minimize_scipyopt(
        distance2norm_scipyopt, new_init_params, cir, shots=num_shots, qiskit_backend=qiskit_backend,
        error_mitigator=None, method="COBYLA", **optim_kwargs
    )
    return loss_v, cir, loss_history


def quantum_state_learning(
    init_params: np.ndarray,
    num_qubits: int,
    outputs_dirname: str,
    num_shots: int = 1024,
    qiskit_backend: Optional[AerSimulator] = None,
    num_evals_per_step: int = 1,
    maxiters: int = 100,
    tol: float = 1e-3,
    return_all: bool = True,
    if_print: bool = False
) -> None:
    """Grow a GHZ-simulation circuit block by block, then log its fidelity.

    For each step ``1..num_qubits`` a new block is optimized with
    ``single_step_optimization`` ``num_evals_per_step`` times. The run with
    the lowest loss wins (first one on ties). Its parameters are appended to
    the accumulated parameter vector, and its loss history is written as one
    space-separated line to ``<outputs_dirname>/loss-<ctime>.txt``. Finally a
    circuit with ``num_qubits`` blocks is rebuilt from all accumulated
    parameters and its fidelity is logged.

    Args:
        init_params: Optimizer starting points. Steps 1-3 use the slices
            ``[0:4]``, ``[4:8]`` and ``[8:12]``; every later step uses only
            ``init_params[-1:]``.
        num_qubits: Number of qubits, and also the number of learning steps.
        outputs_dirname: Existing directory for the loss file (a ``str``; any
            object whose ``str()`` is the path, e.g. ``pathlib.Path``, works).
        num_shots: Shots per circuit execution.
        qiskit_backend: Backend forwarded to the optimizer and ``fidelity``.
        num_evals_per_step: Optimization repeats per step; must be >= 1.
        maxiters: Forwarded to the optimizer as ``max_iters``.
        tol: Forwarded to the optimizer as ``tol``.
        return_all: Forwarded to the optimizer as ``return_all``.
        if_print: Forwarded to the optimizer as ``if_print``.

    Raises:
        ValueError: If ``num_evals_per_step`` is less than 1. This is checked
            before any file is opened.
    """
    if num_evals_per_step < 1:
        raise ValueError(
            f"num_evals_per_step must be >= 1, got {num_evals_per_step}"
        )

    optim_kwargs = {
        "return_all": return_all, "max_iters": maxiters, "tol": tol, "if_print": if_print
    }
    prev_params = np.array([])

    # NOTE(review): time.ctime() contains spaces and colons; colons are not
    # valid in Windows filenames -- consider a strftime stamp if needed there.
    log_path = f"{outputs_dirname}/loss-{time.ctime(time.time())}.txt"
    with open(log_path, "a", encoding="utf-8") as fhandle:
        fhandle.write("----------------- Start -----------------\n")
        for step_indx in range(1, 1 + num_qubits):
            # The starting point depends only on the step, so compute it once.
            if step_indx < 4:
                new_init_params = init_params[4*(step_indx-1):4*step_indx]
            else:
                new_init_params = init_params[-1:]

            # Keep the first run with the smallest loss (same rule as min()).
            best = None
            for _ in range(num_evals_per_step):
                loss_v, cir, loss_history = single_step_optimization(
                    step_indx, num_qubits, prev_params, new_init_params, num_shots, qiskit_backend, optim_kwargs
                )
                if best is None or loss_v < best["loss"]:
                    best = {"loss": loss_v, "param": cir.get_params(), "loss_path": loss_history}

            prev_params = np.concatenate((prev_params, best["param"]))
            cur_loss = best["loss"]
            logging.info("STEP %d:", step_indx)
            logging.info("loss=%.5f", cur_loss)
            logging.info("circuit parameters=%s", prev_params)

            fhandle.write(" ".join(map(str, best["loss_path"])))
            fhandle.write("\n")

        # Rebuild the full circuit (one block per step) from every learned
        # parameter and evaluate its fidelity.
        psisim_cir = GHZSimCircuit(num_qubits)
        for _ in range(num_qubits):
            psisim_cir.add_block()
        psisim_cir.set_all_params(prev_params)
        fval = fidelity(psisim_cir, num_shots, qiskit_backend)
        logging.info("Fidelity=%.5f", fval)
        # Terminate the marker line so a later append to this file starts cleanly.
        fhandle.write("----------------- End -----------------\n")
