import numpy as np
import pickle
import random
import warnings
from PIL import Image
from tqdm import tqdm
# from scipy.io import loadmat
import seaborn as sns
import matplotlib.pyplot as plt
import brainpy as bp
import brainpy.math as bm

# Run BrainPy computations on the CPU backend.
bm.set_platform('cpu')

# Render Chinese text in figures with the SimHei font, and turn off the
# Unicode minus glyph so negative numbers still display correctly.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})


class CANN1D(bp.NeuGroup):
    """One-dimensional continuous attractor neural network (CANN) on a ring.

    ``num`` neurons encode feature values spread over ``[z_min, z_max]``;
    distances between them wrap around (see ``dist``), recurrent weights are
    Gaussian in that wrapped distance (see ``make_conn``), and firing rates are
    divisively normalised by ``1 + k * sum(u**2)`` (see ``derivative``).

    NOTE(review): this class is not referenced anywhere else in this file.
    """

    def __init__(self, num, tau=1., k=8.1, a=0.5, A=10., J0=4., z_min=-bm.pi, z_max=bm.pi):
        super(CANN1D, self).__init__(size=num)

        # parameters
        self.tau = tau  # The synaptic time constant
        self.k = k  # Degree of the rescaled inhibition
        self.a = a  # Half-width of the range of excitatory connections
        self.A = A  # Magnitude of the external input
        self.J0 = J0  # maximum connection value

        # feature space
        self.z_min = z_min
        self.z_max = z_max
        self.z_range = z_max - z_min
        # NOTE(review): linspace includes both endpoints, which coincide on a
        # ring, and its spacing is z_range / (num - 1) rather than self.dx
        # below -- confirm this is intended.
        self.x = bm.linspace(z_min, z_max, num)  # The encoded feature values
        self.rho = num / self.z_range  # The neural density
        self.dx = self.z_range / num  # The stimulus density (nominal spacing, 1 / rho)

        # variables
        self.u = bm.Variable(bm.zeros(num))  # synaptic input of each neuron
        self.input = bm.Variable(bm.zeros(num))  # external input, cleared after every step

        # The connection matrix
        self.conn_mat = self.make_conn(self.x)

        # function: BrainPy ODE integrator wrapping derivative(u, t, Iext)
        self.integral = bp.odeint(self.derivative)

    def derivative(self, u, t, Iext):
        """Return du/dt = (-u + J @ r + Iext) / tau with r = u^2 / (1 + k * sum(u^2))."""
        r1 = bm.square(u)
        r2 = 1.0 + self.k * bm.sum(r1)  # global divisive normalisation
        r = r1 / r2
        Irec = bm.dot(self.conn_mat, r)  # recurrent input
        du = (-u + Irec + Iext) / self.tau
        return du

    def dist(self, d):
        """Wrap the difference ``d`` onto the ring, into (-z_range/2, z_range/2]."""
        d = bm.remainder(d, self.z_range)
        d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d)
        return d

    def make_conn(self, x):
        """Build the Gaussian recurrent weight matrix with J[i, j] depending on dist(x[i] - x[j])."""
        assert bm.ndim(x) == 1
        x_left = bm.reshape(x, (-1, 1))  # column vector: x[i] down the rows
        x_right = bm.repeat(x.reshape((1, -1)), len(x), axis=0)  # every row is x
        d = self.dist(x_left - x_right)
        Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)
        return Jxx

    def get_stimulus_by_pos(self, pos):
        """Return a Gaussian input bump of height A centred on feature value ``pos``."""
        return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a))

    def update(self, tdi):
        """Advance ``u`` by one step of size ``tdi.dt`` at time ``tdi.t``, then clear ``input``."""
        self.u.value = self.integral(self.u, tdi.t, self.input, tdi.dt)
        self.input[:] = 0.

    def cell(self, u):
        """Return du/dt at state ``u`` with t = 0 and no external input."""
        return self.derivative(u, 0., 0.)


