"""
Gene-level loss functions for scBIG.

- Correlation loss: preserve gene-gene relationships
- Pathway OT loss: match distributions in pathway space
"""

import jax
import jax.numpy as jnp
from typing import Optional


def compute_correlation_matrix(X: jnp.ndarray, eps: float = 1e-8) -> jnp.ndarray:
    """
    Return the Pearson correlation matrix between the columns of ``X``.

    Rows are observations and columns are variables. ``eps`` is added to the
    sample-size denominator and to the variances, so (near-)constant columns
    give ~0 entries instead of NaN. The result is clipped to [-1, 1].
    """
    n_obs = X.shape[0]
    deviations = X - jnp.mean(X, axis=0)
    covariance = jnp.matmul(deviations.T, deviations) / (n_obs - 1 + eps)
    scale = jnp.sqrt(jnp.diag(covariance) + eps)
    normalized = covariance / (scale[:, None] * scale[None, :])
    return jnp.clip(normalized, -1.0, 1.0)


def compute_covariance_matrix(X: jnp.ndarray, eps: float = 1e-8) -> jnp.ndarray:
    """
    Return the sample covariance matrix between the columns of ``X``.

    Rows are observations. The normalizer is ``N - 1 + eps`` (N = number of
    rows), so a single-row input does not divide by zero.
    """
    denominator = X.shape[0] - 1 + eps
    deviations = X - jnp.mean(X, axis=0)
    return jnp.matmul(deviations.T, deviations) / denominator


def gene_correlation_loss(X_pred: jnp.ndarray,
                          X_true: jnp.ndarray,
                          use_correlation: bool = True) -> jnp.ndarray:
    """
    Relational consistency loss - keeps gene-gene correlation structure.

    Loss = ||R_pred - R_true||_F / n_genes

    Args:
        X_pred: Predicted expression. Statistics are taken over axis 0 and
            ``n_genes`` is ``X_pred.shape[1]``, so columns are genes (rows
            presumably cells).
        X_true: Reference expression with the same number of columns.
        use_correlation: Compare Pearson correlation matrices when True,
            covariance matrices otherwise.

    Returns:
        Scalar loss.
    """
    if use_correlation:
        mat_pred = compute_correlation_matrix(X_pred)
        mat_true = compute_correlation_matrix(X_true)
    else:
        mat_pred = compute_covariance_matrix(X_pred)
        mat_true = compute_covariance_matrix(X_true)

    diff = mat_pred - mat_true
    sq_norm = jnp.sum(diff ** 2)
    # d/dx sqrt(x) is infinite at 0; through the chain rule that yields NaN
    # gradients when the two matrices match exactly. The "double where"
    # evaluates sqrt at a dummy value in that case and masks the result, so
    # the forward value is unchanged (0) and the gradient stays finite (0).
    is_zero = sq_norm == 0
    frob_norm = jnp.where(is_zero, 0.0,
                          jnp.sqrt(jnp.where(is_zero, 1.0, sq_norm)))

    n_genes = X_pred.shape[1]
    return frob_norm / n_genes


