import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets, transforms
import time
from tqdm import tqdm
import os

import torch.nn.functional as F

# NOTE(review): KMP_DUPLICATE_LIB_OK is read by the Intel OpenMP runtime; "TRUE" lets the
# process continue when more than one copy of that runtime is loaded instead of aborting.
# This hides the duplicate-library situation rather than fixing it — confirm it is still
# needed. It is set at module import time, after torch/numpy have already been imported.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def calCosCorr(f, g):
    """Mean row-wise cosine similarity of ``f`` and ``g`` (UniFast HGR, HGRscore3).

    :param f: 2-D tensor; rows are paired with the rows of ``g``.
    :param g: 2-D tensor with the same shape as ``f``.
    :return: ``(corr, f_unit, g_unit)`` where ``f_unit``/``g_unit`` are the
        L2-normalised rows of the inputs and ``corr`` is the sum of their
        row-wise dot products divided by the number of rows of ``f``.
    """
    f_unit = F.normalize(f, dim=1)
    g_unit = F.normalize(g, dim=1)
    per_row_cos = torch.sum(f_unit * g_unit, 1)
    return torch.sum(per_row_cos) / len(f), f_unit, g_unit

def caldistributionCosCorr(f, g):
    """Mean row-wise cosine similarity of two matrices after dropping their diagonal.

    Used by UniFastHGR, which passes the intra-batch cosine-similarity matrices
    ``f1 @ f1.T`` and ``g1 @ g1.T`` (n x n).

    :param f: square (n x n) tensor — assumed symmetric, TODO confirm for other callers.
    :param g: square (n x n) tensor with the same shape as ``f``.
    :return: sum over the n-1 reshaped rows of their cosine similarity, divided by n-1.

    NOTE(review): n == 1 leaves empty matrices and divides by zero.
    """
    Number2samples = len(f)
    # Keep only the strict upper triangle, then drop the first column and last row:
    # the (n-1) x (n-1) result holds f[i, j+1] for j >= i and 0 below the diagonal.
    f1 = torch.triu(f, diagonal=1)
    f1 = f1[0:Number2samples-1, 1:Number2samples]
    # Mirror the strictly-upper part below the diagonal, so entry (i, j) with j < i
    # becomes f[j, i+1]. NOTE(review): for symmetric f, row i is therefore
    # f[i+1, :i] followed by f[i, i+1:], not exactly row i of f with its diagonal
    # removed — confirm this mixing of rows i and i+1 is intended.
    f1 = f1 + torch.triu(f1, diagonal=1).transpose(0, 1)

    # Same reshaping for g.
    g1 = torch.triu(g, diagonal=1)
    g1 = g1[0:Number2samples - 1, 1:Number2samples]
    g1 = g1 + torch.triu(g1, diagonal=1).transpose(0, 1)

    # L2-normalise each row so the row-wise dot product is a cosine similarity.
    f1 = torch.nn.functional.normalize(f1, dim=1)
    g1 = torch.nn.functional.normalize(g1, dim=1)

    # Average the n-1 row cosines.
    corr = torch.sum(torch.sum(f1 * g1, 1)) / (Number2samples - 1)
    return corr

def UniFastHGR(f, g):
    """UniFast HGR score (HGRscore3) of two paired feature batches.

    Combines the mean row-wise cosine similarity of ``f`` and ``g`` (``corr``,
    from calCosCorr) with the agreement between their intra-batch cosine
    similarity matrices (``tra``, from caldistributionCosCorr).

    :return: ``(1.5 - corr - tra / 2, corr, tra)``.
    """
    corr, f_unit, g_unit = calCosCorr(f, g)

    sim_f = f_unit @ f_unit.t()
    sim_g = g_unit @ g_unit.t()
    tra = caldistributionCosCorr(sim_f, sim_g)

    score = torch.tensor(1.5) - corr - tra / 2
    return score, corr, tra

