import numpy as np
import copy
import pickle
import random
import warnings
from PIL import Image
from tqdm import tqdm
# from scipy.io import loadmat
import seaborn as sns
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']  # use the SimHei font so Chinese labels render correctly
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly with that font
from utils.loggable import Loggable
from utils.pathselection import common_path


def rk4(h, y, inputs, f):
    """Advance ``y`` by one classical fourth-order Runge-Kutta step.

    Args:
        h: Step size.
        y: Current state.
        inputs: Extra argument passed unchanged to every call of ``f``.
        f: Derivative callable, invoked as ``f(state, inputs)``.

    Returns:
        The state after one step of size ``h``.
    """
    half_step = h / 2
    slope_start = f(y, inputs)
    slope_mid_a = f(y + half_step * slope_start, inputs)
    slope_mid_b = f(y + half_step * slope_mid_a, inputs)
    slope_end = f(y + h * slope_mid_b, inputs)

    weighted_slopes = slope_start + 2 * slope_mid_a + 2 * slope_mid_b + slope_end
    return y + h / 6 * weighted_slopes


class CANN1D(Loggable):
    """One-dimensional continuous attractor network on a periodic feature axis.

    The membrane state lives in ``u``; ``step`` computes the next state into
    ``u_new`` and ``update`` commits it back to ``u``.
    """

    def __init__(self, num, tau=1., k=8.1, a=0.5, A=10., J0=4., z_min=-np.pi, z_max=np.pi):
        """Set up parameters, the feature axis, state buffers and connectivity.

        Args:
            num: Number of neurons.
            tau: The synaptic time constant.
            k: Degree of the rescaled inhibition.
            a: Half-width of the range of excitatory connections.
            A: Magnitude of the external input.
            J0: Maximum connection value.
            z_min: Lower end of the feature axis.
            z_max: Upper end of the feature axis.
        """
        # NOTE: the base-class initialiser is intentionally not invoked here,
        # matching the existing behaviour of this class.

        # Dynamics parameters.
        self.tau, self.k, self.a = tau, k, a
        self.A, self.J0 = A, J0

        # Feature space.
        span = z_max - z_min
        self.z_min = z_min
        self.z_max = z_max
        self.z_range = span
        self.x = np.linspace(z_min, z_max, num)  # encoded feature value of each neuron
        self.rho = num / span  # neural density
        self.dx = span / num  # stimulus density

        # State variables.
        self.u = np.zeros(num)
        self.input = np.zeros(num)

        # Recurrent connection matrix between all pairs of neurons.
        self.conn_mat = self.make_conn(self.x)

        # Buffer for the state produced by the most recent ``step``.
        self.u_new = copy.deepcopy(self.u)

    def derivative(self, u, Iext):
        """Return du/dt for state ``u`` under external input ``Iext``."""
        squared = np.square(u)
        # Divisive normalisation of the squared activity.
        rate = squared / (1.0 + self.k * np.sum(squared))
        recurrent = np.dot(self.conn_mat, rate)
        return (-u + recurrent + Iext) / self.tau

    def dist(self, d):
        """Wrap the difference ``d`` onto the periodic axis of length ``z_range``."""
        wrapped = np.remainder(d, self.z_range)
        # Anything beyond half the range is closer going the other way round.
        return np.where(wrapped > 0.5 * self.z_range, wrapped - self.z_range, wrapped)

    def make_conn(self, x):
        """Build the Gaussian connection matrix for the 1-D positions ``x``."""
        assert np.ndim(x) == 1
        # Pairwise wrapped distance: entry (i, j) is dist(x[i] - x[j]).
        pairwise = self.dist(x[:, None] - x[None, :])
        gaussian = np.exp(-0.5 * np.square(pairwise / self.a))
        normaliser = np.sqrt(2 * np.pi) * self.a
        return self.J0 * gaussian / normaliser

    def get_stimulus_by_pos(self, pos):
        """Return a Gaussian-shaped external input of magnitude ``A`` centred at ``pos``."""
        offset = self.dist(self.x - pos)
        return self.A * np.exp(-0.25 * np.square(offset / self.a))

    def update(self):
        """Commit the pending state ``u_new`` as the current state ``u``."""
        self.u = copy.deepcopy(self.u_new)

    def step(self, dt, inputs=0):
        """Integrate one RK4 step of size ``dt`` from ``u``; store and return ``u_new``.

        The current state ``u`` is not modified until ``update`` is called.
        """
        self.u_new = rk4(dt, self.u, inputs, self.derivative)
        return self.u_new


