import torch

from tools import mapto_SPD_cone


# Small positive constant shared by the losses in this module: it is the lower
# clamp applied before the square root in central_moment_discrepancy and the
# value added to the diagonal of the covariance matrices in geo_adapt.
jitter=1e-7


def ddc_mmd(data, nb_src_samples):
    """
    Linear-kernel MMD between the source and target parts of a batch (DDC loss).

    The batch is split at ``nb_src_samples``; the feature means of both parts are
    compared and the squared Euclidean norm of their difference is returned.

    Args:
        data (torch.Tensor): Combined batch of shape (N, D); the first
            ``nb_src_samples`` rows are source features, the rest are target features.
        nb_src_samples (int): Number of source samples in the batch.

    Returns:
        torch.Tensor: Scalar tensor holding ||mean_source - mean_target||^2.
    """
    source_centroid = data[:nb_src_samples].mean(dim=0)
    target_centroid = data[nb_src_samples:].mean(dim=0)

    # Squared L2 distance between the two domain centroids
    gap = source_centroid - target_centroid
    return gap.pow(2).sum()


def coral_loss(batch, nb_src_samples):
    """
    Compute a CORAL-style loss between source and target feature covariances.

    The batch is split at ``nb_src_samples``, the covariance of each part is
    computed with ``torch.cov`` (features as variables, samples as observations)
    and the sum of squared element-wise differences is returned. The value is
    the plain squared Frobenius norm: no scaling by the feature dimension is
    applied.

    Args:
        batch (torch.Tensor): Combined batch of shape (N, D); the first
            ``nb_src_samples`` rows are source features, the rest are target features.
        nb_src_samples (int): Number of source samples in the batch.

    Returns:
        torch.Tensor: Scalar tensor with the squared Frobenius distance between
        the two covariance matrices.
    """
    H_s = batch[:nb_src_samples]
    H_t = batch[nb_src_samples:]

    # torch.cov expects variables in rows, hence the transpose of (samples, features)
    cov_s = torch.cov(H_s.T)
    cov_t = torch.cov(H_t.T)

    return torch.sum((cov_s - cov_t) ** 2)


def minimal_entropy_correlation_alignment(data, nb_src_samples):
    """
    Compute the Minimal-Entropy Correlation Alignment (log-CORAL) loss between source and target domains.
    This function is adapted from the official implementation at https://github.com/pmorerio/minimal-entropy-correlation-alignment/tree/master.

    Each domain's covariance is normalised by its own sample count, so the loss
    stays correct when the source and target parts of the batch differ in size.

    Args:
        data (torch.Tensor): Combined batch of shape (N, D), where N is total samples and D is the feature dimension.
        nb_src_samples (int): Number of source samples in the batch.

    Returns:
        torch.Tensor: Log-CORAL loss value (mean squared difference between the
        matrix logarithms of the two covariance matrices).
    """

    def _log_covariance(feats):
        """Return the matrix logarithm of the unbiased covariance of ``feats``."""
        n_samples = float(feats.size(0))

        # Subtract the mean from the data matrix (center the data)
        centered = feats - torch.mean(feats, dim=0)

        # Unbiased covariance, normalised by this domain's own sample count
        cov = (1.0 / (n_samples - 1)) * torch.mm(centered.t(), centered)

        # Eigenvalue decomposition of the symmetric covariance
        eig_vals, eig_vecs = torch.linalg.eigh(cov)

        # Keep eigenvalues strictly positive so that log() stays finite
        eig_vals = torch.clamp(eig_vals, min=1e-12)

        # log(C) = V * diag(log(λ)) * V^T
        return torch.mm(eig_vecs,
                        torch.mm(torch.diag(torch.log(eig_vals)), eig_vecs.t()))

    # Split data into source and target
    log_cov_source = _log_covariance(data[:nb_src_samples])
    log_cov_target = _log_covariance(data[nb_src_samples:])

    # Mean of the squared element-wise differences
    return torch.mean(torch.square(log_cov_source - log_cov_target))


def central_moment_discrepancy(data, nb_src_samples, n_moments=2):
    """
    Central Moment Discrepancy (CMD) between the source and target parts of a batch.
    Adapted from the officially released code at https://github.com/wzell/cmd/tree/master.

    The loss is the L2 distance between the domain means plus, for every order
    k = 2 .. n_moments, the L2 distance between the per-feature k-th central
    moments of the two domains.

    Args:
        data (torch.Tensor): Combined batch of shape (N, D), where N is total samples and D is the feature dimension.
        nb_src_samples (int): Number of source samples in the batch.
        n_moments (int): Highest moment order to match (default: 2).

    Returns:
        torch.Tensor: CMD loss value.
    """

    def _l2_distance(a, b):
        """L2 norm of ``a - b``; the squared norm is clamped at ``jitter`` before the root."""
        squared_norm = torch.sum((a - b) ** 2)
        return torch.sqrt(torch.clamp(squared_norm, min=jitter))

    # Split the batch into its two domains
    src = data[:nb_src_samples]
    trg = data[nb_src_samples:]

    # First moments and centered features
    src_mean = torch.mean(src, dim=0)
    trg_mean = torch.mean(trg, dim=0)
    src_centered = src - src_mean
    trg_centered = trg - trg_mean

    # Start with the distance between the means
    total = _l2_distance(src_mean, trg_mean)

    # Add the distance between the central moments of order 2 .. n_moments
    for k in range(2, n_moments + 1):
        src_moment = torch.mean(src_centered ** k, dim=0)
        trg_moment = torch.mean(trg_centered ** k, dim=0)
        total = total + _l2_distance(src_moment, trg_moment)

    return total


