 #from sklearn.neighbors import NearestNeighbors
from sklearn.datasets import fetch_openml
from sklearn.svm import LinearSVC
from torch.nn import init
import scipy.io as sio
import torch.nn.functional as F
from sklearn import datasets
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import TensorDataset, DataLoader,Sampler 
from torch import nn
from torch import optim
from sklearn.preprocessing import Normalizer, OneHotEncoder
from sklearn.preprocessing import StandardScaler,MinMaxScaler
from sklearn.neighbors import KNeighborsClassifier 
from sklearn.neighbors import NearestNeighbors,NearestCentroid
import time
from myclassifiers1 import onetrainC
from collections import Counter
import random
class MySampler(Sampler):
    """Class-balanced sampler that interleaves one index per class.

    Every class is oversampled with replacement (``random.choices``) up to the
    size of the most frequent class. Each class's index list is shuffled, and
    indices are then emitted round-robin: one index from each class in turn,
    with classes in ascending label order. With a batch size that is a multiple
    of the number of classes, every batch holds the same number of samples
    per class.

    Args:
        labels: 1-D sequence of class labels (list, numpy array or CPU tensor).
            Position ``i`` is the label of dataset index ``i``. Labels need not
            be contiguous integers.
    """

    def __init__(self, labels):
        # np.asarray so that element-wise ``==`` works for lists as well as arrays.
        self.labels = np.asarray(labels)
        # Indices produced by the most recent ``__iter__`` call.
        self.image_ids = []
        counter = Counter(self.labels.tolist())
        # Iterate over the labels actually present rather than range(n_classes),
        # so non-contiguous label sets still produce every class.
        self._classes = sorted(counter)
        # Size of the largest class; every class is padded up to this size.
        self._per_class = max(counter.values(), default=0)

    def __iter__(self):
        """Build a fresh balanced, interleaved index order for one epoch."""
        pools = []
        for c in self._classes:
            pool = np.flatnonzero(self.labels == c).tolist()
            # Align minority classes to the majority class by resampling.
            if len(pool) < self._per_class:
                pool.extend(random.choices(pool, k=self._per_class - len(pool)))
            random.shuffle(pool)
            pools.append(pool)

        # Take one index from each class in turn. Rebuilt (not appended to) on
        # every call, so each epoch yields exactly len(self) indices.
        self.image_ids = [idx for group in zip(*pools) for idx in group]
        return iter(self.image_ids)

    def __len__(self):
        return len(self._classes) * self._per_class
def _build_encoder(in_dim, out_dim, hidden=500, wide=2000):
    """Return the MLP mapping ``in_dim`` input features to ``out_dim`` outputs."""
    return nn.Sequential(
        nn.Linear(in_dim, hidden),
        nn.ReLU(),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Linear(hidden, wide),
        nn.ReLU(),
        nn.Linear(wide, out_dim),
    )


def _build_decoder(in_dim, out_dim, hidden=500, wide=2000):
    """Return the mirrored MLP mapping ``in_dim`` values back to ``out_dim`` features."""
    return nn.Sequential(
        nn.Linear(in_dim, wide),
        nn.ReLU(),
        nn.Linear(wide, hidden),
        nn.ReLU(),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Linear(hidden, out_dim),
    )


def wassersteinclassifier(X_train, Y_traint, X_test, Y_test, cc, lambda1, lambda2, kn1, ba, epochs=4):
    """Train an encoder/decoder pair with ``onetrainC``, then score a
    nearest-centroid classifier on the encoder's outputs.

    Args:
        X_train: training features, 2-D array-like (rows = samples).
        Y_traint: training labels, one per row of ``X_train``.
        X_test: test features with the same number of columns as ``X_train``.
        Y_test: test labels, compared element-wise with the predictions.
        cc: encoder output width (rounded to int).
        lambda1, lambda2: forwarded unchanged to ``onetrainC``
            (presumably loss weights -- confirm in myclassifiers1).
        kn1: forwarded unchanged to ``onetrainC``.
        ba: mini-batch size for the class-balanced training loader.
        epochs: number of passes over the sampler (default 4, the value that
            was previously hard-coded).

    Returns:
        ``(accuracy, train_seconds)``.
        - ``accuracy``: the fraction of test rows whose nearest training-class
          centroid, in encoder-output space, equals ``Y_test``.
        - ``train_seconds``: the elapsed time of the training loop.
    """
    # Fall back to CPU instead of failing outright when no GPU is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    X_train = torch.tensor(X_train, dtype=torch.float32)
    Y_train = torch.tensor(Y_traint, dtype=torch.int64)
    X_test = torch.tensor(X_test, dtype=torch.float32)
    n_features = X_train.shape[1]
    out_dim = round(cc)

    # Construction order (encoder, then decoder) is kept so RNG-driven weight
    # initialisation matches the original layout.
    net = _build_encoder(n_features, out_dim)
    net1 = _build_decoder(out_dim, n_features)

    start = time.perf_counter()
    train_loader = DataLoader(
        dataset=TensorDataset(X_train, Y_train),
        batch_size=ba,
        shuffle=False,  # ordering is owned by the class-balanced sampler
        sampler=MySampler(Y_train.numpy()),
    )
    for _ in range(epochs):
        for x, label in train_loader:
            # onetrainC receives the batch labels as a numpy array.
            onetrainC(x, cc, label.detach().numpy(), kn1, net, net1, lambda1, lambda2, device)
    train_seconds = time.perf_counter() - start

    # onetrainC presumably already moved net to `device`; .to() is then a no-op.
    net = net.to(device)
    net.eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        test_out = net(X_test.to(device)).cpu().numpy()
        train_out = net(X_train.to(device)).cpu().numpy()

    clf = NearestCentroid()
    clf.fit(train_out, Y_train.numpy())
    y_pred = clf.predict(test_out)
    accuracy = np.mean(np.asarray(Y_test) == y_pred)
    return accuracy, train_seconds