def calCosCorr4(f, g):
    """Mean of the row-wise dot products of ``f`` and ``g`` (OptFast HGR, HGRscore4).

    No normalisation is done here; OptFastHGR passes already L2-normalised rows,
    so the value is the mean row-wise cosine similarity in that case.
    """
    row_dots = torch.sum(f * g, 1)
    return torch.sum(row_dots) / len(f)

def caldistributionCosCorr4(f, g):
    """Weighted sum of ``f * g`` over the strict upper triangle of ``f``.

    :param f: square (n x n) tensor; only entries above its diagonal contribute.
    :param g: tensor broadcastable against ``f`` (OptFastHGR passes n x n).
    :return: ``sum(triu(f, 1) * g) * 2 / (n - 1)`` with n = len(f).
    """
    n_rows = len(f)
    upper_f = f.triu(diagonal=1)
    per_row = torch.sum(upper_f * g, 1)
    return torch.sum(per_row) * 2 / (n_rows - 1)

def calHGR3bias(batch_size, Element_length, Random_times=1000):
    """Average OptFastHGR score of random feature matrices compared with themselves.

    Draws ``Random_times`` matrices of shape (batch_size, Element_length) uniformly
    from [-1, 1), scores each one against itself with ``OptFastHGR`` (bias=0) and
    returns the mean score. NOTE(review): presumably the result is meant to be passed
    back as OptFastHGR's ``bias`` argument — confirm with callers.

    :param batch_size: number of rows of each random matrix.
    :param Element_length: number of columns of each random matrix.
    :param Random_times: number of random draws to average (must be positive).
    :return: the mean score as a 0-dim tensor.
    :raises ValueError: if ``Random_times`` is not positive.
    """
    if Random_times <= 0:
        raise ValueError("Random_times must be positive")
    # The previous body called ``HGRscore4``, which is not defined in this module
    # (NameError at runtime); the "#HGRscore4:OptFast HGR" tags identify it as the
    # former name of OptFastHGR, the only scorer here that takes a ``bias``.
    scores = []
    for _ in range(Random_times):
        f = 2 * torch.rand(batch_size, Element_length) - 1
        hgr, _, _ = OptFastHGR(f, f)
        scores.append(hgr)
    return sum(scores) / Random_times


def OptFastHGR(f, g, bias=0):
    """OptFast HGR score (HGRscore4) of two paired feature batches.

    Rows of ``f`` and ``g`` are L2-normalised, then ``corr`` (calCosCorr4) and
    ``tra`` (caldistributionCosCorr4 on the intra-batch cosine-similarity
    matrices) are combined and rescaled by ``bias``.

    :param bias: offset subtracted before rescaling (see calHGR3bias).
    :return: ``((1.5 - corr - tra - bias) / (1.5 - bias) * 1.5, corr, tra)``.
    """
    f_unit = F.normalize(f, dim=1)
    g_unit = F.normalize(g, dim=1)

    corr = calCosCorr4(f_unit, g_unit)
    tra = caldistributionCosCorr4(f_unit @ f_unit.t(), g_unit @ g_unit.t())

    numerator = torch.tensor(1.5) - corr - tra - bias
    return numerator / (1.5 - bias) * 1.5, corr, tra



