
import numpy as np
import torch

import cca_zoo
from cca_zoo.data.deep import get_dataloaders
import pytorch_lightning as pl
import numpy as np
import pdb
from .numpydataset import NumpyDataset
from .model_AE import DCCA_Noise_Norm_M 
from .utils import Initialize_Seed,Distance_Correlation
import math
import torch.optim.lr_scheduler as lr_scheduler
from sklearn.decomposition import FastICA
from .utils import cca_loss
from tqdm import tqdm
import copy
import torch.nn.functional as F


    
def covariance_loss(z1: torch.Tensor,average=False) -> torch.Tensor:
    """Penalise correlation between the columns of ``z1``.

    The Pearson correlation matrix of the columns is computed with
    ``torch.corrcoef`` and only its off-diagonal entries (cross-column
    correlations) contribute to the loss.

    Args:
        z1: 2-D tensor of shape ``(N, D)``; each of the ``D`` columns is one
            variable observed ``N`` times.
        average: If True, return the mean absolute off-diagonal correlation.
            Otherwise return the squared Frobenius norm of the off-diagonal
            part (sum of squared off-diagonal correlations).

    Returns:
        A 0-dim tensor. When ``D < 2`` there are no cross-column pairs, so a
        zero of ``z1``'s dtype/device is returned in both modes.
    """
    _, num_features = z1.size()
    if num_features < 2:
        # Nothing to decorrelate; avoids indexing a degenerate correlation
        # result and taking the mean of an empty selection (NaN).
        return z1.new_zeros(())

    # corrcoef treats rows as variables, so transpose to correlate columns.
    corr = torch.corrcoef(z1.T)
    off_diagonal = ~torch.eye(num_features, device=z1.device).bool()
    if average:
        return torch.abs(corr[off_diagonal]).mean()
    return torch.norm(corr[off_diagonal]) ** 2



def l2_reg_ortho_loss_func(mdl,device='cuda',weight = 1e-2,method='risp',average=False):
    """Average a decorrelation penalty over the weight matrices of ``mdl``.

    Every parameter with at least two dimensions is transposed and scored with
    the penalty selected by ``method``. Parameters with fewer dimensions (for
    example biases) are skipped. The per-matrix scores are averaged.

    Args:
        mdl: module whose ``parameters()`` are scanned.
        device: unused; kept for backward compatibility.
        weight: unused (not applied to the result); kept for backward
            compatibility.
        method: ``'cor'`` scores ``covariance_loss(W.t(), average=average)``;
            ``'nesum'`` scores ``nesum(W.t())``; ``'debug'`` drops into pdb
            once per weight matrix and contributes no score. The default
            ``'risp'`` is not implemented, so callers must pass ``method``.
        average: forwarded to ``covariance_loss`` when ``method == 'cor'``.

    Returns:
        The mean score over the scored matrices, or ``0.0`` when no matrix was
        scored (no >=2-D parameters, or ``method == 'debug'``).

    Raises:
        ValueError: if ``method`` is not ``'cor'``, ``'nesum'`` or ``'debug'``.
            Without this check an unsupported method would append ``None``
            and fail with an opaque ``TypeError`` inside ``sum``.
    """
    if method not in ('cor', 'nesum', 'debug'):
        raise ValueError(
            f"Unsupported method {method!r}; expected 'cor', 'nesum' or 'debug'"
        )

    scores = []
    # NOTE(review): assumes 2-D weights laid out (output, input); ``.t()``
    # rejects tensors with more than two dims (e.g. conv kernels).
    for W in mdl.parameters():
        if W.ndimension() < 2:
            continue
        if method == 'cor':
            scores.append(covariance_loss(W.t(), average=average))
        elif method == 'nesum':
            scores.append(nesum(W.t()))
        else:  # 'debug'
            pdb.set_trace()

    if not scores:
        return 0.0
    return sum(scores) / len(scores)


def nesum(tensor):
    """Mean eigenvalue of the column-correlation matrix, relative to the largest.

    ``tensor`` is a 2-D tensor whose columns are the variables. Their Pearson
    correlation matrix is built with ``torch.corrcoef`` and its eigenvalues
    are computed with ``torch.linalg.eigvalsh``. Each eigenvalue is divided by
    the largest one and the normalised spectrum is averaged. Despite the
    name, the returned value is a mean, not a sum.
    """
    spectrum = torch.linalg.eigvalsh(torch.corrcoef(tensor.T))
    largest = spectrum.max()
    return (spectrum / largest).mean()

# Module-level list. No code visible in this module reads it:
# MyDCCA_full_fit_transform creates its own local ``errors_6`` list, which
# shadows this name inside that function.
errors_6 = list()

