from typing import List
import argparse
import time
import numpy as np
from project_qsl import GHZSimCircuit, QCircuit
import matplotlib.pyplot as plt
import QCompute
QCompute.Define.Settings.outputInfo = False


def ghz_counts(num_qubits: int, measured_qindx: List[int], num_shots: int = 1024):
    """Sample the reference GHZ-preparation circuit.

    The circuit applies a Hadamard to qubit 0, then a CX with qubit 0 as
    control onto each of qubits 1 .. num_qubits - 1.

    Args:
        num_qubits: width of the circuit.
        measured_qindx: indices of the qubits to measure.
        num_shots: number of measurement shots.

    Returns:
        Whatever ``QCircuit.run`` returns. ``main`` uses it as a mapping from
        state string to shot count (assumption based on that usage).
    """
    control = 0
    circuit = QCircuit(num_qubits)
    circuit.h(control)
    for target in range(control + 1, num_qubits):
        circuit.cx(control, target)
    return circuit.run(measure_qindices=measured_qindx, nshots=num_shots)


def main(args):
    """Plot the learned QSSM state distribution next to an ideal GHZ circuit.

    Both circuits are measured on all ``num_qubits`` qubits. Their shot counts
    are drawn as side-by-side bars and saved to a timestamped PDF in the
    current working directory.

    Args:
        args: parsed command-line namespace with these attributes:
            ``params`` is an iterable text stream. Only its first line is read;
            it is parsed as space-separated floats and passed to
            ``GHZSimCircuit.set_all_params``.
            ``fidelity`` is a float or None. It is shown in the title when given.
            ``numshots`` is an int: the number of shots per circuit.
            ``figsize`` is a pair of floats passed to ``plt.figure``.

    Raises:
        ValueError: if the first line of ``args.params`` is missing or blank.
    """
    num_qubits = 4
    # Number of add_block() calls on the simulation circuit. Kept at the
    # original value; presumably it must match how the params were trained.
    num_blocks = 4
    measured_qubits = list(range(num_qubits))
    fidelity = args.fidelity
    num_shots = args.numshots

    # Only the first line of the params stream is used.
    first_line = next(args.params, "")
    if not first_line.strip():
        raise ValueError("no circuit parameters found on the first line of the params input")
    params = np.fromstring(first_line, sep=" ")

    counts_ghz = ghz_counts(num_qubits, measured_qubits, num_shots)

    simcir = GHZSimCircuit(num_qubits)
    for _ in range(num_blocks):
        simcir.add_block()
    simcir.set_all_params(params)
    counts_sim = simcir.run(num_shots, measured_qubits)

    # Zero-fill states observed by only one of the circuits. Both bar series
    # then have the same length and share the same x positions and labels.
    labels = sorted(set(counts_sim) | set(counts_ghz))
    sim_heights = [counts_sim.get(state, 0) for state in labels]
    ghz_heights = [counts_ghz.get(state, 0) for state in labels]

    bar_width = 0.3
    positions = np.arange(len(labels))
    font = {"fontfamily": "Times New Roman", "weight": "normal"}

    fig = plt.figure(figsize=args.figsize, dpi=300)
    ax = fig.gca()
    ax.bar(positions, sim_heights, bar_width, label="Noisy QSSM")
    ax.bar(positions + bar_width, ghz_heights, bar_width, label="Standard GHZ")
    ax.set_xlabel("state strings", fontsize=13, **font)
    ax.set_ylabel("Shots", fontsize=13, **font)

    title = "(b) Distribution of learned QSSM"
    if fidelity is not None:
        title += f" (fidelity={fidelity:.3f})"
    ax.set_title(title, fontsize=15, pad=20, **font)

    # Each pair of bars is centred on x and x + bar_width, so the tick sits
    # midway between them.
    ax.set_xticks(positions + 0.5 * bar_width, labels, rotation=40)
    ax.legend(prop={"family": "Times New Roman", "weight": "normal", "size": 13})

    # time.ctime() contains ':' and spaces. ':' is not allowed in Windows
    # file names, so use a filesystem-safe timestamp instead.
    stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
    fig.savefig(f"QSL-performance-{stamp}.pdf")
    plt.close(fig)

    
if __name__ == "__main__":
    import sys

    # Command-line entry point: parse options and render the comparison plot.
    parser = argparse.ArgumentParser(description="QSL results show")
    parser.add_argument(
        "params",
        type=argparse.FileType("r"),
        nargs="?",
        default=sys.stdin,
        help="file whose first line holds the space-separated circuit "
        "parameters (default: read from stdin)",
    )
    parser.add_argument(
        "--fidelity",
        type=float,
        help="fidelity of the learned model, printed in the plot title",
    )
    parser.add_argument(
        "--numshots",
        type=int,
        default=1024,
        help="number of measurement shots per circuit (default: 1024)",
    )
    parser.add_argument(
        "--figsize",
        nargs=2,
        type=float,
        default=[6.4, 4.8],
        metavar=("WIDTH", "HEIGHT"),
        help="figure size in inches (default: 6.4 4.8)",
    )
    main(parser.parse_args())
