#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
distributional_min_sse is the main function for calculating the HWD problem (see Definition 1 in the paper).
"""

import torch
import time

class BadShapeError(Exception):
    """Raised by ``sink_`` when samples and the projection matrix cannot be multiplied."""

def rand_projections(dim, num_projections=1000):
    """Draw random unit directions in R^dim.

    A ``(num_projections, dim)`` standard Gaussian matrix is drawn and each
    row is divided by its Euclidean norm, so every row has unit length
    (normalised Gaussians are uniformly distributed on the sphere).
    The tensor uses torch's default dtype and device; callers convert it.
    """
    gaussian = torch.randn((num_projections, dim))
    row_norms = torch.sqrt(torch.sum(gaussian ** 2, dim=1, keepdim=True))
    return gaussian / row_norms

def cosine_distance_torch(x1, x2=None, eps=1e-8):
    """Mean absolute cosine similarity between the rows of ``x1`` and ``x2``.

    Despite its name, the value returned is a *similarity*: the average of
    ``|<a, b>| / max(||a|| * ||b||, eps)`` over every pair (row a of x1,
    row b of x2). In this file it is used as a regulariser that pushes
    projection directions towards mutual orthogonality.

    :param x1: 2-D tensor whose rows are vectors.
    :param x2: 2-D tensor with the same number of columns; defaults to ``x1``.
    :param eps: lower bound on the norm product, avoiding division by zero.
    :return: 0-dim tensor.
    """
    if x2 is None:
        x2 = x1
    norms1 = x1.norm(p=2, dim=1, keepdim=True)
    # Reuse the norms when both arguments are the very same tensor.
    if x2 is x1:
        norms2 = norms1
    else:
        norms2 = x2.norm(p=2, dim=1, keepdim=True)
    denominator = (norms1 * norms2.t()).clamp(min=eps)
    cosines = torch.mm(x1, x2.t()) / denominator
    return cosines.abs().mean()

def set_requires_grad(model,booleen):
    """Freeze or unfreeze every parameter of ``model``.

    Used to alternate which networks receive gradients during the min/max
    optimisation.

    :param model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
    :param booleen: value assigned to each parameter's ``requires_grad``.
    """
    for weight in model.parameters():
        weight.requires_grad = booleen

def distributional_min_sse(first_samples, second_samples, fdistrib, fs, ft, fdistrib_optim, fs_optim, ft_optim,
                           dim_latent=10, nproj=20, num_epochs=10, num_sup_iter=10, num_inf_iter=10, lam=1, lam_ap=1,
                           device='cpu', proj_map=None, return_p=False, verbose=False):
    r"""Alternating min/max solver for the distributional min-SSE (HWD) objective.

    Heavily based on https://github.com/VinAIResearch/DSW/blob/master/utils.py

    Objective::

        \inf_{fs, ft} \sup_{fdistrib \in M} \esp_{\theta \in S^{d-1}}
            W_p^p(fs(fdistrib(\theta), xs), ft(fdistrib(\theta), xt))

    Each epoch first runs ``num_sup_iter`` steps of ``fdistrib_optim``, with fs/ft
    frozen, on ``reg_angle_pres + reg - sse``. It then runs ``num_inf_iter`` steps
    of ``fs_optim``/``ft_optim``, with fdistrib frozen, on ``reg + sse + reg_angle_pres``.

    Parameters
    ----------
    first_samples, second_samples : 2-D tensors (n, d_s) and (n, d_t)
        Source / target samples. Both need the same number of rows because the
        sorted projections are subtracted. Detached copies are used during
        training.
    fdistrib : callable / module
        Maps the latent directions ``proj`` (nproj, dim_latent) to new latent
        directions ("looks for the best distribution of the projections").
    fs, ft : callables / modules
        Map latent directions to source / target directions. Their outputs
        must have d_s / d_t columns.
    fdistrib_optim, fs_optim, ft_optim : optimizers
        Used through ``zero_grad()`` / ``step()``.
    dim_latent, nproj : int
        Size of the randomly drawn latent directions.
    num_epochs, num_sup_iter, num_inf_iter : int
        Outer epochs and inner sup / inf step counts.
    lam : float
        Weight of the cosine regulariser on the mapped directions.
    lam_ap : float
        Weight of the angle-preservation penalty between latent and mapped Gram matrices.
    device : torch device or str
        Device for the randomly drawn latent directions.
    proj_map : tensor, optional
        Latent directions to use as-is. If None, they are drawn with
        ``rand_projections``.
    return_p : bool
        Also return the latent directions used.
    verbose : bool
        Print per-step losses.

    Returns
    -------
    sse_distance, or (sse_distance, proj) when ``return_p``
        Final sliced SSE, computed with detached directions. It stays
        differentiable with respect to the (non-detached) samples.

    Note: after at least one epoch, fs/ft are left with ``requires_grad=True``
    and fdistrib with ``requires_grad=False``.
    """
    p = 2

    # Latent directions theta, sampled on the unit sphere of R^dim_latent.
    # Previously a provided proj_map was ignored and `proj` was left undefined.
    if proj_map is None:
        proj = rand_projections(dim_latent, nproj).float().to(device)
    else:
        proj = proj_map

    # The samples themselves are not optimised.
    first_samples_detach = first_samples.detach()
    second_samples_detach = second_samples.detach()

    def _sse(xs, xt, proj_s, proj_t):
        # Project each sample set onto its own directions, sort every 1-D
        # projection and compare them quantile by quantile (closed-form 1-D W_p).
        xs_proj = xs.matmul(proj_s.transpose(0, 1))
        xt_proj = xt.matmul(proj_t.transpose(0, 1))
        gaps = torch.abs(torch.sort(xs_proj.transpose(0, 1), dim=1)[0] -
                         torch.sort(xt_proj.transpose(0, 1), dim=1)[0])
        per_proj = torch.pow(torch.sum(torch.pow(gaps, p), dim=1), 1. / p)
        return torch.pow(torch.pow(per_proj, p).mean(), 1. / p)

    def _terms():
        # One forward pass: returns (cosine reg, angle-preservation reg, sse).
        proj_d = fdistrib(proj)
        projections_s = fs(proj_d)
        projections_t = ft(proj_d)
        cos_s = cosine_distance_torch(projections_s, projections_s)
        cos_t = cosine_distance_torch(projections_t, projections_t)
        reg = lam * (cos_s + cos_t)
        # Penalise distortion of pairwise inner products by fs / ft.
        proj_dot = proj_d @ proj_d.T
        reg_angle_pres_s = torch.sum((proj_dot - projections_s @ projections_s.T) ** 2)
        reg_angle_pres_t = torch.sum((proj_dot - projections_t @ projections_t.T) ** 2)
        reg_angle_pres = lam_ap * (reg_angle_pres_s + reg_angle_pres_t)
        sse = _sse(first_samples_detach, second_samples_detach, projections_s, projections_t)
        return reg, reg_angle_pres, sse

    for _ in range(num_epochs):
        # sup step: optimise fdistrib only (maximises the SSE).
        set_requires_grad(fs, False)
        set_requires_grad(ft, False)
        set_requires_grad(fdistrib, True)
        for j in range(num_sup_iter):
            reg, reg_angle_pres, sse_distance = _terms()
            loss = reg_angle_pres + reg - sse_distance
            fdistrib_optim.zero_grad()
            loss.backward()
            fdistrib_optim.step()
            if verbose:
                print(f"{j}\t\t{loss.item()}\t\t{sse_distance.item()} \t\t {reg_angle_pres.item()} ***")

        # inf step: optimise fs and ft only (minimises the SSE).
        set_requires_grad(fs, True)
        set_requires_grad(ft, True)
        set_requires_grad(fdistrib, False)
        for i in range(num_inf_iter):
            reg, reg_angle_pres, sse_distance = _terms()
            loss = reg + sse_distance + reg_angle_pres
            fs_optim.zero_grad()
            ft_optim.zero_grad()
            loss.backward()
            fs_optim.step()
            ft_optim.step()
            if verbose:
                print(f"{i}\t\t{loss.item()}\t\t{sse_distance.item()}\t\t{reg_angle_pres.item()}")

    # Final evaluation: directions are frozen (detached), samples are not.
    proj_d = fdistrib(proj)
    sse_distance = _sse(first_samples, second_samples, fs(proj_d).detach(), ft(proj_d).detach())
    return (sse_distance, proj) if return_p else sse_distance
   



def min_sse(first_samples, second_samples, fs, ft, fs_optim, ft_optim, dim_latent=10, nproj=20,
            max_iter=50, lam=1, device='cpu', proj=None,
            return_p=False, verbose=False):
    """Minimise the sliced SSE over the direction maps ``fs`` and ``ft``.

    Heavily based on https://github.com/VinAIResearch/DSW/blob/master/utils.py

    Fixed latent directions ``proj`` are mapped by ``fs`` / ``ft`` to source /
    target directions. For ``max_iter`` steps the loss ``reg + sse`` is minimised
    with ``fs_optim`` and ``ft_optim``.

    Parameters
    ----------
    first_samples, second_samples : 2-D tensors (n, d_s) and (n, d_t)
        Source / target samples with the same number of rows. Detached copies
        are used during training.
    fs, ft : callables / modules
        Map ``proj`` to directions with d_s / d_t columns.
    fs_optim, ft_optim : optimizers
        Used through ``zero_grad()`` / ``step()``.
    dim_latent, nproj : int
        Size of the randomly drawn ``proj`` when it is not given.
    max_iter : int
        Number of optimisation steps.
    lam : float
        Weight of the cosine regulariser on the mapped directions.
    device : torch device or str
        Device for the randomly drawn ``proj``.
    proj : tensor, optional
        Latent directions to use as-is. If None, drawn with ``rand_projections``.
    return_p : bool
        Also return ``proj``.
    verbose : bool
        Print per-step losses.

    Returns
    -------
    sse_distance, or (sse_distance, proj) when ``return_p``
        Final sliced SSE with detached directions, differentiable with respect
        to the samples.
    """
    p = 2
    if proj is None:
        proj = rand_projections(dim_latent, nproj).float().to(device)

    # Here the samples are not optimised.
    first_samples_detach = first_samples.detach()
    second_samples_detach = second_samples.detach()

    def _sse(xs, xt, proj_s, proj_t):
        # Sorted 1-D projections compared quantile by quantile (closed-form 1-D W_p).
        xs_proj = xs.matmul(proj_s.transpose(0, 1))
        xt_proj = xt.matmul(proj_t.transpose(0, 1))
        gaps = torch.abs(torch.sort(xs_proj.transpose(0, 1), dim=1)[0] -
                         torch.sort(xt_proj.transpose(0, 1), dim=1)[0])
        per_proj = torch.pow(torch.sum(torch.pow(gaps, p), dim=1), 1. / p)
        return torch.pow(torch.pow(per_proj, p).mean(), 1. / p)

    for i in range(max_iter):
        projections_s = fs(proj)
        projections_t = ft(proj)
        cos_s = cosine_distance_torch(projections_s, projections_s)
        cos_t = cosine_distance_torch(projections_t, projections_t)
        reg = lam * (cos_s + cos_t)
        sse_distance = _sse(first_samples_detach, second_samples_detach, projections_s, projections_t)

        loss = reg + sse_distance
        fs_optim.zero_grad()
        ft_optim.zero_grad()
        loss.backward()
        fs_optim.step()
        ft_optim.step()
        if verbose:
            print(f"{i}\t\t{loss.item()}\t\t{sse_distance.item()}")

    # Final evaluation: directions frozen, samples keep their graph.
    sse_distance = _sse(first_samples, second_samples, fs(proj).detach(), ft(proj).detach())
    return (sse_distance, proj) if return_p else sse_distance





def distrib_sse(first_samples, second_samples, fs, ft, fs_optim, ft_optim, nproj=20,
                max_iter=50, lam=0.001, device='cpu', ps=None, pt=None,
                return_p=False):
    """Maximise the sliced SSE over ``fs`` / ``ft`` under a cosine penalty.

    Heavily based on https://github.com/VinAIResearch/DSW/blob/master/utils.py

    Source directions ``fs(ps)`` and target directions ``ft(pt)`` are trained
    for ``max_iter`` steps by minimising ``reg - sse``. The optimisation
    therefore increases the SSE while keeping the directions diverse.

    Parameters
    ----------
    first_samples, second_samples : 2-D tensors (n, d_s) and (n, d_t)
        Source / target samples with the same number of rows. Detached copies
        are used during training.
    fs, ft : callables / modules
        Map ``ps`` / ``pt`` to directions with d_s / d_t columns.
    fs_optim, ft_optim : optimizers
        Used through ``zero_grad()`` / ``step()``.
    nproj : int
        Number of directions drawn when ``ps`` / ``pt`` are not given.
    max_iter : int
        Number of optimisation steps.
    lam : float
        Weight of the cosine regulariser.
    device : torch device or str
        Device for randomly drawn directions.
    ps, pt : tensors, optional
        Initial directions. Each one missing is drawn independently with
        ``rand_projections`` (previously, supplying only one of them left the
        other as None and crashed).
    return_p : bool
        Also return ``ps`` and ``pt``.

    Returns
    -------
    sse_distance, or (sse_distance, ps, pt) when ``return_p``
    """
    p = 2
    if ps is None:
        ps = rand_projections(first_samples.shape[1], nproj).float().to(device)
    if pt is None:
        pt = rand_projections(second_samples.shape[1], nproj).float().to(device)

    # Here the samples are not optimised.
    first_samples_detach = first_samples.detach()
    second_samples_detach = second_samples.detach()

    def _sse(xs, xt, proj_s, proj_t):
        # Sorted 1-D projections compared quantile by quantile (closed-form 1-D W_p).
        xs_proj = xs.matmul(proj_s.transpose(0, 1))
        xt_proj = xt.matmul(proj_t.transpose(0, 1))
        gaps = torch.abs(torch.sort(xs_proj.transpose(0, 1), dim=1)[0] -
                         torch.sort(xt_proj.transpose(0, 1), dim=1)[0])
        per_proj = torch.pow(torch.sum(torch.pow(gaps, p), dim=1), 1. / p)
        return torch.pow(torch.pow(per_proj, p).mean(), 1. / p)

    for _ in range(max_iter):
        projections_s = fs(ps)
        projections_t = ft(pt)
        cos_s = cosine_distance_torch(projections_s, projections_s)
        cos_t = cosine_distance_torch(projections_t, projections_t)
        reg = lam * (cos_s + cos_t)
        sse_distance = _sse(first_samples_detach, second_samples_detach, projections_s, projections_t)

        # Minimising reg - sse maximises the SSE while penalising correlated directions.
        loss = reg - sse_distance
        fs_optim.zero_grad()
        ft_optim.zero_grad()
        loss.backward()
        fs_optim.step()
        ft_optim.step()

    # Final evaluation: directions frozen, samples keep their graph.
    sse_distance = _sse(first_samples, second_samples, fs(ps).detach(), ft(pt).detach())
    return (sse_distance, ps, pt) if return_p else sse_distance


def sse_gpu(xs,xt,device,nproj=200,tolog=False,Ps=None,Pt=None,p=2):
    """ Project both spaces using different projections and then
        compute Sliced W on the projected distributions.

    ``sink_2P_`` projects xs with ``Ps`` and xt with ``Pt``, giving one column
    per projection. Within each projection the columns are sorted, and the
    l_p distance between the two sorted vectors is taken (a sum over samples,
    not a mean). The result is the p-power mean of these distances over
    projections.

    Returns d (0-dim tensor), or (d, log) when ``tolog``. ``log`` holds the
    seconds spent under 'time_sink_' and 'time_w_1D'.
    """
    timings = {}

    tic = time.time()
    xsp, xtp = sink_2P_(xs, xt, device, nproj, Ps, Pt)
    timings['time_sink_'] = time.time() - tic

    tic = time.time()
    sorted_s = torch.sort(xsp.transpose(0, 1), dim=1)[0]
    sorted_t = torch.sort(xtp.transpose(0, 1), dim=1)[0]
    gaps = torch.abs(sorted_s - sorted_t)
    per_projection = torch.pow(torch.sum(torch.pow(gaps, p), dim=1), 1. / p)
    d = torch.pow(torch.pow(per_projection, p).mean(), 1. / p)
    timings['time_w_1D'] = time.time() - tic

    return (d, timings) if tolog else d


def ssegromov_gpu(xs,xt,device,nproj=200,tolog=False,Ps=None,Pt=None):
    """ Project both spaces using different projections and then
        compute 1D-GW on the projected distributions.

    xs is projected with ``Ps`` and xt with ``Pt`` through ``sink_2P_``, and
    ``gromov_1d`` is then applied to the projected samples.

    Returns d, or (d, log) when ``tolog``. ``log`` holds 'time_sink_',
    'time_gw_1D' and 'gw_1d_details' (the log produced by gromov_1d).
    """
    timings = {}

    tic = time.time()
    xsp, xtp = sink_2P_(xs, xt, device, nproj, Ps, Pt)
    timings['time_sink_'] = time.time() - tic

    if not tolog:
        return gromov_1d(xsp, xtp, tolog=False)

    tic = time.time()
    d, details = gromov_1d(xsp, xtp, tolog=True)
    timings['time_gw_1D'] = time.time() - tic
    timings['gw_1d_details'] = details
    return d, timings

        

def sgw_gpu(xs,xt,device,nproj=200,tolog=False,P=None):
    """ Returns SGW between xs and xt eq (4) in [1]. Only implemented with the 0 padding operator Delta

    The lower-dimensional samples are zero-padded and both sets are projected
    with the same matrix (``sink_``). ``gromov_1d`` then averages the 1D GW
    cost over projections.

    Parameters
    ----------
    xs : tensor, shape (n, p)
         Source samples
    xt : tensor, shape (n, q)
         Target samples (same number of rows as xs)
    device :  torch device
    nproj : integer
            Number of projections. Ignored if P is not None
    tolog : bool
            Whether to also return timings
    P : tensor, shape (max(p,q),n_proj)
        Projection matrix (its columns are normalised). If None creates a new Gaussian one
    Returns
    -------
    d : 0-dim tensor
        Mean over projections of the 1D GW cost
    log : dict, only when tolog
        'time_sink_', 'time_gw_1D' and 'gw_1d_details' (the log of gromov_1d)
    References
    ----------
    .. [1] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
          and Courty Nicolas
          "Sliced Gromov-Wasserstein"
    Example
    ----------
    import numpy as np
    import torch
    # sgw_gpu imported from the module defining it

    n_samples=300
    Xs=np.random.rand(n_samples,2)
    Xt=np.random.rand(n_samples,1)
    xs=torch.from_numpy(Xs).to(torch.float32)
    xt=torch.from_numpy(Xt).to(torch.float32)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    P=np.random.randn(2,500)
    sgw_gpu(xs,xt,device,P=torch.from_numpy(P).to(torch.float32))
    """
    timings = {}

    tic = time.time()
    xsp, xtp = sink_(xs, xt, device, nproj, P)
    timings['time_sink_'] = time.time() - tic

    if not tolog:
        return gromov_1d(xsp, xtp, tolog=False)

    tic = time.time()
    d, details = gromov_1d(xsp, xtp, tolog=True)
    timings['time_gw_1D'] = time.time() - tic
    timings['gw_1d_details'] = details
    return d, timings

        
        

def _cost(xsp,xtp,tolog=False):   
    """ Returns the 1D GW cost eq (3) in [1] for every projection, in closed form.

    For each column k, let x = xsp[:, k] and y = xtp[:, k], paired row by row
    (row i of xsp is matched with row i of xtp). The value computed is

        C[k] = (1/n**2) * sum_{i,j} ((x_i - x_j)**2 - (y_i - y_j)**2)**2

    The square is expanded into power sums and mixed moments, so each column
    costs O(n) instead of O(n**2).

    Parameters
    ----------
    xsp : tensor, shape (n, n_proj)
         1D sorted samples (after finding sigma opt) for each proj in the source
    xtp : tensor, shape (n, n_proj)
         1D samples for each proj in the target, row-aligned with xsp
         (gromov_1d passes them sorted ascending or descending)
    tolog : bool
            Whether to also return the elapsed time
    Returns
    -------
    C : tensor, shape (n_proj,)
           Cost for each projection (all reductions are over dim 0)
    elapsed : float, only when tolog
           Wall-clock seconds spent in this function
    References
    ----------
    .. [1] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
          and Courty Nicolas
          "Sliced Gromov-Wasserstein"
    """
    st=time.time()

    xs=xsp
    xt=xtp

    # Element-wise powers of the samples.
    xs2=xs*xs
    xs3=xs2*xs
    xs4=xs2*xs2

    xt2=xt*xt
    xt3=xt2*xt
    xt4=xt2*xt2

    # Per-column power sums: X_k = sum_i x_i**k, Y_k = sum_i y_i**k.
    X=torch.sum(xs,0)
    X2=torch.sum(xs2,0)
    X3=torch.sum(xs3,0)
    X4=torch.sum(xs4,0)
    
    Y=torch.sum(xt,0)
    Y2=torch.sum(xt2,0)
    Y3=torch.sum(xt3,0)
    Y4=torch.sum(xt4,0)
    
    # Mixed moments of the row-paired samples, e.g. xxy_ = sum_i x_i**2 * y_i.
    xxyy_=torch.sum((xs2)*(xt2),0)
    xxy_=torch.sum((xs2)*(xt),0)
    xyy_=torch.sum((xs)*(xt2),0)
    xy_=torch.sum((xs)*(xt),0)
    
            
    n=xs.shape[0]

    # C2 = sum_{i,j} (x_i - x_j)**2 * (y_i - y_j)**2
    C2=2*X2*Y2+2*(n*xxyy_-2*Y*xxy_-2*X*xyy_+2*xy_*xy_)

    # power4_x = sum_{i,j} (x_i - x_j)**4 ; power4_y likewise for y.
    power4_x=2*n*X4-8*X3*X+6*X2*X2
    power4_y=2*n*Y4-8*Y3*Y+6*Y2*Y2

    # (1/n**2) * sum_{i,j} ((x_i-x_j)**2 - (y_i-y_j)**2)**2
    C=(1/(n**2))*(power4_x+power4_y-2*C2)
        
        
    ed=time.time()
    
    if not tolog:
        return C 
    else:
        return C,ed-st


def gromov_1d(xs,xt,tolog=False):
    """ Solves the Gromov in 1D (eq (2) in [1]) for each proj and averages the result.

    Each column of xs and xt is sorted. The sorted source is paired once with
    the target sorted ascending and once with the target sorted descending.
    ``_cost`` is evaluated for both pairings, the smaller value is kept per
    projection, and these minima are averaged.

    Parameters
    ----------
    xs : tensor, shape (n, n_proj)
         1D samples for each proj in the source (sorted here, need not be sorted)
    xt : tensor, shape (n, n_proj)
         1D samples for each proj in the target (sorted here, need not be sorted)
    tolog : bool
            Whether to also return timings
    Returns
    -------
    toreturn : 0-dim tensor
           Mean over projections of the per-projection 1D GW cost
    log : dict, only when tolog
           'g1d': total seconds; 't1' / 't2': seconds spent in each _cost call
    References
    ----------
    .. [1] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
          and Courty Nicolas
          "Sliced Gromov-Wasserstein"
    """
    st = time.time()
    # Only the sorted values are needed; the permutation indices are discarded.
    xs_sorted = torch.sort(xs, dim=0)[0]
    xt_asc = torch.sort(xt, dim=0)[0]
    xt_desc = torch.sort(xt, dim=0, descending=True)[0]

    if tolog:
        l1, t1 = _cost(xs_sorted, xt_asc, tolog=True)
        l2, t2 = _cost(xs_sorted, xt_desc, tolog=True)
    else:
        l1 = _cost(xs_sorted, xt_asc, tolog=False)
        l2 = _cost(xs_sorted, xt_desc, tolog=False)

    # Keep the cheaper of the two pairings for each projection, then average.
    toreturn = torch.mean(torch.min(l1, l2))
    ed = time.time()

    if tolog:
        return toreturn, {'g1d': ed - st, 't1': t1, 't2': t2}
    return toreturn
            
def sink_(xs,xt,device,nproj=200,P=None): #Delta operator (here just padding)
    """ Sinks the points of the measure in the lowest dimension onto the highest dimension and applies the projections.
    Only implemented with the 0 padding Delta=Delta_pad operator (see [1])

    The lower-dimensional samples get zero columns appended so that both
    sets share dimension max(p, q). Both are then multiplied by the same
    projection matrix, whose columns are normalised to unit length.

    Parameters
    ----------
    xs : tensor, shape (n, p)
         Source samples
    xt : tensor, shape (n, q)
         Target samples
    device :  torch device
    nproj : integer
            Number of projections. Ignored if P is not None
    P : tensor, shape (max(p,q),n_proj)
        Projection matrix. If None a standard Gaussian one is drawn
    Returns
    -------
    xsp : tensor, shape (n,n_proj)
           Projected source samples
    xtp : tensor, shape (n,n_proj)
           Projected target samples
    Raises
    ------
    BadShapeError
        If the (padded) samples cannot be multiplied by the projection matrix.
        The message lists all relevant shapes, and the original RuntimeError
        is chained as the cause.
    References
    ----------
    .. [1] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
          and Courty Nicolas
          "Sliced Gromov-Wasserstein"
    """
    dim_d = xs.shape[1]
    dim_p = xt.shape[1]

    # Zero-pad whichever sample set has fewer columns (Delta_pad).
    if dim_d < dim_p:
        random_projection_dim = dim_p
        xs2 = torch.cat((xs, torch.zeros((xs.shape[0], dim_p - dim_d)).to(device)), dim=1)
        xt2 = xt
    else:
        random_projection_dim = dim_d
        xt2 = torch.cat((xt, torch.zeros((xt.shape[0], dim_d - dim_p)).to(device)), dim=1)
        xs2 = xs

    if P is None:
        P = torch.randn(random_projection_dim, nproj)
    # Normalise every column (projection direction) to unit Euclidean norm.
    p = P / torch.sqrt(torch.sum(P ** 2, 0, True))

    try:
        p_dev = p.to(device)
        xsp = torch.matmul(xs2, p_dev)
        xtp = torch.matmul(xt2, p_dev)
    except RuntimeError as error:
        raise BadShapeError(
            f"cannot project samples: xs {tuple(xs.shape)}, xt {tuple(xt.shape)}, "
            f"dim_d={dim_d}, dim_p={dim_p}, random_projection_dim={random_projection_dim}, "
            f"padded xs2 {tuple(xs2.shape)}, xt2 {tuple(xt2.shape)}, "
            f"projector {tuple(p.shape)}"
        ) from error

    return xsp, xtp

           
def sink_2P_(xs,xt,device,nproj=200,Ps=None,Pt=None):
    """ Project source and target samples with two separate projection matrices.

    Unlike ``sink_``, no padding is done: xs is projected with ``Ps`` and xt
    with ``Pt``. The columns of each matrix are normalised to unit norm first.

    Parameters
    ----------
    xs : tensor, shape (n_s, d_s)
         Source samples
    xt : tensor, shape (n_t, d_t)
         Target samples
    device : torch device
    nproj : integer
            Number of projections for any matrix drawn here
    Ps : tensor, shape (d_s, n_proj), optional
         Source projection matrix; drawn from a standard Gaussian if None
    Pt : tensor, shape (d_t, n_proj), optional
         Target projection matrix; drawn from a standard Gaussian if None
    Returns
    -------
    xsp : tensor, shape (n_s, n_proj)
    xtp : tensor, shape (n_t, n_proj)
    Raises
    ------
    RuntimeError
        If a sample set and its projection matrix have incompatible shapes.
        The message lists the shapes, and the original error is chained.
    """
    dim_s = xs.shape[1]
    dim_t = xt.shape[1]

    if Ps is None:
        Ps = torch.randn(dim_s, nproj)
    ps = Ps / torch.sqrt(torch.sum(Ps ** 2, 0, True))

    if Pt is None:
        Pt = torch.randn(dim_t, nproj)
    pt = Pt / torch.sqrt(torch.sum(Pt ** 2, 0, True))

    # Project once; previously the product was computed twice and the
    # except-branch could never run (and would have swallowed the error).
    try:
        xsp = torch.matmul(xs, ps.to(device))
        xtp = torch.matmul(xt, pt.to(device))
    except RuntimeError as error:
        raise RuntimeError(
            f"sink_2P_: cannot project xs {tuple(xs.shape)} with ps {tuple(ps.shape)} "
            f"and xt {tuple(xt.shape)} with pt {tuple(pt.shape)}"
        ) from error

    return xsp, xtp

