import numpy as np
import torch

import cca_zoo
from cca_zoo.data.deep import get_dataloaders
import pytorch_lightning as pl
import numpy as np
import pdb
from .numpydataset import NumpyDataset
from .model_AE import DCCA_Noise_Norm_M 
from .utils import Initialize_Seed,Distance_Correlation
import math
import torch.optim.lr_scheduler as lr_scheduler
from sklearn.decomposition import FastICA
from .utils import cca_loss
from tqdm import tqdm
import copy
import torch.nn.functional as F
from cca_zoo.deepmodels import architectures

    
def covariance_loss(z1: torch.Tensor, average=False) -> torch.Tensor:
    """Measure how strongly the columns of ``z1`` are correlated with each other.

    The correlation matrix of the columns is computed with
    ``torch.corrcoef`` (applied to ``z1.T``) and only its off-diagonal
    entries are reduced to a scalar. Despite the function name, the
    statistic is built from correlations, not raw covariances.

    Args:
        z1: 2-D tensor; unpacking ``z1.size()`` into two values fails for
            any other rank.
        average: when true, return the mean absolute off-diagonal
            correlation; otherwise return the squared L2 norm of the
            off-diagonal correlations.

    Returns:
        A scalar tensor holding the selected statistic.
    """
    _, num_columns = z1.size()

    correlation = torch.corrcoef(z1.T)

    # Boolean mask that is True everywhere except on the main diagonal.
    off_diagonal = ~torch.eye(num_columns, device=z1.device).bool()
    cross_terms = correlation[off_diagonal]

    if not average:
        return torch.norm(cross_terms) ** 2
    return torch.abs(cross_terms).mean()



def l2_reg_ortho_loss_func(mdl,device='cuda',weight = 1e-2,method='risp',average=False):
    """Average a decorrelation statistic over every weight matrix of ``mdl``.

    Each parameter of ``mdl`` with at least two dimensions is transposed
    and passed to the statistic selected by ``method``; the per-matrix
    results are averaged. Parameters with fewer than two dimensions
    (e.g. biases) are skipped.

    Args:
        mdl: object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).
        device: accepted for backward compatibility; not used.
        weight: accepted for backward compatibility; not used.
        method: ``'cor'`` uses ``covariance_loss``, ``'nesum'`` uses
            ``nesum``, and ``'debug'`` drops into ``pdb`` once per weight
            matrix without computing anything. Note that the default
            value ``'risp'`` is not a supported method.
        average: forwarded to ``covariance_loss`` when ``method='cor'``.

    Returns:
        The mean of the per-matrix statistics, or ``None`` when
        ``method='debug'`` (no statistic is computed in that mode).

    Raises:
        ValueError: if ``method`` is not supported, or if ``mdl`` has no
            parameter with two or more dimensions to average over.
    """
    supported_methods = ('cor', 'nesum', 'debug')
    if method not in supported_methods:
        # Previously an unknown method appended None for every matrix and
        # failed later inside sum() with an unrelated-looking TypeError.
        raise ValueError(
            "unsupported method {!r}; expected one of {}".format(method, supported_methods)
        )

    terms = []
    for W in mdl.parameters():  # presumably laid out as output * input
        if W.ndimension() < 2:
            continue
        if method == 'cor':
            terms.append(covariance_loss(W.t(), average=average))
        elif method == 'nesum':
            terms.append(nesum(W.t()))
        else:
            # 'debug': inspect W interactively; nothing is accumulated.
            pdb.set_trace()

    if method == 'debug':
        return None

    if not terms:
        # Avoid the bare ZeroDivisionError the average would otherwise raise.
        raise ValueError("mdl has no parameter with at least 2 dimensions")

    return sum(terms) / len(terms)


def nesum(tensor):
    """Return the mean eigenvalue of the column-correlation matrix, scaled by the largest.

    The correlation matrix of the columns of ``tensor`` is formed with
    ``torch.corrcoef`` (applied to ``tensor.T``), its eigenvalues are
    obtained with ``torch.linalg.eigvalsh``, each eigenvalue is divided by
    the largest one, and the mean of those ratios is returned. The result
    is 1 when all eigenvalues are equal and shrinks as the spectrum
    becomes dominated by a single eigenvalue.

    Args:
        tensor: 2-D tensor whose columns are treated as variables.

    Returns:
        A scalar tensor with the mean normalised eigenvalue.
    """
    correlation = torch.corrcoef(tensor.T)
    spectrum = torch.linalg.eigvalsh(correlation)

    # Normalise by the largest eigenvalue, then average the ratios.
    return (spectrum / spectrum.max()).mean()

