import numpy as np
import tensorflow as tf
import random

# Pre-split CIFAR-10 arrays read from .npy files: data and labels for the
# training split, then data and labels for the test split.
x_train, y_train = (np.load("cifar10/training_data.npy"),
                    np.load("cifar10/training_label.npy"))
x_test, y_test = (np.load("cifar10/testing_data.npy"),
                  np.load("cifar10/testing_label.npy"))
# NOTE: the triple-quoted string below is disabled code, not a docstring; it is
# evaluated and discarded, so nothing inside it runs. If enabled, it would draw
# 5000 training indices without replacement (random.sample) for each label
# 0-9 (where y_train == i), collect the class-balanced subset into xx / yy,
# and set rho = 0.5. random.sample raises ValueError if any label has fewer
# than 5000 examples.
"""
num=10
list_train=[]
for i in range(num):
    ind=np.where(y_train==np.array([i]))[0]
    index=random.sample(list(ind),5000)
    list_train=list_train+index
xx=x_train[list_train]
yy=y_train[list_train]
rho=0.5
"""
# Array of background images; later code indexes it as background[i, :, :, :]
# and pastes resized CIFAR images onto it. np.load on an .npz returns an
# NpzFile that holds the archive open, so read the array inside a `with`
# block to close the file handle as soon as the data is in memory.
with np.load('imagenet_plants.npz') as _background_archive:
    background = _background_archive['arr_0']

def mask_back(b, c, rho, size, data):
    """Zero out a random fraction of the pixels surrounding a square window.

    A binary mask over the spatial grid of ``data`` is built row by row.
    Every pixel of the ``size`` x ``size`` window whose top-left corner is at
    (row ``b``, column ``c``) is kept. In each row, ``int(n * (1 - rho))`` of
    the ``n`` pixels outside the window are also kept; these are chosen with
    :func:`random.sample`, so results depend on the ``random`` module's state.
    Every other pixel is set to zero. The same mask is applied to all channels.

    Args:
        b: Row index of the window's top-left corner.
        c: Column index of the window's top-left corner.
        rho: Fraction of out-of-window pixels to drop in each row; expected
            in [0, 1]. ``rho=0`` keeps every pixel, and ``rho=1`` keeps only
            the window. Values far enough outside [0, 1] make
            ``random.sample`` raise ``ValueError``.
        size: Side length of the square window. Any part of the window lying
            outside the image is ignored.
        data: Image array of shape (height, width) or
            (height, width, channels).

    Returns:
        A new ndarray with the same shape as ``data``: ``data`` multiplied
        element-wise by the 0/1 float64 mask.
    """
    data = np.asarray(data)
    height, width = data.shape[:2]
    mask = np.zeros((height, width))

    # Clip the window's columns to the image. Otherwise an overhanging window
    # would wrap into neighbouring rows or index past the end of the image.
    col_lo = min(max(c, 0), width)
    col_hi = max(min(c + size, width), col_lo)

    for row in range(height):
        if b <= row < b + size:
            mask[row, col_lo:col_hi] = 1
            outside = list(range(col_lo)) + list(range(col_hi, width))
            keep = random.sample(outside, int(len(outside) * (1 - rho)))
        else:
            keep = random.sample(range(width), int((1 - rho) * width))
        # Explicit intp dtype keeps an empty selection a valid index array.
        mask[row, np.asarray(keep, dtype=np.intp)] = 1

    if data.ndim > 2:
        # Broadcast the spatial mask over the channel (and any later) axes.
        mask = mask.reshape(mask.shape + (1,) * (data.ndim - 2))
    return np.multiply(data, mask)

import matplotlib.pyplot as plt

# Demo: paste CIFAR training image 199, resized to a PATCH_SIZE x PATCH_SIZE
# patch, onto background image 10 at (PATCH_ROW, PATCH_COL). Then drop a
# RHO fraction of the surrounding pixels with mask_back and save/show it.
PATCH_ROW, PATCH_COL, PATCH_SIZE, RHO = 3, 5, 24, 0.5

# Copy: plain indexing returns a view, so writing the patch into it would
# silently overwrite background[10] for every later user of `background`.
img = background[10, :, :, :].copy()
# NOTE(review): dt is scaled by 1/255; if background has an integer dtype the
# float patch is truncated on assignment below -- confirm background is float.
dt = x_train[199, :, :, :] / 255
# tf.image.resize operates on a batch: resize a batch of one, take element 0.
patch = tf.image.resize(tf.convert_to_tensor([dt]), size=(PATCH_SIZE, PATCH_SIZE))[0]
img[PATCH_ROW:PATCH_ROW + PATCH_SIZE, PATCH_COL:PATCH_COL + PATCH_SIZE, :] = patch
img = mask_back(PATCH_ROW, PATCH_COL, RHO, PATCH_SIZE, img)
plt.imshow(img)
plt.axis('off')
plt.savefig('./rho05.png', bbox_inches='tight')
plt.show()

# NOTE: the triple-quoted string below is disabled code kept as a string
# literal; it is evaluated and discarded, so nothing inside it runs.
# If enabled, it builds x_train_new (same shape as x_train). For every
# training image it:
#   - copies a randomly chosen background image,
#   - pastes the training image (scaled by 1/255 and resized to 24x24) at a
#     random offset b, c in [0, 8], so b + 24 and c + 24 never exceed 32,
#   - calls mask_back with rho=0. With rho=0, mask_back keeps every pixel,
#     so this call leaves the composite unchanged.
# It then calls plt.imshow on x_train_new[1] and x_train_new[2] before a
# single plt.show(). Both go to the same axes, so presumably only the
# second image stays visible.
"""
x_train_new=np.zeros(x_train.shape)
for i in range(x_train.shape[0]):
    a=random.randint(0,background.shape[0]-1)
    x_train_new[i,:,:,:]=background[a,:,:,:]
    #x_train_new[i,:,:,:]=np.zeros([32,32,3])
    b=random.randint(0,8)
    c= random.randint(0,8)
    dt=x_train[i,:,:,:]/255
    #x_train_new[i,b:b+16, c:c+16, :]=dt[0:dt.shape[0]:2, 0:dt.shape[1]:2,:]
    x_train_new[i,b:b+24,c:c+24,:]=tf.image.resize(tf.convert_to_tensor([dt]), size=(24, 24))
    x_train_new[i,:,:,:]=mask_back(b,c,0,24,x_train_new[i,:,:,:])

import matplotlib.pyplot as plt
plt.imshow(x_train_new[1,:,:,:])
plt.imshow(x_train_new[2,:,:,:])
plt.show()
"""