class FlyLSH(object):
    def __init__(self, sample_dim, hash_length, sampling_ratio, embedding_size):
        """Configure the hasher and allocate an all-False projection matrix.

        Args:
            sample_dim: Dimensionality of one input sample (rows of ``weights``).
            hash_length: Hash length; note that in FlyLSH the hash length and
                the embedding size are NOT the same, whereas in usual LSH
                they are.
            sampling_ratio: Fraction of input dims to sample per projection.
                Currently unused: ``num_projections`` is fixed at 400.
            embedding_size: Dimensionality of the projection space
                (columns of ``weights``).
        """
        self.sample_dim = sample_dim
        self.embedding_size = embedding_size
        self.hash_length = hash_length
        self.maxl1distance = 2 * hash_length
        self.K = embedding_size // hash_length

        # Number of input dimensions wired to every projection column.
        self.num_projections = 400  # int(sampling_ratio * self.sample_dim) #40

        # No column has been generated yet; generate_weights() fills them in.
        self.weights = np.zeros((sample_dim, embedding_size), dtype=bool)
        self.max_index_of_generated_weights = 0

    def generate_weights(self, data, code_dims=None, distribution_random=False):
        if code_dims is None:
            code_dims_indexs = np.arange(self.max_index_of_generated_weights, self.embedding_size, dtype=np.int64)
            code_dims_length = int(self.embedding_size - self.max_index_of_generated_weights)
        elif np.isscalar(code_dims):
            if code_dims < 1:
                code_dims = int(code_dims * self.embedding_size)
            if self.max_index_of_generated_weights >= code_dims:
                warnings.warn('Unexpected modification on existing connection!, max_index=%d, code_dims=%d' % (
                    self.max_index_of_generated_weights, code_dims))
            code_dims_indexs = np.arange(self.max_index_of_generated_weights, code_dims, dtype=np.long)
            code_dims_length = int(code_dims - self.max_index_of_generated_weights)
        else:
            if not np.all(np.less(self.max_index_of_generated_weights, code_dims)):
                warnings.warn('Unexpected modification on existing connection!, max_index=%d, code_dims=%d' % (
                    self.max_index_of_generated_weights, code_dims))
            code_dims_indexs = code_dims
            code_dims_length = len(code_dims)

        weights = np.random.random((self.sample_dim, code_dims_length))
        if distribution_random:
            weights = weights * np.mean(data, axis=0)[:, None]
        yindices = code_dims_indexs[None, :]
        xindices = weights.argsort(axis=0)[-self.num_projections:, :]
        self.weights[xindices, yindices] = True  # sparse projection vectors
        self.max_index_of_generated_weights += code_dims_length

    def PNtoKC(self, train_data, test_data, reset_num, center_data=False):
        if center_data:
            data_mean = np.mean(train_data, axis=1)
            self.train_data = (train_data - data_mean[:, None])
            data_mean = np.mean(test_data, axis=1)
            self.test_data = (test_data - data_mean[:, None])
        else:
            self.train_data = train_data
            self.test_data = test_data
        all_activations = (self.train_data @ self.weights)

        xindices_traindata = np.arange(train_data.shape[0])[:, None]
        # self.yindices = all_activations.argsort(axis=1)[:, -self.hash_length:]
        self.yindices_traindata = all_activations.argsort(axis=1)[:, -128:]
        self.kc_activity1_traindata = np.zeros_like(all_activations, dtype=bool)
        self.kc_activity1_traindata[xindices_traindata, self.yindices_traindata] = True  # choose topk activations

        self.kc_activity2_traindata = np.zeros_like(all_activations)
        self.kc_activity2_traindata[xindices_traindata, self.yindices_traindata] = all_activations[xindices_traindata, self.yindices_traindata]

        self.kc_activity3_traindata = np.zeros_like(all_activations)
        self.kc_activity3_traindata[xindices_traindata, self.yindices_traindata] = all_activations[xindices_traindata, self.yindices_traindata]
        print('kc_traindata')


        self.kc_on_off_traindata = np.ones_like(all_activations, dtype=bool)
        restrainnum_kc = np.count_nonzero(self.kc_activity3_traindata, axis=0)
        restrainnum_kc_id = np.where(restrainnum_kc > reset_num)[0]

        self.kc_activity3_traindata[:, restrainnum_kc_id] = 0
        self.kc_on_off_traindata[:, restrainnum_kc_id] = 0

        all_activations_restrain = np.zeros_like(all_activations)
        all_activations_restrain[self.kc_on_off_traindata] = all_activations[self.kc_on_off_traindata]

        yindices_traindata = all_activations_restrain.argsort(axis=1)[:, -128:]
        self.kc_activity3_traindata[xindices_traindata, yindices_traindata] = all_activations_restrain[xindices_traindata, yindices_traindata]
        self.kc_activity3_traindata = self.kc_activity3_traindata / np.sum(self.kc_activity3_traindata, axis=1, keepdims=1)
        self.kc_activity3_traindata = (self.kc_activity3_traindata - np.min(self.kc_activity3_traindata)) / (
                np.max(self.kc_activity3_traindata) - np.min(self.kc_activity3_traindata))
        print('kc_traindata_restrain')


        all_activations = (self.test_data @ self.weights)
        xindices_testdata = np.arange(test_data.shape[0])[:, None]
        self.yindices_testdata = all_activations.argsort(axis=1)[:, -128:]
        self.kc_activity1_testdata= np.zeros_like(all_activations, dtype=bool)
        self.kc_activity1_testdata[xindices_testdata, self.yindices_testdata] = True  # choose topk activations

        self.kc_activity2_testdata = np.zeros_like(all_activations)
        self.kc_activity2_testdata[xindices_testdata, self.yindices_testdata] = all_activations[xindices_testdata, self.yindices_testdata]

        self.kc_activity3_testdata = np.zeros_like(all_activations)
        self.kc_activity3_testdata[xindices_testdata, self.yindices_testdata] = all_activations[xindices_testdata, self.yindices_testdata]
        print('kc_testdata')


        self.kc_on_off_testdata = np.ones_like(all_activations, dtype=bool)
        restrainnum_kc = np.count_nonzero(self.kc_activity3_testdata, axis=0)
        restrainnum_kc_id = np.where(restrainnum_kc > reset_num)[0]

        self.kc_activity3_testdata[:, restrainnum_kc_id] = 0
        self.kc_on_off_testdata[:, restrainnum_kc_id] = 0

        all_activations_restrain = np.zeros_like(all_activations)
        all_activations_restrain[self.kc_on_off_testdata] = all_activations[self.kc_on_off_testdata]

        yindices_testdata = all_activations_restrain.argsort(axis=1)[:, -128:]
        self.kc_activity3_testdata[xindices_testdata, yindices_testdata] = all_activations_restrain[xindices_testdata, yindices_testdata]
        self.kc_activity3_testdata = self.kc_activity3_testdata / np.sum(self.kc_activity3_testdata, axis=1,keepdims=1)
        self.kc_activity3_testdata = (self.kc_activity3_testdata - np.min(self.kc_activity3_testdata)) / (
                np.max(self.kc_activity3_testdata) - np.min(self.kc_activity3_testdata))
        print('kc_testdata_restrain')

    def KCtoMBON(self, train_data, train_lables, center_data=False):
        """Learn KC->MBON weights, read out MBON activity and run the CANN on the test set.

        Requires the KC activities produced by ``PNtoKC``. Three weight
        matrices of shape ``(embedding_size, 72)`` are learned:

        * ``kc_mbon_weight1`` - boolean: a KC is connected to the label's MBON
          once it is among the top-128 KCs of any sample with that label.
        * ``kc_mbon_weight2`` - accumulated ``kc_activity2`` values.
        * ``kc_mbon_weight3`` - moves towards ``kc_activity3`` with a
          per-synapse learning rate that decays by 0.9999 per update.

        The MBON read-outs ``img_mbon1..4`` are stored for train and test
        data, and ``img_cann_test`` receives the final CANN state for every
        test sample.

        Args:
            train_data: Training samples, one per row (only its row count and,
                when centring, its values are used here).
            train_lables: Integer MBON index for every training sample.
            center_data: Subtract each sample's own mean when storing
                ``self.train_data``.
        """
        if center_data:
            self.train_data = train_data - np.mean(train_data, axis=1)[:, None]
        else:
            self.train_data = train_data

        weight_shape = (self.embedding_size, 72)
        self.kc_mbon_weight1 = np.zeros(weight_shape, dtype=bool)
        self.kc_mbon_weight2 = np.zeros(weight_shape)
        self.kc_mbon_weight3 = np.zeros(weight_shape)
        learning_rate = np.ones(weight_shape)

        for sample in range(train_data.shape[0]):
            target = train_lables[sample]

            # Binary synapses: connect the strongest KCs of this sample to its MBON.
            strongest1 = np.argsort(self.kc_activity1_traindata[sample])[-128:]
            self.kc_mbon_weight1[strongest1, target] = True

            # Additive synapses: accumulate the graded KC activity.
            strongest2 = np.argsort(self.kc_activity2_traindata[sample])[-128:]
            self.kc_mbon_weight2[strongest2, target] = (
                self.kc_mbon_weight2[strongest2, target]
                + self.kc_activity2_traindata[sample, strongest2])

            # Learning-rate synapses: decay the rate once for every sample seen
            # so far (at that sample's label), then move towards the activity.
            strongest3 = np.argsort(self.kc_activity3_traindata[sample])[-128:]
            kc_drive = self.kc_activity3_traindata[sample, strongest3]
            for seen in range(sample + 1):
                seen_label = train_lables[seen]
                learning_rate[strongest3, seen_label] = learning_rate[strongest3, seen_label] * 0.9999  # 0.9999 0.95
            previous = self.kc_mbon_weight3[strongest3, target]
            self.kc_mbon_weight3[strongest3, target] = (
                np.multiply(kc_drive - previous, learning_rate[strongest3, target]) + previous)

        # MBON read-outs on the training set.
        self.img_mbon1_traindata = (self.kc_activity1_traindata @ self.kc_mbon_weight1)  # discrete KC, discrete weights
        self.img_mbon2_traindata = np.matmul(self.kc_activity2_traindata, self.kc_mbon_weight1)  # continuous KC, discrete weights
        self.img_mbon3_traindata = np.matmul(self.kc_activity2_traindata, self.kc_mbon_weight2)  # continuous KC, continuous weights
        self.img_mbon4_traindata = np.matmul(self.kc_activity3_traindata, self.kc_mbon_weight3)  # continuous, learning-rate weights
        print('mbon_traindata')

        # MBON read-outs on the test set.
        self.img_mbon1_testdata = (self.kc_activity1_testdata @ self.kc_mbon_weight1)  # discrete KC, discrete weights
        self.img_mbon2_testdata = np.matmul(self.kc_activity2_testdata, self.kc_mbon_weight1)  # continuous KC, discrete weights
        self.img_mbon3_testdata = np.matmul(self.kc_activity2_testdata, self.kc_mbon_weight2)  # continuous KC, continuous weights
        self.img_mbon4_testdata = np.matmul(self.kc_activity3_testdata, self.kc_mbon_weight3)  # continuous, learning-rate weights
        print('mbon_testdata')

        # Only every second MBON drives the CANN; min-max scale them per row
        # and shrink by a factor of 10.
        even_mbons = self.img_mbon4_testdata[:, 0::2]
        lowest = np.min(even_mbons, axis=1)[:, None]
        highest = np.max(even_mbons, axis=1)[:, None]
        cann_drive = np.zeros_like(self.img_mbon4_testdata)
        cann_drive[:, 0::2] = (even_mbons - lowest) / (highest - lowest) / 10

        self.img_cann_test = np.zeros_like(self.img_mbon4_testdata)
        dt = 0.01
        times = np.arange(0, 10, dt)
        for row in range(0, cann_drive.shape[0]):
            path_dict = common_path(experiment='CANN')
            # Fresh network for every test sample.
            cann = CANN1D(num=72, tau=1., k=0.1, a=0.1, A=1, J0=4., z_min=-np.pi, z_max=np.pi)
            cann.init_recording(name_list=['u'], log_path=path_dict['result_path'], log_name='CANN')
            trajectory = []
            # Numerical integration; recording is left disabled here.
            for _ in times:
                trajectory.append(cann.step(dt, cann_drive[row]))
                cann.update()
            self.img_cann_test[row] = trajectory[-1]

    def MBONtoCANN(self, test_start, test_end, center_data=False):
        """Run the CANN on test samples ``test_start`` .. ``test_end - 1``.

        Every second column of ``img_mbon4_testdata`` is min-max scaled per
        row and used as a constant external input. The last recorded CANN
        state of each simulated sample is written to ``img_cann_test``; rows
        outside the requested range stay zero.

        Args:
            test_start: First test-row index to simulate.
            test_end: One past the last test-row index to simulate.
            center_data: Unused.
        """
        print('cann_testdata')
        mbon = self.img_mbon4_testdata
        even_mbons = mbon[:, 0::2]
        lowest = np.min(even_mbons, axis=1)[:, None]
        highest = np.max(even_mbons, axis=1)[:, None]
        cann_drive = np.zeros_like(mbon)
        cann_drive[:, 0::2] = (even_mbons - lowest) / (highest - lowest)

        self.img_cann_test = np.zeros_like(mbon)
        dt = 0.01
        times = np.arange(0, 1, dt)
        for row in range(test_start, test_end):
            path_dict = common_path(experiment='CANN')
            # Fresh network for every test sample.
            cann = CANN1D(num=72, tau=1., k=0.1, a=0.1, A=1, J0=4., z_min=-np.pi, z_max=np.pi)
            cann.init_recording(name_list=['u'], log_path=path_dict['result_path'], log_name='CANN')

            # Numerical integration, recording the state at every step.
            for _ in times:
                cann.step(dt, cann_drive[row])
                cann.recording()
                cann.update()
            cann.save_recording()
            recorded = cann.retrieve_record()
            self.img_cann_test[row] = recorded['u'][-1]
            print('end')
        # np.savetxt('./saved_csv/cann/'+str(test_start)+'.csv', part_mbon, delimiter=',')

    def plot_maxMB(self, images, lables, max_5_mbons):
        """Show the images side by side, annotated with true and predicted angles.

        Each image is reshaped to 128x128x3 and scaled by 1/255. The x label
        shows ``lables[i] * 5`` and the title shows ``max_5_mbons[i] * 5``.

        Args:
            images: Sequence of flattened images.
            lables: Label index of every image.
            max_5_mbons: Strongest MBON indices of every image.
        """
        plt.axes().get_xaxis().set_visible(False)  # hide the x axis
        plt.axes().get_yaxis().set_visible(False)  # hide the y axis

        n_images = len(images)
        for column in range(n_images):
            plt.subplot(1, n_images, column + 1)
            picture = images[column].reshape((128, 128, 3)) / 255
            plt.imshow(picture)
            plt.xlabel('ori_angle' + str(lables[column] * 5))
            plt.title('max_5_mbons' + str(max_5_mbons[column] * 5))
        plt.show()

    def predict(self, images, lables, images_nonave):
        right_angle = 0
        means=[]
        vars=[]
        for img_id in range(images.shape[0]):
            max_indices = np.argsort(self.img_cann_test[img_id])[-5:]
            mean = np.mean(max_indices*5)
            means.append(mean)
            var = np.var(max_indices*5)
            vars.append(var)
            if ((lables[img_id]) in max_indices):  # 方法3
                right_angle = right_angle + 1
        np.savetxt('./saved_csv/cann/means.txt', means, delimiter=',')
        np.savetxt('./saved_csv/cann/vars.txt', vars, delimiter=',')
        acc = right_angle / images.shape[0]
        print('predict')
        print(acc)
        return acc

    def lifelong_predict(self, i, images, lables, images_nonave):
        acc1 = np.zeros(i)
        acc2 = np.zeros(i)
        acc3 = np.zeros(i)
        for id in range(i):
            right_angle2 = 0
            right_angle3 = 0
            right_angle4 = 0
            start_index = id * 72
            end_index = min((id + 1) * 72, images.shape[0])
            for img_id in range(start_index, end_index):
                max_indices2 = np.argsort(self.img_mbon2_traindata[img_id])[-1:]

                max_indices3 = np.argsort(self.img_mbon3_traindata[img_id])[-1:]

                max_indices4 = np.argsort(self.img_mbon4_traindata[img_id])[-1:]

                if lables[img_id] in max_indices2:
                    right_angle2 = right_angle2 + 1

                if lables[img_id] in max_indices3:
                    right_angle3 = right_angle3 + 1

                if lables[img_id] in max_indices4:
                    right_angle4 = right_angle4 + 1

            acc1[id] = right_angle4 / 72  # KC浮点数-Syn浮点数（权重）-MBON浮点数

            acc2[id] = right_angle3 / 72  # KC浮点数-Syn浮点数（直接相加）-MBON浮点数

            acc3[id] = right_angle2 / 72  # KC浮点数-Syn二进制-MBON浮点数
        return acc1, acc2, acc3

