from torch.autograd import Variable
import torch

import torch
from torch.optim import lr_scheduler
import copy
from torch import cuda, nn, optim
from tqdm import tqdm, trange
import numpy
from torch.nn.functional import normalize
from torch.autograd import Variable
from torch import cuda, nn, optim
import torch.nn.functional as F

import pdb


def calculate_lp_norm(x, p):
    """
    Compute the Lp norm over all elements of a tensor.

    Args:
        x: input tensor.
        p: norm order; 0, 1, 2 (or any other non-zero number) or float('inf').

    Returns:
        For ``p == 0`` the number of non-zero elements as a Python int;
        for ``p == inf`` the largest absolute value as a tensor;
        otherwise ``(sum(|x| ** p)) ** (1 / p)`` as a tensor.
    """
    # "L0 norm": simply count the non-zero entries.
    if p == 0:
        return torch.nonzero(x).size(0)

    magnitudes = torch.abs(x)

    # L-infinity norm: the largest magnitude.
    if p == float('inf'):
        return torch.max(magnitudes)

    # General case: p-th root of the sum of p-th powers.
    powered_sum = torch.sum(torch.pow(magnitudes, p))
    return torch.pow(powered_sum, 1.0 / p)


import torch
import torch.nn.functional as F
from solo.utils.misc import gather




def variance_loss(z1: torch.Tensor) -> torch.Tensor:
    """Computes a hinge-style variance loss for a single batch of features.

    The variance is taken along dim 0, a small epsilon is added before the
    square root, and every resulting standard deviation below 1 is penalised
    by ``1 - std``.

    Args:
        z1 (torch.Tensor): batch of projected features; dim 0 is the one that
            is reduced (presumably an NxD tensor -- confirm against callers).

    Returns:
        torch.Tensor: scalar variance regularization loss.
    """

    eps = 1e-4  # keeps sqrt well-defined when a dimension has zero variance
    per_dim_std = (z1.var(dim=0) + eps).sqrt()

    # Only dimensions whose std is below 1 contribute to the loss.
    shortfall = F.relu(1 - per_dim_std)
    return shortfall.mean()


def covariance_loss(z1: torch.Tensor) -> torch.Tensor:
    """Computes a covariance loss for a single batch of features.

    The features are centred along dim 0, the DxD covariance matrix is
    formed with an ``N - 1`` denominator, and the squared off-diagonal
    entries are summed.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features.

    Returns:
        torch.Tensor: scalar covariance regularization loss.
    """

    num_samples, num_dims = z1.size()

    centered = z1 - z1.mean(dim=0)
    covariance = (centered.T @ centered) / (num_samples - 1)

    # Boolean mask selecting everything except the main diagonal.
    off_diagonal = ~torch.eye(num_dims, device=z1.device).bool()
    return covariance[off_diagonal].pow(2).sum()


# def feature_loss_fuc(feature,var_loss_weight=1,cov_loss_weight=0.1):

#         z = gather(feature)
#         var_loss = variance_loss(z)
#         cov_loss = covariance_loss(z)

#         return var_loss_weight * var_loss + cov_loss_weight * cov_loss


# def feature_loss_fuc(feature):
#         ident = torch.eye(feature.size(1),feature.size(1),device=feature.device, dtype=feature.dtype)
#         feature_norm = feature/(feature**2).sum(0, keepdim=True).sqrt()
#         w_tmp = (ident-feature_norm.T@feature_norm)
#         b_k = torch.rand(w_tmp.shape[1],1,device=feature.device, dtype=feature.dtype)

#         v1 = torch.matmul(w_tmp, b_k)
#         norm1 = torch.norm(v1,2)
#         v2 = torch.div(v1,norm1)
#         v3 = torch.matmul(w_tmp,v2)

#         #pdb.set_trace()
#         return (torch.norm(v3,2))
       
        # z = gather(feature)
        # var_loss = variance_loss(z)
        # cov_loss = covariance_loss(z)

        # return var_loss_weight * var_loss + cov_loss_weight * cov_loss
# def feature_loss_fuc(features,idx=256):
#         Z1,Z2 = features[0].detach(),features[1].detach()
#         K = (Z1+Z2)/2
#         return min_nonzero_eig(Z1,K,idx)

