import random

import numpy as np
import scipy
import scipy.special
from sklearn.metrics import accuracy_score


# Build the K-nearest-neighbour adjacency matrix
def refine(X, K=3):
    """Return the symmetric K-nearest-neighbour adjacency matrix of ``X``.

    Each entry along the first axis of ``X`` is one instance; instances with
    more than one dimension are flattened, and distances are Euclidean.  For
    every instance the ``K`` closest *other* instances are marked as
    neighbours, and the relation is symmetrised so ``adj[i, j] == adj[j, i]``.

    Args:
        X: array whose first axis indexes the instances.
        K: number of neighbours per instance.  Values larger than ``n - 1``
            are clamped to ``n - 1``; ``K <= 0`` yields an all-zero matrix.

    Returns:
        ``(n, n)`` float array of 0/1 entries with a zero diagonal.
    """
    # Cast to float up front: subtracting unsigned-integer rows would wrap around.
    points = np.asarray(X, dtype=float)
    n = points.shape[0]
    adj = np.zeros((n, n))
    k = min(K, n - 1)
    if k <= 0:
        return adj

    # Flatten each instance to a vector (1-D input becomes one feature per instance).
    flat = points.reshape(n, points[0].size)

    # Pairwise distances, one row at a time so memory stays O(n * d) per step.
    dist = np.empty((n, n))
    for i in range(n):
        dist[i] = np.linalg.norm(flat - flat[i], axis=1)

    # Exclude each instance from its own neighbour list explicitly; relying on
    # it sorting first fails when duplicate points tie at distance 0.
    np.fill_diagonal(dist, np.inf)
    neighbours = np.argsort(dist, axis=1, kind="stable")[:, :k]

    rows = np.repeat(np.arange(n), k)
    cols = neighbours.ravel()
    adj[rows, cols] = 1
    adj[cols, rows] = 1
    return adj


