from turtle import color, title
from qram import QRAMDataSet, RowPatternDataset
from qmodel import QModel, RowdetectModel, RowdetectModelA, RowdetectModelB,ConvolutionLayer
from qoptimizer import QOptimizer
from qiskit import QuantumCircuit
import matplotlib.pyplot as plt
import torch
from qiskit.visualization import plot_histogram,plot_state_city
from qiskit import IBMQ, Aer, assemble, transpile, execute
import time
def _count_correct(data, label, weights, sample_ids):
    """Return, for each weight vector, how many samples it labels correctly.

    For every ``w`` in *weights* and every index ``i`` in *sample_ids* the
    3x3 sample ``data[i]`` is scored as follows:

    * ``x`` is 1 when the sum over the three rows of
      ``(row[0] ^ w[0]) * (row[1] ^ w[1]) * (row[2] ^ w[2])`` is >= 1;
    * ``y`` is the same test taken over the three columns with
      ``w[3], w[4], w[5]``;
    * the prediction is ``(w[6] ^ x) * (w[7] ^ y)`` and counts as a hit
      when it equals ``label[i]``.

    *sample_ids* must be re-iterable (e.g. a ``range``) because it is
    walked once per weight vector.
    """
    totals = []
    for w in weights:
        hits = 0
        for i in sample_ids:
            sample = data[i]
            # Row-wise feature, driven by w[0..2].
            x = int(
                (sample[0][0] ^ w[0]) * (sample[0][1] ^ w[1]) * (sample[0][2] ^ w[2])
                + (sample[1][0] ^ w[0]) * (sample[1][1] ^ w[1]) * (sample[1][2] ^ w[2])
                + (sample[2][0] ^ w[0]) * (sample[2][1] ^ w[1]) * (sample[2][2] ^ w[2])
                >= 1
            )
            # Column-wise feature, driven by w[3..5].
            y = int(
                (sample[0][0] ^ w[3]) * (sample[1][0] ^ w[4]) * (sample[2][0] ^ w[5])
                + (sample[0][1] ^ w[3]) * (sample[1][1] ^ w[4]) * (sample[2][1] ^ w[5])
                + (sample[0][2] ^ w[3]) * (sample[1][2] ^ w[4]) * (sample[2][2] ^ w[5])
                >= 1
            )
            o = (w[6] ^ x) * (w[7] ^ y)
            if o == label[i]:
                hits += 1
        totals.append(hits)
    return totals