def cluster_correlation_loss(X_pred: jnp.ndarray,
                             X_true: jnp.ndarray,
                             cluster_labels: jnp.ndarray,
                             n_clusters: int,
                             use_correlation: bool = True) -> jnp.ndarray:
    """
    Cheaper relational loss: average genes within clusters, then compare the
    K x K correlation (or covariance) matrices of the cluster means.

    Loss = ||R_pred - R_true||_F / n_clusters

    Cost: the aggregation is a dense (N x G) @ (G x K) product, O(N*G*K), and
    the matrix comparison is O(N*K^2). This avoids forming the G x G matrices
    used by ``gene_correlation_loss``.

    Args:
        X_pred: Predicted expression; columns are genes.
        X_true: Reference expression with the same number of columns.
        cluster_labels: Integer cluster id per gene (one per column of X).
        n_clusters: Number of clusters K.
        use_correlation: Compare Pearson correlation matrices when True,
            covariance matrices otherwise.

    Returns:
        Scalar loss.
    """
    Z_pred = _aggregate_by_cluster(X_pred, cluster_labels, n_clusters)
    Z_true = _aggregate_by_cluster(X_true, cluster_labels, n_clusters)

    if use_correlation:
        mat_pred = compute_correlation_matrix(Z_pred)
        mat_true = compute_correlation_matrix(Z_true)
    else:
        mat_pred = compute_covariance_matrix(Z_pred)
        mat_true = compute_covariance_matrix(Z_true)

    diff = mat_pred - mat_true
    sq_norm = jnp.sum(diff ** 2)
    # Safe Frobenius norm: sqrt has an infinite derivative at 0, which would
    # give NaN gradients when the matrices match exactly. Evaluate sqrt at a
    # dummy value there and mask it; the forward value is unchanged.
    is_zero = sq_norm == 0
    frob_norm = jnp.where(is_zero, 0.0,
                          jnp.sqrt(jnp.where(is_zero, 1.0, sq_norm)))

    return frob_norm / n_clusters


def _aggregate_by_cluster(X: jnp.ndarray,
                          cluster_labels: jnp.ndarray,
                          n_clusters: int) -> jnp.ndarray:
    """
    Average the columns of ``X`` within each cluster.

    ``cluster_labels`` holds one integer label per column of ``X``. The
    result has one column per cluster, holding the mean of its member
    columns. A cluster with no members gets an all-zero column, because its
    size is floored at 1 before dividing.
    """
    membership = jax.nn.one_hot(cluster_labels, n_clusters)
    members_per_cluster = jnp.maximum(membership.sum(axis=0), 1.0)
    cluster_totals = jnp.dot(X, membership)
    return cluster_totals / members_per_cluster[None, :]


def pathway_ot_loss(X_pred: jnp.ndarray,
                    X_true: jnp.ndarray,
                    pathway_matrix: jnp.ndarray,
                    epsilon: float = 0.1,
                    n_iterations: int = 50) -> jnp.ndarray:
    """
    Entropic OT cost between predicted and reference rows in pathway space.

    Both inputs are projected with ``_pathway_aggregate`` (size-normalized
    pathway scores). The Sinkhorn cost between the two point clouds is then
    divided by the number of pathways (rows of ``pathway_matrix``).
    """
    pathway_scores_pred = _pathway_aggregate(X_pred, pathway_matrix)
    pathway_scores_true = _pathway_aggregate(X_true, pathway_matrix)
    transport_cost = _sinkhorn_distance(
        pathway_scores_pred, pathway_scores_true, epsilon, n_iterations
    )
    return transport_cost / pathway_matrix.shape[0]


def _pathway_aggregate(X: jnp.ndarray,
                       pathway_matrix: jnp.ndarray,
                       normalize: bool = True) -> jnp.ndarray:
    """
    Project expression onto pathways: ``X @ pathway_matrix.T``.

    With ``normalize`` each pathway score is divided by the pathway's row sum
    in ``pathway_matrix`` (floored at 1, so empty pathways stay at 0).
    """
    scores = jnp.dot(X, pathway_matrix.T)
    if not normalize:
        return scores
    members_per_pathway = jnp.maximum(pathway_matrix.sum(axis=1), 1.0)
    return scores / members_per_pathway[None, :]