def nesum(tensor):
    """Return the mean of the correlation-matrix eigenvalues, scaled by the largest.

    The correlation matrix is computed with ``torch.corrcoef(tensor.T)``, so
    the columns of ``tensor`` are treated as the variables. Its eigenvalues
    are divided by the largest one and averaged.

    Args:
        tensor: 2D tensor (presumably samples x features -- confirm with callers).

    Returns:
        Scalar tensor: ``mean(eigenvalues / max(eigenvalues))``.
    """
    # eigvalsh is run in float32 regardless of the input dtype.
    correlation = torch.corrcoef(tensor.T).to(torch.float32)
    spectrum = torch.linalg.eigvalsh(correlation)

    # Normalise by the largest eigenvalue, then average (not sum).
    largest = torch.max(spectrum)
    return torch.mean(spectrum / largest)


def feature_loss_fuc(features):
    """Evaluate ``nesum`` on the average of the first two feature tensors.

    Both tensors are detached first, so no gradient flows through this
    function, and the result is returned as a plain Python number.

    Args:
        features: indexable container whose items 0 and 1 are tensors of the
            same shape.

    Returns:
        float: ``nesum((features[0] + features[1]) / 2).item()``.
    """
    first_view = features[0].detach()
    second_view = features[1].detach()
    averaged = (first_view + second_view) / 2
    return nesum(averaged).item()
        # K_centered = K - K.mean(dim=0)
    
        # # 计算协方差矩阵X和Y
        # X = torch.mm(K_centered.T, K_centered) / (K.size(0) - 1)
        # eigenvalues = torch.linalg.eigvalsh(X.to(torch.float32))
        # eigenvalues = torch.where(eigenvalues > 1e-5, eigenvalues, 1e-5*torch.ones_like(eigenvalues))
        # return torch.sum(eigenvalues).item()
         
        
       # return variance_loss(K), covariance_loss(K)
        
def _mat_pow(mat, pow_, epsilon):
    # Computing matrix to the power of pow (pow can be negative as well)
    [D, V] = torch.linalg.eigh(mat)
    mat_pow = V @ torch.diag((D + epsilon).pow(pow_)) @ V.T
    mat_pow[mat_pow != mat_pow] = epsilon  # For stability
    return mat_pow

def min_nonzero_eig(Z, K, idx):
    """Find the smallest eigenvalue of ``I - X^{-1/2} Y X^{-1/2}`` above a threshold.

    ``X`` and ``Y`` are the covariance matrices (``N - 1`` denominator) of the
    column-centred ``Z`` and ``K``. All linear algebra is carried out in
    float32 so that inputs of any floating dtype are handled consistently.

    Args:
        Z: 2D tensor whose covariance is whitened (``X``).
        K: 2D tensor on the same device and with the same dtype as ``Z``
            whose covariance is ``Y``.
        idx: index into the ascending eigenvalues of ``I - C``; it is clamped
            to the last valid index.

    Returns:
        ``(value, j, value_at_idx)`` where ``value`` is the first eigenvalue
        (in ascending order) greater than ``1e-10``, ``j`` is its position and
        ``value_at_idx`` is the eigenvalue at ``min(idx, len - 1)``.
        ``None`` if no eigenvalue exceeds the threshold.

    Raises:
        AssertionError: if ``Z`` and ``K`` differ in device or dtype.
        torch.linalg.LinAlgError: propagated if an eigendecomposition fails.
    """
    # Make sure Z and K live on the same device with the same dtype.
    assert Z.device == K.device and Z.dtype == K.dtype, "Z and K should be on the same device and dtype"

    # Centre the feature matrices column-wise.
    Z_centered = Z - Z.mean(dim=0)
    K_centered = K - K.mean(dim=0)

    # Covariance matrices X (from Z) and Y (from K), both promoted to float32
    # so the matrix products below never mix dtypes.
    X = (torch.mm(Z_centered.T, Z_centered) / (Z.size(0) - 1)).to(torch.float32)
    Y = (torch.mm(K_centered.T, K_centered) / (K.size(0) - 1)).to(torch.float32)

    # Eigendecomposition of X; eigenvalues are floored at 1e-5 so the inverse
    # square root stays finite.
    eigenvalues, eigenvectors = torch.linalg.eigh(X)
    eigenvalues = torch.where(eigenvalues > 1e-5, eigenvalues, 1e-5 * torch.ones_like(eigenvalues))

    # Inverse square root of X. eigh returns orthonormal eigenvectors, so the
    # transpose is the inverse and no explicit matrix inversion is needed.
    X_inv_sqrt = eigenvectors @ torch.diag(torch.pow(eigenvalues, -0.5)) @ eigenvectors.T
    C = X_inv_sqrt @ Y @ X_inv_sqrt

    # Eigenvalues of I - C (eigvalsh returns them in ascending order).
    identity = torch.eye(X.size(0), device=Z.device, dtype=torch.float32)
    spectrum = torch.linalg.eigvalsh(identity - C)

    # The first eigenvalue above the threshold is treated as the smallest
    # non-zero one.
    threshold = 1e-10
    last_index = len(spectrum) - 1
    for j, value in enumerate(spectrum):
        if value > threshold:
            return value, j, spectrum[min(idx, last_index)]

    # No eigenvalue exceeded the threshold.
    return None