def verification():
    """Exhaustively score all 256 8-bit weight vectors on the row-pattern data.

    Loads the samples and labels from ``./row_pattern_dataset/``, counts the
    correct predictions of every weight vector on samples 0..399 and on
    samples 400..511, prints the per-weight counts and summary statistics,
    and writes eight bar charts
    (``weights_sample_possibility_after_grover{0..7}.pdf``) showing a
    per-weight distribution derived from the first set of counts.

    Returns:
        A pair ``(Accuracy, test_Accuracy)`` of float tensors with 256
        entries each: the number of correctly classified samples per
        weight vector on samples 0..399 and 400..511 respectively. Entry
        ``k`` belongs to the weight vector whose bits are the 8-bit binary
        expansion of ``k`` (most significant bit first).
    """
    train_end = 400          # samples [0, 400) form the first split
    total_samples = 512      # samples [400, 512) form the second split
    # Normalisation constant for theta and the miss counts; presumably the
    # size of the 10-bit address space used by Optimize_C -- TODO confirm.
    address_space = 1024
    optimal_threshold = 340  # a weight is "optimal" above this many hits

    data = torch.load('./row_pattern_dataset/row_pattern_full.pt')
    label = torch.load('./row_pattern_dataset/row_pattern_labels_full.pt')

    # Every 8-bit pattern as a list of 0/1 ints, MSB first.
    Weights = [
        [int(bit == '1') for bit in bin(w)[2:].rjust(8, '0')]
        for w in range(256)
    ]

    Accuracy = _count_correct(data, label, Weights, range(train_end))
    for w, hits in zip(Weights, Accuracy):
        print("{} : {}".format(w, hits))
    Accuracy = torch.Tensor(Accuracy).float()
    print(Accuracy.mean())
    print(Accuracy.max())
    optimal = [k for k in range(256) if Accuracy[k] > optimal_threshold]
    print('\n\n')

    test_Accuracy = _count_correct(
        data, label, Weights, range(train_end, total_samples)
    )
    for w, hits in zip(Weights, test_Accuracy):
        print("{} : {}".format(w, hits))
    test_Accuracy = torch.Tensor(test_Accuracy).float()
    print(test_Accuracy.mean())

    theta = torch.arcsin(torch.sqrt(Accuracy.mean() / address_space))
    print(theta)

    Mis = address_space - Accuracy
    optimal_set = set(optimal)
    others = [a for a in range(256) if a not in optimal_set]
    print(optimal)

    # Loop-invariant normalisers for the distribution below.
    hit_total = Accuracy.sum()
    miss_total = Mis.sum()
    for i in range(8):
        plt.title('Weights Sample Possibility Distribution')
        plt.xlabel('weights in dictionary order')
        plt.ylabel('weights sample possibility')
        rot = (2 * i + 1) * theta
        # sin^2(rot) weights the normalised hit counts, cos^2(rot) the
        # normalised miss counts; evaluated elementwise for all 256 weights.
        pos = (
            torch.sin(rot) * torch.sin(rot) * Accuracy / hit_total
            + torch.cos(rot) * torch.cos(rot) * Mis / miss_total
        )
        plt.bar(others, pos[others], fc='b')
        plt.bar(optimal, pos[optimal], fc='r')
        print(pos.sum() / pos[optimal].sum())
        plt.savefig('weights_sample_possibility_after_grover{}.pdf'.format(i))
        plt.clf()

    return Accuracy, test_Accuracy
def Optimize():
    """Run the single-weight-qubit row-pattern experiment on the Aer simulator.

    Builds a 17-qubit / 17-classical-bit circuit, wires a RowPatternDataset,
    a RowdetectModel and a QOptimizer onto it, applies encode -> forward ->
    optimize(iter=3), measures the weight qubit, simulates 1,000,000 shots,
    prints the counts and saves their histogram to ``row_pattern.jpg``.
    """
    n_qubits = 17
    circuit = QuantumCircuit(n_qubits, n_qubits)

    # Qubit layout.
    weight_qubits = [0]
    addr_qubits = list(range(1, 10))
    data_qubits = list(range(1, 12))
    ancilla_qubits = [12, 13, 14, 15]
    out_qubit = 16
    every_qubit = list(range(n_qubits))

    train_mask = torch.Tensor([1] * 512).long()

    dataset = RowPatternDataset(
        qc=circuit,
        address_qubits=addr_qubits,
        dataset_qubits=data_qubits,
        train=train_mask,
    )
    model = RowdetectModel(
        qc=circuit,
        ancilla_bits=ancilla_qubits,
        dataset_qubits=data_qubits,
        weights=weight_qubits,
        output=out_qubit,
    )
    optimizer = QOptimizer(
        qc=circuit,
        dataset_qubits=data_qubits,
        output=out_qubit,
        data=dataset,
        model=model,
        allqubits=every_qubit,
    )

    dataset.encode()
    model.forward()
    optimizer.optimize(iter=3)
    circuit.measure(weight_qubits, weight_qubits)

    backend = Aer.get_backend('aer_simulator')
    result = execute(circuit, backend, shots=1_000_000).result()
    counts = result.get_counts()
    print(counts)
    plot_histogram(counts).savefig("row_pattern.jpg")