def _sinkhorn_distance(X: jnp.ndarray,
                       Y: jnp.ndarray,
                       epsilon: float = 0.1,
                       n_iterations: int = 50) -> jnp.ndarray:
    """
    Entropy-regularized OT cost between the rows of ``X`` and ``Y``.

    Both point clouds get uniform weights (1/n and 1/m). The ground cost is
    the squared Euclidean distance. The returned value is the transport cost
    <P, C> of the Sinkhorn plan P. The entropy term is not added, and the
    value is not debiased, so identical clouds of distinct points can still
    give a small positive value.

    The Sinkhorn iterations run on the dual potentials (f, g) in the log
    domain. The kernel form ``K = exp(-C / epsilon)`` underflows to 0 in
    float32 once C / epsilon exceeds ~100 (squared distance ~10 at the
    default epsilon). That silently made the plan, the loss and its gradient
    zero. With 0 iterations, the plan equals K, as in the kernel form.

    Args:
        X: (n, d) points.
        Y: (m, d) points.
        epsilon: Entropic regularization strength.
        n_iterations: Number of Sinkhorn (f, g) update pairs; must be a
            Python int (it drives a Python loop).

    Returns:
        Scalar transport cost.
    """
    from jax.scipy.special import logsumexp

    n = X.shape[0]
    m = Y.shape[0]

    # Squared Euclidean cost matrix; clamp tiny negatives from cancellation.
    X_sqnorm = jnp.sum(X ** 2, axis=1, keepdims=True)
    Y_sqnorm = jnp.sum(Y ** 2, axis=1, keepdims=True)
    C = X_sqnorm + Y_sqnorm.T - 2 * jnp.dot(X, Y.T)
    C = jnp.maximum(C, 0)

    # Log of the uniform marginals a = 1/n, b = 1/m.
    log_a = -jnp.log(n)
    log_b = -jnp.log(m)

    # f = epsilon * log(u), g = epsilon * log(v); u = v = 1 initially.
    f = jnp.zeros(n, dtype=C.dtype)
    g = jnp.zeros(m, dtype=C.dtype)

    for _ in range(n_iterations):
        f = epsilon * (log_a - logsumexp((g[None, :] - C) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - C) / epsilon, axis=0))

    # P_ij = u_i * K_ij * v_j, computed without ever forming K.
    P = jnp.exp((f[:, None] + g[None, :] - C) / epsilon)
    return jnp.sum(P * C)


def combined_gene_loss(X_pred: jnp.ndarray,
                       X_true: jnp.ndarray,
                       cluster_labels: Optional[jnp.ndarray] = None,
                       n_clusters: int = 32,
                       pathway_matrix: Optional[jnp.ndarray] = None,
                       lambda_corr: float = 1.0,
                       lambda_ot: float = 0.1,
                       *,
                       use_correlation: bool = True,
                       ot_epsilon: float = 0.1,
                       ot_n_iterations: int = 50) -> jnp.ndarray:
    """
    Combined loss: lambda_corr * correlation loss + lambda_ot * pathway OT loss.

    Args:
        X_pred: Predicted expression; columns are genes.
        X_true: Reference expression with the same number of columns.
        cluster_labels: Optional per-gene cluster ids. When given, the
            cluster-level correlation loss is used; otherwise the full
            gene-gene loss is used.
        n_clusters: Number of clusters (only used with ``cluster_labels``).
        pathway_matrix: Optional pathway membership matrix, one row per
            pathway. The OT term is only added when this is given and
            ``lambda_ot > 0``.
        lambda_corr: Weight of the correlation term.
        lambda_ot: Weight of the pathway OT term.
        use_correlation: Compare correlation (True) or covariance (False)
            matrices in the relational term.
        ot_epsilon: Entropic regularization for the pathway OT term.
        ot_n_iterations: Sinkhorn iterations for the pathway OT term.

    Returns:
        Scalar total loss.
    """
    if cluster_labels is not None:
        corr_loss = cluster_correlation_loss(
            X_pred, X_true, cluster_labels, n_clusters,
            use_correlation=use_correlation,
        )
    else:
        corr_loss = gene_correlation_loss(
            X_pred, X_true, use_correlation=use_correlation
        )
    total_loss = lambda_corr * corr_loss

    # ``lambda_ot > 0`` is evaluated in Python, so under jax.jit lambda_ot
    # must be a concrete (static) number, not a traced value.
    if pathway_matrix is not None and lambda_ot > 0:
        ot_loss = pathway_ot_loss(
            X_pred, X_true, pathway_matrix,
            epsilon=ot_epsilon, n_iterations=ot_n_iterations,
        )
        total_loss = total_loss + lambda_ot * ot_loss

    return total_loss
