# -*- coding: utf-8 -*-
"""

"""

 #from sklearn.neighbors import NearestNeighbors
from sklearn.datasets import fetch_openml
from sklearn.svm import LinearSVC
from torch.nn import init
import scipy.io as sio
import torch.nn.functional as F
from sklearn import datasets
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import TensorDataset, DataLoader 
from torch import nn
from torch import optim
from sklearn.preprocessing import Normalizer, OneHotEncoder
from sklearn.preprocessing import StandardScaler,MinMaxScaler
from sklearn.neighbors import KNeighborsClassifier 
from sklearn.neighbors import NearestNeighbors
import time
def _sort_by_class(X, y):
    """Group the rows of ``X`` by label.

    Returns ``(X_sorted, y_sorted, bounds)``.
    - Rows are ordered by ascending label (the ``np.unique`` order).
    - Within a label, rows keep their original relative order.
    - ``bounds`` is a list with one ``(start, stop)`` row range per label,
      indexing into the sorted arrays.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    # A stable argsort reproduces ``concatenate([X[y == c] for c in unique(y)])``.
    order = np.argsort(y, kind="stable")
    _, counts = np.unique(y, return_counts=True)
    stops = np.cumsum(counts)
    bounds = list(zip((stops - counts).tolist(), stops.tolist()))
    return X[order], y[order], bounds


def _neighbor_prior(X_sorted, bounds, n_neighbors):
    """Prior weights over each row's same-class nearest neighbours.

    Steps, for every row:
    1. Query its ``n_neighbors`` nearest rows within its own class block.
    2. Drop the first hit, which is the row itself when there are no exact
       duplicates.
    3. Map each remaining distance ``d`` to ``1 / (1 + d**2)``.
    4. Normalise the row so it sums to 1.

    Returns an array of shape ``(n_rows, n_neighbors - 1)``.
    """
    prior = np.zeros((len(X_sorted), n_neighbors - 1))
    index = NearestNeighbors(n_neighbors=n_neighbors)
    for start, stop in bounds:
        block = X_sorted[start:stop]
        dist, _ = index.fit(block).kneighbors(block)
        sim = 1.0 / (1.0 + dist[:, 1:] ** 2)
        prior[start:stop] = sim / sim.sum(axis=1, keepdims=True)
    return prior


def _transport_weights(emb, bounds, prior, n_neighbors, lambda1):
    """Entropic transport weights from each row to its same-class neighbours.

    Returns ``(index, weights)``, each of shape ``(n_rows, n_neighbors - 1)``.
    - ``index[r]`` holds the global row positions in ``emb`` of row ``r``'s
      nearest same-class neighbours, with the row itself excluded.
    - ``weights[r]`` is ``softmax(-d**2 / lambda1 + log(prior[r] + 1e-11))``
      taken over those neighbours.
    """
    n_rows = len(emb)
    index = np.zeros((n_rows, n_neighbors - 1), dtype=np.int64)
    weights = np.zeros((n_rows, n_neighbors - 1))
    finder = NearestNeighbors(n_neighbors=n_neighbors)
    for start, stop in bounds:
        block = emb[start:stop]
        dist, idx = finder.fit(block).kneighbors(block)
        # kneighbors returns indices local to the class block; shift to global rows.
        index[start:stop] = start + idx[:, 1:]
        logits = -dist[:, 1:] ** 2 / lambda1 + np.log(prior[start:stop] + 1e-11)
        # Numerically stable row-wise softmax (subtract the row max before exp).
        logits = logits - logits.max(axis=1, keepdims=True)
        expd = np.exp(logits)
        weights[start:stop] = expd / expd.sum(axis=1, keepdims=True)
    return index, weights


def _build_autoencoder(in_dim, code_dim, hidden=500, wide=2000):
    """Build the encoder/decoder MLPs.

    Every ``weight`` tensor is initialised with ``xavier_normal_``.
    """
    encoder = nn.Sequential(
        nn.Linear(in_dim, hidden), nn.ReLU(),
        nn.Linear(hidden, hidden), nn.ReLU(),
        nn.Linear(hidden, wide), nn.ReLU(),
        nn.Linear(wide, code_dim),
    )
    decoder = nn.Sequential(
        nn.Linear(code_dim, wide), nn.ReLU(),
        nn.Linear(wide, hidden), nn.ReLU(),
        nn.Linear(hidden, hidden), nn.ReLU(),
        nn.Linear(hidden, in_dim),
    )
    for model in (encoder, decoder):
        for name, param in model.named_parameters():
            if "weight" in name:
                init.xavier_normal_(param, gain=1)
    return encoder, decoder


def wassersteinclassifier(X_train, X_test, Y_traint, Y_test, cc, lambda1, lambda2,
                          *, n_neighbors=3, outer_iters=2, inner_epochs=10,
                          device=None):
    """Train a neighbour-transport embedding and return its 1-NN test accuracy.

    Training rows are first grouped by class. Each row then gets a prior
    distribution over its ``n_neighbors - 1`` nearest same-class neighbours,
    computed in input space.

    Each of ``outer_iters`` rounds does two things:
    - Recompute entropic transport weights to the current embedding-space
      neighbours.
    - Run ``inner_epochs`` full-batch Adam steps on the loss
      ``lambda1 * sum(w * log(w / prior)) + sum(w * ||z_i - z_j||^2)
      + lambda2 * ||decoder(z) - x||_F``.

    Args:
        X_train: training features; rows are samples, columns are features.
        X_test: test features with the same number of columns.
        Y_traint: training labels (array-like, one per row of ``X_train``).
        Y_test: test labels, passed straight to ``KNeighborsClassifier.score``.
        cc: embedding dimension; rounded with ``round``.
        lambda1: temperature of the transport weights and weight of the
            entropy/KL term.
        lambda2: weight of the reconstruction term.
        n_neighbors: neighbours queried per row, counting the row itself.
            Must be >= 2.
        outer_iters: number of transport-weight refreshes.
        inner_epochs: optimisation steps per refresh.
        device: torch device. Defaults to CUDA when available, else CPU.

    Returns:
        Mean accuracy of a 1-NN classifier fit on the training embedding and
        scored on the test embedding.

    Raises:
        ValueError: if ``n_neighbors < 2`` or any class has fewer than
            ``n_neighbors`` training samples.

    Side effect: plots the per-step training loss on the current matplotlib axes.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    X_sorted, y_sorted, bounds = _sort_by_class(X_train, Y_traint)
    class_sizes = [stop - start for start, stop in bounds]
    if n_neighbors < 2 or not class_sizes or min(class_sizes) < n_neighbors:
        raise ValueError(
            f"every class needs at least n_neighbors={n_neighbors} (>= 2) "
            "training samples"
        )

    x_train_t = torch.tensor(X_sorted, dtype=torch.float32, device=device)
    x_test_t = torch.tensor(X_test, dtype=torch.float32, device=device)

    net, net1 = _build_autoencoder(x_train_t.shape[1], round(cc))
    # Move the models before building the optimizers (recommended by torch.optim).
    net = net.to(device)
    net1 = net1.to(device)
    optimizer = optim.Adam(net.parameters(), weight_decay=0.01)
    optimizer1 = optim.Adam(net1.parameters(), weight_decay=0.001)

    prior_np = _neighbor_prior(X_sorted, bounds, n_neighbors)
    prior = torch.tensor(prior_np, dtype=torch.float32, device=device)

    train_losses = []
    for _ in range(outer_iters):
        with torch.no_grad():
            emb_np = net(x_train_t).cpu().numpy()
        index_np, weights_np = _transport_weights(
            emb_np, bounds, prior_np, n_neighbors, lambda1)
        index = torch.as_tensor(index_np, device=device)
        weights = torch.tensor(weights_np, dtype=torch.float32, device=device)

        for _ in range(inner_epochs):
            net.train()
            emb = net(x_train_t)
            recon = net1(emb)
            # Squared distance from every row to each of its neighbours:
            # (n, 1, d) - (n, k-1, d) -> (n, k-1).
            sq_dist = (emb.unsqueeze(1) - emb[index]).pow(2).sum(dim=2)
            transport_cost = (weights * sq_dist).sum()
            # weights/prior are constants here, so this term only shifts the loss value.
            kl_term = (weights * torch.log(weights / (prior + 1e-9) + 1e-9)).sum()
            recon_err = torch.norm(recon - x_train_t)
            loss = lambda1 * kl_term + transport_cost + lambda2 * recon_err

            optimizer.zero_grad()
            optimizer1.zero_grad()
            loss.backward()
            optimizer.step()
            optimizer1.step()
            train_losses.append(loss.item())

    plt.plot(train_losses)

    net.eval()
    with torch.no_grad():
        train_emb = net(x_train_t).cpu().numpy()
        test_emb = net(x_test_t).cpu().numpy()

    knn = KNeighborsClassifier(n_neighbors=1, p=2)
    # Labels must follow the same class-sorted row order as the embeddings.
    knn.fit(train_emb, y_sorted)
    return knn.score(test_emb, Y_test)