def higher_order_moment_matching(data, nb_src_samples, order=2):
    """
    Higher-order Moment Matching (HoMM) loss for moment orders 2 and 3, adapted from
    the official paper code at https://github.com/chenchao666/HoMM-Master/tree/master.

    Both domains are centered, the full order-th moment tensor (mean over samples
    of the order-fold outer product of each feature vector) is built for each
    domain, and the mean squared difference of the two tensors is returned.

    Args:
        data (torch.Tensor): Combined batch of shape (N, D), where N is total samples and D is the feature dimension.
        nb_src_samples (int): Number of source samples in the batch.
        order (int): Order of moments to match (2 or 3).

    Returns:
        torch.Tensor: HoMM loss value.

    Raises:
        ValueError: If ``order`` is neither 2 nor 3.
    """
    src = data[:nb_src_samples]
    trg = data[nb_src_samples:]

    # Remove the per-feature mean of each domain
    src = src - torch.mean(src, dim=0)
    trg = trg - torch.mean(trg, dim=0)

    if order == 2:
        # (B, D, 1) * (B, 1, D) -> (B, D, D): per-sample outer product
        src_outer = src[:, :, None] * src[:, None, :]
        trg_outer = trg[:, :, None] * trg[:, None, :]
    elif order == 3:
        # (B, D, 1, 1) * (B, 1, D, 1) * (B, 1, 1, D) -> (B, D, D, D)
        src_outer = src[:, :, None, None] * src[:, None, :, None] * src[:, None, None, :]
        trg_outer = trg[:, :, None, None] * trg[:, None, :, None] * trg[:, None, None, :]
    else:
        raise ValueError(f"Order {order} not supported. Use order 2 or 3.")

    # Average over the sample axis to obtain the moment tensors
    src_moment = torch.mean(src_outer, dim=0)
    trg_moment = torch.mean(trg_outer, dim=0)

    return torch.mean((src_moment - trg_moment) ** 2)


def geo_adapt(feat_batch, nb_src_samples, metric_type='hilbert'):
    """
    Compute a geometric distance between source and target feature statistics.

    For each domain the feature mean and a symmetrised, jitter-regularised
    covariance are computed and passed to ``mapto_SPD_cone`` (with ``beta=1.0``).
    The two resulting matrices ``Ps`` and ``Pt`` are then compared with the
    requested metric.
    NOTE(review): ``mapto_SPD_cone`` is defined in ``tools``; it presumably embeds
    (covariance, mean) into a single SPD matrix -- confirm against its definition.

    Args:
        feat_batch (torch.Tensor): Combined batch of shape (N, D); the first
            ``nb_src_samples`` rows are source features, the rest are target features.
        nb_src_samples (int): Number of source samples in the batch.
        metric_type (str): One of
            ``'hilbert'``       -- log(max / min) of the real parts of the eigenvalues of Ps^{-1} Pt,
            ``'log_euclidean'`` -- mean squared difference of the matrix logarithms of Ps and Pt,
            ``'airm'``          -- sqrt of the sum of squared log-eigenvalues of Ps^{-1} Pt.

    Returns:
        tuple: ``(loss, Ps_det)`` where ``loss`` is the scalar distance and
        ``Ps_det`` is the product of the eigenvalues of ``Ps``.

    Raises:
        ValueError: If ``metric_type`` is not one of the supported metrics.
    """
    # Reject unknown metrics before doing any work; the error must be raised,
    # otherwise the function would fail later on an unbound `loss`.
    if metric_type not in ('hilbert', 'log_euclidean', 'airm'):
        raise ValueError('The requested distance metric type is not implemented!')

    def _regularised_cov(feats):
        """Covariance of ``feats`` made exactly symmetric, plus ``jitter`` on the diagonal."""
        cov = torch.cov(feats.T)
        cov = (cov + cov.T) / 2.0
        # Allocate the identity directly on the covariance's device
        return cov + torch.eye(cov.size(0), device=cov.device) * jitter

    def _spd_log(mat):
        """Matrix logarithm via eigendecomposition: log(C) = V * diag(log(λ)) * V^T."""
        eig_vals, eig_vecs = torch.linalg.eigh(mat)
        # Keep eigenvalues strictly positive so that log() stays finite
        eig_vals = torch.clamp(eig_vals, min=1e-12)
        return torch.mm(eig_vecs, torch.mm(torch.diag(torch.log(eig_vals)), eig_vecs.t()))

    Z_s = feat_batch[:nb_src_samples]
    Z_t = feat_batch[nb_src_samples:]

    Ps = mapto_SPD_cone(_regularised_cov(Z_s), Z_s.mean(dim=0), beta=1.0)
    Pt = mapto_SPD_cone(_regularised_cov(Z_t), Z_t.mean(dim=0), beta=1.0)

    # Determinant of Ps as the product of its eigenvalues; returned to the caller
    Ps_det = torch.prod(torch.linalg.eigvalsh(Ps).real)

    if metric_type == 'log_euclidean':
        loss = torch.mean(torch.square(_spd_log(Ps) - _spd_log(Pt)))
    else:
        # Both remaining metrics use the eigenvalues of Ps^{-1} Pt, obtained by
        # solving Ps X = Pt through the Cholesky factor of Ps.
        L0 = torch.linalg.cholesky(Ps)
        eigenvalues = torch.linalg.eigvals(torch.cholesky_solve(Pt, L0)).real

        if metric_type == 'hilbert':
            loss = torch.log(torch.max(eigenvalues) / torch.min(eigenvalues))
        else:  # 'airm'
            loss = torch.sqrt(torch.sum(torch.log(eigenvalues) ** 2))

    return loss, Ps_det