class LSH(object):
    """Random-projection locality-sensitive hashing with retrieval metrics.

    Each sample is projected onto ``hash_length`` random directions
    (``self.weights``, sample_dim x hash_length) and the sign of every
    projection gives one bit of its binary hash. The evaluation helpers
    compare the neighbours found by Hamming (L1) distance between hashes with
    the true squared-Euclidean neighbours of ``self.data``.
    """

    def __init__(self, sample_dim, hash_length):
        """
        sample_dim: dimensionality d of each input sample
        hash_length: number of hash bits (random projections)
        """
        self.sample_dim = sample_dim
        self.hash_length = hash_length
        # NOTE(review): for the boolean hashes produced by hashing() the
        # largest possible L1 distance is hash_length, so
        # query(..., not_olap=True) / findZKk() never reach this value for
        # plain LSH -- confirm the intended constant.
        self.maxl1distance = 2 * self.hash_length
        self.weights = np.zeros((self.sample_dim, self.hash_length))
        self.max_index_of_generated_weights = 0

    def generate_weights(self, code_dims=None):
        """Draw uniform [0, 1) projection weights for some or all hash bits.

        code_dims: None -> fill every column not generated yet;
            scalar -> generate that many more columns (a value < 1 is taken
            as a fraction of hash_length);
            sequence -> (re)generate exactly these column indices.
        """
        if code_dims is None:
            # Develop til full
            self.weights[:, self.max_index_of_generated_weights:] = \
                np.random.random((self.sample_dim, self.hash_length - self.max_index_of_generated_weights))
            self.max_index_of_generated_weights = self.hash_length
        elif np.isscalar(code_dims):
            if code_dims < 1:
                code_dims = int(code_dims * self.hash_length)
            # Develop code_dims more
            if self.max_index_of_generated_weights >= code_dims:
                warnings.warn('Unexpected modification on existing connection!')
            self.weights[:, self.max_index_of_generated_weights:self.max_index_of_generated_weights + code_dims] = \
                np.random.random((self.sample_dim, code_dims))
            self.max_index_of_generated_weights += code_dims
        else:
            # Develop specific code dims following the elements in code_dims as indexes
            if not np.all(np.less(self.max_index_of_generated_weights, code_dims)):
                warnings.warn('Unexpected modification on existing connection!')
            self.weights[:, code_dims] = np.random.random((self.sample_dim, len(code_dims)))
            self.max_index_of_generated_weights += len(code_dims)

    def hashing(self, data, center_data=True):
        """Compute binary hashes of ``data`` (n_samples x sample_dim).

        center_data: subtract each sample's mean before projecting (default,
            matching the centring this method always applied to
            ``self.data``); pass False to hash the raw data.

        Sets ``self.data`` (the data that was hashed and that true_nns()
        compares against) and ``self.hashes`` (bool, n_samples x hash_length).
        """
        if center_data:
            self.data = data - np.mean(data, axis=1)[:, None]
        else:
            self.data = data
        # Hash the same data the ground truth is computed from. The weights
        # are non-negative, so hashing raw non-negative inputs (e.g. pixel
        # intensities) would set every bit.
        self.hashes = (self.data @ self.weights) > 0

    def query(self, qidx, nnn, not_olap=False):
        """Return up to ``nnn`` indices nearest to sample ``qidx`` in Hamming distance.

        ``qidx`` itself is excluded. With ``not_olap=True`` return instead the
        number of samples whose distance equals ``self.maxl1distance``.
        """
        L1_distances = np.sum(np.abs(self.hashes[qidx, :] ^ self.hashes), axis=1)
        nnn = min(self.hashes.shape[0], nnn)
        if not_olap:
            return np.sum(L1_distances == self.maxl1distance)

        NNs = L1_distances.argsort()
        # An interesting property of this hash is that the L1 distances are always even.
        return NNs[(NNs != qidx)][:nnn]

    def true_nns(self, qidx, nnn):
        """Return the ``nnn`` nearest samples to ``qidx`` by squared Euclidean distance (excluding ``qidx``)."""
        sample = self.data[qidx, :]
        tnns = np.sum((self.data - sample) ** 2, axis=1).argsort()[:nnn + 1]
        tnns = tnns[(tnns != qidx)]
        if nnn < self.data.shape[0]:
            assert len(tnns) == nnn, 'nnn={}'.format(nnn)
        return tnns

    def construct_true_nns(self, indices, nnn):
        """Stack true_nns() for every index in ``indices`` (rows follow ``indices``)."""
        all_NNs = np.zeros((len(indices), nnn))
        for idx1, idx2 in enumerate(indices):
            all_NNs[idx1, :] = self.true_nns(idx2, nnn)
        return all_NNs

    def AP(self, predictions, truth):
        """Average precision: mean over idx = 1..len(truth) of |pred[:idx] & truth[:idx]| / idx."""
        assert len(predictions) == len(truth) or len(predictions) == self.hashes.shape[0]
        precisions = [len(set(predictions[:idx]).intersection(set(truth[:idx]))) / idx
                      for idx in range(1, len(truth) + 1)]
        return np.mean(precisions)

    def PR(self, qidx, truth, atindices):
        """Precision and hit counts of the hash ranking of ``qidx`` at each cut-off.

        truth should be a set of true-neighbour indices. Returns
        ``(precisions, hit_counts)`` where hit_counts[i] is
        ``|top atindices[i] & truth|`` and precision is hit_count / (idx + 1).
        """
        L1_distances = np.sum((self.hashes[qidx, :] ^ self.hashes), axis=1)
        NNs = L1_distances.argsort()
        NNs = NNs[(NNs != qidx)]
        all_recalls = [len(set(NNs[:idx]) & truth) for idx in atindices]
        precisions = [recall / (idx + 1) for idx, recall in zip(atindices, all_recalls)]
        return precisions, all_recalls

    def ROC(self, qidx, truth, atindices):
        """x: False positive rate, y: True positive rate, truth should be a set."""
        L1_distances = np.sum((self.hashes[qidx, :] ^ self.hashes), axis=1)
        NNs = L1_distances.argsort()
        NNs = NNs[(NNs != qidx)]
        x, y = [], []
        for idx in atindices:
            ntruepos = len((set(NNs[:idx]) & truth))  # retrieved and truly neighbours
            nfalsepos = idx - ntruepos  # retrieved but not true neighbours
            tpr = ntruepos / len(truth)  # positives correctly classified / total positives
            fpr = nfalsepos / (len(NNs) - len(truth))  # negatives incorrectly classified / total negatives
            x.append(fpr)
            y.append(tpr)
        return x, y

    def findmAP_given_true(self, nnn, n_points, all_NNs):
        """Mean of findAPs_given_true()."""
        return np.mean(self.findAPs_given_true(nnn, n_points, all_NNs))

    def findAPs_given_true(self, nnn, n_points, all_NNs):
        """AP for ``n_points`` random samples; ``all_NNs`` rows are indexed by sample index."""
        sample_indices = np.random.choice(self.data.shape[0], n_points)
        self.allAPs = []
        for didx in sample_indices:
            this_nns = self.query(didx, nnn)
            self.allAPs.append(self.AP(list(this_nns), list(all_NNs[didx, :])))
        return self.allAPs

    def findmAP_given_true_labels(self, nnn, n_points, all_NNs, labels, label_set=None):
        """Per-label mean AP.

        NOTE(review): ``labels`` is used as a mask over self.allAPs, i.e. it
        is assumed to be aligned with the randomly sampled points -- confirm.
        """
        if label_set is None:
            label_set = set(labels)
        self.findAPs_given_true(nnn, n_points, all_NNs)
        self.mAP_per_lable = {}
        for a_label in label_set:
            self.mAP_per_lable[a_label] = np.mean(np.array(self.allAPs)[labels == a_label])
        return self.mAP_per_lable

    def findmAP(self, nnn, n_points):
        """Mean average precision of hash retrieval over ``n_points`` random queries.

        The previously drawn but unused random start index was removed: it
        raised ValueError whenever n_points >= number of samples.
        """
        sample_indices = np.random.choice(self.data.shape[0], n_points)
        all_NNs = self.construct_true_nns(sample_indices, nnn)
        self.allAPs = []
        for eidx, didx in enumerate(sample_indices):
            # eidx: enumeration id (row of all_NNs), didx: index of sample in self.data
            this_nns = self.query(didx, nnn)
            self.allAPs.append(self.AP(list(this_nns), list(all_NNs[eidx, :])))
        return np.mean(self.allAPs)

    def findZKk(self, n_points):
        """
        ZKk is the number of vectors whose overlap with a specific vector is zero
        """
        sample_indices = np.random.choice(self.data.shape[0], n_points)
        no_overlaps = [self.query(didx, -20, not_olap=True) for didx in sample_indices]
        return np.mean(no_overlaps)

    def computePRC(self, n_points=1, nnn=200, atindices=None):
        """
        This function calculates precision-recall metrics for model.

        atindices (required): cut-off ranks passed to PR().
        Returns [mean precisions, mean hit counts] over the sampled queries.
        """
        sample_indices = np.random.choice(self.data.shape[0], n_points)
        all_NNs = self.construct_true_nns(sample_indices, nnn)
        self.allprecisions = np.zeros((n_points, len(atindices)))
        self.allrecalls = np.zeros((n_points, len(atindices)))
        for eidx, didx in enumerate(sample_indices):
            # eidx: enumeration id, didx: index of sample in self.data
            this_p, this_r = self.PR(didx, set(all_NNs[eidx, :]), atindices)
            self.allprecisions[eidx, :] = this_p
            self.allrecalls[eidx, :] = this_r

        return [self.allprecisions.mean(axis=0),
                self.allrecalls.mean(axis=0)]

    def computeROC(self, n_points=1, nnn=200, atindices=None):
        """
        This function calculates receiver operator characteristics (ROC)
        """
        sample_indices = np.random.choice(self.hashes.shape[0], n_points)
        all_NNs = self.construct_true_nns(sample_indices, nnn)
        alltprs = np.zeros((n_points, len(atindices)))
        allfprs = np.zeros((n_points, len(atindices)))
        for eidx, didx in enumerate(sample_indices):
            this_fpr, this_tpr = self.ROC(didx, set(all_NNs[eidx, :]), atindices)
            allfprs[eidx, :] = this_fpr
            alltprs[eidx, :] = this_tpr
        return [allfprs.mean(axis=0), alltprs.mean(axis=0)]