def SoftHGR(f, g):
    """Soft-HGR style score (HGRscore2) of two paired feature batches.

    With R = len(f) rows and C = len(f.T) columns:

    * each row of ``f``/``g`` is centred and scaled by its (unbiased) std across
      columns; ``corr`` is the summed product of the standardised values / (R * C);
    * ``cov_*`` = centred.T @ centred / (C - 1) is an R x R matrix between rows;
      each is column-centred and column-standardised, and
      ``tra = trace(cov_f @ cov_g) / 2 / R / R``.

    No epsilon is used, so a constant row or column yields NaN.

    :return: ``(1.5 - corr - tra, corr, tra)``.
    """
    n_rows = len(f)
    n_cols = len(f.T)

    def centre_rows(x):
        # Transpose so dim 0 runs over the original columns, then remove each row's mean.
        xt = x.T
        return xt - torch.mean(xt, 0)

    def standardise(x):
        return x / (torch.sqrt(torch.var(x, dim=0)))

    f_c = centre_rows(f)
    g_c = centre_rows(g)

    corr = torch.sum(torch.sum(standardise(f_c) * standardise(g_c), 1)) / (n_rows * n_cols)

    cov_f = torch.mm(f_c.t(), f_c) / (n_cols - 1)
    cov_g = torch.mm(g_c.t(), g_c) / (n_cols - 1)

    cov_f_std = standardise(cov_f - torch.mean(cov_f, 0))
    cov_g_std = standardise(cov_g - torch.mean(cov_g, 0))

    tra = torch.trace(torch.mm(cov_f_std, cov_g_std)) / 2 / n_rows / n_rows

    return 1.5 - corr - tra, corr, tra



class DeepCCA(nn.Module):
    """Two independent ReLU MLP towers, one per input view (DeepCCA).

    Each tower is a stack of ``nn.Linear`` layers: the first maps the view's
    input width to ``hidden_dim`` and the rest map ``hidden_dim -> hidden_dim``.
    At least one layer is built even when ``num_layers`` < 1. Every layer,
    including the last, is followed by ReLU, so outputs are non-negative.
    """

    def __init__(self, input_dim1, input_dim2, hidden_dim, num_layers):
        super().__init__()
        depth = max(num_layers, 1)
        # Tower 1 is fully constructed before tower 2 so parameter initialisation
        # consumes the RNG in the same order as before.
        self.layers1 = nn.ModuleList(
            nn.Linear(input_dim1 if idx == 0 else hidden_dim, hidden_dim)
            for idx in range(depth)
        )
        self.layers2 = nn.ModuleList(
            nn.Linear(input_dim2 if idx == 0 else hidden_dim, hidden_dim)
            for idx in range(depth)
        )

    def forward(self, x1, x2):
        """Run ``x1`` through tower 1 and ``x2`` through tower 2; return both."""
        def run_tower(tower, h):
            for linear in tower:
                h = torch.relu(linear(h))
            return h

        return run_tower(self.layers1, x1), run_tower(self.layers2, x2)
#
# def HGRscore(x1, x2):
def DeepCCAscore(x1, x2):
    """Mean row-wise cosine similarity between two paired embedding batches.

    Rows are paired in order; rows beyond the shorter input are ignored (as the
    previous ``zip``-based loop did).

    :param x1: 2-D tensor of embeddings.
    :param x2: 2-D tensor whose rows have the same width as those of ``x1``.
    :return: ``(result, corr, tra)`` — all three are the mean cosine similarity,
        as a 0-dim tensor on the inputs' device that stays in the autograd graph.
    :raises ValueError: if either input has no rows.
    """
    n = min(len(x1), len(x2))
    if n == 0:
        raise ValueError("DeepCCAscore requires at least one pair of rows")
    a, b = x1[:n], x2[:n]
    # Vectorised form of dot(row1, row2) / (||row1|| * ||row2||) for every pair.
    # The old loop rebuilt the result with torch.tensor(list), which detached it from
    # autograd, and returned the *last* row's cosine as corr/tra via the leaked loop variable.
    cos = torch.sum(a * b, dim=1) / (torch.norm(a, dim=1) * torch.norm(b, dim=1))
    result = torch.mean(cos)
    return result, result, result

