from turtle import color, title
from qram import QRAMDataSet, RowPatternDataset
from qmodel import QModel, RowdetectModel, RowdetectModelA, RowdetectModelB,ConvolutionLayer
from qoptimizer import QOptimizer
from qiskit import QuantumCircuit
import matplotlib.pyplot as plt
import torch
from qiskit.visualization import plot_histogram,plot_state_city
from qiskit import IBMQ, Aer, assemble, transpile, execute
import time
def _all_weights():
    """Return every 8-bit weight vector as a (256, 8) int64 tensor of 0/1 bits.

    Row j is the binary expansion of j, most-significant bit first, i.e. the
    same bits as ``bin(j)[2:].rjust(8, '0')``.
    """
    codes = torch.arange(256).unsqueeze(1)  # (256, 1)
    shifts = torch.arange(7, -1, -1)        # 7, 6, ..., 0 -> MSB first
    return (codes >> shifts) & 1


def _predict_all(samples, weights):
    """Evaluate the row/column classifier for every (weight, sample) pair.

    Args:
        samples: (N, 3, 3) integer tensor, indexed ``[n, row, col]``
            (expected to hold 0/1 values).
        weights: (K, 8) integer tensor of 0/1 weight bits.

    Returns:
        (K, N) int64 tensor ``o``. For weight ``w`` and sample ``s``:
        ``x = int(sum_r prod_c (s[r][c] ^ w[c]) >= 1)``,
        ``y = int(sum_c prod_r (s[r][c] ^ w[3 + r]) >= 1)`` and
        ``o = (w[6] ^ x) * (w[7] ^ y)``.
    """
    s = samples.unsqueeze(0)   # (1, N, 3, 3)
    w = weights.unsqueeze(1)   # (K, 1, 8)
    # Row terms: XOR each column c with w[c], multiply across columns.
    row_terms = (s ^ w[:, :, None, 0:3]).prod(dim=-1)   # (K, N, 3) one per row
    x = (row_terms.sum(dim=-1) >= 1).long()
    # Column terms: XOR each row r with w[3 + r], multiply down the rows.
    col_terms = (s ^ w[:, :, 3:6, None]).prod(dim=-2)   # (K, N, 3) one per column
    y = (col_terms.sum(dim=-1) >= 1).long()
    return (w[:, :, 6] ^ x) * (w[:, :, 7] ^ y)


def verification(s=0, data=None, label=None):
    """Estimate the accuracy of the best-ranked of ``s`` sampled weight vectors.

    Every 8-bit weight vector is scored on samples 0..399 (train) and
    400..511 (test). Weights are ranked by train accuracy, a per-draw
    probability over ranks is built as a ``sin^2(3*theta)`` / ``cos^2(3*theta)``
    mix of the normalised accuracy and miss-count profiles, and from it the
    probability that the best of ``s`` independent draws lands on each rank.

    Args:
        s: number of independent draws ("shots").
        data: optional preloaded samples; if None, loaded from
            ``./row_pattern_dataset/row_pattern_full.pt``. Must contain at
            least 512 samples indexable as ``data[i][row][col]``.
        label: optional preloaded labels, one per sample; if None, loaded from
            ``./row_pattern_dataset/row_pattern_labels_full.pt``.

    Returns:
        ``(ex, var, t_ex, t_var)`` as 0-d tensors. ``ex``/``var`` are the mean
        and variance of the selected weight's train accuracy (fraction of 400).
        ``t_ex``/``t_var`` are the same for its test accuracy (fraction of 112).
        Variances are clamped at 0 to absorb float rounding.

    Raises:
        ValueError: if fewer than 512 samples or labels are available.
    """
    if data is None:
        data = torch.load('./row_pattern_dataset/row_pattern_full.pt')
    if label is None:
        label = torch.load('./row_pattern_dataset/row_pattern_labels_full.pt')
    data = torch.as_tensor(data)
    # One label per sample; flattening also accepts an (N, 1) label tensor.
    label = torch.as_tensor(label).reshape(-1)

    n_train, n_total = 400, 512
    n_test = n_total - n_train
    if data.shape[0] < n_total or label.shape[0] < n_total:
        raise ValueError(
            f'need at least {n_total} samples and labels, got '
            f'{data.shape[0]} and {label.shape[0]}'
        )

    # Score all 256 weights on all samples in one broadcast pass instead of
    # 256 * 512 Python-level iterations.
    samples = data[:n_total, :3, :3].long()
    correct = _predict_all(samples, _all_weights()) == label[:n_total]  # (256, 512)
    train_acc = correct[:, :n_train].sum(dim=1).float()
    test_acc = correct[:, n_train:n_total].sum(dim=1).float()
    print(test_acc.mean())

    # NOTE(review): normalising by 1024 (not n_train) is kept from the
    # original; confirm its intended meaning.
    theta = torch.arcsin(torch.sqrt(train_acc.mean() / 1024))
    print(theta)

    sorted_acc, order = train_acc.sort(descending=True)
    print(sorted_acc)
    print(order)
    # NOTE(review): misses are counted against 512 rather than n_train; kept
    # as in the original.
    mis = 512 - sorted_acc
    rot = 3 * theta
    # Both profiles are normalised and sin^2 + cos^2 = 1, so pos sums to 1.
    pos = (torch.sin(rot) ** 2 * sorted_acc / sorted_acc.sum()
           + torch.cos(rot) ** 2 * mis / mis.sum())

    # survive[j] = P(one draw has rank >= j) = 1 - sum(pos[:j]).
    prefix = torch.cat([pos.new_zeros(1), torch.cumsum(pos, dim=0)[:-1]])
    survive = (1.0 - prefix).clamp(min=0.0)
    # tail[j] = P(all s draws have rank >= j); consecutive differences give
    # the probability that the best draw lands exactly on rank j.
    tail = torch.pow(survive, s)
    best = tail - torch.cat([tail[1:], tail.new_zeros(1)])

    train_frac = sorted_acc / n_train
    ex = (best * train_frac).sum()
    # E[X^2] - E[X]^2 can round slightly below 0 in float32; clamp it.
    var = ((best * train_frac * train_frac).sum() - ex * ex).clamp(min=0.0)
    print(ex)
    print(var)

    test_frac = test_acc[order] / n_test
    t_ex = (best * test_frac).sum()
    t_var = ((best * test_frac * test_frac).sum() - t_ex * t_ex).clamp(min=0.0)
    print(t_ex)
    print(t_var)

    return ex, var, t_ex, t_var
import math
# Sweep the number of shots and collect the mean and spread of the selected
# weight's train (EX/VAR) and test (T_EX/T_VAR) accuracy.
EX = []
VAR = []    # holds standard deviations (sqrt of the variance), despite the name
T_EX = []
T_VAR = []  # likewise: standard deviations of the test accuracy
shots = [1, 2, 4, 8, 16, 32, 64, 128]
for s in shots:
    ex, var, t_ex, t_var = verification(s)
    EX.append(ex)
    # The variances are computed as E[X^2] - E[X]^2 in float32 and can round
    # to a tiny negative number when the distribution is nearly degenerate
    # (large s); clamp at 0 so math.sqrt does not raise "math domain error".
    VAR.append(math.sqrt(max(float(var), 0.0)))
    T_EX.append(t_ex)
    T_VAR.append(math.sqrt(max(float(t_var), 0.0)))
print(EX)
print(VAR)
print(T_EX)
print(T_VAR)