class FlyLSH(LSH):
    def __init__(self, sample_dim, hash_length, sampling_ratio, embedding_size):
        """
        data: uxd matrix
        hash_length: scalar
        sampling_ratio: fraction of input dims to sample from when producing a hash
        embedding_size: dimensionality of projection space, m
        Note that in Flylsh, the hash length and embedding_size are NOT the same
        whereas in usual LSH they are
        """
        self.sample_dim = sample_dim
        self.hash_length = hash_length
        self.embedding_size = embedding_size
        self.K = embedding_size // hash_length
        self.num_projections = 400  # int(sampling_ratio * self.sample_dim) #40

        self.maxl1distance = 2 * self.hash_length
        self.max_index_of_generated_weights = 0
        self.weights = np.zeros((self.sample_dim, self.embedding_size), dtype=bool)

    def generate_weights(self, data, code_dims=None, distribution_random=False):
        if code_dims is None:
            code_dims_indexs = np.arange(self.max_index_of_generated_weights, self.embedding_size, dtype=np.int64)
            code_dims_length = int(self.embedding_size - self.max_index_of_generated_weights)
        elif np.isscalar(code_dims):
            if code_dims < 1:
                code_dims = int(code_dims * self.embedding_size)
            if self.max_index_of_generated_weights >= code_dims:
                warnings.warn('Unexpected modification on existing connection!, max_index=%d, code_dims=%d' % (
                    self.max_index_of_generated_weights, code_dims))
            code_dims_indexs = np.arange(self.max_index_of_generated_weights, code_dims, dtype=np.long)
            code_dims_length = int(code_dims - self.max_index_of_generated_weights)
        else:
            if not np.all(np.less(self.max_index_of_generated_weights, code_dims)):
                warnings.warn('Unexpected modification on existing connection!, max_index=%d, code_dims=%d' % (
                    self.max_index_of_generated_weights, code_dims))
            code_dims_indexs = code_dims
            code_dims_length = len(code_dims)

        weights = np.random.random((self.sample_dim, code_dims_length))
        if distribution_random:
            weights = weights * np.mean(data, axis=0)[:, None]
        yindices = code_dims_indexs[None, :]
        xindices = weights.argsort(axis=0)[-self.num_projections:, :]
        self.weights[xindices, yindices] = True  # sparse projection vectors
        self.max_index_of_generated_weights += code_dims_length

    def PNtoKC(self, data,  center_data=False):
        if center_data:
            data_mean = np.mean(data, axis=1)
            self.data = (data - data_mean[:, None])
        else:
            self.data = data
        all_activations = (self.data @ self.weights)
        xindices = np.arange(data.shape[0])[:, None]
        # self.yindices = all_activations.argsort(axis=1)[:, -self.hash_length:]
        self.yindices = all_activations.argsort(axis=1)[:, -128:]
        self.kc_activity1 = np.zeros_like(all_activations, dtype=bool)
        self.kc_activity1[xindices, self.yindices] = True  # choose topk activations

        self.kc_activity2 = np.zeros_like(all_activations)
        self.kc_activity2[xindices, self.yindices] = all_activations[xindices, self.yindices]

        self.kc_activity3 = np.zeros_like(all_activations)
        self.kc_activity3[xindices, self.yindices] = all_activations[xindices, self.yindices]
        print('kc')


    def KCtoMBON(self, data, lables, center_data=False):
        if center_data:
            data_mean = np.mean(data, axis=1)
            self.data = (data - data_mean[:, None])
        else:
            self.data = data

        self.kc_mbon_weight1 = np.zeros((self.embedding_size, 72), dtype=bool)
        self.kc_mbon_weight2 = np.zeros((self.embedding_size, 72))
        self.kc_mbon_weight3 = np.zeros((self.embedding_size, 72))
        lr = np.ones((self.embedding_size, 72))


        for i in range(data.shape[0]):
            indices = self.yindices[i]  # 第i个图片对应的最强的2个KC的下标
            self.kc_mbon_weight1[indices, lables[i]] = True  # 第i个图片对应的最强的2个KC连接到mbon上（第i个图片对应的mbon）
            self.kc_mbon_weight2[indices, lables[i]] = self.kc_mbon_weight2[indices, lables[i]] + self.kc_activity2[
                i, indices]
            active_kc = self.kc_activity3[i, self.yindices[i]]
            for j in range(i + 1):
                lr[indices, lables[j]] = lr[indices, lables[j]] * 0.9999  # 0.9999 0.95
            self.kc_mbon_weight3[indices, lables[i]] = np.multiply(
                (active_kc - self.kc_mbon_weight3[indices, lables[i]]), lr[indices, lables[i]]) + self.kc_mbon_weight3[
                                                           indices, lables[i]]
            # self.kc_mbon_weight3[indices, lables[i]] = self.kc_mbon_weight3[indices, lables[i]] + kc_weight_mbon[i, indices]

        # kc_activity1: max 01; kc_activity1: max 01
        self.img_mbon1 = (self.kc_activity1 @ self.kc_mbon_weight1)  # 离散 离散
        self.img_mbon2 = np.matmul(self.kc_activity2, self.kc_mbon_weight1)  # 连续 离散
        self.img_mbon3 = np.matmul(self.kc_activity2, self.kc_mbon_weight2)  # 连续 连续
        self.img_mbon4 = np.matmul(self.kc_activity3, self.kc_mbon_weight3)  # 连续 连续 学习率
        print('mbon')

    def plot_maxMB(self, images, lables, max_5_mbons):
        plt.axes().get_xaxis().set_visible(False)  # 隐藏x坐标轴
        plt.axes().get_yaxis().set_visible(False)  # 隐藏y坐标轴
        lables = lables
        max_5_mbons = max_5_mbons

        for i in range(len(images)):
            plt.subplot(1, len(images), i + 1)
            plt.imshow(images[i].reshape((128, 128, 3)) / 255)
            xlabel = 'ori_angle' + str(lables[i] * 5)
            plt.xlabel(xlabel)
            title = 'max_5_mbons' + str(max_5_mbons[i] * 5)
            plt.title(title)
        plt.show()

    def predict(self, images, lables, images_nonave):
        right_angle2 = 0
        right_angle3 = 0
        right_angle4 = 0
        max_5_mbons = []
        for img_id in range(images.shape[0]):

            max_5_mbon = np.zeros((3, 5), dtype=int)

            # 方法3
            max_indices2 = np.argsort(self.img_mbon2[img_id])[-1:]
            max_5_mbon[0] = np.argsort(self.img_mbon2[img_id])[-5:][::-1]  # 升序变降序

            # 方法2
            max_indices3 = np.argsort(self.img_mbon3[img_id])[-1:]
            max_5_mbon[1] = np.argsort(self.img_mbon3[img_id])[-5:][::-1]

            # 方法1
            max_indices4 = np.argsort(self.img_mbon4[img_id])[-1:]
            max_5_mbon[2] = np.argsort(self.img_mbon4[img_id])[-5:][::-1]

            # if img_id % 5 == 0:
            #     self.plot_maxMB(images_nonave[img_id-5:img_id], lables[img_id-5:img_id], max_5_mbons[img_id-5:img_id])
            # max_5_mbons.append(max_5_mbon)

            if lables[img_id] in max_indices2:  # 方法3
                right_angle2 = right_angle2 + 1

            if lables[img_id] in max_indices3:  # 方法2
                right_angle3 = right_angle3 + 1

            if lables[img_id] in max_indices4:  # 方法1
                right_angle4 = right_angle4 + 1
            # print('end')

        acc = np.zeros(3)

        acc[0] = right_angle4 / images.shape[0]  # KC浮点数-Syn浮点数（权重）-MBON浮点数 方法1

        acc[1] = right_angle3 / images.shape[0]  # KC浮点数-Syn浮点数（直接相加）-MBON浮点数 方法2

        acc[2] = right_angle2 / images.shape[0]  # KC浮点数-Syn二进制-MBON浮点数  方法3
        return acc

    def lifelong_predict(self, i, images, lables, images_nonave):
        acc1 = np.zeros(i)
        acc2 = np.zeros(i)
        acc3 = np.zeros(i)
        for id in range(i):
            right_angle2 = 0
            right_angle3 = 0
            right_angle4 = 0
            start_index = id * 72
            end_index = min((id + 1) * 72, images.shape[0])
            for img_id in range(start_index, end_index):
                max_indices2 = np.argsort(self.img_mbon2[img_id])[-1:]

                max_indices3 = np.argsort(self.img_mbon3[img_id])[-1:]

                max_indices4 = np.argsort(self.img_mbon4[img_id])[-1:]

                if lables[img_id] in max_indices2:
                    right_angle2 = right_angle2 + 1

                if lables[img_id] in max_indices3:
                    right_angle3 = right_angle3 + 1

                if lables[img_id] in max_indices4:
                    right_angle4 = right_angle4 + 1

            acc1[id] = right_angle4 / 72  # KC浮点数-Syn浮点数（权重）-MBON浮点数

            acc2[id] = right_angle3 / 72  # KC浮点数-Syn浮点数（直接相加）-MBON浮点数

            acc3[id] = right_angle2 / 72  # KC浮点数-Syn二进制-MBON浮点数
        return acc1, acc2, acc3

    def predict_accnums(self, images, lables, images_nonave):
        right_angle2 = 0
        right_angle3 = 0
        right_angle4 = 0
        max_5_mbons = []

        for img_id in range(images.shape[0]):
            max_5_mbon = np.zeros((3, 5), dtype=int)
            a = self.img_mbon2[img_id]
            # 方法3
            max_indices2 = np.argsort(self.img_mbon2[img_id])[-1:]
            max_5_mbon[0] = np.argsort(self.img_mbon2[img_id])[-5:][::-1]  # 升序变降序

            # 方法2
            max_indices3 = np.argsort(self.img_mbon3[img_id])[-1:]
            max_5_mbon[1] = np.argsort(self.img_mbon3[img_id])[-5:][::-1]

            # 方法1
            max_indices4 = np.argsort(self.img_mbon4[img_id])[-1:]
            max_5_mbon[2] = np.argsort(self.img_mbon4[img_id])[-5:][::-1]

            if lables[img_id] in max_indices2:  # 方法3
                right_angle2 = right_angle2 + 1

            if lables[img_id] in max_indices3:  # 方法2
                right_angle3 = right_angle3 + 1

            if lables[img_id] in max_indices4:  # 方法1
                right_angle4 = right_angle4 + 1

        acc = np.zeros(3)

        acc[0] = right_angle4

        acc[1] = right_angle3

        acc[2] = right_angle2
        return acc