class SoftCCA(nn.Module):
    """Two-tower ReLU MLP with a stochastic decorrelation loss (SDL) helper.

    Each tower is ``num_layers`` ``nn.Linear`` layers (always at least one), each
    followed by ReLU. The first layer maps the view's input width to ``hidden_dim``
    and the remaining layers are ``hidden_dim -> hidden_dim``.

    :param lambda_: stored as ``self.lambda_``; nothing in this class reads it
        (presumably the SDL weight for a training loss — TODO confirm).
    """

    def __init__(self, input_dim1, input_dim2, hidden_dim, num_layers, lambda_=1.0):
        super(SoftCCA, self).__init__()
        self.layers1 = nn.ModuleList([nn.Linear(input_dim1, hidden_dim)])
        for _ in range(num_layers - 1):
            self.layers1.append(nn.Linear(hidden_dim, hidden_dim))
        self.layers2 = nn.ModuleList([nn.Linear(input_dim2, hidden_dim)])
        for _ in range(num_layers - 1):
            self.layers2.append(nn.Linear(hidden_dim, hidden_dim))
        self.lambda_ = lambda_

    def forward(self, x1, x2):
        """Run ``x1`` through tower 1 and ``x2`` through tower 2; return both outputs."""
        for layer in self.layers1:
            x1 = torch.relu(layer(x1))
        for layer in self.layers2:
            x2 = torch.relu(layer(x2))
        return x1, x2

    def stochastic_decorrelation_loss(self, z1, z2):
        """Sum of |off-diagonal| entries of each view's mini-batch matrix ``z.T @ z / (m - 1)``.

        ``m`` is the number of rows of ``z1`` and is used for both views, as before.
        No mean is subtracted from ``z``.

        :param z1: (m x k1) activations of view 1.
        :param z2: (m x k2) activations of view 2.
        :return: scalar tensor, differentiable w.r.t. ``z1`` and ``z2``.
        :raises ValueError: if the mini-batch has fewer than 2 rows (the scale
            ``1 / (m - 1)`` is undefined).
        """
        m = z1.shape[0]
        if m < 2:
            raise ValueError("stochastic_decorrelation_loss needs a mini-batch of at least 2 rows")

        # The previous version added C to an accumulator m times and divided by m,
        # which just reproduces C, then summed |C[i][j]| for i != j in a Python k x k
        # loop. Both are replaced by one masked tensor reduction.
        def off_diagonal_l1(z):
            c = torch.matmul(z.T, z) / (m - 1)
            off_diag = ~torch.eye(c.shape[0], dtype=torch.bool, device=c.device)
            return torch.abs(c[off_diag]).sum()

        return off_diagonal_l1(z1) + off_diagonal_l1(z2)

