"""
Ordinal-entropy regularizers, pairwise distance helpers, and a balanced Pearson correlation loss.
"""
import torch
import torch.nn.functional as F
import random

# def ordinalentropy_tig_ent(features, gt,  mask=None):
#     """
#     Features: a certain layer's features
#     gt: pixel-wise ground truth values, in depth estimation, gt.size()= n, h, w
#     mask: In case values of some pixels do not exist. For depth estimation, there are some pixels lack the ground truth values
#     """
#     f_n, f_c = features.size()

#     u_value, u_index, u_counts = torch.unique(gt, return_inverse=True, return_counts=True)
#     # center_f = torch.zeros([len(u_value), f_c]).cuda()
#     # for idx in range(len(u_value)):
#     #     center_f[idx, :] = torch.mean(_features[u_index==idx, :], dim=0)

#     center_f = torch.zeros([len(u_value), f_c]).cuda()
#     u_index = u_index.squeeze()
#     center_f.index_add_(0, u_index, features)
#     u_counts = u_counts.unsqueeze(1)
#     center_f = center_f / u_counts

#     p = F.normalize(center_f, dim=1)
#     _distance = euclidean_dist(p, p)
#     _distance = up_triu(_distance)

#     u_value = u_value.unsqueeze(1)
#     _weight = euclidean_dist(u_value, u_value)
#     _weight = up_triu(_weight)
#     _max = torch.max(_weight)
#     _min = torch.min(_weight)
#     _weight = ((_weight - _min) / (_max - _min))

#     _distance = _distance * _weight
#     _entropy = torch.mean(_distance)


#     _features = F.normalize(features, dim=1)
#     _features_center = p[u_index, :]
#     _features = _features - _features_center
#     _features = _features.pow(2)
#     _tightness = torch.sum(_features, dim=1)
#     _mask = _tightness > 0
#     _tightness = _tightness[_mask]
#     _tightness = torch.sqrt(_tightness)
#     _tightness = torch.mean(_tightness)

#     return _tightness - _entropy

def ordinalentropy_tig_ent(features, gt,  mask=None):
    """Ordinal-entropy regularizer on per-label feature centers, plus a tightness term.

    Samples are grouped by their ground-truth value, and each group's mean
    feature (its center) is L2-normalized. The returned loss is
    ``tightness - entropy``:

    * ``entropy`` is the mean pairwise Euclidean distance between the
      normalized centers. Each pair is weighted by the min-max normalized
      distance between the two labels.
    * ``tightness`` is the mean Euclidean distance between each L2-normalized
      feature and its group's normalized center. Exact zeros are excluded.

    Args:
        features: a certain layer's features, shape [n, c].
        gt: ground-truth values for the n samples, shape [n] or [n, k, ...].
            When a sample carries several values, only its first value is
            used as the grouping key.
        mask: currently ignored; kept for interface compatibility. It is
            meant for samples/pixels that lack a ground-truth value.

    Returns:
        Scalar tensor ``tightness - entropy``.
    """
    f_n, f_c = features.size()

    # One grouping key per sample. Taking the unique values of the key only
    # (not of every element of gt) prevents values that never act as a key
    # from producing empty, all-zero centers.
    key = gt.reshape(f_n, -1)[:, 0]
    u_value, u_index, u_counts = torch.unique(key, return_inverse=True, return_counts=True)

    # Per-group mean feature. new_zeros follows the device and dtype of
    # `features` (no hard-coded .cuda()), so this also runs on CPU / AMP.
    center_f = features.new_zeros((len(u_value), f_c))
    center_f = center_f.index_add(0, u_index, features)
    center_f = center_f / u_counts.unsqueeze(1)

    p = F.normalize(center_f, dim=1)
    _distance = up_triu(euclidean_dist(p, p))

    # Label distances between the distinct values, min-max normalized to [0, 1].
    u_value = u_value.unsqueeze(1)
    _weight = up_triu(euclidean_dist(u_value, u_value))
    if _weight.numel() > 0:
        _max = torch.max(_weight)
        _min = torch.min(_weight)
        # Clamp the range so identical label distances give weight 0, not 0/0 = NaN.
        _weight = (_weight - _min) / torch.clamp(_max - _min, min=1e-12)
        _entropy = torch.mean(_distance * _weight)
    else:
        # A single distinct label: no center pairs. The zero stays attached
        # to the autograd graph.
        _entropy = _distance.sum()

    # Tightness: distance of each normalized feature to its normalized center.
    _features_center = p[u_index, :]
    _sq_dist = (F.normalize(features, dim=1) - _features_center).pow(2).sum(dim=1)
    # Exact zeros are dropped before the sqrt, whose gradient is infinite at 0.
    _sq_dist = _sq_dist[_sq_dist > 0]
    if _sq_dist.numel() > 0:
        _tightness = torch.mean(torch.sqrt(_sq_dist))
    else:
        # E.g. every label is unique, so each feature equals its own center;
        # the mean of an empty tensor would be NaN.
        _tightness = _sq_dist.sum()

    return _tightness - _entropy

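### The original implementation, kept for reference. It normalized the label distances incorrectly: it divided by `_max` instead of `(_max - _min)`, so the weights did not span [0, 1].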
# def ordinal_entropy(features, gt):
#     """
#     Features: The last layer's features
#     gt: The corresponding ground truth values
#     """

#     """
#     sample in case the training size too large
#     """
#     # samples = random.sample(range(0, len(gt)-1), 100)  # random sample 100 features
# #     samples = random.sample(range(0, len(gt)-1), 10)  # random sample 100 features
# #     features = features[samples]
# #     gt = gt[samples]

#     """
#     calculate distances in the feature space, i.e. ||z_{c_i} - z_{c_j}||_2
#     """
#     p = F.normalize(features, dim=1)
#     _distance = euclidean_dist(p, p)
#     _distance = up_triu(_distance)

#     """
#     calculate the distances in the label space, i.e. w_{ij} = ||y_i -y_j||_2
#     """
#     _weight = euclidean_dist(gt, gt)
#     _weight = up_triu(_weight)
#     _max = torch.max(_weight)
#     _min = torch.min(_weight)
#     _weight = ((_weight - _min) / _max)

#     """
#     L_d = - mean(w_ij ||z_{c_i} - z_{c_j}||_2)
#     """
#     _distance = _distance * _weight
#     L_d = - torch.mean(_distance)

#     return L_d

def weight_ordinal_entropy(features, gt, weight):
    """Ordinal-entropy regularizer with an extra multiplicative weight.

    Computes ``L_d = -mean(w_ij * weight * ||z_i - z_j||_2)`` over all pairs
    i < j. Here ``z`` are the L2-normalized features and ``w_ij`` is the
    Euclidean label distance ``||y_i - y_j||_2``, min-max normalized to [0, 1].

    Args:
        features: The last layer's features, shape [n, d].
        gt: The corresponding ground-truth values, shape [n, k].
        weight: extra weight multiplied into every pair term. Either a scalar
            or a tensor broadcastable against the flattened upper-triangular
            pair vector of length n*(n-1)/2, in the order produced by
            ``up_triu``.

    Returns:
        Scalar tensor L_d. If there are fewer than two samples, returns a zero
        that is still attached to the autograd graph.
    """
    # Distances in the feature space, ||z_i - z_j||_2, for pairs i < j.
    p = F.normalize(features, dim=1)
    _distance = up_triu(euclidean_dist(p, p))
    if _distance.numel() == 0:
        # No pairs. Avoid torch.max on an empty tensor and a NaN mean.
        return _distance.sum()

    # Distances in the label space, w_ij = ||y_i - y_j||_2, scaled to [0, 1].
    _weight = up_triu(euclidean_dist(gt, gt))
    _max = torch.max(_weight)
    _min = torch.min(_weight)
    # Clamp the range so all-equal label distances give weight 0, not 0/0 = NaN.
    _weight = (_weight - _min) / torch.clamp(_max - _min, min=1e-12)

    # L_d = - mean(w_ij * weight * ||z_i - z_j||_2)
    _distance = _distance * _weight * weight
    L_d = - torch.mean(_distance)

    return L_d


def diffweight_ordinal_entropy(features, gt, weight):
    """Ordinal-entropy regularizer with caller-supplied pair weights.

    Computes ``L_d = -mean(weight_ij * ||z_i - z_j||_2)`` over all pairs i < j,
    where ``z`` are the L2-normalized features. No label distance is computed
    here; the pair weighting comes entirely from ``weight``.

    Args:
        features: The last layer's features, shape [n, d].
        gt: unused; kept so the signature matches the other regularizers.
        weight: pair weights. Either a scalar or a tensor broadcastable against
            the flattened upper-triangular pair vector of length n*(n-1)/2,
            in the order produced by ``up_triu``.

    Returns:
        Scalar tensor L_d. If there are fewer than two samples, returns a zero
        that is still attached to the autograd graph.
    """
    # Distances in the feature space, ||z_i - z_j||_2, for pairs i < j.
    p = F.normalize(features, dim=1)
    _distance = up_triu(euclidean_dist(p, p))
    if _distance.numel() == 0:
        # No pairs: the mean of an empty tensor would be NaN.
        return _distance.sum()

    L_d = - torch.mean(_distance * weight)

    return L_d


def ordinal_entropy(features, gt):
    """Ordinal-entropy regularizer.

    Computes ``L_d = -mean(w_ij * ||z_i - z_j||_2)`` over all pairs i < j.
    Here ``z`` are the L2-normalized features and ``w_ij`` is the Euclidean
    label distance ``||y_i - y_j||_2``, min-max normalized to [0, 1].
    Minimizing L_d pushes apart the features of samples whose labels are far
    apart.

    Args:
        features: The last layer's features, shape [n, d].
        gt: The corresponding ground-truth values, shape [n, k].

    Returns:
        Scalar tensor L_d. If there are fewer than two samples, returns a zero
        that is still attached to the autograd graph.
    """
    # Distances in the feature space, ||z_i - z_j||_2, for pairs i < j.
    p = F.normalize(features, dim=1)
    _distance = up_triu(euclidean_dist(p, p))
    if _distance.numel() == 0:
        # No pairs. Avoid torch.max on an empty tensor and a NaN mean.
        return _distance.sum()

    # Distances in the label space, w_ij = ||y_i - y_j||_2, scaled to [0, 1].
    _weight = up_triu(euclidean_dist(gt, gt))
    _max = torch.max(_weight)
    _min = torch.min(_weight)
    # Clamp the range so all-equal label distances give weight 0, not 0/0 = NaN.
    _weight = (_weight - _min) / torch.clamp(_max - _min, min=1e-12)

    # L_d = - mean(w_ij * ||z_i - z_j||_2)
    _distance = _distance * _weight
    L_d = - torch.mean(_distance)

    return L_d

def ordinal_entropy_cosdis(features, gt):
    """Ordinal-entropy regularizer using cosine distance in the label space.

    Computes ``L_d = -mean(w_ij * ||z_i - z_j||_2)`` over all pairs i < j.
    Here ``z`` are the L2-normalized features and ``w_ij`` is the cosine
    distance ``1 - cos(y_i, y_j)``, min-max normalized to [0, 1].

    NOTE(review): for single-column labels the cosine distance only reflects
    the sign of each label (0 for same sign, 2 for opposite sign). It is
    therefore only informative for multi-dimensional gt.

    Args:
        features: The last layer's features, shape [n, d].
        gt: The corresponding ground-truth values, shape [n, k].

    Returns:
        Scalar tensor L_d. If there are fewer than two samples, returns a zero
        that is still attached to the autograd graph.
    """
    # Distances in the feature space, ||z_i - z_j||_2, for pairs i < j.
    p = F.normalize(features, dim=1)
    _distance = up_triu(euclidean_dist(p, p))
    if _distance.numel() == 0:
        # No pairs. Avoid torch.max on an empty tensor and a NaN mean.
        return _distance.sum()

    # Cosine distances in the label space, w_ij = 1 - cos(y_i, y_j), scaled to [0, 1].
    _weight = up_triu(cosine_dist(gt, gt))
    _max = torch.max(_weight)
    _min = torch.min(_weight)
    # Clamp the range so all-equal label distances give weight 0, not 0/0 = NaN.
    _weight = (_weight - _min) / torch.clamp(_max - _min, min=1e-12)

    # L_d = - mean(w_ij * ||z_i - z_j||_2)
    _distance = _distance * _weight
    L_d = - torch.mean(_distance)

    return L_d


def euclidean_dist(x, y):
    """Pairwise Euclidean distances between the rows of ``x`` and ``y``.

    Uses the expansion ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>.

    Args:
      x: tensor with shape [m, d]
      y: tensor with shape [n, d]
    Returns:
      dist: tensor with shape [m, n]. Entries are at least 1e-6, because the
        squared distance is clamped to 1e-12 before the square root.
    """
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    # dist <- 1 * dist + (-2) * (x @ y^T). Keyword form: the positional
    # addmm_(beta, alpha, mat1, mat2) overload is deprecated in PyTorch.
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
    # Clamp before sqrt: rounding can make the expansion slightly negative,
    # and sqrt has an infinite gradient at 0.
    dist = dist.clamp(min=1e-12).sqrt()
    return dist


def up_triu(x):
    """Return the strictly upper-triangular entries of a square matrix, flattened.

    Entries come out in row-major order (row 0 first, columns ascending). This
    is the same order as boolean-mask indexing with
    ``torch.triu(torch.ones(n, n), diagonal=1)``.
    """
    n, m = x.shape
    assert n == m
    # Build the indices directly on x's device. This avoids allocating an
    # n x n CPU mask on every call and copying it to the GPU for CUDA inputs.
    rows, cols = torch.triu_indices(n, n, offset=1, device=x.device)
    return x[rows, cols]

import torch
import torch.nn.functional as F

def cosine_dist(x, y):
    """Pairwise cosine distance (1 - cosine similarity) between row vectors.

    Args:
      x: torch.Tensor of shape [m, d]
      y: torch.Tensor of shape [n, d]

    Returns:
      torch.Tensor of shape [m, n] whose entry (i, j) is 1 - cos(x_i, y_j).
      F.normalize leaves zero-norm rows as zero vectors, so such rows get
      similarity 0 and distance 1.
    """
    # Project every row onto the unit sphere; the dot products of unit
    # vectors are the cosine similarities.
    unit_x = F.normalize(x, p=2, dim=1)
    unit_y = F.normalize(y, p=2, dim=1)
    return 1 - unit_x @ unit_y.t()


class BalancedPearsonCorrelationLoss(torch.nn.Module):
    """Pearson correlation loss balanced across genes and cells, plus MSE.

    For 2-D inputs ``[genes x cells]`` the loss is::

        2 * (w_gene * (1 - mean(r_tss)) + w_cell * (1 - mean(r_celltype)))
            / (w_gene + w_cell)
        + MSE(preds, targets)

    ``r_tss`` holds one correlation per cell, computed across genes (dim 0).
    ``r_celltype`` holds one correlation per gene, computed across cells
    (dim 1). Each correlation is the cosine similarity of the centered
    signals. Centering subtracts either the mean or the median of the
    non-zero entries (see ``norm_by``).
    """

    def __init__(
        self,
        rel_weight_gene: float = 1.0,
        rel_weight_cell: float = 1.0,
        norm_by: str = "mean",
        eps: float = 1e-8,
    ):
        """Initialise BalancedPearsonCorrelationLoss.

        Parameters
        ----------
        rel_weight_gene: float = 1.0
            The relative weight to put on the across gene/tss correlation.
        rel_weight_cell: float = 1.0
            The relative weight to put on the across cells correlation.
        norm_by: Literal['mean', 'nonzero_median'] = 'mean'
            The across gene / cell average subtracted from the signal to
            center it. "mean" uses the mean; any other value uses the median
            of the non-zero entries.
        eps: float = 1e-8
            Epsilon passed to ``cosine_similarity`` to avoid division by zero.
        """
        super().__init__()
        self.eps = eps
        self.norm_by = norm_by
        self.rel_weight_gene = rel_weight_gene
        self.rel_weight_cell = rel_weight_cell
        self.mse_loss = torch.nn.MSELoss()

    @staticmethod
    def _nonzero_median(
        tensor: torch.Tensor, dim: int, keepdim: bool = False
    ) -> torch.Tensor:
        """Median of the non-zero entries of ``tensor`` along ``dim``.

        Zeros are turned into NaN and skipped by ``torch.nanmedian``. Like
        ``torch.median``, it returns the lower of the two middle values when
        the count is even. Slices with no non-zero entry yield 0.
        """
        masked = torch.where(tensor != 0, tensor, torch.full_like(tensor, float("nan")))
        medians = torch.nanmedian(masked, dim=dim, keepdim=keepdim).values
        # All-zero slices produce NaN; map only those to 0.
        return torch.where(torch.isnan(medians), torch.zeros_like(medians), medians)

    def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the balanced correlation loss plus MSE.

        Parameters
        ----------
        preds: torch.Tensor
            2D torch tensor [genes x cells], batched over genes.
        targets: torch.Tensor
            2D torch tensor [genes x cells], batched over genes.

        Returns
        -------
        torch.Tensor
            Scalar loss.
        """
        if self.norm_by == "mean":
            preds_avg_gene = preds.mean(dim=0, keepdim=True)
            targets_avg_gene = targets.mean(dim=0, keepdim=True)
            preds_avg_cell = preds.mean(dim=1, keepdim=True)
            targets_avg_cell = targets.mean(dim=1, keepdim=True)
        else:
            preds_avg_gene = self._nonzero_median(preds, 0, keepdim=True)
            targets_avg_gene = self._nonzero_median(targets, 0, keepdim=True)
            preds_avg_cell = self._nonzero_median(preds, 1, keepdim=True)
            targets_avg_cell = self._nonzero_median(targets, 1, keepdim=True)

        # Correlation across genes: one value per cell.
        r_tss = torch.nn.functional.cosine_similarity(
            preds - preds_avg_gene,
            targets - targets_avg_gene,
            eps=self.eps,
            dim=0,
        )

        # Correlation across cells: one value per gene.
        r_celltype = torch.nn.functional.cosine_similarity(
            preds - preds_avg_cell,
            targets - targets_avg_cell,
            eps=self.eps,
            dim=1,
        )

        loss = self.rel_weight_gene * (1 - r_tss.mean()) + self.rel_weight_cell * (
            1 - r_celltype.mean()
        )

        # Rescale by 2 / (sum of relative weights), so the correlation part
        # equals 2 when both mean correlations are 0.
        loss = (loss * 2) / (self.rel_weight_gene + self.rel_weight_cell)

        return loss + self.mse_loss(preds, targets)