from cProfile import label
from genericpath import exists
import itertools
import posix
from select import select
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
import os
from tqdm import tqdm
from random import random, seed, shuffle
from matplotlib import cm
from matplotlib.collections import LineCollection
import matplotlib.pyplot as plt
from matplotlib import pyplot
# Seaborn-derived styles were renamed to 'seaborn-v0_8-*' in Matplotlib 3.6 and
# the old names removed in 3.8; use the new name when this install provides it.
_WHITE_STYLE = ('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available
                else 'seaborn-white')
plt.style.use(_WHITE_STYLE)
# Qualitative colour map; palette(i) colours the i-th method in the summary plots.
palette = pyplot.get_cmap('Set1')
#torch.set_printoptions(profile="full")
from brokenaxes import brokenaxes
from matplotlib import rcParams

# Bold 40pt Times New Roman for axis labels and legends.
fontdict = {'family': 'Times New Roman', 'size': 40, 'weight': 'bold'}
fontdict1 = {'family': 'Times New Roman', 'size': 40, 'weight': 'bold'}
plt.rc('font', family='Times New Roman', size=40)
plt.rcParams["axes.labelweight"] = "bold"
fig_size = [12, 8]  # figure (width, height) in inches
plt.rcParams["figure.figsize"] = fig_size

a = [0, 8192, 16384, 24576, 32768]  # x-tick positions for the 32768-bar distribution plots
WN = 2**20  # length of the full accuracy vectors (one entry per weight)
TIMES = 32768  # largest number of shots simulated
ACCX = 0.7671  # test-accuracy threshold whose first crossing is printed
SHARDS = 32  # accuracy results are split across this many files
shots = list(range(1, TIMES + 1))

train_parts = []
test_parts = []
for shard in range(SHARDS):
    # Integer bounds of this shard's slice; covers all WN entries even if
    # WN is not a multiple of SHARDS.
    lo = shard * WN // SHARDS
    hi = (shard + 1) * WN // SHARDS
    # Shard files appear to be indexed over all WN weights, of which only this
    # shard's slice is kept (presumably the part that shard computed — TODO confirm).
    train_ac = torch.load('./accuracy/2153/temp_train{}_{}.pt'.format(shard, SHARDS))
    test_ac = torch.load('./accuracy/2153/temp_test{}_{}.pt'.format(shard, SHARDS))
    train_parts.append(train_ac[lo:hi])
    test_parts.append(test_ac[lo:hi])

# Divisors 79 / 46 turn the stored values into accuracies; presumably the number
# of train / test samples evaluated per weight — TODO confirm.
train_accuracy = torch.cat(train_parts, dim=0) / 79
test_accuracy = torch.cat(test_parts, dim=0) / 46
# (values, indices): train accuracies in descending order and their original positions.
sorted_train_accuracy = train_accuracy.sort(descending=True)

def _amplified_distribution(values, k, rotations, sampler, search_space=1):
    """Return the single-draw sampling distribution over all weights and over a subsample.

    The per-weight score is ``values ** k``.  With
    ``theta = arcsin(sqrt(mean(score) / search_space))`` and ``rot = rotations * theta``,
    probability mass ``sin(rot)**2`` is spread proportionally to the score and
    ``cos(rot)**2`` proportionally to ``search_space - score``.  With
    ``rotations == 1`` and ``search_space == 1`` this mixture is exactly uniform.

    Args:
        values: 1-D tensor of accuracies (callers pass them sorted descending).
        k: exponent applied to ``values``.
        rotations: multiplier applied to ``theta``.
        sampler: list of indices of the subsample used for plotting.
        search_space: normaliser inside ``arcsin``; every caller uses 1.

    Returns:
        ``(pos, posx)``: ``pos`` sums to 1 over all entries of ``values``;
        ``posx`` is the same mixture renormalised over ``values[sampler]``.
    """
    score = torch.pow(values, k)
    theta = torch.arcsin(torch.sqrt(score.mean() / search_space))
    rot = rotations * theta
    hit_weight = torch.sin(rot) * torch.sin(rot)
    miss_weight = torch.cos(rot) * torch.cos(rot)
    miss = search_space - score

    def mix(hit_part, miss_part):
        # Each part is normalised to 1 and weighted by sin^2 / cos^2, which sum to 1.
        return hit_weight * hit_part / hit_part.sum() + miss_weight * miss_part / miss_part.sum()

    return mix(score, miss), mix(score[sampler], miss[sampler])


