
import numpy as np
import torch

import random


from cca_zoo.data.deep import get_dataloaders
import pytorch_lightning as pl
import numpy as np
import pdb
from .numpydataset import NumpyDataset
from .model_AE import DCCA_Noise_Norm_M 



from tqdm import tqdm
import copy



def MyDCCA_full_fit_transform(train, test, dim=100, out_dim=200, epochs=50, num_views=2,
                              noise='normal', loss_name='cca', approach=True, recon=False,
                              private=False, a=1, lr=10*1e-5, batch_size=2000):
    """Train a ``DCCA_Noise_Norm_M`` model on ``train`` and return its projection of ``test``.

    Side effects: sets torch's global default dtype to float64, and prints
    progress to stdout. Requires a CUDA device.

    Args:
        train: list of per-view 2-D numpy arrays used for fitting; each view's
            input width is read from ``view.shape[1]``. Rows are presumably
            aligned across views (row i of every view is the same sample) --
            TODO confirm against callers.
        test: list of per-view 2-D numpy arrays projected after training.
            Must contain at least one view with at least one row.
        dim: unused; kept for backward compatibility.
        out_dim: output width passed to the model. Also the number of leading
            columns compared in the noise penalty, and (with ``noise='none'``,
            ``recon=True``, ``private=True``) the number of leading projection
            columns fed to the main loss.
        epochs: number of passes over ``train``.
        num_views: ignored; the view count is taken from ``len(train)``.
        noise: ``'normal'`` adds the Gaussian-noise alignment penalty;
            ``'none'`` trains on ``model.get_loss`` alone, optionally plus an
            MSE reconstruction term.
        loss_name: forwarded to ``NumpyDataset`` and the model.
        approach: unused; kept for backward compatibility.
        recon: forwarded to the model; with ``noise='none'`` also adds the MSE
            reconstruction loss.
        private: forwarded to the model; see ``out_dim``.
        a: weight of the noise alignment penalty (``noise='normal'`` only).
        lr: Adam learning rate.
        batch_size: maximum rows per training batch and per inference chunk.

    Returns:
        ``[res, f_score]``: ``res`` holds one numpy array per test view (the
        model's outputs concatenated over inference chunks); ``f_score`` is the
        value of ``model.get_loss`` on the last training batch, excluding the
        noise and reconstruction terms (0 if no batch ran).

    Raises:
        ValueError: unknown ``noise`` mode, ``batch_size < 1``, or empty
            ``test`` views.
    """
    # Validate before any work: an unknown mode would otherwise leave `loss`
    # undefined in the training loop, and empty test data would only fail
    # after the full training run.
    if noise not in ('normal', 'none'):
        raise ValueError(f"noise must be 'normal' or 'none', got {noise!r}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(test) == 0 or len(test[0]) == 0:
        raise ValueError("test must contain at least one view with at least one row")

    # The model parameters and every tensor created below rely on float64.
    torch.set_default_dtype(torch.float64)

    print(len(train[0]))
    print(len(test[0]))
    print(lr)

    in_dims = [view.shape[1] for view in train]
    num_views = len(in_dims)  # deliberately overrides the num_views argument
    dataset = NumpyDataset(train, labels=None, loss_name=loss_name)
    loss_fn = torch.nn.MSELoss(reduction='mean')

    model = DCCA_Noise_Norm_M(in_dims=in_dims, out_dim=out_dim, view_num=num_views,
                              loss_name=loss_name, recon=recon, private=private).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()

    dataloader = get_dataloaders(dataset=dataset,
                                 batch_size=min(len(train[0]), batch_size),
                                 shuffle_train=False, drop_last=True, num_workers=1)

    f_score = 0
    for epoch in tqdm(range(epochs)):
        for batch in dataloader:
            views = [view.cuda() for view in batch['views']]

            if noise == 'normal':
                noises = _gaussian_noise_like(views)
                view_project, noise_project = model(views, noises)
                loss_2 = model.get_loss(view_project)
                f_score = loss_2.item()
                # Author's notes on `a` per dataset (presumably "a (score)"):
                # CUB: 1.5 (0.921), cal: 45 (0.611).
                loss_1, gt, here = _noise_alignment_loss(
                    model, views, noises, view_project, noise_project, out_dim, a)
                print(epoch, a, gt.item(), here.item(), loss_1.item(), loss_2.item())
                loss = loss_2 + loss_1
            elif recon:
                view_project, view_recon = model(views, recon=recon)
                recon_loss = torch.zeros(1).cuda()
                for i, view in enumerate(views):
                    recon_loss = recon_loss + loss_fn(view_recon[i], view)
                if private:
                    # Only the leading out_dim columns enter the main loss;
                    # presumably the remainder is the private part -- TODO confirm.
                    view_project = [proj[:, :out_dim] for proj in view_project]
                loss = model.get_loss(view_project)
                f_score = loss.item()
                loss = loss + recon_loss
            else:
                view_project = model(views)
                loss = model.get_loss(view_project)
                print(epoch, loss.item())
                f_score = loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    res = _transform_views(model, test, batch_size)
    return [res, f_score]


def _gaussian_noise_like(views):
    """Return one standard-normal float64 tensor per view, matching its shape and device."""
    return [torch.randn(view.shape, dtype=torch.float64, device=view.device) for view in views]


def _noise_alignment_loss(model, views, noises, view_project, noise_project, out_dim, a):
    """Sum, over views, ``a * |gt - here|``.

    ``gt`` is ``model.get_loss`` on the raw (noise, view) pair truncated to the
    first ``out_dim`` columns, computed without gradients. ``here`` is
    ``model.get_loss`` on the projected (noise, view) pair.

    Returns ``(penalty, gt, here)``; ``gt``/``here`` are those of the last view
    (used for logging).
    """
    penalty = torch.zeros(1).cuda()
    gt = here = None
    for i, view in enumerate(views):
        with torch.no_grad():
            gt = model.get_loss([noises[i][:, :out_dim], view[:, :out_dim]])
        here = model.get_loss([noise_project[i], view_project[i]])
        penalty = penalty + a * torch.abs(gt - here)
    return penalty, gt, here


def _transform_views(model, views, chunk_size):
    """Run ``model`` over ``views`` in row chunks (eval mode, no grad); return one numpy array per view."""
    outputs = [[] for _ in views]
    model.eval()
    with torch.no_grad():
        for start in range(0, len(views[0]), chunk_size):
            # Slice before converting so each chunk copies only its own rows.
            chunk = [torch.as_tensor(np.asarray(view[start:start + chunk_size],
                                                dtype=np.float64)).cuda()
                     for view in views]
            projected = model(chunk)
            for j, out in enumerate(outputs):
                out.append(projected[j].cpu())
    return [torch.cat(out, dim=0).numpy() for out in outputs]