def mean_variance_normalize(input_tensor,sigma_squared=0.1):
    """Centre each column of ``input_tensor`` and rescale it to a fixed spread.

    Every column has its mean (over dim 0) subtracted and is then
    multiplied by ``sigma_squared / std``, where ``std`` is the column's
    standard deviation plus ``1e-10`` to avoid division by zero.

    Args:
        input_tensor: tensor normalised along dim 0.
        sigma_squared: despite its name, this value is used directly as
            the target standard deviation of each column, not as a
            variance.

    Returns:
        The centred and rescaled tensor, same shape as ``input_tensor``.
    """
    centered = input_tensor - torch.mean(input_tensor, dim=0, keepdim=True)

    # Small epsilon keeps constant columns from dividing by zero.
    column_std = torch.std(centered, dim=0, keepdim=True) + 1e-10

    scale = sigma_squared / column_std
    return centered * scale

def MyDCCA_full_fit_transform(train,test,dim=100,out_dim=200,epochs=50,num_views=2,noise='normal',loss_name='cca',approach=True,recon=False,private=False,a=1, lr = 10*1e-5,linear=False,regular=None):
    """Train a ``DCCA_Noise_Norm_M`` model on ``train`` and project ``test`` with it.

    The function builds the model on CUDA, optimises it with Adam for
    ``epochs`` passes over ``train``, then projects ``test`` in chunks of
    2000 rows while collecting several diagnostic statistics per view.
    A CUDA device is required (``.cuda()`` and ``torch.cuda.DoubleTensor``
    are used unconditionally).

    Side effects: sets PyTorch's global default dtype to float64 and
    prints progress/diagnostic values to stdout.

    Args:
        train: sequence of per-view arrays; ``view.shape[1]`` gives each
            view's input width and ``len(train[0])`` the sample count.
        test: sequence of per-view arrays supporting ``.astype`` and
            2-D slicing (presumably numpy arrays — confirm with callers).
        dim: not used by this function.
        out_dim: forwarded to the model; also used to slice the first
            ``out_dim`` columns of inputs/projections in some branches.
        epochs: number of passes over the training dataloader.
        num_views: ignored; overwritten with ``len(train)``.
        noise: ``'normal'`` or ``'none'`` selects the training objective.
            Any other value leaves ``loss`` unassigned (see note at the
            backward call).
        loss_name: forwarded to ``NumpyDataset`` and the model.
        approach: not used by this function.
        recon: forwarded to the model; with ``noise='none'`` it enables
            the reconstruction branches.
        private: forwarded to the model; with ``noise='none'`` and
            ``recon`` it selects the branch that slices projections to
            ``out_dim`` columns before the main loss.
        a: weight of the ``|gt - here|`` term when ``noise='normal'``.
        lr: Adam learning rate.
        linear: forwarded to the model.
        regular: when truthy (with ``noise='none'``, ``recon`` and not
            ``private``) selects the decorrelation-regularised branch.

    Returns:
        A list ``[res, f_score, m1, m2, m3, m4, m5, m6, m7]`` where
        ``res`` holds one numpy array of test projections per view,
        ``f_score`` is the main loss value of the last training batch,
        and ``m1``..``m7`` are the means of ``errors_1``..``errors_7``
        collected during evaluation (``m1`` and ``m6`` come from sklearn
        MSE values; the others are averages of values returned by
        ``covariance_loss``, ``nesum``, ``l2_reg_ortho_loss_func`` and
        ``model.get_loss``).
    """
    # Process-wide switch: remains in effect after this function returns.
    torch.set_default_dtype(torch.float64)

    print(len(train[0]))
    print(len(test[0]))
    print(lr)

    # Input width of every view.
    in_dims = [view.shape[1] for view in train]
    dataset = NumpyDataset(train, labels=None,loss_name=loss_name)
    # The num_views argument is discarded in favour of the actual view count.
    num_views = len(in_dims)

    # Reconstruction loss used by the recon branches below.
    loss_fn = torch.nn.MSELoss(reduction='mean')

    #out_dim = 200
    model = DCCA_Noise_Norm_M(in_dims=in_dims,out_dim=out_dim,view_num=num_views,loss_name=loss_name,recon=recon,private=private,linear=linear,noise=noise).cuda()
    optimzer = torch.optim.Adam(model.parameters(), lr=lr)

    # NOTE(review): `encoder` is an empty list at the time of the loop below,
    # so the orthogonal/zero initialisation never runs. `encoder_ori`,
    # `encoder_o` and `encoder` are only referenced from commented-out code
    # in this function. Constructing these layers presumably consumes random
    # numbers, so deleting them may change seeded runs — confirm before
    # removing.
    encoder = []
    encoder_ori = copy.deepcopy(model.encoder)
    encoder_o = torch.nn.ModuleList([torch.nn.Linear(in_dims[i],out_dim) for i in range(num_views)]).cuda()
    for i in range(len(encoder)):
        torch.nn.init.orthogonal_(encoder_o[i].weight)
        # The bias can also be given an initial value, e.g. 0.
        #if m.bias is not None:
        torch.nn.init.constant_(encoder_o[i].bias, 0)
    encoder = torch.nn.ModuleList([torch.nn.Linear(in_dims[i],out_dim) for i in range(num_views)]).cuda()
    model.train()
    # Holds the most recent main-loss value; returned to the caller.
    f_score = 0
    # Re-initialised before evaluation; nothing is appended during training.
    errors_6 = []
    # Batch size is capped at 2000 samples.
    bs = min(len(train[0]),2000)
    #bs = len(train[0])
    dataloader = get_dataloaders(dataset=dataset,batch_size=bs,shuffle_train=False,drop_last=True,num_workers=1)
    # `gts` and `pre_loss` are written but never read in this function.
    gts = []
    pre_loss = 0
    #l2_reg_ortho_loss_func(model.encoder[0],method='debug',weight=1)
    for epoch in tqdm(range(epochs)):

        for batch_idx, batch in enumerate(dataloader):

            # Move every view of the batch to the GPU.
            batch['views'] = [view.cuda() for view in batch['views']]

            # One standard-normal noise tensor per view, same shape as the view.
            noise_tensor =  []

            for i in range(len(batch['views'])):
                shape = batch['views'][i].shape

                #    shape = (300,shape[1],)
                # Allocate on the GPU, then fill in place with N(0, 1) samples.
                noise_add = torch.cuda.DoubleTensor(shape)
                torch.randn(shape, out=noise_add)

                noise_tensor.append(noise_add)

            if noise =='normal':
                # The model projects both the views and the noise tensors.
                view_project,noise_project = model(batch['views'],noise_tensor)

                loss_1 = torch.zeros(1).cuda()
                #loss_3 = torch.zeros(1).cuda()
                # NOTE(review): on the first batch `loss_2` does not exist yet,
                # so the NameError is swallowed by the bare except. `pre_loss`
                # is not read anywhere, so this block has no visible effect.
                try:
                    pre_loss = loss_2.item()
                except:
                    pass
                # Main objective on the projected views.
                loss_2 =   model.get_loss(view_project)
                f_score = loss_2.item()

                #print(epoch,loss_1.item(),loss_2.item())

                for i in range(len(batch['views'])):

                    # # # #gt = gts[i]
                    # `std` is only used by the commented-out normalised variant below.
                    std = 0.01
                    with torch.no_grad():
                        #std = batch['views'][i][:,:out_dim].std(0).mean(0)
                        # Reference value: loss between the first out_dim columns
                        # of the raw noise and of the raw view (no gradient).
                        gt = model.get_loss([noise_tensor[i][:,:out_dim],batch['views'][i][:,:out_dim]])
                        # Disabled alternatives:
                        #gt =  model.get_loss([encoder_ori[i](noise_tensor[i]),encoder_ori[i](batch['views'][i])])
                        #gt =  model.get_loss([noise_tensor[i].T,batch['views'][i].T])
                    #here = model.get_loss([mean_variance_normalize(noise_project[i],std),mean_variance_normalize(view_project[i],std)])
                    # Same loss evaluated between the projected noise and the projected view.
                    here = model.get_loss([noise_project[i],view_project[i]])
                    # Recorded settings from earlier experiments:
                    # CUB: 1.5 (0.921)
                    # cal: 45 (0.611)
                    #a= 0
                    # Penalise the gap between the reference and the projected value.
                    loss_1 = loss_1+a*torch.abs(gt-here)
                #loss_1 = loss_1/len(batch['views'])
                #loss_1 = loss_1+1e-5

                    # Printed once per view; loss_1 is the running sum so far.
                    print(epoch,a,gt.item(),here.item(),loss_1.item(),loss_2.item())

                loss = loss_2 + loss_1

            elif noise =='none':
                loss_1 = torch.zeros(1).cuda()
                if recon:
                    if private:
                        # Reconstruction term on all views; the main loss only sees
                        # the first out_dim columns of each projection.
                        view_project,view_recon = model(batch['views'],recon=recon)
                        for i in range(len(batch['views'])):
                            loss_1 =  loss_1+ 0.01*loss_fn(view_recon[i],batch['views'][i])
                        loss =  model.get_loss([view[:,:out_dim] for view in view_project])
                        f_score = loss.item()
                        loss = loss+loss_1
                    else:
                        if regular:
                            view_project= model(batch['views'])

                            #view_project,view_recon = model([0.3*batch['views'][k]+ 0.7*noise_tensor[k] for k in range(len(batch['views']))],recon=recon)
                            for i in range(len(batch['views'])):
                                #view_recon_here = model.decoder[i](model.encoder[i](0.1*batch['views'][i]+ 0.9*noise_tensor[i]))
                                #r_here = apply_noise_to_model_input(model.encoder[i],batch['views'][i])
                                r_here = view_project[i]
                                # Decode the projection back to the input space.
                                view_recon_here = model.decoder[i](r_here)
                                #view_project.append(r_here)
                                #decov_loss = covariance_loss(model.encoder[i].layers(batch['views'][i]))
                                # NOTE(review): method='so' is not one of the methods
                                # handled by l2_reg_ortho_loss_func as defined in this
                                # file ('cor', 'nesum', 'debug'), so this call does not
                                # look able to produce a valid loss — confirm intent.
                                # Disabled extra term: + l2_reg_ortho_loss_func(model.decoder[i].layers[0],method='so',weight=1)
                                decov_loss =  25*l2_reg_ortho_loss_func(model.encoder[i],method='so',weight=1)
                                #decov_loss = covariance_loss(r_here)
                                # va_loss = variance_loss(r_here)
                                #print(decov_loss,variance_loss(r_here))
                                print(decov_loss)
                                # The regulariser is subtracted from the reconstruction term.
                                # Disabled extra term: +0.1*va_loss
                                loss_1 =  loss_1+ 0.01* loss_fn(view_recon_here,batch['views'][i]) - 1*decov_loss

                        else:
                            view_project,view_recon = model(batch['views'],recon=recon)
                            for i in range(len(batch['views'])):
                                loss_1 =  loss_1+ 0.01* loss_fn(view_recon[i],batch['views'][i])
                        # Shared by both non-private recon branches.
                        loss =  model.get_loss(view_project)
                        f_score = loss.item()

                        loss = loss+loss_1
                else:
                    # Plain objective: main loss on the projections only.
                    view_project = model(batch['views'])
                    #l1 = l2_reg_ortho_loss_func(model.encoder[0],method='so',weight=1) + l2_reg_ortho_loss_func(model.encoder[1],method='so',weight=1)
                    # Disabled extra terms: - l2_reg_ortho_loss_func(model.encoder[0],method='nesum',weight=1) - l2_reg_ortho_loss_func(model.encoder[1],method='nesum',weight=1) +l1*10
                    loss =  model.get_loss(view_project)
                    #print(epoch,loss.item(),l1.item())
                    #nesum_s = nesum(view_project[0])+nesum(view_project[1])
                    #loss =  loss #- 0.0001*nesum_s
                    #print(epoch,loss.item())
                    #print(nesum_s)
                    loss =  loss
                    f_score = loss.item()

            #print(l2_reg_ortho_loss_func(model.encoder[0],method='nesum',weight=1))
            optimzer.zero_grad()
            # NOTE(review): the bare except hides every failure of the backward
            # pass. `loss` is only assigned in the 'normal' and 'none' branches,
            # so for any other `noise` value the resulting NameError is swallowed
            # here too, and the optimiser step below still runs.
            try:
                loss.backward()
            except:
                pass
            # for j in range(len(test)):
            #     errors_6.append(l2_reg_ortho_loss_func(weight,method='gradient_norm',weight=1))
            optimzer.step()
                #errors_6.append(l2_reg_ortho_loss_func(model.encoder[1],method='gradient_norm',weight=1))
            #print(nesum(view_project[0]))
            # Disabled: compute eigenvalues / rank of the projection.
            #print(torch.linalg.matrix_rank(view_project[0]))
            # apply_noise_to_model_input(model.encoder[i],input_tensor=batch['views'][i])
        # Disabled per-epoch diagnostics over the Linear layers of encoder 0:
        # vars = []
        # covs =[]
        # for module in model.encoder[0].modules():
        #     if isinstance(module, torch.nn.Linear):  # check whether the module is an nn.Linear
        #         w = module.weight.detach().t()
        #         vars.append(variance_loss(w))
        #         covs.append(nesum(w))
        # print("var:{} cor:{}".format(sum(vars)/len(vars),sum(covs)/len(covs)))

        # Disabled: extract the weights and print orthogonality diagnostics.
        #weights = module.weight.data
        #w= model.encoder[0].weight.detach().t()
        #print(torch.norm(w.t()@w-torch.eye(w.size(1),w.size(1)).cuda()))
        #print('var:{} cov: {}'.format(variance_loss(w),covariance_loss(w)))

    # res[j] collects the projected chunks of test view j.
    res = []

    with torch.no_grad():
        model.eval()

        # Per-(chunk, view) diagnostics; each list is averaged in the return value.
        errors_1 = []
        errors_2 = []
        errors_3 = []
        errors_4 = []
        errors_5 = []
        errors_6 = []
        errors_7 = []

        # Evaluate the test set in chunks of 2000 rows.
        for i in range(0,len(test[0]),2000):
            batch_views = [torch.Tensor(view.astype(np.float64))[i:i+2000,:].cuda() for view in test]

            view_project = model(batch_views)
            # Disabled: post-process the projections with MCCA.
            #view_project = []
            #for j in range(len(test)):
            #    view_project.append(view_project_t[j].cpu().numpy())

            #MCCA_T = MCCA(latent_dimensions=100)
            #view_project = MCCA_T.fit_transform(view_project)
            #for j in range(len(test)):
            #    view_project[j] = torch.Tensor(view_project[j]).cuda()
            for j in range(len(test)):
                #var,cor = model.feature_quality(view_project[i])
                #print('view {}: var: {} cor:{} '.format(j,var,cor))
                data = view_project[j].cpu()
                #print(nesum(batch_views[j]))

                # The first chunk of view j has no list yet, so indexing fails and
                # the list is created. NOTE(review): the bare except would also
                # hide unrelated errors.
                try:
                    res[j].append(data)
                except:
                    res.append([data])
                # Imported inside the loop; re-executed for every chunk and view.
                from sklearn.linear_model import LinearRegression
                from sklearn.metrics import mean_squared_error
                # errors_1: fit a linear regression from the projection to the
                # original view and score it on the same chunk.
                model_r = LinearRegression()
                x = data.numpy()
                y = batch_views[j].cpu().numpy()
                # Fit the model.
                model_r.fit(x, y)

                # Predict with the fitted model.
                X_pred = model_r.predict(x)

                # Mean squared error of the fit.
                error = mean_squared_error(X_pred, y)
                errors_1.append(error)

                shape = batch_views[j].shape
                #shape = data.shape
                # NOTE(review): this rebinds the `noise` parameter (a string during
                # training) to a tensor; training has finished by this point.
                noise = torch.cuda.DoubleTensor(shape)
                torch.randn(shape, out=noise)

                # errors_6: fit a linear regression from the encoder output of a
                # noise-mixed input (0.7*noise + 0.3*view) to the clean projection.
                model_r = LinearRegression()

                #y = batch_views[j].cpu().numpy()
                #y= noise.cpu().numpy()
                y = data.numpy()
                x = model.encoder[j](0.7*noise+0.3*batch_views[j]).cpu().numpy()
                #y = noise.cpu().numpy()
                model_r.fit(x, y)

                X_pred = model_r.predict(x)

                # Mean squared error of the fit.
                error = mean_squared_error(X_pred, y)
                errors_6.append(error)

                # errors_2: mean absolute off-diagonal correlation of the projection.
                errors_2.append(covariance_loss(view_project[j],average=True))

                # errors_3: nesum statistic of the projection.
                errors_3.append(nesum(view_project[j]))

                # errors_4 / errors_5: the same two statistics averaged over the
                # weight matrices of encoder j.
                errors_4.append(l2_reg_ortho_loss_func(model.encoder[j],method='cor',weight=1,average=True))

                errors_5.append(l2_reg_ortho_loss_func(model.encoder[j],method='nesum',weight=1))

                # errors_7: model loss between the encoded pure noise and the projection.
                errors_7.append(model.get_loss([model.encoder[j](noise),view_project[j]]))
    # Stitch the chunks of every view back together as one numpy array.
    for i in range(len(test)):
        res[i] = torch.cat(res[i],dim=0).numpy()

    # Each sum(...)/len(...) is the mean over all (chunk, view) pairs.
    return [res,f_score,sum(errors_1)/len(errors_1),sum(errors_2)/len(errors_2),sum(errors_3)/len(errors_3),sum(errors_4)/len(errors_4),sum(errors_5)/len(errors_5),sum(errors_6)/len(errors_6),sum(errors_7)/len(errors_7)]