from cProfile import label
from cgitb import handler
from genericpath import exists
import itertools
from random import random, seed, shuffle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
import os
from tqdm import tqdm
from matplotlib import cm
from matplotlib.collections import LineCollection
import matplotlib.pyplot as plt
from matplotlib import pyplot
# Global matplotlib styling shared by every figure produced by this script.
# 'seaborn-white' was renamed 'seaborn-v0_8-white' in matplotlib 3.6 and the old
# name was removed in 3.8, so pick whichever spelling this installation provides.
_SEABORN_WHITE = 'seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available else 'seaborn-white'
plt.style.use(_SEABORN_WHITE)
# Qualitative colour map; palette(i) is the colour of the i-th plotted strategy.
palette = pyplot.get_cmap('Set1')
# Bold 40pt Times New Roman for axis labels (fontdict) and legends (fontdict1).
fontdict = {'family':'Times New Roman', 'size': 40,'weight': 'bold'}
fontdict1 = {'family':'Times New Roman', 'size':40,'weight': 'bold'}
plt.rc('font',family='Times New Roman', size=40)
plt.rcParams["axes.labelweight"] = "bold"
# Every figure is 12 x 8 inches.
fig_size = [12, 8]
plt.rcParams["figure.figsize"] = fig_size
# Tick positions for the 256-entry weight axis of the distribution plots.
a = list(range(0, 257, 64))

# Pre-computed edge-detection samples and their per-sample target pairs.
# NOTE(review): presumably 512 samples in total — first 400 train, indices 400-511 test; confirm.
samples: torch.Tensor = torch.load('./data/edge_detection_full.pt')
labels: torch.Tensor = torch.load('./data/edge_detection_full_label.pt')
train_data, train_label = samples[:400], labels[:400]
test_data, test_label = samples[400:512], labels[400:512]

# Shot budgets 1..TIMES are simulated; ACC / ACCX are the train / test accuracy
# targets reported by the simulations.
TIMES = 1024
ACC = ACCX = 0.95
def model(data : torch.Tensor, w : list):
    """Evaluate the two-output edge detector on one 3x3 patch ``data``.

    ``w`` holds 8 weight bits:
      * w[0:3] are XOR masks applied to columns 0..2 of every row,
      * w[3:6] are XOR masks applied to rows 0..2 of every column,
      * w[6], w[7] are XORed onto the final outputs.

    o1 is 1 when the sum over rows of the product of the masked cells is >= 1
    (for 0/1 entries: some masked row is all ones), then XOR w[6].
    o2 is the same test over columns, then XOR w[7].

    Returns:
        (o1, o2) as Python ints.
    """
    def _fires(lines, masks, flip):
        # Sum of per-line products of masked cells; >= 1 means at least one line fired.
        total = sum((p ^ masks[0]) * (q ^ masks[1]) * (r ^ masks[2]) for p, q, r in lines)
        return int(total >= 1) ^ flip

    rows = [(data[i][0], data[i][1], data[i][2]) for i in range(3)]
    cols = [(data[0][j], data[1][j], data[2][j]) for j in range(3)]
    return _fires(rows, w[0:3], w[6]), _fires(cols, w[3:6], w[7])
def calculate():
    """Count, for every 8-bit weight vector, how many samples ``model`` gets right.

    Weight index ``x`` (0-255) is expanded into its 8 binary digits, most
    significant first, and passed to ``model``. A sample counts as correct
    only when both outputs equal ``label[0]`` and ``label[1]``.

    Results are cached in ``./accuracy/train.pt`` and ``./accuracy/test.pt``.
    When both files exist they are loaded instead of recomputed.

    Returns:
        (train_accuracy, test_accuracy): float tensors of length 256 holding
        correct-sample counts (not fractions) on the train and test split.
    """
    train_path = './accuracy/train.pt'
    test_path = './accuracy/test.pt'
    # Only reuse the cache when it is complete; a lone train.pt is not enough.
    if os.path.exists(train_path) and os.path.exists(test_path):
        return torch.load(train_path), torch.load(test_path)

    # Weights[x][i] is bit (7 - i) of x, i.e. x written as 8 binary digits, MSB first.
    Weights = [[(x >> (7 - i)) & 1 for i in range(8)] for x in range(256)]

    def _count_correct(data, label, w):
        # Number of samples whose two model outputs both equal the label pair.
        correct = 0
        for sample, target in zip(data, label):
            o1, o2 = model(sample, w)
            if (o1 == target[0]) and (o2 == target[1]):
                correct += 1
        return correct

    train_accuracy = torch.zeros(256)
    test_accuracy = torch.zeros(256)
    for x, w in enumerate(Weights):
        train_accuracy[x] = _count_correct(train_data, train_label, w)
        test_accuracy[x] = _count_correct(test_data, test_label, w)

    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(train_path), exist_ok=True)
    torch.save(train_accuracy, train_path)
    torch.save(test_accuracy, test_path)
    return train_accuracy, test_accuracy