def single_test(hash_length, embedding_size, train_set, test_set, sampling_ratio,
                all_expriments, train_angles, test_angles, images_nonave, if_center_data=True):
    """Train and evaluate one model on a fixed train/test split.

    Both random generators are seeded from
    ``hash_length * embedding_size * sampling_ratio``.

    Args:
        hash_length: Hash length handed to the model.
        embedding_size: Number of projection columns.
        train_set: Training samples, one per row.
        test_set: Test samples, one per row.
        sampling_ratio: Handed to the model constructor.
        all_expriments: Names of the models to run (keys: ``'Fly'``).
        train_angles: Label index of every training sample.
        test_angles: Label index of every test sample.
        images_nonave: Forwarded to ``predict``.
        if_center_data: Centre every sample before projecting.

    Returns:
        Accuracy of the FIRST experiment in ``all_expriments``; the function
        returns from inside the loop. ``None`` if the list is empty.
    """
    seed = hash_length * embedding_size * sampling_ratio
    random.seed(seed)
    np.random.seed(int(seed))

    functions = {'Fly': FlyLSH}
    model = {}
    sample_dim = train_set.shape[1]
    # A KC active in more than 2.5% of the training samples gets silenced.
    reset_num = int(0.25 * train_set.shape[0])
    kc_reset_num = int(reset_num / 10)

    for expriment in all_expriments:  # ['Fly', 'FlylshDevelop', 'FlylshDevelopThreshold','FlylshDevelopThresholdRandomChoice']
        learner = functions[expriment](sample_dim, hash_length, sampling_ratio, embedding_size)
        model[expriment] = learner
        learner.generate_weights(train_set)
        learner.PNtoKC(train_set, test_set, reset_num=kc_reset_num, center_data=if_center_data)
        learner.KCtoMBON(train_set, train_angles, center_data=if_center_data)
        return learner.predict(test_set, test_angles, images_nonave)


