import torch, torch.nn.functional as F


"""================================================================================================="""
# Module-level capability flags for this criterion. Nothing in this file reads
# them; presumably the surrounding training framework inspects them when the
# loss is selected -- TODO confirm against the loss-selection code.
# NOTE(review): names suggest "no batch-mining restrictions / no batch miner
# needed / no extra optimizer parameters" -- verify before relying on this.
ALLOWED_MINING_OPS  = None
REQUIRES_BATCHMINER = False
REQUIRES_OPTIM      = False


@torch.jit.script
def _orthogonalize(matrix: torch.Tensor, eps: float = 1e-8) -> None:
    """Orthonormalise the columns of ``matrix`` in place (Gram-Schmidt).

    Columns are processed left to right: each one is scaled to unit length
    and its component is then removed from every column to its right.

    Args:
        matrix: 2-D tensor, modified in place.
        eps: Lower bound applied to a column norm before dividing. A column
            that is (numerically) zero -- e.g. because it was linearly
            dependent on earlier columns -- is left at ~0 instead of turning
            into NaN/inf. Columns with norm above ``eps`` are unaffected.

    Returns:
        None; the result is written into ``matrix``.
    """
    # Unpacking also enforces that the input is exactly 2-D.
    num_rows, num_cols = matrix.shape
    for i in range(num_cols):
        # Normalize the i'th column (a view, so the write lands in `matrix`).
        col = matrix[:, i : i + 1]
        col /= torch.sqrt(torch.sum(col ** 2)).clamp_min(eps)
        # Project it on the remaining columns and remove that component.
        if i + 1 < num_cols:
            rest = matrix[:, i + 1 :]
            # Per-column dot product with `col`, broadcast back over the rows.
            rest -= torch.sum(col * rest, dim=0) * col

class Criterion(torch.nn.Module):
    """Label-correlation loss on proxy embeddings.

    The Pearson correlation ``r`` between label columns is turned into a
    target distance ``d = sqrt(2 * (1 - r))``. The loss is the mean squared
    error between these targets and the pairwise Euclidean distances of the
    standardised (optionally subspace-projected) proxies.
    """

    def __init__(self, opt):
        """
        Args:
            opt: Namespace containing all relevant parameters. Read here:
                ``loss_labelcorr_num_bases``, ``loss_labelcorr_not_svd``,
                ``loss`` and, depending on ``loss``, ``loss_s2sd_source``.
        """
        super(Criterion, self).__init__()

        ####
        self.opt = opt
        self.num_bases = opt.loss_labelcorr_num_bases
        self.use_svd = not opt.loss_labelcorr_not_svd
        # self.correlation_measure ## TODO
        # self.internal_loss = "mse" ## TODO

        ## check that loss includes proxy information
        assert ("proxy" in opt.loss and "s2sd" not in opt.loss) or ("proxy" in opt.loss_s2sd_source and "s2sd" in opt.loss), "Label correlation intended for use with proxy embeddings."

        self.name = "labelcorr"

    def _svd(self, batch, num_bases, power_iter=1):
        """Estimate a basis of the dominant subspace of ``batch``'s columns.

        Starts from a random Gaussian matrix and applies ``power_iter`` steps
        of ``L <- batch.T @ (batch @ L)``, orthonormalising the columns of
        ``L`` in place after every step.

        Args:
            batch: 2-D tensor; its second dimension ``p`` is the feature size.
            num_bases: Requested number of basis vectors (capped at ``p``).
            power_iter: Number of power-iteration steps.

        Returns:
            Tensor of shape ``(p, min(num_bases, p))`` on ``batch.device``.
        """
        num_p = batch.shape[1]

        num_bases = min(num_bases, num_p)
        L = torch.normal(0, 1.0, size=(num_p, num_bases), device=batch.device)
        for _ in range(power_iter):
            R = torch.matmul(batch, L) # n x k
            L = torch.matmul(batch.T, R) # p x k
            _orthogonalize(L)
        return L

    def forward(self, batch, labels, proxies, **kwargs):
        """Compute the label-correlation loss.

        Args:
            batch: Embeddings; only used (detached) to estimate the projection
                basis when SVD projection is enabled.
            labels: 2-D label tensor; it needs as many columns as there are
                proxies so that targets and proxy distances line up.
                Presumably (num_samples, num_classes) -- TODO confirm.
            proxies: 2-D proxy tensor, one row per proxy. This is the only
                input that receives gradients from the returned loss.
            **kwargs: Ignored.

        Returns:
            Scalar tensor: MSE between pairwise proxy distances and the
            correlation-derived target distances.
        """
        batch = batch.detach()
        labels = labels.detach().float().to(batch.device)

        with torch.no_grad():
            if self.use_svd:
                ## project proxies onto an estimated dominant subspace of the batch
                L = self._svd(batch, self.num_bases)
            else:
                ## identity projection: use the proxies as they are
                L = torch.eye(proxies.shape[-1], device=proxies.device)

            ## standardize
            mean_labels = labels.mean(dim=-1, keepdim=True)
            slabels = F.normalize(labels - mean_labels, dim=-1)

            ## r = (1 - d^2 / 2)
            r = torch.corrcoef(slabels.T) ## pearson coefficient
            ## F.pdist returns the strict upper triangle in row-major order
            ## ((0,1),(0,2),...,(1,2),...), so the targets must be read from r
            ## in that same order; lower-triangle order only coincides for <= 3
            ## proxies and pairs up the wrong entries beyond that.
            rows, cols = torch.triu_indices(*r.shape, offset=1, device=r.device).unbind()
            r = r[rows, cols]
            d = torch.sqrt(2 * (1 - r))

        proxy_embeds = torch.matmul(proxies, L)

        ## standardize proxies so that Euc distance <==> pearson relationship holds
        mean_proxy_embeds = proxy_embeds.mean(dim=-1, keepdim=True)
        proxy_embeds = F.normalize(proxy_embeds - mean_proxy_embeds, dim=-1)
        pdists = F.pdist(proxy_embeds)

        ## mean squared loss between proxy distances and target distances
        loss = F.mse_loss(pdists, d)

        return loss