def Optimize_A():
    """Run the three-weight-qubit RowdetectModelA experiment on the Aer simulator.

    Builds a 19-qubit / 3-classical-bit circuit, wires a RowPatternDataset,
    a RowdetectModelA and a QOptimizer onto it, applies encode -> forward ->
    optimize(iter=3), measures the three weight qubits, simulates 1,000,000
    shots, prints the counts and saves their histogram to
    ``row_pattern_A.jpg``.
    """
    n_qubits = 19

    # Qubit layout.
    weight_qubits = [0, 1, 2]
    addr_qubits = list(range(3, 12))
    data_qubits = list(range(3, 14))
    ancilla_qubits = [14, 15, 16, 17]
    out_qubit = 18
    every_qubit = list(range(n_qubits))

    circuit = QuantumCircuit(n_qubits, len(weight_qubits))
    train_mask = torch.Tensor([1] * 512).long()

    dataset = RowPatternDataset(
        qc=circuit,
        address_qubits=addr_qubits,
        dataset_qubits=data_qubits,
        train=train_mask,
    )
    model = RowdetectModelA(
        qc=circuit,
        ancilla_bits=ancilla_qubits,
        dataset_qubits=data_qubits,
        weights=weight_qubits,
        output=out_qubit,
    )
    optimizer = QOptimizer(
        qc=circuit,
        dataset_qubits=data_qubits,
        output=out_qubit,
        data=dataset,
        model=model,
        allqubits=every_qubit,
    )

    dataset.encode()
    model.forward()
    optimizer.optimize(iter=3)
    circuit.measure(weight_qubits, weight_qubits)

    backend = Aer.get_backend('aer_simulator')
    result = execute(circuit, backend, shots=1_000_000).result()
    counts = result.get_counts()
    print(counts)
    plot_histogram(counts).savefig("row_pattern_A.jpg")
def Optimize_B(iter):
    """Run the three-weight-qubit RowdetectModelB experiment and print the counts.

    Builds a 23-qubit / 3-classical-bit circuit, wires a RowPatternDataset,
    a RowdetectModelB and a QOptimizer onto it, applies encode -> forward ->
    optimize(iter=iter), measures the three weight qubits, transpiles for the
    single-precision ``aer_simulator`` backend, simulates 1,000,000 shots and
    prints the resulting counts.

    Args:
        iter: value forwarded unchanged to ``QOptimizer.optimize(iter=...)``.
    """
    n_qubits = 23

    # Qubit layout.
    weight_qubits = [0, 1, 2]
    addr_qubits = list(range(3, 12))
    data_qubits = list(range(3, 14))
    ancilla_qubits = list(range(14, 22))
    out_qubit = 22
    every_qubit = list(range(n_qubits))

    circuit = QuantumCircuit(n_qubits, len(weight_qubits))
    train_mask = torch.Tensor([1] * 512).long()

    dataset = RowPatternDataset(
        qc=circuit,
        address_qubits=addr_qubits,
        dataset_qubits=data_qubits,
        train=train_mask,
    )
    model = RowdetectModelB(
        qc=circuit,
        ancilla_bits=ancilla_qubits,
        dataset_qubits=data_qubits,
        weights=weight_qubits,
        output=out_qubit,
    )
    optimizer = QOptimizer(
        qc=circuit,
        dataset_qubits=data_qubits,
        output=out_qubit,
        data=dataset,
        model=model,
        allqubits=every_qubit,
    )

    dataset.encode()
    model.forward()
    optimizer.optimize(iter=iter)
    circuit.measure(weight_qubits, weight_qubits)

    backend = Aer.get_backend('aer_simulator')
    backend.set_options(precision='single')
    compiled = transpile(circuit, backend)
    result = execute(compiled, backend, shots=1_000_000).result()
    counts = result.get_counts()
    print(counts)