def approximate_largest_singular_value(A, num_iterations=2):
    """
    Approximate the largest singular value of a matrix by power iteration.

    Args:
        A: matrix whose singular value is estimated (must be a 2D tensor).
        num_iterations: number of power-iteration steps (default 2).

    Returns:
        Scalar tensor with the approximate largest singular value.

    Raises:
        ValueError: if ``A`` is not two-dimensional.
    """
    # Only 2D matrices are supported.
    if A.ndim != 2:
        raise ValueError("A must be a 2D matrix.")

    # Start from a random unit vector with one entry per column of A.
    direction = torch.randn(A.size(1), device=A.device, dtype=A.dtype)
    direction = direction / torch.norm(direction)

    # Each step applies A^T A and renormalises.
    for _ in range(num_iterations):
        projected = torch.mv(A, direction)
        updated = torch.mv(A.T, projected)
        direction = updated / torch.norm(updated)

    # The length of A applied to the converged direction is the estimate.
    return torch.norm(torch.mv(A, direction))


"""Function used for Orthogonal Regularization"""
# Method names handled by exact match; any name containing 'LP' (e.g. 'LP_1')
# and the special name 'cor' are accepted in addition to these.
_REG_EXACT_METHODS = frozenset({
    'risp', 'so', 'trace_norm', 'cross_norm', 'cross_norm_1', 'cross_norm_dia',
    'diag', 'strict', 'spectral_norm_min', 'ssv', 'decov', 'mma',
})


def _reg_gram(w1, rows, cols):
    """Return the smaller Gram matrix of the 2D weight ``w1``."""
    wt = torch.transpose(w1, 0, 1).contiguous()
    if rows > cols:
        return torch.matmul(wt, w1)
    return torch.matmul(w1, wt)


def _reg_residual(w1, rows, cols, device):
    """Return ``gram - I`` for the smaller Gram matrix of ``w1``."""
    gram = _reg_gram(w1, rows, cols)
    return gram - torch.eye(gram.shape[0], device=device)


def _reg_axis_norms(w1, rows, cols):
    """Return L2 norms of the columns (rows > cols) or rows (otherwise) of ``w1``."""
    axis = 0 if rows > cols else 1
    return (w1 ** 2).sum(axis, keepdim=True).sqrt()


def _reg_normalized_gram(w1, rows, cols):
    """Return the Gram matrix of the norm-scaled weight together with the norms."""
    norms = _reg_axis_norms(w1, rows, cols)
    unit = w1 / norms
    unit_t = torch.transpose(unit, 0, 1)
    if rows > cols:
        return torch.matmul(unit_t, unit), norms
    return torch.matmul(unit, unit_t), norms


def _reg_unit_norm_penalty(norms, device):
    """Return the 2-norm of ``norms ** 2 - 1``."""
    return torch.norm(norms ** 2 - torch.ones(norms.shape, device=device), 2)