# Per-weight correct-sample counts (length-256 tensors) for the train and test
# split; read as module globals by the plotting/simulation functions below.
train_accuracy, test_accuracy = calculate()
def accuracy_distribution():
    """Bar-plot the training accuracy of all weight vectors, best first.

    Bars are coloured with the 'OrRd' colormap according to their count.
    The normalisation's lower bound is 800 below the minimum count, so the
    worst weight maps above the colormap's lightest end (presumably to keep
    every bar visible). Heights are fractions of the training-set size.
    Saves ``./img/new8/edge_detection_accuracy_distribution.pdf`` and clears
    the current figure.
    """
    norm = plt.Normalize(train_accuracy.min() - 800, train_accuracy.max())
    sorted_acc, _ = train_accuracy.sort(descending=True)
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap is
    # available on both old and new releases.
    color = plt.get_cmap('OrRd')(norm(sorted_acc))

    plt.xlabel('Weight',fontdict=fontdict)
    plt.ylabel('Accuracy',fontdict=fontdict)
    # Counts are out of the number of training samples; convert to a fraction.
    plt.bar(list(range(len(sorted_acc))), sorted_acc / len(train_data), color=color)

    os.makedirs('./img/new8', exist_ok=True)
    plt.savefig('./img/new8/edge_detection_accuracy_distribution.pdf',pad_inches=0.5)
    plt.clf()
def expected_best_of_shots(pos, train_frac, test_frac, times):
    """Mean/variance of the accuracy of the best weight found in 1..times draws.

    Weights are assumed ranked best-first (rank 0 = highest training accuracy),
    and each shot independently draws rank ``j`` with probability ``pos[j]``.
    The kept weight is the best-ranked one seen.

    Args:
        pos: 1-D tensor of draw probabilities per rank (should sum to ~1).
        train_frac: training accuracy (fraction) of each ranked weight.
        test_frac: test accuracy (fraction) of each ranked weight, same order.
        times: largest shot budget to evaluate.

    Returns:
        (train_ex, test_ex, train_var, test_var): 1-D tensors of length
        ``times``; entry i corresponds to a budget of i + 1 shots. Variances
        are clamped at 0 to absorb float cancellation in E[X^2] - E[X]^2.
    """
    # survive[j] = P(a single draw has rank >= j) = 1 - sum(pos[:j]).
    survive = (1.0 - torch.cat((pos.new_zeros(1), pos.cumsum(0)[:-1]))).clamp(0.0, 1.0)
    budgets = torch.arange(1, times + 1, dtype=survive.dtype).unsqueeze(1)
    # tail[i, j] = P(best rank among i + 1 draws >= j)
    tail = survive.unsqueeze(0) ** budgets
    # best[i, j] = P(best rank among i + 1 draws == j)
    best = tail - torch.cat((tail[:, 1:], tail.new_zeros(times, 1)), dim=1)

    train_ex = best @ train_frac
    test_ex = best @ test_frac
    train_var = (best @ train_frac ** 2 - train_ex ** 2).clamp(min=0)
    test_var = (best @ test_frac ** 2 - test_ex ** 2).clamp(min=0)
    return train_ex, test_ex, train_var, test_var