def Optimize_C(iter=0,shots=1e6, draw_path=None):
    """Run the eight-weight-qubit ConvolutionLayer experiment and return counts.

    Builds a 28-qubit / 8-classical-bit circuit, wires a RowPatternDataset,
    a ConvolutionLayer and a QOptimizer onto it, applies encode -> forward ->
    optimize(iter=iter), measures the eight weight qubits and simulates the
    transpiled circuit on the single-precision ``aer_simulator`` backend.

    The ``train`` mask handed to the dataset has 1024 entries; the first 400
    are 1 and the remaining 624 are 0.

    Args:
        iter: value forwarded unchanged to ``QOptimizer.optimize(iter=...)``.
        shots: number of simulator shots. Coerced with ``int()`` so the
            float default ``1e6`` reaches the backend as an integer.
        draw_path: optional file path. When given, the circuit diagram
            (matplotlib drawer) is saved there before simulation. Drawing is
            off by default; previously the function always drew to
            ``qc.pdf`` and then called ``exit()``, which terminated the
            process before any simulation ran.

    Returns:
        The counts object returned by ``job.result().get_counts()`` for the
        eight measured weight bits.
    """
    weights_bits = [0, 1, 2, 3, 4, 5, 6, 7]
    address_bits = list(range(8, 18))
    dataset_bits = list(range(8, 19))
    ancillas = [19, 20, 21, 22, 23, 24, 25, 26]
    output = 27
    allqubits = list(range(28))
    qc = QuantumCircuit(28, 8)

    # Only the first 400 of the 1024 mask entries are flagged.
    train = torch.Tensor([1 for _ in range(1024)]).long()
    train[400: 1024] = 0

    row_pattern_dataset = RowPatternDataset(address_qubits=address_bits,qc=qc,dataset_qubits=dataset_bits,train=train)
    row_pattern_model = ConvolutionLayer(qc=qc,ancilla_bits=ancillas,dataset_qubits=dataset_bits, weights=weights_bits,output=output)
    optimizer = QOptimizer(qc=qc,dataset_qubits=dataset_bits,output=output,data=row_pattern_dataset,model=row_pattern_model,allqubits=allqubits)

    row_pattern_dataset.encode()
    row_pattern_model.forward()
    optimizer.optimize(iter=iter)
    qc.measure(weights_bits, weights_bits)

    if draw_path is not None:
        fig = qc.draw('mpl')
        fig.savefig(draw_path)

    aer_sim = Aer.get_backend('aer_simulator')
    aer_sim.set_options(precision='single')
    transpiled_qc = transpile(qc, aer_sim)
    # The backend expects an integer shot count; the default 1e6 is a float.
    job = execute(transpiled_qc, aer_sim, shots=int(shots))
    counts = job.result().get_counts()

    return counts

# Classical reference: correct-prediction counts of every 8-bit weight
# vector on samples 0..399 (train_acc) and 400..511 (test_acc).
train_acc , test_acc = verification()
shots = [1, 2, 4, 8, 16, 32, 64, 128]
RUNS = 10
# Row = run, column = shot budget; filled with the reference count of the
# best weight vector observed among that run's measurement outcomes.
PerFormance = torch.zeros((RUNS, len(shots)))
T_PerFormance = torch.zeros((RUNS, len(shots)))
for run in range(RUNS):

    for i, x in enumerate(shots):
        counts = Optimize_C(iter=1,shots=x)
        keys = list(counts.keys())
        # Among all measured bitstrings keep the one with the highest
        # train_acc (first one wins on ties).
        # NOTE(review): int(key, 2) is used directly as an index into the
        # tensors returned by verification(); this assumes the bit order of
        # the counts keys matches the MSB-first weight order used there --
        # confirm against ConvolutionLayer.
        MAX_ID = int(keys[0], 2)
        for it in keys:
            number = int(it, 2)
            if train_acc[number] > train_acc[MAX_ID]:
                MAX_ID = number
        PerFormance[run][i] = train_acc[MAX_ID]
        T_PerFormance[run][i] = test_acc[MAX_ID]

print(PerFormance)
print(T_PerFormance)

# Transpose so that each row collects all runs of one shot budget.
PerFormance = PerFormance.t()
T_PerFormance = T_PerFormance.t()

for train_row, test_row in zip(PerFormance, T_PerFormance):
    print(train_row.mean())
    print(test_row.mean())
    print(train_row.var())
    print(test_row.var())



