from typing import Dict
from pathlib import Path
import logging
import time
import toml
import argparse
import numpy as np
import QCompute
from qiskit_aer import AerSimulator
from noise_model import noise_model1, noise_model2
from utils import quantum_state_learning

# Turn off QCompute's `outputInfo` setting for this process.
QCompute.Define.Settings.outputInfo = False

# Root logger: bare messages at INFO and above, written to a per-run file
# named after the current ctime string (opened in "w" mode, i.e. truncated).
logging.basicConfig(
    filename=f"log-{time.ctime()}",
    filemode="w",
    format="%(message)s",
    level=logging.INFO,
)

# A threshold one above CRITICAL filters out every standard-level record
# emitted through the "qiskit" logger.
logging.getLogger("qiskit").setLevel(logging.CRITICAL + 1)

# Maps the MODELTYPE value from the config file to a noise-model factory.
NOISEMODEL = {
    1: noise_model1,
    2: noise_model2,
}


def main(args):
    """Run repeated rounds of quantum state learning on a noisy simulator.

    Loads the TOML file given by ``args.config``, builds a noise model from
    its entries, wraps that model in an ``AerSimulator`` and then calls
    ``quantum_state_learning`` ``args.Nrounds`` times. Start/end times, the
    loaded settings and a banner per round are written through the root
    logger. All rounds receive the same ``outputs/<timestamp>`` directory,
    which is created before the first round.

    Args:
        args: Parsed command-line namespace. ``args.config`` is handed to
            ``toml.load`` and ``args.Nrounds`` to ``range``.

    Raises:
        KeyError: If the config lacks one of the keys read here, or if its
            ``MODELTYPE`` value is not a key of ``NOISEMODEL``.
    """
    clock_format = "%Y%m%d-%H:%M:%S"
    started_at = time.strftime(clock_format, time.localtime())
    logging.info(f"Job start at {started_at:s}")

    # Read the run settings and echo them into the log.
    cfg: Dict = toml.load(args.config)
    logging.info("################ init settings ################")
    logging.info(toml.dumps(cfg))

    # Build the noisy simulator. The factory is chosen by MODELTYPE and is
    # called positionally with the config values listed below, in order.
    # NOTE(review): GATETIME fills two consecutive slots -- presumably the
    # 1- and 2-qubit gate times; confirm against the noise_model module.
    build_noise_model = NOISEMODEL[cfg["MODELTYPE"]]
    noise_arg_keys = (
        "DEPOLARIZING_LAMBDA",
        "T1",
        "T2",
        "GATETIME",
        "GATETIME",
        "Q3GATETIME",
        "READOUTPROB_00",
        "READOUTPROB_11",
    )
    noise_model = build_noise_model(*[cfg[key] for key in noise_arg_keys])
    simulator = AerSimulator(noise_model=noise_model)

    # Learning rounds.
    logging.info("################ learning process ################")
    param_count = 4 * cfg["NUMQUBITS"] - 3
    start_params = np.zeros(param_count)
    run_stamp = time.ctime()
    output_dir = f"outputs/{run_stamp:s}"
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # NOTE(review): the same start_params array object is passed to every
    # round; whether quantum_state_learning mutates it is not visible here.
    for round_index in range(args.Nrounds):
        logging.info(f"################ {round_index:d} ################")
        quantum_state_learning(
            start_params,
            cfg["NUMQUBITS"],
            output_dir,
            cfg["NUMSHOTS"],
            simulator,
            cfg["NUMSTEPEVALS"],
            cfg["MAXITERS"],
            cfg["TOL"],
        )

    finished_at = time.strftime(clock_format, time.localtime())
    logging.info(f"\nJob end at {finished_at:s}\n")


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    #
    # Both options are mandatory. main() passes --Nrounds to range() and
    # --config to toml.load(); neither accepts the None that argparse would
    # supply for an omitted option, so require them here and let argparse
    # print a usage error instead of failing later with a TypeError.
    parser = argparse.ArgumentParser(description="Quantum State Learning.")
    parser.add_argument(
        "--Nrounds",
        type=int,
        required=True,
        help="Number of times to run the quantum state learning",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Input the config file with toml format.",
    )
    main(parser.parse_args())
