from cca_zoo.models import KCCA,MCCA,PRCCA
import pdb




from utils import Initialize_Seed
def Ori_CCA_fit_transform(multi_view_train,multi_view_test,dim=100,epochs=20,method='cca'):
    """Fit a cca_zoo CCA-family model on training views and project test views.

    The global seed is set via ``Initialize_Seed(2)`` and the model is built
    with ``latent_dims=200`` and ``random_state=2``, fitted on
    ``multi_view_train``, and then used to transform ``multi_view_test``.

    Args:
        multi_view_train: Training data passed to the model's ``fit``.
            Presumably a list/tuple of per-view arrays with matching sample
            counts -- TODO confirm against callers.
        multi_view_test: Data passed to the fitted model's ``transform``;
            presumably the same view layout as ``multi_view_train``.
        dim: NOTE(review): currently unused -- latent dimensionality is fixed
            at 200 regardless of this value. Kept for interface compatibility.
        epochs: NOTE(review): currently unused. Kept for interface
            compatibility.
        method: Which model to use: ``'cca'`` (MCCA), ``'kcca'`` (KCCA) or
            ``'prcca'`` (PRCCA).

    Returns:
        Whatever the fitted model's ``transform`` returns for
        ``multi_view_test`` (presumably one latent projection per view).

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    supported_methods = ('cca', 'kcca', 'prcca')
    # Validate before doing any work: previously an unknown method left the
    # model variable unassigned and failed later with an UnboundLocalError.
    if method not in supported_methods:
        raise ValueError(
            f"Unsupported method {method!r}; expected one of {supported_methods}"
        )

    Initialize_Seed(2)

    model_cls = {'cca': MCCA, 'kcca': KCCA, 'prcca': PRCCA}[method]
    model = model_cls(latent_dims=200, random_state=2)

    model.fit(multi_view_train)
    return model.transform(multi_view_test)