def _simulate_sampling(k, rot_factor, dist_path, curve_path, *,
                       ylim=(0, 0.016), report_test=True, pad_inches=0.5):
    """Shared body of the ``numerical_simulate_*`` functions.

    Steps:
      1. Rank the weights by training accuracy (best first) and build the
         measurement distribution
             pos[j] = sin^2(r) * good[j] / sum(good) + cos^2(r) * miss[j] / sum(miss)
         where good = acc**k, miss = 512**k - acc**k,
         theta = arcsin(sqrt(mean(good) / 512**k)) and r = rot_factor * theta.
      2. Plot ``pos`` against the normalized training accuracy to ``dist_path``.
      3. Compute mean/variance of the best weight's train/test accuracy for
         1..TIMES shots via ``expected_best_of_shots``.
      4. Plot expected train accuracy vs shots to ``curve_path``. Print the
         first shot budget reaching ``ACC`` on train, and ``ACCX`` on test
         when ``report_test``.

    NOTE(review): 512**k is used as the search-space size although only the
    400 training samples are scored; confirm the intended constant.

    Returns:
        (train_ex, test_ex, train_var, test_var), tensors of length TIMES;
        index i corresponds to i + 1 shots.
    """
    sorted_acc, order = train_accuracy.sort(descending=True)
    n_weights = sorted_acc.numel()

    search_space = 512 ** k
    good = torch.pow(sorted_acc, k)
    miss = search_space - good
    theta = torch.arcsin(torch.sqrt(good.mean() / search_space))
    rot = rot_factor * theta
    pos = (torch.sin(rot) ** 2 * good / good.sum()
           + torch.cos(rot) ** 2 * miss / miss.sum())

    # Figure 1: measuring probability per ranked weight vs normalized accuracy.
    plt.xlabel('Weight',fontdict=fontdict)
    plt.ylabel('Measuring Probability',fontdict=fontdict)
    plt.subplots_adjust(left=0.16, bottom=0.14, right=0.84, top=0.96,hspace=0.5,wspace=0.5)
    plt.xticks(a)
    if ylim is not None:
        plt.ylim(*ylim)
    ref, = plt.plot(sorted_acc / sorted_acc.sum(), color='darkorange', dashes=[6, 2])
    plt.legend([ref], ['normalized accuracy'], prop=fontdict1)
    plt.bar(list(range(n_weights)), pos, color='royalblue')
    os.makedirs(os.path.dirname(dist_path), exist_ok=True)
    save_kwargs = {} if pad_inches is None else {'pad_inches': pad_inches}
    plt.savefig(dist_path, **save_kwargs)
    plt.clf()

    shots = list(range(1, TIMES + 1))
    train_frac = sorted_acc / len(train_data)
    test_frac = test_accuracy[order] / len(test_data)
    train_accuracy_ex, test_accuracy_ex, train_accuracy_var, test_accuracy_var = \
        expected_best_of_shots(pos, train_frac, test_frac, TIMES)

    # Figure 2: expected training accuracy of the best-found weight vs shots.
    plt.title('shots-expected-accuracy-relationship')
    plt.ylabel('expected accuracy')
    plt.xlabel('shots')
    plt.plot(shots, train_accuracy_ex)
    os.makedirs(os.path.dirname(curve_path), exist_ok=True)
    plt.savefig(curve_path)
    plt.clf()

    if report_test:
        hit = (test_accuracy_ex > ACCX).nonzero()
        if len(hit):
            print("test accuracy exceeds {} after {} shots".format(ACCX, shots[int(hit[0, 0])]))
        else:
            print("test accuracy never exceeds {} within {} shots".format(ACCX, TIMES))
    reached = (train_accuracy_ex >= ACC).nonzero()
    if len(reached):
        # Report the shot count itself, not its 0-based index.
        print("needs {} times of run to get a near-optimal weight".format(shots[int(reached[0, 0])]))
    else:
        print("train accuracy never reaches {} within {} shots".format(ACC, TIMES))
    return train_accuracy_ex, test_accuracy_ex, train_accuracy_var, test_accuracy_var


def numerical_simulate_initialize(k=1):
    """Baseline (legend 'URS'): sample directly from the initial state, r = theta."""
    return _simulate_sampling(k, 1,
                              './img/new8/edge_detection_initial_state.pdf',
                              './img/uniform_sampling.pdf',
                              pad_inches=None)


def numerical_simulate_1PD(k=1):
    """1-PD: measurement distribution rotated to r = 3 * theta."""
    return _simulate_sampling(k, 3,
                              './img/new8/edge_detection_1PD.pdf',
                              './img/1PD.pdf')