def _plot_distribution(values, posx, sampler, ylim, path):
    """Bar-plot ``posx`` over the normalised sampled accuracies and save to ``path``.

    ``ylim`` is a ``(bottom, top)`` pair, or ``None`` to keep Matplotlib's
    autoscaling.  Clears the current figure afterwards.
    """
    plt.ticklabel_format(style='sci', scilimits=(0, 0), axis='y')
    plt.xlabel('Weight', fontdict=fontdict)
    plt.ylabel('Measuring Probability', fontdict=fontdict)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.subplots_adjust(left=0.16, bottom=0.14, right=0.84, top=0.95, hspace=0.5, wspace=0.5)
    plt.xticks(a)
    subset = values[sampler]
    accuracy_line, = plt.plot(subset / subset.sum(), color='darkorange', dashes=[6, 2])
    plt.legend([accuracy_line], ['normalized accuracy'], prop=fontdict1)
    # One bar per sampled weight.
    plt.bar(list(range(len(sampler))), posx, color='royalblue')
    plt.savefig(path)
    plt.clf()


def _best_sample_moments(pos, train_vals, test_vals, shot_counts, progress=None):
    """Mean and variance of the best-ranked weight among ``s`` i.i.d. draws from ``pos``.

    Index 0 is the best weight (callers pass accuracies sorted descending), so
    "best of s draws" is the smallest drawn index ``j``.  Its distribution is
    ``P(min >= j) - P(min >= j + 1)`` with ``P(min >= j) = (1 - sum(pos[:j])) ** s``.

    Args:
        pos: 1-D single-draw probabilities.
        train_vals, test_vals: per-index values (aligned with ``pos``) whose
            moments are computed.
        shot_counts: sequence of draw counts ``s``.
        progress: optional wrapper such as ``tqdm``, called as
            ``progress(iterable, total=len(shot_counts))``.

    Returns:
        ``(train_ex, test_ex, train_var, test_var)``, each of length ``len(shot_counts)``.
    """
    below = pos.new_zeros(pos.shape[0])
    below[1:] = pos.cumsum(dim=0)[:-1]
    # Float rounding in the long prefix sum can push the tail slightly below 0.
    tail = (1.0 - below).clamp(min=0.0)
    # Loop-invariant squares, computed once.
    train_sq = train_vals * train_vals
    test_sq = test_vals * test_vals

    count = len(shot_counts)
    train_ex = torch.zeros(count)
    test_ex = torch.zeros(count)
    train_var = torch.zeros(count)
    test_var = torch.zeros(count)
    steps = enumerate(shot_counts)
    if progress is not None:
        steps = progress(steps, total=count)
    for i, s in steps:
        survive = torch.pow(tail, s)
        # Appending 0 keeps the last index's mass P(min == n-1) = survive[-1]
        # instead of dropping it, so the probabilities sum to 1.
        prob = -torch.diff(survive, append=survive.new_zeros(1))
        ex = (prob * train_vals).sum()
        t_ex = (prob * test_vals).sum()
        train_ex[i] = ex
        test_ex[i] = t_ex
        # E[X^2] - E[X]^2 can dip below 0 from cancellation; a variance never does.
        train_var[i] = ((prob * train_sq).sum() - ex * ex).clamp(min=0.0)
        test_var[i] = ((prob * test_sq).sum() - t_ex * t_ex).clamp(min=0.0)
    return train_ex, test_ex, train_var, test_var