# Registry mapping an adaptation-method name to its loss function.
# compute_batch_loss looks the function up by P['adapt_method'] and calls it
# with (latent features, number of source samples), plus a moment order for
# 'cmd'/'homm' or a metric_type keyword for 'geo_adapt'.
adapt_loss_functions = {
    'ddc': ddc_mmd,
    'coral': coral_loss,
    'log_coral': minimal_entropy_correlation_alignment,
    'cmd': central_moment_discrepancy,
    'homm': higher_order_moment_matching,
    'geo_adapt': geo_adapt
}


'''
top-level wrapper
'''


def compute_batch_loss(batch, P):
    """
    Compute the training loss for one batch and store the results in ``batch``.

    In ``'adaptation'`` mode the cross-entropy is computed on the first
    ``batch['num_src_samples']`` logits only and a domain-adaptation loss,
    selected by ``P['adapt_method']`` from ``adapt_loss_functions``, is added
    with weight ``batch['lamda']``. In every other mode only the cross-entropy
    on all logits is used.

    Args:
        batch (dict): Reads ``'logits'`` (2-D tensor) and ``'labels'``; in
            adaptation mode also ``'num_src_samples'`` and ``'latent_feat'``.
        P (dict): Reads ``'exp_mode'``; in adaptation mode also
            ``'adapt_method'``, ``'user_lamda'`` and, depending on the method,
            ``'dist_metric_type'`` and ``'det_thr'`` (geo_adapt) or
            ``'highest_moment'`` (cmd / homm).

    Returns:
        dict: The same ``batch`` object with ``'loss_tensor'``, ``'cs_loss_np'``
        and ``'da_loss_np'`` set (``None`` outside adaptation mode); adaptation
        mode additionally sets ``'lamda'`` and, for geo_adapt, ``'Ps_det'``.
    """
    assert batch['logits'].dim() == 2

    if P['exp_mode'] == 'adaptation':
        num_src = batch['num_src_samples']
        method = P['adapt_method']

        # Labels are only available for the source part of the batch
        class_loss = torch.nn.functional.cross_entropy(batch['logits'][:num_src], batch['labels'])

        adapt_fn = adapt_loss_functions[method]

        if method == 'geo_adapt':
            da_loss, batch['Ps_det'] = adapt_fn(batch['latent_feat'], num_src,
                                                metric_type=P['dist_metric_type'])

            # Switch the adaptation term off when det(Ps) is below the threshold
            if batch['Ps_det'] >= P['det_thr']:
                batch['lamda'] = P['user_lamda']
            else:
                batch['lamda'] = 0.0

        elif method in ['cmd', 'homm']:
            da_loss = adapt_fn(batch['latent_feat'], num_src, P['highest_moment'])

            batch['lamda'] = P['user_lamda']

        else:
            da_loss = adapt_fn(batch['latent_feat'], num_src)

            batch['lamda'] = P['user_lamda']

        if batch['lamda'] == 0:
            # A zero weight must really disable the term: 0 * nan (or 0 * inf)
            # is nan, so multiplying would still poison the loss and gradients.
            batch['loss_tensor'] = class_loss
        else:
            batch['loss_tensor'] = class_loss + (batch['lamda'] * da_loss)

        batch['da_loss_np'] = da_loss.clone().detach().cpu().numpy()

    else:
        class_loss = torch.nn.functional.cross_entropy(batch['logits'], batch['labels'])

        batch['loss_tensor'] = class_loss

        batch['da_loss_np'] = None

    batch['cs_loss_np'] = class_loss.clone().detach().cpu().numpy()

    return batch