def numerical_simulate_2PD(k=2):
    """2-PD: measurement distribution rotated to r = 5 * theta."""
    return _simulate_sampling(k, 5,
                              './img/new8/edge_detection_2PD.pdf',
                              './img/2PD.pdf',
                              ylim=None, report_test=False)


def numerical_simulate_3PD(k=3):
    """3-PD: measurement distribution rotated to r = 7 * theta."""
    return _simulate_sampling(k, 7,
                              './img/new8/edge_detection_3PD.pdf',
                              './img/3PD.pdf',
                              ylim=None, report_test=False)


def numerical_simulate_4PD(k=4):
    """4-PD: measurement distribution rotated to r = 9 * theta."""
    return _simulate_sampling(k, 9,
                              './img/new8/edge_detection_4PD.pdf',
                              './img/4PD.pdf')
# ---------------------------------------------------------------------------
# Script driver: plot the per-weight accuracy landscape, run the URS baseline
# and the 1- to 4-PD simulations, then draw the shots-vs-accuracy figures.
# ---------------------------------------------------------------------------
# Every figure is written under ./img or ./img/new8. Create them up front so
# savefig does not fail on a fresh checkout.
os.makedirs('./img/new8', exist_ok=True)

accuracy_distribution()

train0, test0, v0, u0 = numerical_simulate_initialize()
train1, test1, v1, u1 = numerical_simulate_1PD()
train2, test2, v2, u2 = numerical_simulate_2PD()
train3, test3, v3, u3 = numerical_simulate_3PD()
train4, test4, v4, u4 = numerical_simulate_4PD()

_SHOT_AXIS = list(range(1, TIMES + 1))  # x-axis: shot budgets 1..TIMES
_SUBPLOT_MARGINS = dict(left=0.16, bottom=0.14, right=0.84, top=0.96, hspace=0.5, wspace=0.5)
_LABELS = ('URS', '1-PD', '2-PD', '3-PD', '4-PD')


def _std(var):
    """Standard deviation from a variance tensor; tiny negative float noise is clamped to 0 (no NaN)."""
    return torch.clamp(var, min=0).sqrt()


def _plot_shot_curves(curves, xlabel, ylabel, path):
    """Draw one line per curve against the shot axis, save to ``path`` and clear.

    ``curves`` is a sequence of ``(y, halfwidth, colour_index, label)``. When
    ``halfwidth`` is not None, a translucent ``y +/- halfwidth`` band is added.
    """
    plt.xlabel(xlabel, fontdict=fontdict)
    plt.ylabel(ylabel, fontdict=fontdict)
    handles = []
    for y, halfwidth, colour_index, _ in curves:
        colour = palette(colour_index)
        line, = plt.plot(_SHOT_AXIS, y.tolist(), color=colour)
        if halfwidth is not None:
            plt.fill_between(_SHOT_AXIS, (y - halfwidth).tolist(), (y + halfwidth).tolist(),
                             color=colour, alpha=0.2)
        handles.append(line)
    plt.subplots_adjust(**_SUBPLOT_MARGINS)
    plt.legend(handles, [c[3] for c in curves], prop=fontdict1)
    plt.savefig(path)
    plt.clf()


# Mean +/- one std of the best-found weight's accuracy. 2-PD and 3-PD are
# computed above but not drawn in these two figures.
_plot_shot_curves(
    [(train0, _std(v0), 0, 'URS'), (train1, _std(v1), 1, '1-PD'), (train4, _std(v4), 4, '4-PD')],
    'measurement budget(shots)', 'training accuracy', './img/new8/shots_train.pdf')
_plot_shot_curves(
    [(test0, _std(u0), 0, 'URS'), (test1, _std(u1), 1, '1-PD'), (test4, _std(u4), 4, '4-PD')],
    'measurement budget(shots)', 'test accuracy', './img/new8/shots_test.pdf')

# Standard deviation alone, for all five strategies.
_plot_shot_curves(
    [(_std(v), None, i, name) for i, (v, name) in enumerate(zip((v0, v1, v2, v3, v4), _LABELS))],
    'number of shots', 'std', './img/new8/shots_train_std.pdf')
_plot_shot_curves(
    [(_std(u), None, i, name) for i, (u, name) in enumerate(zip((u0, u1, u2, u3, u4), _LABELS))],
    'number of shots', 'std', './img/new8/shots_test_std.pdf')












