import torch

import pdb

import torch.nn as nn


import torch
import torch.nn as nn

import copy

def _shallow_copy_module(module, memo):
    """Copy one module for ``shallow_copy_model``; ``memo`` maps ``id(original) -> copy``."""
    key = id(module)
    if key in memo:
        # The same submodule is reachable from several parents: keep that sharing in the copy.
        return memo[key]
    if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d):
        # BatchNorm layers get fully independent parameters and buffers.
        copied = copy.deepcopy(module)
        memo[key] = copied
        return copied
    copied = copy.copy(module)
    # copy.copy builds a new instance __dict__, but the ``_modules`` dict stored in it is still
    # the original's object. Give the copy its own dict so that replacing children below does
    # not mutate the original model.
    copied.__dict__["_modules"] = module._modules.copy()
    memo[key] = copied
    for name, child in module._modules.items():
        if child is not None:
            copied._modules[name] = _shallow_copy_module(child, memo)
    return copied


def shallow_copy_model(original_model):
    """Return a copy of ``original_model`` that shares weights but owns its BatchNorm layers.

    Every ``nn.BatchNorm1d`` / ``nn.BatchNorm2d`` in the module tree is deep-copied. Every
    other module is shallow-copied (``copy.copy``), so it still shares its parameter and
    buffer tensors with the original. The container structure is rebuilt, so the copy's BN
    layers can be replaced or updated without touching ``original_model``. A submodule that
    appears several times in the tree maps to a single copy.

    The original model is not mutated.

    Args:
        original_model: the ``nn.Module`` to copy.

    Returns:
        The copied ``nn.Module``.
    """
    return _shallow_copy_module(original_model, {})

def _demean(views):
    return tuple([view - view.mean(dim=0) for view in views])

def inv_sqrtm(A, eps=1e-9, out_dtype=torch.half):
    """Compute the inverse square root of a symmetric positive-definite matrix.

    The matrix is factorised in float64 as ``A = U diag(S) Vh``. The singular values are
    clamped from below at ``eps``, so near-singular inputs stay finite, and
    ``U diag(S ** -0.5) Vh`` is returned. For a symmetric positive-definite ``A`` this is
    ``A^{-1/2}``. Leading batch dimensions are supported.

    Args:
        A: matrix (or batch of matrices) to invert-sqrt.
        eps: lower bound applied to the singular values before taking ``S ** -0.5``.
        out_dtype: dtype of the result. Defaults to ``torch.half`` (the historical behaviour);
            pass ``None`` to get ``A``'s own dtype back.

    Returns:
        Tensor with the same shape as ``A``.
    """
    # torch.linalg.svd replaces the deprecated torch.svd; it returns Vh (= V^T) directly.
    U, S, Vh = torch.linalg.svd(A.double(), full_matrices=False)
    S = S.clamp(min=eps)
    B = U @ torch.diag_embed(S.pow(-0.5)) @ Vh
    return B.to(A.dtype if out_dtype is None else out_dtype)



def freeze_bn_layers(model):
    """Call ``.eval()`` on every ``nn.BatchNorm1d`` / ``nn.BatchNorm2d`` layer inside ``model``.

    Only the train/eval mode of those layers changes; the ``requires_grad`` flags of their
    parameters are left as they are.
    """
    bn_types = (nn.BatchNorm2d, nn.BatchNorm1d)
    bn_layers = [mod for mod in model.modules() if isinstance(mod, bn_types)]
    for layer in bn_layers:
        layer.eval()


def unfreeze_bn_layers(model):
    """Call ``.train()`` on every ``nn.BatchNorm1d`` / ``nn.BatchNorm2d`` layer inside ``model``.

    Counterpart of ``freeze_bn_layers``. Parameter ``requires_grad`` flags are not changed.
    """
    for layer in model.modules():
        if not isinstance(layer, (nn.BatchNorm2d, nn.BatchNorm1d)):
            continue
        layer.train()