def single_test(hash_length, embedding_size, training_data, testing_data, sampling_ratio,
                all_expriments, lables, images_nonave, if_center_data=True):
    """Train a FlyLSH model once and report its angle-classification accuracy.

    Random seeds are derived from the hyper-parameters. Input->KC weights are
    drawn using ``training_data``; KC responses and KC->MBON synapses are
    computed on ``testing_data`` with labels ``lables``, and accuracy is
    measured on that same data.

    Returns the array from ``FlyLSH.predict`` for the FIRST entry of
    ``all_expriments`` (later entries are not run), or None if it is empty.

    Raises:
        NotImplementedError: for 'LSH', which has no KC->MBON read-out or
            predict() method (this path previously failed with an opaque error).
        KeyError: for an unknown experiment name.
    """
    seed = hash_length * embedding_size * sampling_ratio
    random.seed(seed)
    np.random.seed(int(seed))
    functions = {'LSH': LSH, 'Fly': FlyLSH}
    sample_dim = training_data.shape[1]
    for expriment in all_expriments:
        if expriment == 'LSH':
            raise NotImplementedError("single_test: the 'LSH' model has no MBON read-out / predict()")
        model = functions[expriment](sample_dim, hash_length, sampling_ratio, embedding_size)
        model.generate_weights(training_data)
        model.PNtoKC(testing_data, center_data=if_center_data)
        model.KCtoMBON(testing_data, lables, center_data=if_center_data)
        # Only the first experiment is evaluated (unchanged behaviour).
        return model.predict(testing_data, lables, images_nonave)
    return None


