# -*- coding: utf-8 -*-
"""

"""
from sklearn.decomposition import PCA, KernelPCA
from sklearn.datasets import fetch_openml
from sklearn.svm import LinearSVC
from torch.nn import init
import scipy.io as sio
import torch.nn.functional as F
from sklearn import datasets
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import TensorDataset, DataLoader 
from torch import nn
from torch import optim
from sklearn.preprocessing import Normalizer, OneHotEncoder
from sklearn.preprocessing import StandardScaler,MinMaxScaler
from sklearn.neighbors import KNeighborsClassifier 
from sklearn.neighbors import NearestNeighbors
from myclassifiers import wassersteinclassifier
from sklearn.datasets import load_svmlight_file
#from kernelLDA1 import kernelLDA
import time
 
def onetrainC(X_train,cc,Y_traint,kn1, net,net1,lambda1,lambda2,device):
    """Fit ``net`` and ``net1`` with a within-class neighbour-weighted loss.

    The routine alternates between two phases:

    1. With ``net`` frozen (``torch.no_grad``), embed all samples, find each
       sample's nearest neighbours *inside its own class* in the embedded
       space, and compute neighbour weights
       ``w = softmax(-d**2 / lambda1 + log(q))`` where ``q`` are fixed prior
       weights computed once from distances in the input space
       (``q ~ 1 / (1 + d**2)``, normalised per sample).
    2. With the weights fixed, take 10 Adam steps on
       ``lambda1 * sum(w * log(w / q)) + sum(w * d_embedded**2)
       + lambda2 * ||net1(net(X)) - X||``.

    Phase 1 + phase 2 are repeated 5 times.

    Parameters
    ----------
    X_train :
        Training samples. It is indexed with a boolean mask and passed to
        ``len`` / ``np.concatenate``, so a NumPy array with one sample per
        row is assumed -- TODO confirm with caller.
    cc :
        Not used by any executable statement visible here (only referenced
        in commented-out code).
    Y_traint :
        Per-sample class labels, compared element-wise against each unique
        label. See the NOTE(review) comments below about the assumptions
        made on label values and ordering.
    kn1 :
        Requested number of neighbours per query, *including* the nearest
        hit that is later discarded. It is clamped to the smallest class
        size, so ``kn1 - 1`` neighbours are actually used per sample.
    net :
        Module mapping input samples to the embedding.
    net1 :
        Module applied to the embedding; its output is compared against the
        input samples, so it must produce tensors of the same shape as
        ``X_train``.
    lambda1 :
        Divisor of the squared distance in the weight update and multiplier
        of the ``w * log(w / q)`` term in the loss.
    lambda2 :
        Multiplier of the reconstruction-norm term in the loss.
    device :
        Torch device that the modules and tensors are moved to.

    Notes
    -----
    No ``return`` statement is visible in this part of the file; per-step
    loss values are appended to the local list ``train_losses``. ``net`` and
    ``net1`` are updated in place by their optimizers.
    """
    # Keep a handle on the caller's array; ``X_train`` is rebound to a
    # class-grouped tensor further down.
    XX_train=X_train
    # Separate Adam optimizers (default learning rate) for the two modules.
    optimizer=optim.Adam(net.parameters(), weight_decay=0.001)
    #scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=10)
    optimizer1=optim.Adam(net1.parameters(),weight_decay=0.001)
    train_losses=[]
    # NOTE(review): ``test_losses`` and ``aaaa1`` are not read anywhere in the
    # visible code -- possibly used further down the file.
    test_losses=[]
    n1=len(X_train)
    aaaa1=time.time()
   # weight=torch.zeros(n1,cc)

    #kn1=3
    # Group the samples by class and record the size of every class.
    label1=np.unique(Y_traint)
    nnn=np.zeros((1,len(label1)))
    XXX_train=[]
    for j in label1:
        selected=XX_train[Y_traint==j]
        XXX_train.append(selected)
        sn1=len(selected)
        # NOTE(review): the label value itself is used as a column index, so
        # this presumably requires integer labels in 0..len(label1)-1 --
        # confirm against callers.
        nnn[0,j]=sn1
    # Clamp the neighbour count to the smallest class size so that every
    # per-class neighbour query below can be satisfied.
    nnn1=np.min(nnn)
    kn2 = np.floor(np.min(np.array([kn1, nnn1])))
    kn1=kn2.astype("int")
    
    
    
    neigh = NearestNeighbors(n_neighbors=kn1)
    # All per-sample tables have kn1-1 columns because the first neighbour
    # returned by each query is dropped (see ``distance[1:kn1]``).
    qq=np.zeros((n1,kn1-1))       # prior neighbour weights (input space)
    index4=np.zeros((n1,kn1-1))   # neighbour row indices in class-grouped order
    weight1=np.zeros((n1,kn1-1))  # updated neighbour weights (embedded space)
    
        
    # From here on ``X_train`` is the class-grouped copy of the data: rows of
    # the first label, then the second, and so on (``np.unique`` order).
    XXX_train=np.concatenate(XXX_train,axis=0)
    X_train=torch.tensor(XXX_train, dtype=torch.float32)
    # ``iii`` is the running row index in class-grouped order.
    iii=0
    for j in label1:
        selected=XX_train[Y_traint==j]
        sn1=len(selected)
        neigh.fit(selected)
        for ii in range(sn1):
            selected1=selected[ii]
            # Query the kn1 nearest same-class samples of this sample.
            rng=neigh.kneighbors(selected1.reshape((1,-1)))
            # NOTE(review): ``index1`` is not used in this first pass.
            index1=np.asarray(rng[1][0])
            distance=np.asarray(rng[0][0])
            # Drop the closest hit -- presumably the query sample itself
            # (distance 0); with duplicated samples that may not hold.
            distance1=distance[1:kn1]
            # Heavy-tailed similarity 1 / (1 + d^2), normalised to sum to 1.
            distance1=1./(1+np.power(distance1,2))
            qq[iii,:]=distance1/np.sum(distance1)
            iii=iii+1
        
        
        #X_train(selected)
        #mean_vec[j]=np.mean(XX_train[Y_traint==j], axis=0)
    #mean_vec=torch.tensor(mean_vec,dtype=torch.float32)
    #enc=OneHotEncoder()
   # Y_traint1=Y_traint.reshape(-1,1)
    #enc.fit(Y_traint1)
    #tempdata=enc.transform(Y_traint1).toarray()
    #pp=torch.tensor(tempdata,dtype=torch.int64)
    #lambda1=0.01*100
        #net.train()
        
    net= net.to(device)
    net1=net1.to(device)   
    X_train=X_train.to(device)   
    # Numerically stable log(sum(exp(x))): subtract the max before
    # exponentiating.
    def logsumexp(x):
        c = x.max()
        return c + np.log(np.sum(np.exp(x - c)))      
    # Outer loop: 5 rounds of (recompute weights, then optimise).
    for i in range(5):   
        # Embed all samples without tracking gradients. ``net`` is never
        # switched to eval() in the visible code, so from the second round
        # on this forward pass runs in train mode.
        with torch.no_grad():
            #y=net(mean_vec)
            X_train1=net(X_train)
        
            
        # ``XX_train`` now holds the embedded samples, in class-grouped order.
        XX_train=X_train1.cpu().detach().numpy()
        iii=0
        # ``snn1`` is the row offset of the current class inside the
        # class-grouped array; it converts per-class neighbour indices into
        # global row indices.
        snn1=0
        for j in label1:
            # NOTE(review): ``XX_train`` is in class-grouped order here, but
            # the mask ``Y_traint==j`` is in the caller's original order. The
            # two only agree if the input was already sorted by ascending
            # label -- confirm, otherwise rows of the wrong class are picked.
            selected=XX_train[Y_traint==j]
            sn1=len(selected)
            neigh.fit(selected)
            for ii in range(sn1):
                selected1=selected[ii]
                rng=neigh.kneighbors(selected1.reshape((1,-1)))
                index1=np.asarray(rng[1][0])
                # Global (class-grouped) row indices of the kept neighbours.
                index2=snn1+index1[1:kn1]
                
                distance=np.asarray(rng[0][0])
                distance1=distance[1:kn1]
                # Unnormalised log-weight: -d^2/lambda1 + log(prior); the
                # small constant guards against log(0).
                distance1=-np.power(distance1,2)/lambda1+np.log(qq[iii,:]+0.00000000001)
                # Softmax over the neighbours of this sample.
                weighttemp=np.exp(distance1-logsumexp(distance1))
                index4[iii,:]=index2;
                weight1[iii,:]=weighttemp
                iii=iii+1
            snn1=snn1+sn1
        # Freeze this round's weights / indices / priors as plain tensors
        # (built from NumPy arrays, so they carry no gradient).
        weight2=torch.tensor(weight1, dtype=torch.float32)        
        index5=torch.tensor(index4, dtype=torch.int64)
        pp=qq
        pp1=torch.tensor(pp, dtype=torch.float32)    
            
       
        
        weight2=weight2.to(device)
        
        index5=index5.to(device)
         
        #mean_vec=mean_vec.to(device)
        
        pp1=pp1.to(device)
        
        # weight2 is update weight
        
        # Inner loop: 10 full-batch gradient steps with the weights fixed.
        for j in range(10):
            net.train()
            # Squared embedded distance from every sample to each of its
            # neighbours, filled row by row below.
            output1=torch.zeros_like(weight2)
            output1=output1.to(device)
           # yz=net(mean_vec)
            X_train1=net(X_train)
            # Map the embedding through net1 for the reconstruction term.
            data1=net1(X_train1) 
            for ss in range(n1):
                index6=index5[ss,:]
                temp=X_train1[index6,:]
                # reshape(-1,1).t() turns the sample into a 1 x d row so that
                # cdist returns a 1 x (kn1-1) distance matrix.
                output11=torch.cdist(X_train1[ss,:].reshape(-1,1).t(),temp,p=2)
            
                output1[ss,:]=torch.squeeze(torch.pow(output11,2))
            #out2=torch.cdist(yz,yz,p=2)
            # NOTE(review): ``sum1`` is only referenced by the commented-out
            # loss variant below.
            sum1=torch.zeros((1,))
            #yz1=torch.mean(yz,dim=0)
            
           
            # w / q with a small constant against division by zero.
            dotdiv1=torch.div(weight2,(pp1+0.000000001))
            # Weighted squared embedded distances.
            a1=weight2*output1
            # w * log(w / q). Both factors are constant tensors, so this term
            # changes the reported loss value but contributes no gradient.
            a2=weight2*torch.log(dotdiv1+0.000000001)
            #a3=torch.pow(out2,2)
            # Norm (not squared) of the reconstruction error net1(net(X)) - X.
            a4=torch.norm(data1-X_train)
            #a5=torch.pow(a4,2)
            loss=lambda1*a2.sum()+a1.sum()+lambda2*a4
            #loss=lambda1*a2.sum()+a1.sum()-0*1000*sum1/(cc)
            optimizer.zero_grad()
            optimizer1.zero_grad()
            loss.backward()
            optimizer.step()
            optimizer1.step()
            #scheduler.step()
            
        
            train_losses.append(loss.item())