# -*- coding: utf-8 -*-
"""

"""

from torch import linalg as LA
from sklearn.datasets import fetch_openml
import torch.nn.functional as F
from torch.nn import init
import torch.nn.functional as F
from sklearn import datasets
from sklearn.cluster import KMeans
from utils import accuracy
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import TensorDataset, DataLoader 
from torch import nn
from torch import optim
from sklearn.preprocessing import Normalizer
from sklearn.preprocessing import StandardScaler,MinMaxScaler
import time
import scipy.io as sio
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.datasets import fetch_olivetti_faces
from sklearn.cluster import SpectralClustering
def wassersteincluster(X1,Y1,X,Y,cc, lambda1, lambda2):
    """Train an encoder/decoder with a soft-clustering loss and score the embedding.

    An encoder (input -> 10-d) and a decoder (10-d -> input) are built from
    the feature count of ``X``.  ``X1`` is then split into shuffled
    mini-batches of 2000 rows for two epochs and every batch is handed to
    ``onetrain`` together with the ``updatecluster`` and ``loss1`` closures
    defined below.  Finally the whole of ``X1`` is encoded, min-max scaled to
    [0, 1] and clustered with KMeans and spectral clustering (10 clusters
    each) to compute the returned scores against ``Y1``.

    Args:
        X1: 2-d array-like of training samples; converted to a float32 tensor.
            Must have the same number of columns as ``X`` because it is fed
            to the encoder sized from ``X``.
        Y1: Labels for ``X1``; converted to an int64 tensor and used only for
            scoring.
        X: 2-d array-like used to size the networks and to fit the initial
            KMeans centres in input space.
        Y: Not used by this function; kept for interface compatibility.
        cc: Number of clusters used for the initial KMeans and forwarded to
            ``onetrain``.
        lambda1: Temperature of the soft assignment and weight of the
            ``weight * log(weight / prior)`` term in ``loss1``.
        lambda2: Forwarded to ``onetrain``, which passes it to ``loss1`` as
            the reconstruction weight ``lambda0``.

    Returns:
        Tuple ``(elapsed, ari, nmi_kmeans, nmi_spectral, ari, acc)`` where
        ``elapsed`` is the wall-clock training time in seconds, ``ari`` is the
        adjusted Rand score of the spectral clustering, ``nmi_kmeans`` and
        ``nmi_spectral`` are normalised mutual information scores, and
        ``acc`` is whatever ``utils.accuracy(labels, prediction)`` returns.
    """
    # Use the GPU when there is one; fall back to the CPU instead of crashing.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    X_train = torch.tensor(X, dtype=torch.float32)
    k1 = X_train.shape[1]
    k = 500
    wide = 2000
    embed_dim = 10
    # Number of clusters used when scoring the embedding (fixed, independent
    # of ``cc``).
    n_eval_clusters = 10

    # Encoder; construction order is kept so seeded initialisation is stable.
    net = nn.Sequential(
        nn.Linear(k1, k, bias=False),
        nn.ReLU(),
        nn.Linear(k, k, bias=False),
        nn.ReLU(),
        nn.Linear(k, wide, bias=False),
        nn.ReLU(),
        nn.Linear(wide, embed_dim),
    )
    # Decoder mirroring the encoder.
    net1 = nn.Sequential(
        nn.Linear(embed_dim, wide, bias=False),
        nn.ReLU(),
        nn.Linear(wide, k, bias=False),
        nn.ReLU(),
        nn.Linear(k, k, bias=False),
        nn.ReLU(),
        nn.Linear(k, k1, bias=False),
    )

    # Nothing appends to this list: ``onetrain`` does not report losses, so
    # the plot below is empty.  Kept to preserve the figure side effect.
    train_losses = []

    # Input-space KMeans centres; they define the per-sample prior PP1.
    kmeans = KMeans(cc)
    kmeans.fit(X_train.numpy())
    input_centers = kmeans.cluster_centers_
    aa = torch.tensor(input_centers)

    def updatecluster(y3, aa, n1, cc, PP1):
        """Return new centres as the responsibility-weighted mean of ``y3``.

        Responsibilities are a row-wise softmax of
        ``-||y3_i - aa_j||^2 / lambda1 + log(PP1_ij + 1e-7)``.  ``n1`` and
        ``cc`` are accepted for call compatibility; the shapes come from the
        arrays themselves.
        """
        sq_dist = np.power(euclidean_distances(y3, aa), 2)
        logits = -sq_dist / lambda1 + np.log(PP1 + 0.0000001)
        # Stable row-wise log-sum-exp, vectorised over all samples at once.
        row_max = logits.max(axis=1, keepdims=True)
        log_norm = row_max + np.log(
            np.sum(np.exp(logits - row_max), axis=1, keepdims=True))
        # float64 accumulation for the weighted means.
        resp = np.exp(logits - log_norm).astype(np.float64)
        weighted_sum = np.matmul(resp.T, y3)
        mass = np.sum(resp, axis=0)
        return weighted_sum / (mass[:, None] + 0.000001)

    def loss1(data,X_train,xx,yz,weight2,lambda1,n1, pp,lambda0=10):
        """Per-sample average of assignment cost, prior divergence and reconstruction.

        ``data`` is the reconstruction of ``X_train``, ``xx`` the embedding,
        ``yz`` the centres, ``weight2`` the soft assignments and ``pp`` the
        prior the assignments are compared against.
        """
        sq_dist = torch.pow(torch.cdist(xx, yz), 2)
        assignment_cost = (weight2 * sq_dist).sum()
        divergence = (weight2 * torch.log(weight2 / pp + 0.000001)).sum()
        reconstruction = torch.pow(LA.matrix_norm(data - X_train), 2)
        total = assignment_cost + lambda1 * divergence + lambda0 * reconstruction
        return total / n1

    started = time.time()

    X1 = torch.tensor(X1, dtype=torch.float32)
    Y1 = torch.tensor(Y1, dtype=torch.int64)
    batch_size = 2000
    train_loader = DataLoader(
        dataset=TensorDataset(X1, Y1),
        batch_size=batch_size,
        shuffle=True)
    for epoch in range(2):
        for x, _ in train_loader:
            n1 = len(x)
            # Prior: inverse (1 + squared distance) to the input-space
            # centres, normalised per sample.
            sq_dist = np.power(euclidean_distances(x.numpy(), input_centers), 2)
            PP1 = 1 / (sq_dist + 1)
            PP1 = PP1 / PP1.sum(axis=1, keepdims=True)
            onetrain(x, cc, aa, n1, PP1, net, net1, updatecluster,
                     lambda1, lambda2, device, loss1)
    elapsed = time.time() - started

    plt.plot(train_losses)

    # Encode the full data set without building an autograd graph.
    net = net.to(device)
    net.eval()
    with torch.no_grad():
        embedding = net(X1.to(device)).cpu().numpy()
    y2 = Y1.numpy()

    # NOTE(review): the original also ran spectral clustering on the unscaled
    # embedding and stored its NMI in ``cf2``, but that value was overwritten
    # by the ARI below before being returned.  The dead fit is dropped and the
    # returned tuple is unchanged (ARI appears in slots 1 and 4) -- confirm
    # whether slot 1 was meant to carry the unscaled NMI.
    embedding = MinMaxScaler(feature_range=(0, 1.0)).fit_transform(embedding)

    model = KMeans(n_eval_clusters)
    model.fit(embedding)
    d = normalized_mutual_info_score(y2, model.labels_)

    y_pred = SpectralClustering(
        n_clusters=n_eval_clusters,
        n_neighbors=10,
        affinity='nearest_neighbors').fit_predict(embedding)
    cf1 = normalized_mutual_info_score(y2, y_pred)
    cf2 = adjusted_rand_score(y2, y_pred)
    cf3 = accuracy(y2, y_pred)
    return elapsed, cf2, d, cf1, cf2, cf3