def _reg_power_step(mat, device, transpose_last=False):
    """Return ``||M' (M b / ||M b||)||^2`` for a random ``b`` (M' is M or M^T)."""
    # The random vector is drawn on the CPU and then moved, so a seeded CPU
    # generator determines the result on any device.
    b_k = torch.rand(mat.shape[1], 1).to(device)
    v1 = torch.matmul(mat, b_k)
    v2 = torch.div(v1, torch.norm(v1, 2))
    v3 = torch.matmul(mat.t() if transpose_last else mat, v2)
    return torch.norm(v3, 2) ** 2


def _reg_diagonal_cross_entropy(logits, device):
    """Return the cross-entropy of ``logits`` with row ``i`` labelled ``i``."""
    labels = torch.arange(logits.shape[0]).to(device)
    return nn.CrossEntropyLoss()(logits, labels)


def _reg_layer_penalty(W, device, method):
    """Return the un-weighted penalty of one parameter tensor for ``method``."""
    # Flatten to 2D: one row per slice along dim 0.
    cols = W[0].numel()
    rows = W.shape[0]
    w1 = W.reshape(-1, cols)

    if method == 'risp':
        return _reg_power_step(_reg_residual(w1, rows, cols, device), device)
    if method == 'so':
        return torch.norm(_reg_residual(w1, rows, cols, device), 2) ** 2
    if method == 'trace_norm':
        return torch.linalg.norm(_reg_residual(w1, rows, cols, device), 'nuc')
    if method == 'ssv':
        return torch.linalg.norm(_reg_residual(w1, rows, cols, device), -2)
    if method == 'spectral_norm_min':
        return _reg_power_step(_reg_gram(w1, rows, cols), device, transpose_last=True)
    if method == 'diag':
        return _reg_unit_norm_penalty(_reg_axis_norms(w1, rows, cols), device)

    if method in ('cross_norm', 'cross_norm_1', 'cross_norm_dia', 'strict'):
        gram, norms = _reg_normalized_gram(w1, rows, cols)
        if method == 'cross_norm':
            return _reg_diagonal_cross_entropy(torch.abs(gram) / 2, device)
        if method == 'cross_norm_1':
            return _reg_diagonal_cross_entropy(gram / 2, device)
        if method == 'cross_norm_dia':
            return (_reg_diagonal_cross_entropy(gram / 2, device)
                    + _reg_unit_norm_penalty(norms, device))
        # 'strict'
        identity = torch.eye(gram.shape[0], device=device)
        return torch.norm(gram - identity, 2) + _reg_unit_norm_penalty(norms, device)

    if method == 'decov':
        correlation = torch.corrcoef(w1)
        off_diagonal = ~torch.eye(w1.size(0), device=device).bool()
        return torch.norm(correlation[off_diagonal]) ** 2

    if method == 'mma':
        unit_rows = F.normalize(w1, p=2, dim=1)
        cosine = torch.matmul(unit_rows, unit_rows.t())
        # Flip the sign of the diagonal so a row never picks itself as the
        # most similar one.
        cosine = cosine - 2. * torch.diag(torch.diag(cosine))
        return -torch.acos(cosine.max(dim=1)[0].clamp(-0.99999, 0.99999)).mean()

    # Remaining accepted names contain 'LP'; the order p follows the first '_'.
    p = float(method.split('_')[1])
    return calculate_lp_norm(_reg_residual(w1, rows, cols, device), p)


