# -*- coding: utf-8 -*-
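"""
Evaluate ``wassersteinclassifier`` (from the local ``newclassifiers`` module)
on data loaded from ``.npz`` archives, sweeping its last argument over the
values in ``ba1`` and saving both returned scores to ``dataKFBT.mat``.
"""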
import torch
import torchvision
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler,MinMaxScaler
from sklearn import datasets
from sklearn.cluster import KMeans,SpectralClustering
from sklearn.metrics import normalized_mutual_info_score
from scipy import io as scio
#from mycluster4 import wassersteincluster
from newclassifiers import wassersteinclassifier
# Dataset archives. To run on another dataset (e.g. trainCIFA.npz /
# testCIFA.npz), change these two file names rather than loading extra
# archives that are immediately overwritten.
c = np.load("FatrainMI.npz")
c1 = np.load("testMI.npz")

# Training samples (first pair passed to wassersteinclassifier).
# NOTE(review): X/Y are read from the *test* archive c1, while c
# (FatrainMI.npz) is not used anywhere in this script -- confirm that
# testMI.npz really holds the training split, or switch these to c.
X = c1['X']
Y = c1['Y']

# Test samples (second pair passed to wassersteinclassifier).
X1 = c1['X2']
Y1 = c1['Y1']


# Optional rescaling of the features to [0, 1]. The scaler is constructed
# here but is only applied by the disabled line directly below.
min_max_scaler = MinMaxScaler(feature_range=(0, 1.0))
# X = min_max_scaler.fit_transform(X)

# Disabled clustering baselines, each scored against Y with NMI:
# KMeans with cc clusters.
# model = KMeans(cc)
# model.fit(X)
# b = model.labels_
# a = normalized_mutual_info_score(Y, b)
# Spectral clustering on a 10-nearest-neighbour affinity graph.
# model = SpectralClustering(n_clusters=10, n_neighbors=10, affinity='nearest_neighbors')
# model.fit(X)
# b = model.labels_
# a1 = normalized_mutual_info_score(Y, b)
# ---- Experiment configuration ----
# Passed as the 5th argument of wassersteinclassifier; presumably the number
# of classes -- TODO confirm against newclassifiers.
cc = 10

# Values swept as the last argument of wassersteinclassifier in the main loop
# (their meaning is not visible in this file -- TODO document).
# cc1 = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55]
ba1 = [200, 500, 1000, 1500, 2000, 2500]
cc1 = ba1

# Candidate lambda1 (aa) / lambda2 (bb) values for the disabled grid search.
aa = [0.1, 1, 10, 100, 1000]
bb = [1, 10, 100, 1000, 10000]
ff = np.arange(5, 155, 10)  # not used by the active code in this script

# One entry per sweep value: the first and second score returned by
# wassersteinclassifier.
liang11 = np.zeros(len(cc1))
liang12 = np.zeros(len(cc1))

# Grid-search score tables indexed [i, j] -> (aa[i], bb[j]). They are sized by
# both grids so they stay valid if aa and bb ever differ in length.
liang1 = np.zeros((len(aa), len(bb)))
liang2 = np.zeros((len(aa), len(bb)))
liang3 = np.zeros((len(aa), len(bb)))
liang4 = np.zeros((len(aa), len(bb)))

# Output file; results are written in MATLAB format with scipy.io.savemat.
file_path = 'dataKFBT.mat'

#for i in range(1):
    #for j in range(5):
       # lambda1=aa[i]
        #lambda2=bb[j]
        #score1=wassersteinclassifier(X,Y,X1, Y1, cc, lambda1, lambda2,kn1)
        #print(score1)
        
        #liang1[i,j]=score1
# Run the classifier once per value in cc1 (passed as its last argument) and
# store both returned scores at the matching position of liang11 / liang12.
# The .mat file is rewritten after every run, so the scores computed so far
# are on disk even if the sweep is interrupted.
for ii, i in enumerate(cc1):
    # Arguments 6-8 (0.1, 10, 15) line up with lambda1, lambda2 and kn1 in the
    # disabled grid-search call above -- presumably the same hyper-parameters;
    # confirm against newclassifiers.wassersteinclassifier.
    score1, score7 = wassersteinclassifier(X, Y, X1, Y1, cc, 0.1, 10, 15, i)
    print(score1)
    liang11[ii] = score1
    liang12[ii] = score7
    print(score7)

    scio.savemat(file_path, {'my_array': liang11, 'my_array1': liang12})