def onetrain(X_train,cc,aa,n1,PP1,net,net1,updatecluster,lambda1,lambda2,device,loss1,
             n_rounds=10,n_steps=100):
    """Run alternating centre updates and gradient steps on one batch.

    Each round encodes the batch without gradients, refreshes the cluster
    centres with ``updatecluster``, recomputes the soft assignments, and then
    takes ``n_steps`` Adam steps on both networks using ``loss1``.

    Args:
        X_train: Batch tensor fed to ``net``.
        cc: Number of clusters for the initial KMeans on the embedding.
        aa: Not read; centres are re-initialised from KMeans on the embedding.
            Kept for interface compatibility.
        n1: Batch size, forwarded to ``updatecluster`` and ``loss1``.
        PP1: Array of per-sample prior weights over clusters; converted to a
            float32 tensor (assumed shape ``(n1, cc)`` -- it is added to the
            ``(n1, cc)`` distance matrix).
        net: Encoder module, moved to ``device`` and trained in place.
        net1: Decoder module, moved to ``device`` and trained in place.
        updatecluster: Callable ``(embedding, centres, n1, cc, PP1)`` returning
            the new centres as an array.
        lambda1: Temperature dividing the squared distances; also forwarded
            to ``loss1``.
        lambda2: Forwarded to ``loss1`` as its last positional argument.
        device: Torch device on which training runs.
        loss1: Callable ``(reconstruction, batch, embedding, centres, weights,
            lambda1, n1, prior, lambda2)`` returning a scalar loss tensor.
        n_rounds: Number of centre-update rounds (default 10).
        n_steps: Gradient steps per round (default 100).

    Returns:
        The soft-assignment tensor computed in the last round, on ``device``.

    Raises:
        ValueError: If ``n_rounds`` is smaller than 1 (no assignments would
            be computed).
    """
    if n_rounds < 1:
        raise ValueError("n_rounds must be at least 1")

    # Move data and modules first, then build the optimizers on the moved
    # parameters, as the torch.optim documentation recommends.
    X_train = X_train.to(device)
    net = net.to(device)
    net1 = net1.to(device)
    optimizer = optim.Adam(net.parameters(), weight_decay=0.001)
    optimizer1 = optim.Adam(net1.parameters(), weight_decay=0.001)

    # The prior does not change between rounds: convert and move it once.
    prior = torch.tensor(PP1, dtype=torch.float32)
    log_prior = torch.log(prior + 0.0000001)
    prior_dev = prior.to(device)

    # Initial centres: KMeans on the current embedding of the batch.
    with torch.no_grad():
        initial_embedding = net(X_train).cpu().numpy()
    kmeans = KMeans(cc)
    kmeans.fit(initial_embedding)
    centers_np = kmeans.cluster_centers_

    for _ in range(n_rounds):
        with torch.no_grad():
            embedding = net(X_train).cpu()

        # Centre update runs in NumPy; the result is rounded to float32 and
        # reused as the starting point of the next round.
        centers = torch.tensor(
            updatecluster(embedding.numpy(), centers_np, n1, cc, PP1),
            dtype=torch.float32)
        centers_np = centers.numpy()

        # Soft assignments: row-wise softmax of the tempered negative squared
        # distances plus the log prior.
        logits = -torch.pow(torch.cdist(embedding, centers), 2) / lambda1 + log_prior
        log_norm = torch.logsumexp(logits, dim=1, keepdim=True)
        weight2 = torch.exp(logits - log_norm).to(device)

        centers_dev = centers.to(device)
        net.train()
        for _ in range(n_steps):
            encoded = net(X_train)
            decoded = net1(encoded)
            loss = loss1(decoded, X_train, encoded, centers_dev, weight2,
                         lambda1, n1, prior_dev, lambda2)
            optimizer.zero_grad()
            optimizer1.zero_grad()
            loss.backward()
            optimizer.step()
            optimizer1.step()
    return weight2
        
            #train_losses.append(loss.item())