def _pad_rows(rows, size):
    """Copy ragged 1-D ``rows`` into a zero-padded (size, size) matrix, row by row."""
    result = np.zeros((size, size))
    for i, row in enumerate(rows):
        result[i][:len(row)] = row
    return result


def lifelong_test(hash_length, embedding_size, training_data, testing_data, sampling_ratio,
                  all_expriments, lables, images_nonave, if_center_data=True,
                  n_tasks=87, save_dir='./saved_csv/coil_del'):
    """Lifelong-learning evaluation of FlyLSH on a growing number of 72-image tasks.

    For i = 0..n_tasks-1 a fresh model is trained on the first 72 * (i + 1)
    rows of ``training_data`` and scored per task with
    ``FlyLSH.lifelong_predict``. Row i of each result matrix holds the
    accuracies of tasks 0..i after learning i + 1 tasks (lower triangle,
    zero-padded). The three matrices (img_mbon4 / img_mbon3 / img_mbon2
    read-outs) are written to ``save_dir/longlife_sim_me{1,2,3}.csv``.
    ``testing_data`` is unused.

    Returns the img_mbon2 matrix (n_tasks x n_tasks) for the FIRST entry of
    ``all_expriments``, or None if it is empty.

    Raises:
        NotImplementedError: for 'LSH', which has no KC->MBON read-out
            (this path previously ended with an unbound ``result3``).
    """
    seed = hash_length * embedding_size * sampling_ratio
    random.seed(seed)
    np.random.seed(int(seed))
    functions = {'LSH': LSH, 'Fly': FlyLSH}
    sample_dim = training_data.shape[1]
    for expriment in all_expriments:
        if expriment == 'LSH':
            raise NotImplementedError("lifelong_test: the 'LSH' model has no MBON read-out")
        accs1, accs2, accs3 = [], [], []
        for i in range(n_tasks):
            images = training_data[:72 * (i + 1), :]
            lables_ = lables[:72 * (i + 1)]
            model = functions[expriment](sample_dim, hash_length, sampling_ratio, embedding_size)
            model.generate_weights(images)
            model.PNtoKC(images, center_data=if_center_data)
            # KCtoMBON reads one label per image, so the matching prefix suffices.
            model.KCtoMBON(images, lables_, center_data=if_center_data)
            acc1, acc2, acc3 = model.lifelong_predict(i + 1, images, lables_, images_nonave)
            accs1.append(acc1)
            accs2.append(acc2)
            accs3.append(acc3)
            print(i)

        result1 = _pad_rows(accs1, n_tasks)
        result2 = _pad_rows(accs2, n_tasks)
        result3 = _pad_rows(accs3, n_tasks)
        np.savetxt(save_dir + '/longlife_sim_me1.csv', result1, delimiter=',')
        np.savetxt(save_dir + '/longlife_sim_me2.csv', result2, delimiter=',')
        np.savetxt(save_dir + '/longlife_sim_me3.csv', result3, delimiter=',')
        # Only the first experiment is evaluated (unchanged behaviour).
        return result3
    return None