class MA3S:
    """Active label-acquisition loop over a K-nearest-neighbour graph.

    Each instance keeps a vector of per-class counts ``Phi[i]`` (initialised
    to ones and incremented by ``update``) and a scalar uncertainty.  At each
    step ``select`` picks the instance whose own uncertainty plus the mean
    uncertainty of its graph neighbours is largest.  ``annotate`` then
    simulates a noisy annotator for it, and ``update`` folds the label back
    into the counts.  ``MV`` reports majority-vote accuracy against ``y``.

    Attributes set in ``__init__``:
        X, y: instances and their ground-truth labels.
        numClass: number of classes.  Labels are used as indices into
            ``Phi``, so they must be ints in ``[0, numClass)``.
        K: neighbours per instance used to build ``A``.
        L: per-instance list of collected labels.
        Phi: per-instance float array of length ``numClass`` with class counts.
        uncertainties: float array with one value per instance.
        A: symmetric 0/1 adjacency matrix produced by ``refine``.
    """

    # Each simulated annotation's accuracy is drawn uniformly from this range.
    ANNOTATOR_ACCURACY_RANGE = (0.55, 0.75)
    # Point at which the regularized incomplete beta is evaluated in ``uncertainty``.
    UNCERTAINTY_POINT = 0.4

    def __init__(self, X, y, classnum, trueL=None, K=3):
        """Set up per-instance label state and the neighbour graph.

        Args:
            X: instances; the first axis indexes instances (passed to ``refine``).
            y: ground-truth labels, indexable by instance.
            classnum: number of classes.
            trueL: optional per-instance lists of already-collected labels.
                The lists are copied, so later annotations do not mutate the
                caller's object.  When omitted, every instance starts with no
                labels and an uncertainty of 0.5.
            K: neighbours per instance for the adjacency graph.
        """
        self.X = X
        self.y = y
        self.numClass = classnum
        self.K = K

        self.numInstance = X.shape[0]
        self.uncertainties = np.zeros(self.numInstance)
        self.Phi = [np.ones(self.numClass) for _ in range(self.numInstance)]

        # `is None`, not `== None`: the latter is elementwise (and ambiguous in
        # an `if`) when trueL is a numpy array.
        if trueL is None:
            self.L = [[] for _ in range(self.numInstance)]
            # Unlabelled instances start from a fixed prior uncertainty.
            self.uncertainties[:] = 0.5
        else:
            self.L = [list(labels) for labels in trueL]
            for i in range(self.numInstance):
                for label in self.L[i]:
                    self.Phi[i][label] += 1
                self.uncertainty(i)

        self.A = refine(self.X, self.K)

    def uncertainty(self, i):
        """Recompute, store and return the uncertainty of instance ``i``.

        Let ``Nm`` and ``Ns`` be the largest and second-largest counts in
        ``Phi[i]``.  The result is the regularized incomplete beta
        ``I_x(Nm, Ns)`` at ``x = UNCERTAINTY_POINT``, i.e. the CDF of a
        Beta(Nm, Ns) distribution at that point.  It is small when ``Nm``
        dominates ``Ns``.  Requires at least two classes.
        """
        second, top = np.sort(self.Phi[i])[-2:]
        self.uncertainties[i] = scipy.special.betainc(top, second, self.UNCERTAINTY_POINT)
        return self.uncertainties[i]

    def select(self):
        """Return the index of the instance to annotate next.

        Score = own uncertainty + mean uncertainty of graph neighbours.
        Neighbours contribute nothing when ``K == 0`` or when an instance has
        none.  Ties go to the lowest index.
        """
        scores = np.asarray(self.uncertainties, dtype=float)
        if self.K != 0:
            adj = np.asarray(self.A, dtype=float)
            degree = adj.sum(axis=1)
            neighbour_sum = adj @ scores
            # Guard isolated nodes (e.g. a single instance) against 0/0.
            neighbour_mean = np.divide(neighbour_sum, degree,
                                       out=np.zeros_like(neighbour_sum),
                                       where=degree > 0)
            scores = scores + neighbour_mean
        return np.argmax(scores)

    def update(self, i, label):
        """Count one more vote for ``label`` on instance ``i`` and refresh its uncertainty."""
        self.Phi[i][label] += 1
        self.uncertainty(i)

    def annotate(self, i):
        """Simulate one noisy annotation of instance ``i``; record and return it.

        The annotator's accuracy ``p`` is drawn uniformly from
        ``ANNOTATOR_ACCURACY_RANGE``.  With probability ``p`` the true label
        ``y[i]`` is returned; otherwise a different class is returned, chosen
        uniformly at random.  The label is appended to ``L[i]``; ``Phi`` is not
        touched here (see ``update``).

        Raises:
            ValueError: if a wrong label is required but ``numClass < 2``, so
                no other class exists to draw.
        """
        low, high = self.ANNOTATOR_ACCURACY_RANGE
        p = random.uniform(low, high)
        if random.random() < p:
            label = self.y[i]
        else:
            if self.numClass < 2:
                raise ValueError("cannot draw a wrong label with fewer than two classes")
            # Rejection sampling: redraw until the label differs from the true one.
            label = random.randint(0, self.numClass - 1)
            while label == self.y[i]:
                label = random.randint(0, self.numClass - 1)
        self.L[i].append(label)
        return label

    def total(self, T):
        """Run ``T`` select/annotate/update rounds and return the selected indices.

        Every 100 rounds (starting with round 0) the current majority-vote
        accuracy is printed, tab-separated.  A final newline follows the loop.
        """
        selected = []
        for t in range(T):
            index = self.select()
            label = self.annotate(index)
            selected.append(index)
            self.update(index, label)
            if t % 100 == 0:
                print(self.MV(), end='\t')
        print('')
        return selected

    def MV(self):
        """Return majority-vote accuracy over instances with at least one label.

        For each labelled instance, the most frequent collected label (lowest
        class index on ties) is compared with ``y[i]``.
        """
        predicted = []
        truth = []
        for i in range(self.numInstance):
            labels = self.L[i]
            if not labels:
                continue
            votes = np.bincount(labels, minlength=self.numClass)
            predicted.append(int(np.argmax(votes)))
            truth.append(self.y[i])
        return accuracy_score(truth, predicted)