def l2_reg_ortho_loss_func(mdl, device, weight=1e-2, method='risp'):
    """Compute a weighted orthogonality-style regularizer over a model's weights.

    Every parameter with at least two dimensions is reshaped to a 2D matrix
    ``w1`` (one row per slice along dim 0); parameters with fewer dimensions
    are skipped. Most methods work on the smaller Gram matrix ``G`` of ``w1``
    (``w1^T w1`` when there are more rows than columns, ``w1 w1^T`` otherwise).
    The per-parameter penalties are summed and multiplied by ``weight``.

    Supported ``method`` values:
        'cor'               -- delegates to ``ortho(mdl, device)``.
        'risp'              -- one random power-iteration step on ``G - I``.
        'so'                -- ``torch.norm(G - I, 2) ** 2``.
        'LP_<p>'            -- ``calculate_lp_norm(G - I, p)``.
        'trace_norm'        -- nuclear norm of ``G - I``.
        'ssv'               -- ``torch.linalg.norm(G - I, -2)``.
        'spectral_norm_min' -- one random power-iteration step on ``G``.
        'cross_norm'        -- cross-entropy of ``|Gn| / 2`` against the
                               diagonal, ``Gn`` being the Gram matrix of the
                               norm-scaled weight.
        'cross_norm_1'      -- same with ``Gn / 2``.
        'cross_norm_dia'    -- 'cross_norm_1' plus the 'diag' term.
        'diag'              -- 2-norm of (squared row/column norms minus 1).
        'strict'            -- ``torch.norm(Gn - I, 2)`` plus the 'diag' term.
        'decov'             -- squared norm of the off-diagonal entries of
                               ``torch.corrcoef(w1)``.
        'mma'               -- negative mean arccos of each row's largest
                               cosine similarity to another row.

    Args:
        mdl: module whose ``parameters()`` (or, for 'cor', ``modules()``) are
            regularized.
        device: device on which helper tensors are created.
        weight: scalar multiplier applied to the summed penalty.
        method: one of the names listed above.

    Returns:
        ``weight`` times the summed penalty; a zero scalar tensor when the
        model has no parameter with two or more dimensions.

    Raises:
        ValueError: if ``method`` is not a recognised name.
    """
    if method == 'cor':
        return weight * ortho(mdl, device)

    # Fail fast with a clear message instead of printing per parameter and
    # then crashing on ``weight * None``.
    if method not in _REG_EXACT_METHODS and 'LP' not in method:
        raise ValueError(f"unknown orthogonal regularization method: {method!r}")

    l2_reg = None
    for W in mdl.parameters():
        # Biases and other 1D parameters are not regularized.
        if W.ndimension() < 2:
            continue
        term = _reg_layer_penalty(W, device, method)
        l2_reg = term if l2_reg is None else l2_reg + term

    if l2_reg is None:
        # Nothing to regularize: return a well-defined zero penalty.
        return weight * torch.zeros((), device=device)
    return weight * l2_reg

def ortho(mdl, device):
    """Sum ``ortho_conv`` over the eligible Conv2d layers of a model.

    A Conv2d layer is skipped when its kernel size is 7x7 or when dim 1 of
    its weight has size 3.

    Args:
        mdl: module whose ``modules()`` are scanned.
        device: forwarded to ``ortho_conv``.

    Returns:
        The summed penalty; the int ``0`` when no layer is eligible.
    """
    def _is_eligible(layer):
        if not isinstance(layer, nn.Conv2d):
            return False
        skipped = layer.kernel_size == (7, 7) or layer.weight.shape[1] == 3
        return not skipped

    return sum(
        ortho_conv(layer, device)
        for layer in mdl.modules()
        if _is_eligible(layer)
    )

def ortho_conv(m, device):
    """Return the orthogonality penalty of one convolution layer.

    The layer's weight is convolved with itself (padding ``kernel - 1``, the
    layer's own stride and groups). The result is compared against a target
    that is zero everywhere except at the central spatial position, where it
    holds an identity matrix repeated ``groups`` times along dim 1.

    Args:
        m: convolution module providing ``weight``, ``groups``,
            ``kernel_size`` and ``stride``.
        device: device on which the target tensor is placed.

    Returns:
        Scalar tensor: half the sum of squared differences.
    """
    weight = m.weight
    # Regroup the per-group output blocks side by side along dim 1.
    stacked = torch.cat(torch.chunk(weight, m.groups, dim=0), dim=1)

    # Work along whichever of dim 0 / dim 1 of the weight is smaller.
    swap_axes = weight.shape[1] < weight.shape[0]
    if swap_axes:
        channels = weight.shape[1]
        stacked = stacked.transpose(1, 0)
        kernel = weight.transpose(1, 0)
    else:
        channels = weight.shape[0]
        kernel = weight

    pad = (m.kernel_size[0] - 1, m.kernel_size[1] - 1)
    gram = F.conv2d(stacked, kernel, padding=pad, stride=m.stride, groups=m.groups)

    # Target: identity at the spatial centre, zeros elsewhere.
    target = torch.zeros(gram.shape).to(device)
    mid_h = target.shape[2] // 2
    mid_w = target.shape[3] // 2
    target[:, :, mid_h, mid_w] = torch.eye(channels).repeat(1, m.groups)

    return torch.sum((gram - target) ** 2.0) / 2.0