def hugeimage_test(hash_length, embedding_size, training_data, testing_data, sampling_ratio,
                   all_expriments, lables, images_nonave, if_center_data=True, chunk_size=7200):
    """Train and score FlyLSH independently on consecutive chunks of a large dataset.

    ``training_data`` is split into consecutive chunks of ``chunk_size`` rows.
    A fresh model is trained on each chunk and scored on that same chunk,
    because the MBON responses predict_accnums() reads are computed on the
    chunk. Hit counts are summed over chunks and divided by the total number
    of images scored.

    Previously the loop index was both stepped by 7200 and multiplied by 7200
    (every chunk after the first was empty). It also scored ``testing_data``
    against per-chunk responses, and divided by the last chunk's size.
    ``testing_data`` is now unused.

    Returns the accuracy array (ordered like FlyLSH.predict) for the FIRST
    entry of ``all_expriments``, or None if it is empty.

    Raises:
        NotImplementedError: for 'LSH', which has no KC->MBON read-out.
    """
    seed = hash_length * embedding_size * sampling_ratio
    random.seed(seed)
    np.random.seed(int(seed))
    functions = {'LSH': LSH, 'Fly': FlyLSH}
    sample_dim = training_data.shape[1]
    for expriment in all_expriments:
        if expriment == 'LSH':
            raise NotImplementedError("hugeimage_test: the 'LSH' model has no MBON read-out")
        accnums = np.zeros(3)
        n_scored = 0
        for start in range(0, training_data.shape[0], chunk_size):
            images = training_data[start:start + chunk_size, :]
            lables_ = lables[start:start + chunk_size]
            model = functions[expriment](sample_dim, hash_length, sampling_ratio, embedding_size)
            model.generate_weights(images)
            model.PNtoKC(images, center_data=if_center_data)
            model.KCtoMBON(images, lables_, center_data=if_center_data)
            accnum = model.predict_accnums(images, lables_, images_nonave)
            accnums = accnums + accnum
            n_scored += images.shape[0]
            print(accnum)
        # Only the first experiment is evaluated (unchanged behaviour).
        return accnums / max(n_scored, 1)
    return None