def _run_simulation(k, rotations, ylim, dist_path):
    """Build the distribution, plot it to ``dist_path``, and compute moments for every shot count."""
    sampler = list(range(0, WN, 32))  # every 32nd weight, for plotting only
    values = sorted_train_accuracy[0]
    pos, posx = _amplified_distribution(values, k, rotations, sampler)
    _plot_distribution(values, posx, sampler, ylim, dist_path)
    # Test accuracies reordered to follow the descending train ranking;
    # gathered once, outside the shot loop.
    test_vals = test_accuracy[sorted_train_accuracy[1]]
    return _best_sample_moments(pos, values, test_vals, shots, progress=tqdm)


def _plot_curve(curve, path):
    """Plot an expected-accuracy curve against ``shots``, save it, and clear the figure."""
    plt.title('shots-expected-accuracy-relationship')
    plt.ylabel('expected accuracy')
    plt.xlabel('shots')
    plt.plot(shots, curve)
    plt.savefig(path)
    plt.clf()


def _print_first_above(curve, threshold):
    """Print the first 0-based index whose value exceeds ``threshold``.

    Index ``i`` corresponds to ``shots[i]`` draws.  Prints a notice instead of
    raising ``IndexError`` when the threshold is never exceeded.
    """
    hits = (curve > threshold).nonzero()
    if hits.numel() == 0:
        print('no entry exceeds {}'.format(threshold))
    else:
        print(hits[0])


def numerical_simulate_initialize(k=1):
    """Simulate sampling from the initial (rotation multiplier 1, i.e. uniform) distribution.

    Saves ./img/new10/mnist_initial_state.pdf and ./img/uniform_sampling.pdf
    (expected test accuracy vs shots), then prints the first 8000 expected test
    accuracies and the first index above ``ACCX``.

    Returns:
        ``(train_ex, test_ex, train_var, test_var)``, one entry per shot count.
    """
    train_ex, test_ex, train_var, test_var = _run_simulation(
        k, 1, (0, 1.2e-4), './img/new10/mnist_initial_state.pdf')
    _plot_curve(test_ex, './img/uniform_sampling.pdf')
    print(test_ex[0:8000])
    _print_first_above(test_ex, ACCX)
    return train_ex, test_ex, train_var, test_var


def numerical_simulate_1PD(k=1):
    """Simulate sampling with rotation multiplier 3.

    Saves ./img/new10/mnist_1PD.pdf and ./img/1PD.pdf (expected train accuracy),
    prints the first index above ``ACCX`` and the expected train accuracies.

    Returns:
        ``(train_ex, test_ex, train_var, test_var)``, one entry per shot count.
    """
    train_ex, test_ex, train_var, test_var = _run_simulation(
        k, 3, (0, 1.2e-4), './img/new10/mnist_1PD.pdf')
    _plot_curve(train_ex, './img/1PD.pdf')
    _print_first_above(test_ex, ACCX)
    print(train_ex)
    return train_ex, test_ex, train_var, test_var


def numerical_simulate_2PD(k=2):
    """Simulate sampling with rotation multiplier 5.

    Saves ./img/new10/mnist_2PD.pdf and ./img/2PD.pdf (expected train accuracy).

    Returns:
        ``(train_ex, test_ex, train_var, test_var)``, one entry per shot count.
    """
    train_ex, test_ex, train_var, test_var = _run_simulation(
        k, 5, None, './img/new10/mnist_2PD.pdf')
    _plot_curve(train_ex, './img/2PD.pdf')
    return train_ex, test_ex, train_var, test_var


def numerical_simulate_3PD(k=3):
    """Simulate sampling with rotation multiplier 7.

    Saves ./img/new10/mnist_3PD.pdf and ./img/3PD.pdf (expected train accuracy).

    Returns:
        ``(train_ex, test_ex, train_var, test_var)``, one entry per shot count.
    """
    train_ex, test_ex, train_var, test_var = _run_simulation(
        k, 7, None, './img/new10/mnist_3PD.pdf')
    _plot_curve(train_ex, './img/3PD.pdf')
    return train_ex, test_ex, train_var, test_var


