from copy import deepcopy

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from MA3S import *
from GRR import *
from ENTROPY import *
from LU import *
from DPP import *
from MU import *
from LMU import *
from ALFC import *
from matplotlib.colors import Normalize

# Module-level generator (seed 7). make_imbalanced_moons does not use it: it
# builds its own generator from its `random_state` argument.
rng = np.random.default_rng(7)

# Global Matplotlib styling applied to every figure this script draws.
mpl.rcParams.update({
    "font.family": "Times New Roman",  # a list such as ["Times New Roman"] is also accepted
    "axes.unicode_minus": False,       # draw negative signs with ASCII '-' so they render correctly
    "font.size": 20,                   # global base font size
})

def make_imbalanced_moons(n_inner=90, n_outer=90, noise=0.06, random_state=1,
                          group_size=30):
    """Generate a noisy two-moons dataset together with simulated crowd labels.

    The upper arc (class 0) samples angles uniformly on [0, pi] of the unit
    semicircle. The lower arc (class 1) is that semicircle point-reflected
    through the origin and shifted by (+1, -0.5). Gaussian noise with standard
    deviation ``noise`` is added to every coordinate.

    Simulated annotations are assigned by sample index in consecutive groups
    of ``group_size`` samples. Each group receives a fixed number of correct
    votes (``y[i]``) followed by a fixed number of wrong votes (``1 - y[i]``):

        group 0: 1 correct, 0 wrong
        group 1: 1 correct, 1 wrong
        group 2: 2 correct, 1 wrong
        group 3: 2 correct, 2 wrong
        group 4: 4 correct, 1 wrong
        group 5 and every later index: 4 correct, 2 wrong

    Parameters
    ----------
    n_inner : int
        Number of class-0 samples on the upper arc.
    n_outer : int
        Number of class-1 samples on the lower arc.
    noise : float
        Standard deviation of the additive Gaussian noise.
    random_state : int or None
        Seed passed to ``numpy.random.default_rng``.
    group_size : int
        Number of consecutive samples sharing one vote pattern. The default
        of 30 reproduces the original fixed thresholds (30, 60, ..., 150).

    Returns
    -------
    X : numpy.ndarray, shape (n_inner + n_outer, 2)
        Sample coordinates. Upper-arc samples come first.
    y : numpy.ndarray, shape (n_inner + n_outer,)
        Integer class labels: 0 for the upper arc, 1 for the lower arc.
    crowdL : list[list]
        ``crowdL[i]`` lists the simulated annotator labels for sample ``i``.

    Raises
    ------
    ValueError
        If ``group_size`` is not positive.
    """
    if group_size <= 0:
        raise ValueError(f"group_size must be positive, got {group_size!r}")

    rng = np.random.default_rng(random_state)

    # Upper arc (class 0). The draw order (angles, then noise, per arc) is
    # kept fixed so a given random_state always yields the same X.
    theta1 = rng.uniform(0, np.pi, size=n_inner)
    x1 = np.column_stack([np.cos(theta1), np.sin(theta1)])
    x1 += rng.normal(scale=noise, size=x1.shape)

    # Lower arc (class 1): the upper arc point-reflected, shifted by (+1, -0.5).
    theta2 = rng.uniform(0, np.pi, size=n_outer)
    x2 = np.column_stack([1 - np.cos(theta2), -np.sin(theta2) - 0.5])
    x2 += rng.normal(scale=noise, size=x2.shape)

    X = np.vstack([x1, x2])
    y = np.concatenate([np.zeros(n_inner, int), np.ones(n_outer, int)])

    # (correct votes, wrong votes) per index group. The last entry also
    # covers every sample beyond the final full group.
    vote_patterns = ((1, 0), (1, 1), (2, 1), (2, 2), (4, 1), (4, 2))
    last_group = len(vote_patterns) - 1

    crowdL = []
    for i, label in enumerate(y):
        n_correct, n_wrong = vote_patterns[min(i // group_size, last_group)]
        crowdL.append([label] * n_correct + [1 - label] * n_wrong)
    return X, y, crowdL


if __name__ == '__main__':
    numClass = 2
    # Alternative dataset:
    # X, y = make_gaussian_mixture_2d(n_classes=numClass)
    X, y, crowdL = make_imbalanced_moons()

    # Run the MA3S selection strategy (provided by the star imports above).
    # `count` is only used below via membership tests / ordered indexing, so it
    # is presumably the sequence of selected sample indices -- confirm in MA3S.
    strategy = MA3S(X=X, y=y, classnum=numClass, trueL=crowdL, K=3)
    count = strategy.total(10)

    # Disabled diagnostic: walk through `count` in order and, after every 30
    # entries, snapshot how many selected indices fall into each of the
    # 30/60/90/120/150 annotation groups used when building crowdL.
    # temp = [0, 0, 0, 0, 0, 0]
    # total = []
    # total.append(deepcopy(temp))
    #
    # print(count)
    #
    # for i in range(len(count)):
    #     if count[i] < 30: temp[0] += 1
    #     elif count[i] < 60: temp[1] += 1
    #     elif count[i] < 90: temp[2] += 1
    #     elif count[i] < 120: temp[3] += 1
    #     elif count[i] < 150: temp[4] += 1
    #     else:
    #         temp[5] += 1
    #     if (i + 1) % 30 == 0:
    #         total.append(deepcopy(temp))
    #
    # print(total)

    # 0/1 indicator per sample: 1 when the sample's index appears in `count`.
    selected_mask = np.array([1 if idx in count else 0 for idx in range(X.shape[0])])

    print(sum(selected_mask))

    # Figure 1: annotated (selected) vs. unannotated samples.
    marker_style = dict(s=80, linewidths=0.5)
    plt.figure(figsize=(8, 6))
    plt.scatter(X[selected_mask == 0, 0], X[selected_mask == 0, 1],
                facecolors='none', edgecolors='grey', label='unannotated', **marker_style)
    plt.scatter(X[selected_mask == 1, 0], X[selected_mask == 1, 1],
                facecolors='#08316D', edgecolors='#08316D', label='annotated', **marker_style)
    plt.legend(loc='best', fontsize=16)

    plt.savefig("fig3_1.pdf", bbox_inches="tight", pad_inches=0.01, dpi=300)

    plt.show()

    # Figure 2 (disabled): per-sample uncertainty, one legend entry per distinct
    # uncertainty value (`labels` has 6 entries, so at most 6 distinct values).
    # NOTE(review): `MA3L` is not among this file's explicit imports (only MA3S
    # is); it may come from one of the star imports -- confirm before enabling.
    # norm = Normalize(vmin=0, vmax=0.4)
    # cmap = plt.cm.Blues  # alternatives: 'plasma', 'magma', 'inferno', 'cividis'
    #
    # temp_s = MA3L(X=X, y=y, classnum=numClass, trueL=crowdL)
    #
    # # Distinct uncertainty values, sorted ascending (use uniq_u[::-1] for descending).
    # uniq_u = np.unique(temp_s.uncertainties)
    #
    # fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
    # # Single-call alternative without per-value legend entries:
    # # sc = plt.scatter(X[:, 0], X[:, 1], c=temp_s.uncertainties, cmap=cmap, norm=norm, edgecolors='grey', linewidths=0.5, s=80)
    #
    # labels = [r'$\mathcal{I}_5$', r'$\mathcal{I}_6$', r'$\mathcal{I}_1$',
    #           r'$\mathcal{I}_3$', r'$\mathcal{I}_4$', r'$\mathcal{I}_2$']
    #
    # for i, u in enumerate(uniq_u):
    #     mask = (temp_s.uncertainties == u)
    #     if not np.any(mask):
    #         continue
    #     sc = plt.scatter(X[mask, 0], X[mask, 1], c=np.full(mask.sum(), u), cmap=cmap, norm=norm,
    #                      edgecolors='grey', linewidths=0.5, s=80, label=labels[i])
    #
    # # Colour bar for reading off uncertainty values.
    # cbar = plt.colorbar(sc, pad=0.02)
    # cbar.set_label('Uncertainty')
    #
    # legend = ax.legend(loc="best", ncols=1, frameon=True, handlelength=2, columnspacing=1.6, fontsize=16)
    #
    # plt.tight_layout()
    # plt.savefig("fig2_1.pdf", bbox_inches="tight", pad_inches=0.01, dpi=300)
    # plt.show()