from turtle import title
from qram import QRAMDataSet, RowPatternDataset, MNISTDataset
from qmodel import QModel, RowdetectModel, RowdetectModelA, RowdetectModelB, MNISTMODEL
from qoptimizer import QOptimizer
from qiskit import QuantumCircuit
import matplotlib.pyplot as plt
import torch
from qiskit.visualization import plot_histogram,plot_state_city
from qiskit import IBMQ, Aer, assemble, transpile, execute
import time
import numpy as np

def _to_bits(value, width):
    """Return the binary digits of ``value`` as 0/1 ints, most significant first.

    The digit string is left-padded with zeros to ``width``. Values that need
    more than ``width`` digits keep all of their digits; callers only read the
    leading positions they need.
    """
    return [int(digit == "1") for digit in bin(value)[2:].rjust(width, "0")]


def _count_correct(weights, samples, labels):
    """Count, for every weight configuration, how many samples it gets right.

    A 10-bit configuration ``w`` predicts
    ``int(sum(w[j] * x[j] for j in range(9)) >= 1) ^ w[9]`` for a sample ``x``
    (only its first 9 bits are used). The prediction is compared with the
    sample's label.

    Returns a float32 tensor holding one correct-count per row of ``weights``.
    """
    w = torch.tensor(weights, dtype=torch.float32)
    # 0/1 values summed over at most 9 terms, so float32 arithmetic is exact.
    x = torch.tensor([bits[:9] for bits in samples], dtype=torch.float32).reshape(-1, 9)
    fires = (x @ w[:, :9].T >= 1).long()  # (n_samples, n_weights)
    pred = fires ^ w[:, 9].long()  # the last weight bit flips the output
    # NOTE(review): assumes one label per sample (1-D or (N, 1)) — confirm.
    target = torch.as_tensor(labels).reshape(-1)[: len(samples)].reshape(-1, 1)
    return (pred == target).sum(dim=0).float()


def _best_of_s_distribution(sorted_acc, n_train, theta, s):
    """Return, for every rank j, the probability that the best of ``s`` draws is rank j.

    ``sorted_acc`` holds correct-counts in descending order. A single draw
    selects rank j with probability
    ``sin²θ · acc_j / Σacc + cos²θ · (n_train − acc_j) / Σ(n_train − acc)``,
    and these weights sum to 1.

    Let ``x[j] = P(single draw has rank ≥ j)``. Then ``x[j] ** s`` is the
    probability that all ``s`` draws have rank ≥ j. The successive differences
    (with 0 after the last rank) give the distribution of the lowest rank drawn,
    i.e. the most accurate configuration among the ``s`` draws.
    """
    mis = n_train - sorted_acc
    pos = (
        torch.sin(theta) ** 2 * sorted_acc / sorted_acc.sum()
        + torch.cos(theta) ** 2 * mis / mis.sum()
    )
    # Exclusive prefix sum: below[j] = pos[0:j].sum().
    below = torch.cat([pos.new_zeros(1), torch.cumsum(pos, dim=0)[:-1]])
    survival = torch.pow(1.0 - below, s)
    return survival - torch.cat([survival[1:], survival.new_zeros(1)])


def verification(s, data_dir="./mnist_dataset"):
    """Evaluate all 1024 binary weight configurations and best-of-``s`` statistics.

    Loads the ``*2_7.pt`` train/test tensors from ``data_dir``. Each sample is
    decoded as a 9-bit pattern. Every 10-bit weight configuration is scored by
    its number of correct predictions on each split.

    Prints ``theta``, the max and the mean train count. Then prints and returns
    the mean and variance of the normalised train and test accuracy of the best
    configuration among ``s`` draws.

    Args:
        s: number of independent draws (the caller's ``shots``); ``0`` puts all
            mass on the lowest-ranked configuration.
        data_dir: directory containing the dataset ``.pt`` files.

    Returns:
        ``(ex, var, t_ex, t_var)`` as 0-d tensors.
    """
    torch.set_printoptions(profile="full")

    # NOTE(review): 128 is used as the total-sample count in the amplitude, and
    # 94 / 56 as the train / test normalisers. All three are hard-coded in the
    # original; confirm they match the dataset.
    n_train = 128
    train_norm = 94
    test_norm = 56

    weights = [_to_bits(k, 10) for k in range(1024)]
    train_data = torch.load(f"{data_dir}/train_data2_7.pt")
    train_label = torch.load(f"{data_dir}/train_label2_7.pt")
    test_data = torch.load(f"{data_dir}/test_data2_7.pt")
    test_label = torch.load(f"{data_dir}/test_label2_7.pt")

    train_bits = [_to_bits(value, 9) for value in train_data]
    test_bits = [_to_bits(value, 9) for value in test_data]

    acc = _count_correct(weights, train_bits, train_label)
    tacc = _count_correct(weights, test_bits, test_label)

    sorted_acc, order = acc.sort(descending=True)
    theta = torch.arcsin(torch.sqrt(acc.mean() / n_train))
    print(theta)
    print(acc.max())
    print(acc.mean())

    best = _best_of_s_distribution(sorted_acc, n_train, theta, s)

    # Expectations over ALL ranks (the distribution sums to 1 over 1024 ranks).
    scaled = sorted_acc / train_norm
    ex = (best * scaled).sum()
    var = (best * scaled * scaled).sum() - ex * ex
    print(ex)
    print(var)

    t_scaled = tacc[order] / test_norm
    t_ex = (best * t_scaled).sum()
    t_var = (best * t_scaled * t_scaled).sum() - t_ex * t_ex
    print(t_ex)
    print(t_var)
    return ex, var, t_ex, t_var
def main():
    """Run ``verification`` for each number of draws and print the collected statistics.

    Prints four lists, in order: train means, train variances, test means and
    test variances, one entry per value in ``shots``.
    """
    train_means = []
    train_vars = []
    test_means = []
    test_vars = []
    verification(0)
    shots = [1, 2, 4, 8, 16, 32, 64, 128, 256]
    for s in shots:
        print("---------------------{}--------------".format(s))
        ex, var, t_ex, t_var = verification(s)
        train_means.append(ex)
        train_vars.append(var)
        test_means.append(t_ex)
        test_vars.append(t_var)
    print(train_means)
    print(train_vars)
    print(test_means)
    print(test_vars)


# Guarded so importing this module does not load the dataset and run the sweep.
if __name__ == "__main__":
    main()