def numerical_simulate_4PD(k=4):
    """Simulate sampling with rotation multiplier 9.

    Saves ./img/new10/mnist_4PD.pdf and ./img/4PD.pdf (expected train accuracy)
    and prints the first index above ``ACCX``.

    Returns:
        ``(train_ex, test_ex, train_var, test_var)``, one entry per shot count.
    """
    train_ex, test_ex, train_var, test_var = _run_simulation(
        k, 9, None, './img/new10/mnist_4PD.pdf')
    _plot_curve(train_ex, './img/4PD.pdf')
    _print_first_above(test_ex, ACCX)
    return train_ex, test_ex, train_var, test_var
def _std(var):
    """Return sqrt(var), clamping float round-off negatives to 0 so sqrt never yields NaN."""
    return var.clamp(min=0).sqrt()


def _plot_mean_with_band(curves, variances, colors, labels, ylabel, path):
    """Plot each mean curve with a +/- one-std band against shots, save to ``path``, clear."""
    x_axis = list(range(1, TIMES + 1))
    plt.xlabel('measurement budget(shots)', fontdict=fontdict)
    plt.ylabel(ylabel, fontdict=fontdict)
    plt.ylim(0.5, 0.9)
    plt.subplots_adjust(left=0.16, bottom=0.14, right=0.84, top=0.96, hspace=0.5, wspace=0.5)
    lines = []
    for curve, var, color in zip(curves, variances, colors):
        line, = plt.plot(x_axis, curve, color=color)
        spread = _std(var)
        # Vectorised band bounds instead of a per-element Python loop.
        plt.fill_between(x_axis, (curve - spread).tolist(), (curve + spread).tolist(),
                         color=color, alpha=0.2)
        lines.append(line)
    plt.legend(tuple(lines), labels, prop=fontdict1)
    plt.savefig(path)
    plt.clf()


def _plot_std(variances, colors, labels, path):
    """Plot the standard deviation of each method against shots, save to ``path``, clear."""
    x_axis = list(range(1, TIMES + 1))
    plt.xlabel('number of shots', fontdict=fontdict)
    plt.ylabel('std', fontdict=fontdict)
    plt.subplots_adjust(left=0.16, bottom=0.14, right=0.84, top=0.96, hspace=0.5, wspace=0.5)
    lines = []
    for var, color in zip(variances, colors):
        line, = plt.plot(x_axis, _std(var), color=color)
        lines.append(line)
    plt.legend(tuple(lines), labels, prop=fontdict1)
    plt.savefig(path)
    plt.clf()


# Each simulation returns (train mean, test mean, train var, test var) per shot count.
train0, test0, v0, u0 = numerical_simulate_initialize()
train1, test1, v1, u1 = numerical_simulate_1PD()
train2, test2, v2, u2 = numerical_simulate_2PD()
train3, test3, v3, u3 = numerical_simulate_3PD()
train4, test4, v4, u4 = numerical_simulate_4PD()

print(v0)
print(v4)

ALL_LABELS = ['URS', '1-PD', '2-PD', '3-PD', '4-PD']
ALL_COLORS = [palette(i) for i in range(5)]
train_curves = (train0, train1, train2, train3, train4)
test_curves = (test0, test1, test2, test3, test4)
train_vars = (v0, v1, v2, v3, v4)
test_vars = (u0, u1, u2, u3, u4)
# The mean +/- std plots show only URS, 1-PD and 4-PD.
SHOWN = (0, 1, 4)

_plot_mean_with_band([train_curves[i] for i in SHOWN], [train_vars[i] for i in SHOWN],
                     [ALL_COLORS[i] for i in SHOWN], [ALL_LABELS[i] for i in SHOWN],
                     'training accuracy', './img/new10/shots_train.pdf')
_plot_mean_with_band([test_curves[i] for i in SHOWN], [test_vars[i] for i in SHOWN],
                     [ALL_COLORS[i] for i in SHOWN], [ALL_LABELS[i] for i in SHOWN],
                     'test accuracy', './img/new10/shots_test.pdf')
_plot_std(train_vars, ALL_COLORS, ALL_LABELS, './img/new10/shots_train_std.pdf')
_plot_std(test_vars, ALL_COLORS, ALL_LABELS, './img/new10/shots_test_std.pdf')