def _pad_to_square(rows, size):
    """Stack variable-length accuracy rows into a zero-padded ``size x size`` array."""
    padded = np.zeros((size, size))
    for row_index, row in enumerate(rows):
        padded[row_index][:len(row)] = row
    return padded


def lifelong_test(hash_length, embedding_size, training_data, testing_data, sampling_ratio,
                  all_expriments, lables, images_nonave, if_center_data=True):
    """Retrain on a growing prefix of the data and track per-task accuracy.

    In round ``i`` a fresh model is trained on the first ``72 * (i + 1)``
    training rows and evaluated task by task with ``lifelong_predict``. The
    resulting accuracy rows are zero-padded into three 87x87 matrices and
    written to ``./saved_csv/coil_del/longlife_sim_me{1,2,3}.csv``.

    Args:
        hash_length: Hash length handed to the model.
        embedding_size: Number of projection columns.
        training_data: Training samples, one per row.
        testing_data: Test samples, one per row (projected by ``PNtoKC``).
        sampling_ratio: Handed to the model constructor.
        all_expriments: Names of the models to run (keys: ``'Fly'``).
        lables: Label index of every training sample.
        images_nonave: Forwarded to ``lifelong_predict``.
        if_center_data: Centre every sample before projecting.

    Returns:
        The third accuracy matrix of the FIRST experiment; the function
        returns from inside the experiment loop. ``None`` if
        ``all_expriments`` is empty.
    """
    n_rounds = 87  # 13 #87
    task_size = 72

    seed = hash_length * embedding_size * sampling_ratio
    random.seed(seed)
    np.random.seed(int(seed))
    model = {}
    functions = {'Fly': FlyLSH}
    sample_dim = training_data.shape[1]
    for expriment in all_expriments:  # ['Fly', 'FlylshDevelop', 'FlylshDevelopThreshold','FlylshDevelopThresholdRandomChoice']
        accs1 = []
        accs2 = []
        accs3 = []
        for i in range(n_rounds):
            seen_rows = task_size * (i + 1)
            images = training_data[:seen_rows, :]
            lables_ = lables[:seen_rows]
            # PNtoKC requires test data and a reset threshold.
            # NOTE(review): the threshold below mirrors the formula used in
            # single_test - confirm it is the intended value for this protocol.
            reset_num = int(0.25 * images.shape[0])

            model[expriment] = functions[expriment](sample_dim, hash_length, sampling_ratio, embedding_size)
            model[expriment].generate_weights(images)
            model[expriment].PNtoKC(images, testing_data, reset_num=int(reset_num / 10),
                                    center_data=if_center_data)
            # Use the labels truncated to the same prefix as the images.
            model[expriment].KCtoMBON(images, lables_, center_data=if_center_data)
            acc1, acc2, acc3 = model[expriment].lifelong_predict(i + 1, images, lables_, images_nonave)
            accs1.append(acc1)
            accs2.append(acc2)
            accs3.append(acc3)
            print(i)

        # Round i yields i + 1 accuracies; pad them into square matrices.
        result1 = _pad_to_square(accs1, n_rounds)
        result2 = _pad_to_square(accs2, n_rounds)
        result3 = _pad_to_square(accs3, n_rounds)
        np.savetxt('./saved_csv/coil_del/longlife_sim_me1.csv', result1, delimiter=',')
        np.savetxt('./saved_csv/coil_del/longlife_sim_me2.csv', result2, delimiter=',')
        np.savetxt('./saved_csv/coil_del/longlife_sim_me3.csv', result3, delimiter=',')

        return result3


