import numpy as np
import torch

import cca_zoo
from cca_zoo.data.deep import get_dataloaders
import pytorch_lightning as pl
import numpy as np
import pdb
from .numpydataset import NumpyDataset
from .model_AE import DCCA_Noise_Norm_M 
from .utils import Initialize_Seed,Distance_Correlation
import math
import torch.optim.lr_scheduler as lr_scheduler
from sklearn.decomposition import FastICA
from .utils import cca_loss
from tqdm import tqdm
import copy
import torch.nn.functional as F


    
def covariance_loss(z1: torch.Tensor,average=False) -> torch.Tensor:
    """Score how strongly the columns of ``z1`` are correlated with each other.

    Despite the name, the penalty is built from the Pearson *correlation*
    matrix of the columns (``torch.corrcoef``), not from the raw covariance.
    Only the off-diagonal entries contribute.

    Args:
        z1: 2-D tensor; each column is treated as one variable and each row
            as one observation.
        average: when true, return the mean absolute off-diagonal
            correlation; otherwise return the squared L2 norm of the
            off-diagonal entries.

    Returns:
        A 0-dim tensor holding the penalty.
    """
    # Unpacking into two names also rejects anything that is not 2-D.
    _, n_features = z1.shape

    # corrcoef expects variables along rows, hence the transpose.
    corr = torch.corrcoef(z1.t())

    # Boolean mask that is True everywhere except on the main diagonal.
    off_diagonal = ~torch.eye(n_features, device=z1.device).bool()
    cross_terms = corr[off_diagonal]

    if not average:
        return torch.norm(cross_terms) ** 2
    return cross_terms.abs().mean()



def l2_reg_ortho_loss_func(mdl,device='cuda',weight = 1e-2,method='risp',average=False):
    """Average a decorrelation score over the weight matrices of ``mdl``.

    Every parameter of ``mdl`` with at least two dimensions is transposed and
    scored with the function selected by ``method``; the per-parameter scores
    are then averaged with equal weight. Parameters with fewer than two
    dimensions (e.g. biases) are skipped.

    Args:
        mdl: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        device: unused; kept so existing callers keep working.
        weight: unused; kept so existing callers keep working.
        method: ``'cor'`` scores each matrix with ``covariance_loss``,
            ``'nesum'`` scores it with ``nesum``, ``'debug'`` opens a ``pdb``
            breakpoint on each matrix and produces no score. The default
            ``'risp'`` is not implemented in this function and is rejected.
        average: forwarded to ``covariance_loss`` when ``method == 'cor'``.

    Returns:
        The mean of the per-parameter scores (a 0-dim tensor for the scoring
        functions used here).

    Raises:
        ValueError: if ``method`` is not one of the supported names, or if no
            score was produced (no parameter with >= 2 dimensions, or
            ``method == 'debug'``).
    """
    supported = ('cor', 'nesum', 'debug')
    if method not in supported:
        # Previously an unknown method appended ``None`` for every matrix and
        # blew up later inside ``sum`` with an unrelated-looking TypeError.
        raise ValueError(
            "unsupported method {!r}; expected one of {}".format(method, supported)
        )

    scores = []
    for W in mdl.parameters():  # each matrix is laid out as output * input
        if W.ndimension() < 2:
            continue
        # NOTE: ``Tensor.t()`` only accepts tensors with at most two
        # dimensions, so higher-rank parameters raise here.
        if method == 'cor':
            scores.append(covariance_loss(W.t(), average=average))
        elif method == 'nesum':
            scores.append(nesum(W.t()))
        else:  # 'debug': inspect the matrix interactively, contribute nothing
            pdb.set_trace()

    if not scores:
        # Avoids the bare ZeroDivisionError the empty average used to raise.
        raise ValueError(
            "no score was computed: the module has no parameter with at "
            "least two dimensions, or method='debug' was used"
        )

    return sum(scores)/len(scores)


def nesum(tensor):
    """Return the mean eigenvalue of the column correlation matrix, scaled by the largest.

    The columns of ``tensor`` are treated as variables. Their correlation
    matrix is computed with ``torch.corrcoef``, its eigenvalues are obtained
    with ``torch.linalg.eigvalsh``, each eigenvalue is divided by the largest
    one, and the mean of those ratios is returned.

    Args:
        tensor: 2-D tensor with observations in rows and variables in columns.

    Returns:
        A 0-dim tensor; it equals 1 when all eigenvalues are identical.
    """
    # corrcoef wants variables along rows, so transpose first.
    correlation = torch.corrcoef(tensor.T)

    # The correlation matrix is symmetric, so the symmetric solver applies.
    spectrum = torch.linalg.eigvalsh(correlation)

    # Normalise by the largest eigenvalue, then average the ratios.
    return torch.mean(spectrum / spectrum.max())