def mean_variance_normalize(input_tensor,sigma_squared=0.1):
    """Centre each column of ``input_tensor`` and rescale it to a target std.

    Statistics are taken along ``dim=0``, so for an ``N x dim`` tensor every
    column is normalised independently.

    Args:
        input_tensor: tensor whose first dimension indexes samples.
        sigma_squared: target standard deviation of every column. Despite the
            name it is used directly as the standard deviation (no square root
            is taken); the name is kept for backward compatibility.

    Returns:
        A tensor of the same shape with zero column means and column standard
        deviations of (almost exactly, due to a 1e-10 epsilon) ``sigma_squared``.
        With fewer than two samples the sample standard deviation is undefined,
        so the centred tensor (all zeros) is returned instead of NaNs.
    """
    # Per-column mean kept as a (1, dim) row so it broadcasts over samples.
    mean = torch.mean(input_tensor, dim=0, keepdim=True)
    zero_mean_tensor = input_tensor - mean

    if input_tensor.shape[0] < 2:
        # torch.std divides by N - 1 (Bessel's correction), which is NaN for a
        # single sample; centring alone already yields zeros there.
        return zero_mean_tensor

    # Unbiased per-column std. The epsilon guards constant columns against
    # division by zero; such columns are all zeros after centring anyway.
    std = torch.std(zero_mean_tensor, dim=0, keepdim=True) + 1e-10

    desired_std = sigma_squared
    return zero_mean_tensor * (desired_std / std)