def hugeimage_test(hash_length, embedding_size, training_data, testing_data, sampling_ratio,
                   all_expriments, lables, images_nonave, if_center_data=True):
    """Train a fresh model per chunk of the training data and sum prediction counts.

    Both random generators are seeded from
    ``hash_length * embedding_size * sampling_ratio``. For every chunk a new
    model is built, trained and queried; the three-element results are summed
    and finally divided by the row count of the last chunk.

    Returns the result of the FIRST experiment in ``all_expriments`` (the
    function returns from inside the experiment loop).

    NOTE(review): as written this function does not appear runnable against
    the ``FlyLSH`` class visible in this file - see the inline notes below.
    """
    seed = hash_length * embedding_size * sampling_ratio
    random.seed(seed)
    np.random.seed(int(seed))
    model = {}
    functions = {'Fly': FlyLSH}
    sample_dim = training_data.shape[1]
    for expriment in all_expriments:  # ['Fly', 'FlylshDevelop', 'FlylshDevelopThreshold','FlylshDevelopThresholdRandomChoice']
        accnums = np.zeros(3)
        # NOTE(review): ``i`` already advances in steps of 7200, yet the slices
        # below multiply it by 7200 again, so only i == 0 selects rows
        # 0..7199; later chunks start at row 7200 * 7200 - confirm the intent.
        for i in range(0, 72001, 7200):
            images = training_data[7200 * i:7200 * (i + 1), :]
            lables_ = lables[7200 * i: 7200 * (i + 1)]
            model[expriment] = functions[expriment](sample_dim, hash_length, sampling_ratio, embedding_size)
            model[expriment].generate_weights(images)
            # NOTE(review): PNtoKC as defined in this file also requires
            # ``test_data`` and ``reset_num``; this call omits both.
            model[expriment].PNtoKC(images, center_data=if_center_data)
            model[expriment].KCtoMBON(images, lables_, center_data=if_center_data)
            # NOTE(review): no ``predict_accnums`` method is defined on FlyLSH
            # in the visible source - presumably it lives elsewhere or was removed.
            accnum = model[expriment].predict_accnums(testing_data, lables, images_nonave)
            accnums = accnums + accnum
            print(accnum)
        # Normalised by the size of the last chunk only.
        acc = accnums / images.shape[0]
        return acc


