
import numpy as np
import torch


from cca_zoo.data.deep import get_dataloaders
import pytorch_lightning as pl
import numpy as np
import pdb
from .numpydataset import NumpyDataset
from .model_AE import DCCA_Noise_Norm_M 
from .utils import Initialize_Seed,Distance_Correlation
import math
import torch.optim.lr_scheduler as lr_scheduler
from sklearn.decomposition import FastICA
from .utils import cca_loss
from tqdm import tqdm
import copy



def _gaussian_noise_like(views):
    """Return one standard-normal float64 tensor per view, matching its shape and device."""
    return [
        torch.randn(view.shape, dtype=torch.float64, device=view.device)
        for view in views
    ]


def _project_in_chunks(model, views, batch_size):
    """Run ``model`` over ``views`` in row chunks, without gradients, on CUDA.

    Each view is converted to a tensor once (not once per chunk), sliced into
    chunks of ``batch_size`` rows, moved to CUDA and passed through ``model``.
    The per-view outputs are gathered on the CPU and concatenated.

    Returns:
        list of numpy arrays, one per input view.
    """
    tensors = [torch.as_tensor(view, dtype=torch.get_default_dtype()) for view in views]
    chunks = [[] for _ in tensors]
    with torch.no_grad():
        model.eval()
        for start in range(0, len(tensors[0]), batch_size):
            batch_views = [t[start:start + batch_size].cuda() for t in tensors]
            # NOTE(review): assumes model(views), called without noise/recon,
            # returns a per-view indexable sequence of projections — confirm
            # against DCCA_Noise_Norm_M.forward.
            view_project = model(batch_views)
            for j in range(len(tensors)):
                chunks[j].append(view_project[j].cpu())
    return [torch.cat(parts, dim=0).numpy() for parts in chunks]


def MyDCCA_full_fit_transform(train, test, out_dim=100, epochs=50, num_views=2, noise='normal',
                              loss_name='cca', recon=False, private=False, record_noise=True,
                              lr=50 * 1e-4 * 3, a=200, batch_size=2000):
    """Train a DCCA_Noise_Norm_M model on ``train`` and embed ``test`` with it.

    Side effects: sets torch's global default dtype to float64, and prints
    per-batch loss values. Requires CUDA, because the model and every batch
    are moved with ``.cuda()``.

    Args:
        train: sequence of per-view training arrays. ``view.shape[1]`` is used
            as each view's input dimension.
        test: sequence of per-view arrays to embed after training. The views
            are chunked by row together, so they should share a row count.
        out_dim: output dimension passed to the model.
        epochs: number of passes over the training dataloader.
        num_views: ignored. The view count is taken from ``len(train)``. It is
            kept only for backward compatibility.
        noise: either ``'normal'`` or ``'none'``.
            ``'normal'`` feeds Gaussian noise through the model alongside the
            data and adds ``a * |gt - here|`` per view to the objective.
            ``'none'`` uses the plain objective, optionally plus a
            reconstruction term (see ``recon``).
        loss_name: forwarded to ``NumpyDataset`` and to the model.
        recon: forwarded to the model. In ``'none'`` mode it adds a per-view
            MSE reconstruction term to the loss.
        private: forwarded to the model. With ``recon`` in ``'none'`` mode,
            only the first ``out_dim`` output columns enter the objective.
        record_noise: in ``'none'`` mode, compute a no-grad noise-vs-data
            score on the last view of each batch.
        lr: Adam learning rate. Values the author noted per setting:
            normal 50e-4*1, private 50e-4*2, noise 50e-4*3.
        a: weight of the noise-consistency penalty in ``'normal'`` mode.
        batch_size: rows per batch, used both for training and for test
            embedding.

    Returns:
        ``[res, f_score, noise_score]``, where:
            ``res`` is a list of numpy arrays, one per test view.
            ``f_score`` is the objective value of the last training batch,
            excluding the noise penalty and reconstruction terms.
            ``noise_score`` is the last recorded noise score. Both scores are
            0 if no batch was run.

    Raises:
        ValueError: if ``noise`` is not ``'normal'`` or ``'none'``.
    """
    if noise not in ('normal', 'none'):
        raise ValueError(f"noise must be 'normal' or 'none', got {noise!r}")

    # Global side effect: tensors created afterwards default to float64.
    torch.set_default_dtype(torch.float64)

    in_dims = [view.shape[1] for view in train]
    dataset = NumpyDataset(train, labels=None, loss_name=loss_name)
    num_views = len(in_dims)  # the num_views argument is deliberately overridden
    loss_fn = torch.nn.MSELoss(reduction='mean')

    model = DCCA_Noise_Norm_M(in_dims=in_dims, out_dim=out_dim, view_num=num_views,
                              loss_name=loss_name, recon=recon, private=private).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    dataloader = get_dataloaders(dataset=dataset, batch_size=batch_size, shuffle_train=False,
                                 drop_last=False, num_workers=1)

    f_score = 0
    noise_score = 0
    for epoch in tqdm(range(epochs)):
        for batch in dataloader:
            views = [view.cuda() for view in batch['views']]
            # Noise is drawn for every batch regardless of `noise`/`record_noise`.
            # Keeping this draw unconditional preserves the RNG consumption
            # pattern, which matters if the model also samples from the CUDA RNG.
            noise_tensor = _gaussian_noise_like(views)

            if noise == 'normal':
                view_project, noise_project = model(views, noise_tensor)
                loss_1 = torch.zeros(1).cuda()
                loss_2 = model.get_loss(view_project)
                f_score = loss_2.item()
                for i in range(len(views)):
                    # Reference: objective between raw noise and raw view, no gradient.
                    with torch.no_grad():
                        gt = model.get_loss([noise_tensor[i], views[i]])
                    # The same objective between the projected noise and the projected view.
                    here = model.get_loss([noise_project[i], view_project[i]])
                    noise_score = here.item()
                    # Penalise any drift between the two, weighted by `a`.
                    loss_1 = loss_1 + a * torch.abs(gt - here)
                print(epoch, gt.item(), here.item(), loss_1.item(), loss_2.item())
                loss = loss_2 + loss_1
            else:  # noise == 'none'
                if recon:
                    view_project, view_recon = model(views, recon=recon)
                    loss_1 = torch.zeros(1).cuda()
                    for i in range(len(views)):
                        loss_1 = loss_1 + loss_fn(view_recon[i], views[i])
                    # With private=True, only the first out_dim columns are shared.
                    if private:
                        shared = [view[:, :out_dim] for view in view_project]
                    else:
                        shared = view_project
                    loss = model.get_loss(shared)
                    f_score = loss.item()
                    print(loss.item(), loss_1.item())
                    loss = loss + loss_1
                else:
                    view_project = model(views)
                    loss = model.get_loss(view_project)
                    print(epoch, loss.item())
                    f_score = loss.item()
                if record_noise:
                    # Only the last view is probed.
                    last = len(views) - 1
                    with torch.no_grad():
                        here = model.get_loss([model.encoder[last](noise_tensor[last]),
                                               model.encoder[last](views[last])])
                        noise_score = here.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    res = _project_in_chunks(model, test, batch_size)
    return [res, f_score, noise_score]