def MyDCCA_full_fit_transform(train,test,out_dim=100,epochs=50,num_views=2,noise='normal',loss_name='cca',recon=False,private=False,record_noise=True,lr= 50*1e-4 *3,a=200,linear=False,regular=None):
    """Train a ``DCCA_Noise_Norm_M`` model on ``train`` and embed ``test``.

    Side effects: sets the global default torch dtype to float64, prints one
    line per training batch when ``noise == 'normal'``, and requires CUDA
    (the model, batches and noise tensors are all moved to the GPU).

    Args:
        train: sequence of per-view 2-D arrays; each view's feature width is
            read from ``shape[1]`` and the views are wrapped in a
            ``NumpyDataset`` served in batches of 2000.
        test: sequence of per-view numpy arrays, embedded in 2000-row chunks.
        out_dim: latent dimensionality passed to the model.
        epochs: number of passes over the training dataloader.
        num_views: ignored; the number of views is taken from ``train``.
        noise: ``'normal'`` adds the noise-matching term weighted by ``a``;
            ``'none'`` trains on the plain objective, optionally with
            reconstruction. Any other value raises ``ValueError`` on the first
            training batch.
        loss_name: forwarded to ``NumpyDataset`` and the model.
        recon, private, linear: forwarded to the model; ``recon`` and
            ``private`` also select the training branch when ``noise == 'none'``.
        record_noise: unused.
        lr: Adam learning rate. Author's suggested values: normal 50e-4 * 1,
            private 50e-4 * 2, noise 50e-4 * 3.
        a: weight of the noise-matching term when ``noise == 'normal'``.
        regular: when truthy (with ``noise == 'none'``, ``recon`` true and
            ``private`` false) adds a penalty on the encoder weights.

    Returns:
        ``[res, f_score, e1, e2, e3, e4, e5, e6, e7]``. ``res`` is a list of
        per-view numpy embeddings of ``test``. ``f_score`` is the model's
        objective on the last training batch, before auxiliary terms are
        added (0 if no batch ran). ``e1``..``e7`` are averaged over every
        (chunk, view) pair:

        - e1: MSE of a linear regression from the embedding back to the input.
        - e2: ``covariance_loss(embedding, average=True)``.
        - e3: ``nesum(embedding)``.
        - e4: ``l2_reg_ortho_loss_func`` of the encoder with ``'cor'``.
        - e5: ``l2_reg_ortho_loss_func`` of the encoder with ``'nesum'``.
        - e6: MSE of a linear regression from the encoder output on a
          70% noise / 30% input mixture to the embedding.
        - e7: ``model.get_loss`` between the encoder output on pure noise
          and the embedding.
    """
    # Only needed for the evaluation diagnostics; imported once up front
    # rather than on every (chunk, view) iteration.
    from sklearn.linear_model import LinearRegression
    from sklearn.metrics import mean_squared_error

    # Global side effect: tensors created afterwards default to float64.
    torch.set_default_dtype(torch.float64)

    in_dims = [view.shape[1] for view in train]
    dataset = NumpyDataset(train, labels=None, loss_name=loss_name)
    num_views = len(in_dims)  # overrides the num_views argument
    loss_fn = torch.nn.MSELoss(reduction='mean')

    model = DCCA_Noise_Norm_M(
        in_dims=in_dims,
        out_dim=out_dim,
        view_num=num_views,
        loss_name=loss_name,
        recon=recon,
        private=private,
        linear=linear,
        noise=noise,
    ).cuda()
    optimzer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    f_score = 0
    dataloader = get_dataloaders(dataset=dataset, batch_size=2000, shuffle_train=False, drop_last=False, num_workers=1)

    for epoch in tqdm(range(epochs)):
        for batch in dataloader:
            views = [view.cuda() for view in batch['views']]
            # One standard-normal float64 tensor per view, shaped like that view.
            noise_tensor = [
                torch.randn(view.shape, dtype=torch.float64, device='cuda')
                for view in views
            ]

            if noise == 'normal':
                view_project, noise_project = model(views, noise_tensor)
                loss_1 = torch.zeros(1).cuda()
                loss_2 = model.get_loss(view_project)
                f_score = loss_2.item()

                for i in range(len(views)):
                    with torch.no_grad():
                        # Reference objective between raw noise and raw input
                        # (author's note: variance 0.022, loss about -14).
                        gt = model.get_loss([noise_tensor[i], views[i]])
                    # Same objective between projected noise and projected view
                    # (author's note: variance 0.005, loss about -5).
                    here = model.get_loss([noise_project[i], view_project[i]])
                    # Penalise the gap between the projected and raw relations.
                    loss_1 = loss_1 + a * torch.abs(gt - here)
                print(epoch, gt.item(), here.item(), loss_1.item(), loss_2.item())

                loss = loss_2 + loss_1
            elif noise == 'none':
                loss_1 = torch.zeros(1).cuda()
                if recon:
                    if private:
                        view_project, view_recon = model(views, recon=recon)
                        for i in range(len(views)):
                            loss_1 = loss_1 + 0.01 * loss_fn(view_recon[i], views[i])
                        # Only the first out_dim columns of each projection
                        # enter the objective.
                        loss = model.get_loss([view[:, :out_dim] for view in view_project])
                        f_score = loss.item()
                        loss = loss + loss_1
                    else:
                        if regular:
                            view_project = model(views)
                            for i in range(len(views)):
                                view_recon_here = model.decoder[i](view_project[i])
                                # NOTE(review): l2_reg_ortho_loss_func in this file
                                # handles only 'cor', 'nesum' and 'debug'; confirm
                                # which penalty 'mma' is meant to select.
                                decov_loss = 200 * l2_reg_ortho_loss_func(model.encoder[i], method='mma', weight=1)
                                loss_1 = loss_1 + 0.01 * loss_fn(view_recon_here, views[i]) + decov_loss
                        else:
                            view_project, view_recon = model(views, recon=recon)
                            for i in range(len(views)):
                                loss_1 = loss_1 + 0.01 * loss_fn(view_recon[i], views[i])
                        loss = model.get_loss(view_project)
                        f_score = loss.item()
                        loss = loss + loss_1
                else:
                    view_project = model(views)
                    loss = model.get_loss(view_project)
                    f_score = loss.item()
            else:
                raise ValueError(
                    f"Unsupported noise mode {noise!r}; expected 'normal' or 'none'"
                )

            optimzer.zero_grad()
            loss.backward()
            optimzer.step()

    # Per-view list of embedding chunks, concatenated after evaluation.
    res = [[] for _ in test]
    errors_1 = []
    errors_2 = []
    errors_3 = []
    errors_4 = []
    errors_5 = []
    errors_6 = []
    errors_7 = []

    with torch.no_grad():
        model.eval()
        # Convert every test view to a tensor once, then slice per chunk.
        test_tensors = [torch.Tensor(view.astype(np.float64)) for view in test]

        for start in range(0, len(test[0]), 2000):
            batch_views = [t[start:start + 2000, :].cuda() for t in test_tensors]
            view_project = model(batch_views)

            for j in range(len(test)):
                data = view_project[j].cpu()
                res[j].append(data)

                # e1: how well a linear map recovers the input from the embedding.
                model_r = LinearRegression()
                x = data.numpy()
                y = batch_views[j].cpu().numpy()
                model_r.fit(x, y)
                errors_1.append(mean_squared_error(y, model_r.predict(x)))

                # e6: how well the embedding is linearly predicted from the
                # encoder's output on a 70% noise / 30% input mixture.
                noise_sample = torch.randn(batch_views[j].shape, dtype=torch.float64, device='cuda')
                model_r = LinearRegression()
                y = data.numpy()
                x = model.encoder[j](0.7 * noise_sample + 0.3 * batch_views[j]).cpu().numpy()
                model_r.fit(x, y)
                errors_6.append(mean_squared_error(y, model_r.predict(x)))

                errors_2.append(covariance_loss(view_project[j], average=True))
                errors_3.append(nesum(view_project[j]))
                errors_4.append(l2_reg_ortho_loss_func(model.encoder[j], method='cor', weight=1, average=True))
                errors_5.append(l2_reg_ortho_loss_func(model.encoder[j], method='nesum', weight=1))
                errors_7.append(model.get_loss([model.encoder[j](noise_sample), view_project[j]]))

    res = [torch.cat(chunks, dim=0).numpy() for chunks in res]

    def _avg(values):
        # Plain arithmetic mean; works for floats and tensors alike.
        return sum(values) / len(values)

    return [
        res,
        f_score,
        _avg(errors_1),
        _avg(errors_2),
        _avg(errors_3),
        _avg(errors_4),
        _avg(errors_5),
        _avg(errors_6),
        _avg(errors_7),
    ]