def MyDCCA_full_fit_transform(train,test,dim=100,out_dim=100,epochs=50,num_views=2,noise='normal',loss_name='cca',approach=True,recon=False,private=False,lr=1e-4,a=5,linear=False):
    """Train a ``DCCA_Noise_Norm_M`` model on ``train`` and project ``test``.

    The model is optimised with Adam on CUDA for ``epochs`` passes over
    ``train`` (batches of 2000, unshuffled). Afterwards ``test`` is pushed
    through the model in chunks of 2000 rows and several diagnostics are
    collected for every (chunk, view) pair.

    Side effects: sets the global torch default dtype to ``float64``,
    requires a CUDA device, and prints one line per view and batch while
    training with ``noise == 'normal'`` and ``approach`` true.

    Args:
        train: sequence of 2-D arrays, one per view; ``shape[1]`` of each is
            used as the model's input width for that view.
        test: sequence of 2-D NumPy arrays, one per view, with the same
            number of rows (the row count of ``test[0]`` drives chunking).
        dim: unused.
        out_dim: ignored; overwritten with 25 when ``loss_name == 'tcca'``
            and with 200 otherwise.
        epochs: number of passes over the training data.
        num_views: ignored; recomputed as ``len(train)``.
        noise: ``'normal'`` (model is fed the views plus standard-normal
            tensors of the same shape) or ``'none'`` (views only).
        loss_name: forwarded to the dataset and the model; ``'tcca'`` also
            selects the alternative ``out_dim``/learning rate above.
        approach: with ``noise == 'normal'``, add ``a * |gt - here|`` per
            view to the loss, where both values come from ``model.get_loss``
            (``gt`` on the raw noise/view columns, ``here`` on their
            projections).
        recon: with ``noise == 'none'``, ask the model for reconstructions
            and add ``0.01 * MSE`` per view to the loss.
        private: forwarded to the model; with ``recon`` it restricts the
            loss to the first ``out_dim`` columns of each projection.
        lr: Adam learning rate (replaced by ``20*1e-5`` for ``'tcca'``).
        a: weight of the ``|gt - here|`` term.
        linear: forwarded to the model (non-``'tcca'`` only).

    Returns:
        A 9-element list:
        ``[0]`` list with one NumPy array per view holding the test
        projections concatenated over chunks;
        ``[1]`` ``model.get_loss`` value of the last training batch as a
        float (0 if no batch was processed);
        ``[2]`` mean MSE of a linear regression from projection to input;
        ``[3]`` mean ``covariance_loss(projection, average=True)``;
        ``[4]`` mean ``nesum(projection)``;
        ``[5]`` mean ``l2_reg_ortho_loss_func(encoder, method='cor')``;
        ``[6]`` mean ``l2_reg_ortho_loss_func(encoder, method='nesum')``;
        ``[7]`` mean MSE of a linear regression from
        ``encoder(0.7*noise + 0.3*view)`` to the projection;
        ``[8]`` mean ``model.get_loss([encoder(noise), projection])``.

    Raises:
        ValueError: if ``noise`` is not ``'normal'`` or ``'none'``, or if
            ``test`` contains no rows.
    """
    # The training loop only knows these two modes. Anything else used to
    # leave ``loss`` undefined and end up in a pdb prompt.
    if noise not in ('normal', 'none'):
        raise ValueError(
            "unsupported noise {!r}; expected 'normal' or 'none'".format(noise)
        )

    # Imported once here instead of on every inner-loop iteration.
    from sklearn.linear_model import LinearRegression
    from sklearn.metrics import mean_squared_error

    torch.set_default_dtype(torch.float64)

    in_dims = [view.shape[1] for view in train]
    dataset = NumpyDataset(train, labels=None,loss_name=loss_name)
    num_views = len(in_dims)

    loss_fn = torch.nn.MSELoss(reduction='mean')
    if loss_name=='tcca':
        out_dim=25
        lr = 20*1e-5
        model = DCCA_Noise_Norm_M(in_dims=in_dims,out_dim=out_dim,view_num=num_views,loss_name=loss_name,recon=recon,private=private).cuda()
    else:
        out_dim=200
        model = DCCA_Noise_Norm_M(in_dims=in_dims,out_dim=out_dim,view_num=num_views,loss_name=loss_name,recon=recon,private=private,linear=linear,noise=noise).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model.train()
    f_score = 0

    dataloader = get_dataloaders(dataset=dataset,batch_size=2000,shuffle_train=False,drop_last=False,num_workers=1)

    for epoch in tqdm(range(epochs)):
        for batch in dataloader:
            views = [view.cuda() for view in batch['views']]

            # One standard-normal tensor per view, same shape as the view.
            # Drawn in both modes so the RNG stream matches earlier runs.
            noise_tensor = []
            for view in views:
                noise_add = torch.cuda.DoubleTensor(view.shape)
                torch.randn(view.shape, out=noise_add)
                noise_tensor.append(noise_add)

            if noise == 'normal':
                view_project,noise_project = model(views,noise_tensor)

                loss_1 = torch.zeros(1).cuda()
                loss_2 = model.get_loss(view_project)
                f_score = loss_2.item()

                if approach:
                    for i in range(len(views)):
                        # Reference value from the raw inputs; no gradient
                        # should flow through it.
                        with torch.no_grad():
                            gt = model.get_loss([noise_tensor[i][:,:out_dim],views[i][:,:out_dim]])
                        here = model.get_loss([noise_project[i],view_project[i]])

                        loss_1 = loss_1+a*torch.abs(gt-here)

                        print(epoch,a,gt.item(),here.item(),loss_1.item(),loss_2.item())

                    loss = loss_2 + loss_1
                else:
                    # loss_1 is still zero here, so this reduces to loss_2.
                    loss = 200*loss_1+loss_2
            else:  # noise == 'none'
                loss_1 = torch.zeros(1).cuda()
                if recon:
                    view_project,view_recon = model(views,recon=recon)
                    for i, view in enumerate(views):
                        loss_1 = loss_1+ 0.01*loss_fn(view_recon[i],view)
                    if private:
                        # Only the first out_dim columns enter the loss.
                        loss = model.get_loss([view[:,:out_dim] for view in view_project])
                    else:
                        loss = model.get_loss(view_project)
                    f_score = loss.item()
                    loss = loss+loss_1
                else:
                    view_project = model(views)
                    loss = model.get_loss(view_project)
                    f_score = loss.item()

            optimizer.zero_grad()
            # Errors propagate to the caller instead of opening a debugger.
            loss.backward()
            optimizer.step()

    res = [[] for _ in test]
    errors_1 = []
    errors_2 = []
    errors_3 = []
    errors_4 = []
    errors_5 = []
    errors_6 = []
    errors_7 = []

    with torch.no_grad():
        model.eval()

        for start in range(0,len(test[0]),2000):
            # Slice before converting so only the current chunk is copied.
            batch_views = [torch.Tensor(view[start:start+2000,:].astype(np.float64)).cuda() for view in test]

            view_project = model(batch_views)

            for j in range(len(test)):
                data = view_project[j].cpu()
                res[j].append(data)

                # How well can the input be linearly recovered from the
                # projection?
                x = data.numpy()
                y = batch_views[j].cpu().numpy()
                model_r = LinearRegression()
                model_r.fit(x, y)
                errors_1.append(mean_squared_error(y, model_r.predict(x)))

                shape = batch_views[j].shape
                eval_noise = torch.cuda.DoubleTensor(shape)
                torch.randn(shape, out=eval_noise)

                # How well can the projection be linearly recovered from
                # the encoding of a noise-dominated mix of the input?
                y = data.numpy()
                x = model.encoder[j](0.7*eval_noise+0.3*batch_views[j]).cpu().numpy()
                model_r = LinearRegression()
                model_r.fit(x, y)
                errors_6.append(mean_squared_error(y, model_r.predict(x)))

                errors_2.append(covariance_loss(view_project[j],average=True))

                errors_3.append(nesum(view_project[j]))

                errors_4.append(l2_reg_ortho_loss_func(model.encoder[j],method='cor',weight=1,average=True))

                errors_5.append(l2_reg_ortho_loss_func(model.encoder[j],method='nesum',weight=1))

                errors_7.append(model.get_loss([model.encoder[j](eval_noise),view_project[j]]))

    if not errors_1:
        raise ValueError("test views contain no samples to evaluate")

    for i in range(len(test)):
        res[i] = torch.cat(res[i],dim=0).numpy()

    def _mean(values):
        # Plain arithmetic mean; keeps tensor entries as tensors.
        return sum(values)/len(values)

    return [res,f_score,_mean(errors_1),_mean(errors_2),_mean(errors_3),_mean(errors_4),_mean(errors_5),_mean(errors_6),_mean(errors_7)]