if __name__ == '__main__':
    # Experiment configuration.
    max_index = 10000
    sampling_ratio = 0.10
    nnn = 200  # number of nearest neighbours to compare, 2% of max_index as in paper
    hash_lengths = [512]
    number_of_tests = 1
    ratio = 20
    result_root_path = "./Results/NoPreprocessing/"  # "./Results/main/"
    lifelong = False
    hugeimge = False
    sorted_data = False
    number_of_process = 1

    threshold = 10
    if_center_data = False
    all_MAPs = {}

    # One flattened image per row, and the matching angle per row.
    images_sim_all = np.loadtxt('./saved_csv/coil_image_del.csv', delimiter=',')  # , max_rows=144 , max_rows=936 , max_rows=14800
    angles = np.loadtxt('./saved_csv/coil_lable_del.csv', dtype=int, delimiter=',')  # , max_rows=144
    lables_all = angles // 5  # sim

    # Even-indexed rows are used for training, odd-indexed rows for testing.
    train_set = np.array(images_sim_all[0::2])
    test_set = np.array(images_sim_all[1::2])
    train_angles = np.array(lables_all[0::2])
    test_angles = np.array(lables_all[1::2])
    # lables_all = angles #coil-ori aloi
    print('loaded')
    images_nonave = images_sim_all
    images_ave = (images_sim_all - np.mean(images_sim_all, axis=1)[:, None])
    # Initialise the experiment list and the accuracy list.
    all_expriments = ['Fly']

    img_accs = []  # 87

    if lifelong:
        for hash_length in hash_lengths:  # k
            print('life-long')
            embedding_size = int(
                ratio * hash_length)  # int(10*input_dim) #20k or 10d  #20*[2, 4, 8, 12, 16, 20, 24, 28, 32]
            # lifelong_test takes a single label array; passing test_angles as
            # well shifted every later argument and exceeded the arity.
            acc = lifelong_test(hash_length, embedding_size, train_set, test_set, sampling_ratio,
                                all_expriments, train_angles, images_nonave, if_center_data)
            print(acc[:, :10])
            img_accs.append(acc)

    else:
        for hash_length in hash_lengths:
            print(hash_length)  # k
            embedding_size = int(20 * hash_length)  # int(10*input_dim) #20k or 10d
            acc_nonave = single_test(hash_length, embedding_size, train_set, test_set, sampling_ratio,
                                     all_expriments, train_angles, test_angles, images_nonave, if_center_data)

            print(acc_nonave)
            # np.savetxt rejects 0-d input, so wrap the scalar accuracy.
            np.savetxt('./saved_csv/coil_del_cann_acc_nonave.txt', np.atleast_1d(acc_nonave), delimiter=',')