def generate_gaussian_noise_with_rank(m, n, k=50):
    """Return an ``(m, n)`` standard-normal noise matrix truncated to rank at most ``k``.

    A matrix is drawn with ``torch.randn`` (default dtype and device). When ``k`` is below
    ``min(m, n)``, all but its ``k`` largest singular values are set to zero and the matrix
    is rebuilt from its SVD.

    Args:
        m: number of rows.
        n: number of columns.
        k: target rank, must be non-negative. Any ``k >= min(m, n)`` is already full rank,
            so the raw noise sample is returned unchanged.

    Returns:
        Tensor of shape ``(m, n)``.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    noise = torch.randn((m, n))
    # Truncating at or above full rank is a no-op: skip the SVD (and its round-off) and
    # return the sample exactly.
    if k >= min(m, n):
        return noise
    U, S, Vh = torch.linalg.svd(noise, full_matrices=False)
    # Singular values come back in descending order, so this keeps the k largest.
    S[k:] = 0
    return U @ torch.diag_embed(S) @ Vh


def cca_loss(views, r=0.0):
    """Return the Frobenius norm of the whitened joint covariance of ``views``.

    Each view is a 2-D tensor whose rows (dim 0) are samples. All views must have the same
    number of rows, because they are concatenated along dim 1. The steps are:

    1. Mean-centre each view.
    2. Form the joint covariance ``C`` of the concatenation.
    3. Build the block-diagonal ``D`` from each view's own covariance.
    4. Whiten ``C`` with ``R = D^{-1/2}`` (via ``inv_sqrtm``).
    5. Return ``||R C R^T||_F``.

    Args:
        views: sequence of 2-D tensors, each of shape ``(n, d_i)``.
        r: shrinkage towards the identity for every diagonal block of ``D``
            (``(1 - r) * cov + r * I``). The default ``0`` is the unregularised form.

    Returns:
        0-dim tensor.

    Raises:
        ValueError: if there are fewer than two samples (the ``n - 1`` normaliser would be 0).
    """
    n = views[0].shape[0]
    if n < 2:
        raise ValueError(f"cca_loss needs at least 2 samples, got {n}")
    views = _demean(views)

    # Joint covariance of all views side by side.
    all_views = torch.cat(views, dim=1)
    C = all_views.T @ all_views / (n - 1)

    # Block covariance with each view's own X_i^T X_i / (n - 1) on the diagonal.
    D = torch.block_diag(
        *[
            (1 - r) * v.T @ v / (n - 1) + r * torch.eye(v.shape[1], device=v.device)
            for v in views
        ]
    )

    # inv_sqrtm may return a different dtype than C (it has historically returned float16);
    # cast R so the matmuls below see matching dtypes instead of failing on a mismatch.
    R = inv_sqrtm(D, eps=1e-3).to(C.dtype)

    C_whitened = R @ C @ R.T
    # trace(M M^T) is the squared Frobenius norm of M.
    return torch.sqrt(torch.trace(C_whitened @ C_whitened.T))



# def generate_low_rank_matrix(m, n, k=50):

    
    

#     shape = (m,n)
#         #integers = np.arange(n + 1)  # 生成0到n的整数数组
#     try:
        
#         A = torch.randn(shape)
#         if k==n:
#             return A

#         B = torch.zeros((m,n))
        
    

#         for j in range(n):
#             B[:,j] = A[:,j%k]
#         return B
#     except:
#             pdb.set_trace()

def _tile_columns(A, n, k):
    """Return a tensor whose column ``j`` (for ``j < n``) is column ``j % k`` of ``A``."""
    if n == 0:
        return A[:, :0]
    if k < 1:
        raise ValueError(f"k must be >= 1 to build {n} columns, got {k}")
    return A[:, torch.arange(n, device=A.device) % k]


def generate_low_rank_matrix(m, n, k=50, source=None):
    """Build an ``(m, n)`` matrix whose columns repeat at most ``k`` distinct columns.

    Without ``source``:
        An ``(m, n)`` matrix is drawn with ``torch.randn``. If ``k == n`` it is returned as-is.
        Otherwise column ``j`` of the result is column ``j % k`` of the draw, and the columns
        are then shuffled with ``torch.randperm``.

    With ``source``:
        ``k`` column indices are drawn from ``torch.randperm(n)`` (so ``source`` needs at least
        ``n`` columns). Those columns of ``source`` are tiled cyclically over ``n`` columns, with
        no final shuffle. The result is float64 and is allocated on ``source``'s device.

    Args:
        m: number of rows of the result.
        n: number of columns of the result.
        k: number of distinct columns, so the rank is at most ``k``.
        source: optional 2-D tensor to take columns from. Its rows must match ``m``
            (or broadcast to it).

    Returns:
        Tensor of shape ``(m, n)``.

    Raises:
        ValueError: if ``n > 0`` and ``k < 1`` (there is no column to tile from), except for
            the ``k == n`` early return in the ``source is None`` branch.
    """
    if source is None:
        # Only k columns of A are used, but the full (m, n) draw is kept so the random
        # stream (and therefore seeded results) stays the same as before.
        A = torch.randn((m, n))
        if k == n:
            return A
        B = _tile_columns(A, n, k)
        return B[:, torch.randperm(n)]

    picked = torch.randperm(n)[:k]
    A = source[:, picked]
    # Allocate on source's device rather than a hard-coded .cuda(), so CPU-only runs work.
    B = torch.zeros((m, n), dtype=torch.double, device=source.device)
    B[:] = _tile_columns(A, n, k)
    return B






def set_bn_requires_grad(model, requires_grad):
    """Apply ``requires_grad_(requires_grad)`` to every BatchNorm1d/BatchNorm2d layer in ``model``."""
    bn_types = (nn.BatchNorm2d, nn.BatchNorm1d)
    for bn in filter(lambda mod: isinstance(mod, bn_types), model.modules()):
        bn.requires_grad_(requires_grad)



# def noise_regular_loss_func(encoder,input_img ,features,params):
#     bs = features.shape[0]
#     output_dim = features.shape[1]
#     r = torch.eye(bs, output_dim).to(features.device)
#     # feature_loss = torch.abs(output_dim+cca_loss([(features+r*1e-4),features]))

#     # try:
#     #     loss =  eval(params['feature_loss']) *feature_loss  # self_1 : no_trans  self_2 : trans
#     # except:
#     #     loss =   params['feature_loss']*feature_loss

#     # return loss




#     bs = features.shape[0]
#     output_dim = features.shape[1]

#     input_img_shape = input_img.shape

#     input_img_flat = input_img.reshape(input_img.shape[0],-1)
#     input_img_flat_dim = input_img_flat.shape[1]
#     r = torch.eye(bs, output_dim).to(features.device)

#     #encoder_copy = shallow_copy_model(encoder)
    
#     if params['regular_method']== 'self_regular':
#         k = output_dim//4
#         with torch.no_grad():
#             noise_tensor = generate_low_rank_matrix(m=bs,n=input_img_flat_dim,k =k).to(features.device)
#             #print(features.dtype,noise_tensor.deype)
#             #exit(0)
#             noise_tensor_cor = cca_loss([noise_tensor[:,:output_dim]+r*1e-4,noise_tensor[:,:output_dim]])
            
#         # # #freeze_bn_layers(encoder)
#         noise_project = encoder(noise_tensor.reshape(input_img_shape))

#         noise_loss = torch.abs( cca_loss([noise_project+r*1e-4,noise_project]) - noise_tensor_cor) 
#         feature_loss = torch.abs(output_dim  - cca_loss([features+r*1e-4,features]))
#         #pdb.set_trace()
#         try:
#             loss =  eval(params['feature_loss']) *feature_loss + eval(params['noise_loss']) * noise_loss 
#         except:
#             loss =   params['feature_loss']*feature_loss + params['noise_loss'] * noise_loss
#         #print(noise_loss.item(),feature_loss.item(),loss.item())
#         #pdb.set_trace()
#         #unfreeze_bn_layers(encoder)
#     else:
#         with torch.no_grad():
#             #encoder.eval()
#             noise_tensor = generate_low_rank_matrix(m=bs,n=input_img_flat_dim,k =input_img_flat_dim).to(features.device)
#             #pdb.set_trace()
#             cross_cor = cca_loss([noise_tensor[:,:output_dim].t(),input_img_flat[:,:output_dim].t()])
#         freeze_bn_layers(encoder)
#         noise_project = encoder(noise_tensor.reshape(input_img_shape))
#         #pdb.set_trace()
        
#         cross_cor_project = cca_loss([noise_project.t(),features.t()])           # cor_1 : outdim+ trans cor_2 : full+ trans

#         #feature_loss = torch.abs(output_dim+cca_loss([features+r*1e-4,features]))  `   `
#         #pdb.set_trace()
#         unfreeze_bn_layers(encoder)
#         try:
            
#             loss = eval(params['cor_loss']) * torch.abs(cross_cor_project-cross_cor)
#         except:
#             loss = params['cor_loss'] * torch.abs(cross_cor_project-cross_cor)
#         #unfreeze_bn_layers(encoder)
#         # try:
#         #     loss = loss+ eval(params['feature_loss']) *feature_loss 
#         # except:
#         #     loss = loss+  params['feature_loss']*feature_loss
#     return loss
        



def _loss_weight(value):
    """Resolve a loss weight from ``params``: strings are evaluated, anything else is used as-is."""
    if isinstance(value, str):
        # NOTE(review): eval() runs arbitrary code. This is only acceptable because ``params``
        # is expected to come from a trusted experiment config; never pass external input here.
        return eval(value)
    return value


def noise_regular_loss_func(encoder, input_img, features, params):
    """Noise-probe regulariser built from two ``cca_loss`` values.

    Steps:

    1. Draw a ``torch.randn`` batch via ``generate_low_rank_matrix`` (full rank) with the
       same shape as ``input_img``.
    2. ``cross_cor`` is ``cca_loss`` between the first ``output_dim`` flattened noise
       columns and the first ``output_dim`` flattened image columns, both transposed.
    3. ``encoder`` is applied to the noise batch in eval mode.
    4. ``cross_cor_project`` is ``cca_loss`` between the encoded noise and ``features``,
       both transposed.
    5. Return ``weight * |cross_cor_project - cross_cor|``.

    The noise branch and ``cross_cor`` are computed under ``torch.no_grad()``, so only
    ``features`` contributes gradients.

    The train/eval flag of every submodule of ``encoder`` is restored afterwards, even if
    the forward pass raises. The previous code always called ``encoder.train()``, which
    also re-enabled layers frozen by ``freeze_bn_layers``.

    Args:
        encoder: ``nn.Module`` applied to a batch shaped like ``input_img``. Presumably it
            returns ``(bs, output_dim)`` like ``features`` (TODO confirm with callers).
        input_img: input batch, with dim 0 as the batch. It is flattened per sample and its
            first ``output_dim`` columns are used.
        features: 2-D tensor ``(bs, output_dim)``.
        params: mapping whose ``'cor_loss'`` entry is the weight, either a number or a
            string expression evaluated with ``eval``.

    Returns:
        The weighted regularisation loss (0-dim tensor).
    """
    bs, output_dim = features.shape[0], features.shape[1]
    input_img_shape = input_img.shape
    input_img_flat = input_img.reshape(input_img_shape[0], -1)
    input_img_flat_dim = input_img_flat.shape[1]

    # Remember each submodule's own mode so mixed states (e.g. frozen BN) survive.
    saved_modes = [(mod, mod.training) for mod in encoder.modules()]
    encoder.eval()
    try:
        with torch.no_grad():
            noise_tensor = generate_low_rank_matrix(
                m=bs, n=input_img_flat_dim, k=input_img_flat_dim
            ).to(features.device)
            cross_cor = cca_loss(
                [noise_tensor[:, :output_dim].t(), input_img_flat[:, :output_dim].t()]
            )
            noise_project = encoder(noise_tensor.reshape(input_img_shape))
    finally:
        for mod, mode in saved_modes:
            mod.training = mode

    cross_cor_project = cca_loss([noise_project.t(), features.t()])
    return _loss_weight(params['cor_loss']) * torch.abs(cross_cor_project - cross_cor)
        