def SoftCCAscore(x1, x2):
    """Cosine similarity between the flattened SoftCCA embeddings of two views.

    The inputs are detached, cast to float32 and moved to ``cuda:0`` when CUDA is
    available (CPU otherwise). They are passed through a freshly constructed SoftCCA
    (hidden_dim=64, num_layers=3, lambda_=0.1), and the cosine similarity of the two
    flattened outputs is returned.

    NOTE(review): the network is used at its random initialisation. The previous
    body created an Adam optimizer and ran 100 "epochs" that only called
    ``zero_grad()`` and the forward pass — no loss, ``backward()`` or ``step()`` —
    so the weights never changed. That no-op loop has been removed; the output for
    a given RNG seed is unchanged. If training was intended, a loss (e.g. built from
    ``model.stochastic_decorrelation_loss`` and ``model.lambda_``) plus
    ``backward()``/``step()`` has to be added.

    :return: ``(result, corr, tra)`` — all three are the same 0-dim tensor.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    x1 = x1.clone().detach().float().to(device)
    x2 = x2.clone().detach().float().to(device)

    model = SoftCCA(input_dim1=x1.shape[1], input_dim2=x2.shape[1], hidden_dim=64, num_layers=3, lambda_=0.1)
    model.to(device)

    z1, z2 = model(x1, x2)
    z1_flat = z1.view(-1)
    z2_flat = z2.view(-1)
    corr = torch.dot(z1_flat, z2_flat) / (torch.norm(z1_flat) * torch.norm(z2_flat))
    return corr, corr, corr

class CCA:
    """Linear canonical correlation analysis (CCA) with ridge-regularised covariances.

    After :meth:`fit`:

    * ``canonical_correlations_`` -- the ``n_components`` largest canonical
      correlations, in decreasing order (real and non-negative).
    * ``X_loadings_`` (p x k) and ``Y_loadings_`` (q x k) -- projection directions.
      Column i of each gives the i-th canonical pair, scaled so that the projected
      variable has unit variance under the regularised covariance.

    :param n_components: number of canonical pairs to keep.
    :param reg: ridge term added to the diagonal of both auto-covariance matrices
        (default 0.001, the previously hard-coded value).
    """

    def __init__(self, n_components, reg=0.001):
        self.n_components = n_components
        self.reg = reg

    def fit(self, X, Y):
        """Estimate canonical directions from paired samples.

        :param X: array of shape (N, p).
        :param Y: array of shape (N, q).
        :return: self
        :raises ValueError: if X and Y have different numbers of rows.
        """
        if X.shape[0] != Y.shape[0]:
            raise ValueError("X and Y must have the same sample size")

        p, q = X.shape[1], Y.shape[1]
        X_centered = X - np.mean(X, axis=0)
        Y_centered = Y - np.mean(Y, axis=0)

        cov_XX = np.cov(X_centered, rowvar=False) + np.eye(p) * self.reg
        cov_YY = np.cov(Y_centered, rowvar=False) + np.eye(q) * self.reg
        cov_XY = np.cov(X_centered, Y_centered, rowvar=False)[:p, p:]

        # Whitening: with cov = L L^T (Cholesky) and W = L^{-1}, W cov W^T = I.
        # The previous code eigen-decomposed L_x^{-1} C_xy L_y^{-1} C_xy^T, which is
        # not the CCA operator. It is generally non-symmetric, so eig() could return
        # complex or negative eigenvalues (NaN after sqrt) and complex loadings.
        W_x = np.linalg.inv(np.linalg.cholesky(cov_XX))
        W_y = np.linalg.inv(np.linalg.cholesky(cov_YY))

        # The singular values of the whitened cross-covariance are the canonical
        # correlations; the singular vectors map back to loadings through W^T.
        # u_i^T M v_i = s_i >= 0, so every canonical pair is positively correlated.
        U, s, Vt = np.linalg.svd(W_x @ cov_XY @ W_y.T, full_matrices=False)
        k = self.n_components
        self.canonical_correlations_ = s[:k]
        self.X_loadings_ = W_x.T @ U[:, :k]
        self.Y_loadings_ = W_y.T @ Vt[:k].T
        return self

    def transform(self, X, Y):
        """Project X and Y onto the fitted canonical directions.

        The inputs are not centred here, so each projection carries a constant
        offset; correlations between projections are unaffected.
        """
        return X @ self.X_loadings_, Y @ self.Y_loadings_

def CCAscore(x1, x2):
    """Correlation of the first canonical pair of two feature tensors.

    Both tensors are copied to CPU NumPy arrays, a one-component CCA is fitted,
    and the Pearson correlation of the two first-component projections is
    returned.

    :return: ``(result, corr, tra)`` — all three are the same NumPy float.
    """
    arr1, arr2 = (t.clone().detach().cpu().numpy() for t in (x1, x2))

    model = CCA(n_components=1)
    model.fit(arr1, arr2)

    proj1, proj2 = model.transform(arr1, arr2)
    corr = np.corrcoef(proj1[:, 0], proj2[:, 0])[0, 1]
    return corr, corr, corr




def SwiftHGR(f, g, rank=10):
    """
    Optimized UniFast HGR implementation with low-rank approximation.

    Rows of both inputs are L2-normalised. ``corr`` is their mean row-wise cosine
    similarity. ``tr`` compares the intra-batch similarity matrices of f and g,
    sketched through one shared Gaussian projection: it is the dot product of
    their strict upper triangles divided by the product of their Frobenius norms.
    The projection is drawn fresh on every call, so ``tr`` is random unless the
    RNG is seeded.

    :param f: feature matrix 1 (m x n)
    :param g: feature matrix 2 (m x n)
    :param rank: number of columns of the random projection (default=10)
    :return: (hgr, corr, tr) with hgr = corr - 0.5 * tr
    """
    m = f.size(0)
    eps = 1e-8

    # 1. Feature normalization
    f_norm = F.normalize(f, p=2, dim=1)
    g_norm = F.normalize(g, p=2, dim=1)

    # 2. Cosine correlation
    corr = torch.sum(f_norm * g_norm) / m

    # 3. Low-rank distribution matrices from one shared random projection.
    # The projection's dtype must follow the features: the default float32
    # matrix made ``f_norm @ proj_matrix`` raise for float64 inputs.
    proj_matrix = torch.randn(f.size(1), rank, device=f.device, dtype=f_norm.dtype)
    f_proj = f_norm @ proj_matrix
    g_proj = g_norm @ proj_matrix
    distri_f = f_proj @ f_proj.t()
    distri_g = g_proj @ g_proj.t()

    # 4. Strict-upper-triangle dot product via (full - diagonal) / 2, valid
    #    because both matrices are symmetric Gram matrices.
    full_dot = torch.sum(distri_f * distri_g)
    diag_dot = torch.sum(torch.diag(distri_f) * torch.diag(distri_g))
    tri_dot = (full_dot - diag_dot) / 2

    # 5. Normalized trace
    norm_f = torch.norm(distri_f, p='fro')
    norm_g = torch.norm(distri_g, p='fro')
    tr = tri_dot / (norm_f * norm_g + eps)

    # 6. SwiftHGR objective
    hgr = corr - 0.5 * tr
    return hgr, corr, tr


def _adaptive_rank(matrix, threshold, max_rank):
    """Smallest rank whose cumulative spectral energy reaches ``threshold``, clamped to [5, max_rank].

    The energy profile is estimated from the singular values of the Gaussian sketch
    ``matrix @ R``, where R has shape (n, min(100, n)) and the same dtype/device as ``matrix``.
    """
    sketch = torch.randn(matrix.size(1), min(100, matrix.size(1)),
                         device=matrix.device, dtype=matrix.dtype)
    S = torch.linalg.svdvals(matrix @ sketch)
    energy = torch.cumsum(S ** 2, dim=0) / torch.sum(S ** 2)
    # First index whose cumulative energy reaches the threshold. nonzero() is used
    # because argmax on bool tensors is not supported by every PyTorch version. If the
    # threshold is never reached (rounding, or NaN for an all-zero input), every
    # singular value is kept instead of silently collapsing to rank 1.
    hits = torch.nonzero(energy >= threshold)
    rank = int(hits[0, 0].item()) + 1 if hits.numel() > 0 else S.numel()
    return min(max(rank, 5), max_rank)


def _low_rank_gram(matrix, rank, sketch, sparse):
    """Rank-limited approximation of the Gram matrix ``matrix @ matrix.T`` (m x m).

    sparse=True : ``sketch`` is a 1-D tensor of sampled column indices. A column-
                  sampling range finder Q is followed by a truncated SVD of Q^T A.
    sparse=False: ``sketch`` is an (n x rank) Gaussian projection matrix.
    """
    if sparse:
        Q, _ = torch.linalg.qr(matrix[:, sketch])
        U, S, _ = torch.linalg.svd(Q.t() @ matrix, full_matrices=False)
        # A ~= Q U S V^T  =>  A A^T ~= (Q U S)(Q U S)^T. The previous version returned
        # the m x rank factor Q U S itself, so triu/Frobenius were applied to a
        # non-Gram (generally non-square) matrix.
        factor = (Q @ U[:, :rank]) * S[:rank]
        return factor @ factor.t()
    proj = matrix @ sketch
    return proj @ proj.t()


def _amp_context(enabled):
    """Return a CUDA autocast context when ``enabled``, otherwise a no-op context."""
    import contextlib
    return torch.autocast(device_type="cuda") if enabled else contextlib.nullcontext()


def AdaptiveSwiftHGR(f, g, energy_threshold=0.95, max_rank=50):
    """
    Enhanced SwiftHGR with adaptive rank selection and hardware optimization.

    Rows are L2-normalised. ``corr`` is the mean row-wise cosine similarity. ``tr``
    compares low-rank approximations of the two intra-batch Gram matrices: the dot
    product of their strict upper triangles divided by the product of their Frobenius
    norms. The rank is chosen per call from the inputs' spectra. Random sketches are
    drawn on every call, and one sketch is shared by f and g.

    :param f: feature matrix 1 (m x n); torch.sparse tensors are densified
    :param g: feature matrix 2, same shape as f
    :param energy_threshold: fraction of spectral energy the rank must keep (default=0.95)
    :param max_rank: maximum rank for approximation (default=50)
    :return: (hgr, corr, tr) with hgr = corr - 0.5 * tr
    """
    # Slicing, triu and row norms below need strided tensors, so torch.sparse
    # inputs are densified; their layout still selects the sparse code path.
    layout_sparse = f.is_sparse or g.is_sparse
    if f.is_sparse:
        f = f.to_dense()
    if g.is_sparse:
        g = g.to_dense()

    m, n = f.size(0), f.size(1)
    eps = 1e-8

    # 1. Path selection: "sparse" when the layout is sparse or either input has
    #    fewer than 30% nonzero entries. Autocast only affects CUDA tensors and is
    #    only used on the dense path.
    is_sparse = layout_sparse or bool(
        torch.sum(f != 0) / (m * n) < 0.3 or torch.sum(g != 0) / (m * n) < 0.3
    )
    use_amp = torch.cuda.is_available() and f.is_cuda and not is_sparse

    # 2. Feature normalization
    if is_sparse:
        f_norm = f / (torch.norm(f, p=2, dim=1, keepdim=True) + eps)
        g_norm = g / (torch.norm(g, p=2, dim=1, keepdim=True) + eps)
    else:
        f_norm = F.normalize(f, p=2, dim=1)
        g_norm = F.normalize(g, p=2, dim=1)

    # 3. Cosine correlation
    with _amp_context(use_amp):
        corr = torch.sum(f_norm * g_norm) / m
    if use_amp:
        corr = corr.float()

    # 4. Adaptive rank: large enough for both inputs.
    rank = max(_adaptive_rank(f_norm, energy_threshold, max_rank),
               _adaptive_rank(g_norm, energy_threshold, max_rank))

    # 5. Low-rank Gram matrices. The sketch is shared by f and g (as in SwiftHGR),
    #    so both are approximated in the same random subspace; independent sketches
    #    added noise that is unrelated to the inputs.
    if is_sparse:
        sketch = torch.randperm(n)[:2 * rank]
    else:
        sketch = torch.randn(n, rank, device=f.device, dtype=f_norm.dtype)
    with _amp_context(use_amp):
        distri_f = _low_rank_gram(f_norm, rank, sketch, is_sparse)
        distri_g = _low_rank_gram(g_norm, rank, sketch, is_sparse)
    if use_amp:
        distri_f = distri_f.float()
        distri_g = distri_g.float()

    # 6. Strict-upper-triangle dot product, normalised by the Frobenius norms.
    tri_dot = torch.sum(torch.triu(distri_f * distri_g, diagonal=1))
    norm_f = torch.norm(distri_f, p='fro')
    norm_g = torch.norm(distri_g, p='fro')
    tr = tri_dot / (norm_f * norm_g + eps)

    # 7. AdaptiveSwiftHGR objective
    hgr = corr - 0.5 * tr
    return hgr, corr, tr