if __name__ == '__main__':
    # Alternative data source (unused):
    # import tensorflow_datasets as tfds
    # dataset = tfds.load("coil100", split=tfds.Split.TRAIN, as_supervised=True)
    max_index = 10000
    sampling_ratio = 0.10
    nnn = 200  # number of nearest neighbours to compare, 2% of max_index as in paper
    hash_lengths = [512]  # 2, 32, 64, 128, 256,
    number_of_tests = 1
    ratio = 20  # KC expansion factor: embedding_size = ratio * hash_length
    result_root_path = "./Results/NoPreprocessing/"  # "./Results/main/"
    lifelong = False
    hugeimge = False
    sorted_data = False
    number_of_process = 1

    threshold = 10
    if_center_data = False
    all_MAPs = {}

    # Assumed: one flattened image per row (optional limits used before:
    # max_rows=144 / 936 / 14800).
    images_sim_all = np.loadtxt('./saved_csv/coil_image_del.csv',
                                delimiter=',')
    angles = np.loadtxt('./saved_csv/coil_lable_del.csv', dtype=int, delimiter=',')
    # Class index = viewing angle in 5-degree bins (plots map it back with * 5).
    lables_all = angles // 5  # sim

    # lables_all = angles #coil-ori aloi
    print('loaded')
    images_nonave = images_sim_all
    # Copy of the data with each image's mean subtracted.
    images_ave = (images_sim_all - np.mean(images_sim_all, axis=1)[:, None])
    # 5. Initialise the accuracy lists
    all_expriments = ['Fly']

    img_accs = []  # 87
    img_nonaccs = []
    img_aveaccs = []

    if lifelong:
        for hash_length in hash_lengths:  # k
            print('life-long')
            embedding_size = int(ratio * hash_length)  # 20k (or ~10x input dim)
            acc = lifelong_test(hash_length, embedding_size, images_nonave, images_nonave, sampling_ratio,
                                all_expriments, lables_all, images_nonave, if_center_data)
            print(acc[:, :10])
            img_accs.append(acc)

    else:
        for hash_length in hash_lengths:
            print(hash_length)  # k

            # Same expansion ratio as the lifelong branch (was a literal 20 here).
            embedding_size = int(ratio * hash_length)
            acc_ave = single_test(hash_length, embedding_size, images_ave, images_ave, sampling_ratio,
                                  all_expriments, lables_all, images_nonave, if_center_data)
            acc_nonave = single_test(hash_length, embedding_size, images_nonave, images_nonave, sampling_ratio,
                                     all_expriments, lables_all, images_nonave, if_center_data)

            img_nonaccs.append(acc_nonave)
            print(acc_ave)
            print(acc_nonave)
            img_aveaccs.append(acc_ave)
            np.savetxt('./saved_csv/coil_del/reset_acc_ave.txt', acc_ave, delimiter=',')
            np.savetxt('./saved_csv/coil_del/reset_acc_nonave.txt', acc_nonave, delimiter=',')

