import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

# def _batch_mahalanobis(bL, bx):
#     r"""
#     Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
#     for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

#     Accepts batches for both bL and bx. Their batch shapes need not be identical,
#     but the batch shape of `bL` must be broadcastable to that of `bx`.
#     """
#     n = bx.size(-1)
#     bx_batch_shape = bx.shape[:-1]

# #     # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
# #     # we are going to make bx have shape (..., 1, j,  i, 1, n) to apply batched tri.solve
# #     bx_batch_dims = len(bx_batch_shape)     #1
# #     bL_batch_dims = bL.dim() - 2            #0
# #     outer_batch_dims = bx_batch_dims - bL_batch_dims     #1
# #     old_batch_dims = outer_batch_dims + bL_batch_dims    #1
# #     new_batch_dims = outer_batch_dims + 2 * bL_batch_dims   #1
# #     # Reshape bx with the shape (..., 1, i, j, 1, n)
# #     bx_new_shape = bx.shape[:outer_batch_dims]           #M
# #     for (sL, sx) in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
# #         bx_new_shape += (sx // sL, sL)
# #     bx_new_shape += (n,)
# #     bx = bx.reshape(bx_new_shape)           #MxD
# #     # Permute bx to make it have shape (..., 1, j, i, 1, n)
# #     permute_dims = (list(range(outer_batch_dims)) +
# #                     list(range(outer_batch_dims, new_batch_dims, 2)) +
# #                     list(range(outer_batch_dims + 1, new_batch_dims, 2)) +
# #                     [new_batch_dims])
# #     bx = bx.permute(permute_dims)

#     flat_L = bL.reshape(-1, n, n)  # shape = b x n x n  1xDxD
#     flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n     Mx1xD
#     flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c     1xDxM
#     M_swap = torch.triangular_solve(flat_x_swap, flat_L, upper=False)[0].pow(2).sum(-2)  # shape = b x c   1xM
#     M = M_swap.t()  # shape = c x b   Mx1

# #     # Now we revert the above reshape and permute operators.
# #     permuted_M = M.reshape(bx.shape[:-1])  # shape = (..., 1, j, i, 1)   M
# #     permute_inv_dims = list(range(outer_batch_dims))
# #     for i in range(bL_batch_dims):
# #         permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
# #     reshaped_M = permuted_M.permute(permute_inv_dims)  # shape = (..., 1, i, j, 1)
# #     return reshaped_M.reshape(bx_batch_shape)

#     return M.reshape(bx_batch_shape)

def _batch_mahalanobis(bL, bx):
    r"""
    Squared Mahalanobis-style distance that uses only the diagonal of ``bL``.

    Evaluates :math:`\sum_d x_d^2 / L_{dd}^2` for every row of ``bx``. This matches
    :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}` with
    :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top` only when ``bL`` is diagonal;
    off-diagonal entries of ``bL`` are ignored.

    Args:
        bL: matrix handed to ``torch.diag``; its main diagonal must hold
            ``n = bx.size(-1)`` entries. A batch of factors is not handled here.
        bx: tensor whose last dimension has size ``n``. All leading dimensions
            are flattened into one.

    Returns:
        1-D tensor holding one distance per flattened row of ``bx``.
    """
    n = bx.size(-1)
    rows = bx.reshape(-1, n)

    # Keep the diagonal as a (1, n) row so it broadcasts across every row of `rows`.
    diag = torch.diag(bL).reshape(1, n)
    inv_sq_diag = 1 / torch.square(diag)

    return (torch.square(rows) * inv_sq_diag).sum(dim=1)
    
# def nll(feats, labels, means, scale_tril):
#     '''
#     only on new classes, new dimension
#     feats: MxD
#     labels: M
#     means: CxD
#     cov: DxD
#     '''
#     sample_means = means[labels]     # MxD
#     diff = feats.detach() - sample_means      # MxD
    
# #     scale_tril = torch.linalg.cholesky(covariance_matrix)
    
#     M = _batch_mahalanobis(scale_tril, diff)
#     half_log_det = scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    
#     return half_log_det + 0.5 * torch.mean(M)
    
# #     return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det
    
def nll(x, scale):
    '''
    Mean of x^2 / (2 * scale^2) + log(scale) over all elements, i.e. a
    zero-centred Gaussian negative log-likelihood without the constant term.

    `x` is detached before use, so this term back-propagates into `scale` only.

    x: residual tensor; must broadcast against `scale`
    scale: standard deviations; the log is taken, so entries should be positive
    returns: scalar tensor
    '''
    residual = x.detach()
    quadratic = residual.pow(2) / (2 * scale.pow(2))
    return (quadratic + scale.log()).mean()
    
# def md(feats, labels, means, scale_tril):
#     sample_means = means[labels]     # MxD
#     diff = feats - sample_means      # MxD
    
# #     scale_tril = torch.linalg.cholesky(covariance_matrix.detach())
    
#     M = _batch_mahalanobis(scale_tril.detach(), diff)
    
#     return 0.5 * torch.mean(M)

def md(x, scale):
    '''
    Mean of x^2 / (2 * scale^2) over all elements.

    `scale` is detached before use, so this term back-propagates into `x` only.

    x: residual tensor; must broadcast against `scale`
    scale: standard deviations, treated as constants here
    returns: scalar tensor
    '''
    frozen_var = scale.detach().pow(2)
    return (x.pow(2) / (2 * frozen_var)).mean()

def ccg(feats, labels, loc, scale, lam, eps=1e-7):
    '''
    Angular loss between each feature and the location vector of its label:
    lam * nll(theta, scale) + md(theta, scale), where theta is the angle between
    the L2-normalised feature and the L2-normalised location.

    feats: (M, D) features (2-D, since they are unsqueezed into a bmm operand)
    labels: indices used to select rows of `loc` and `scale`
    loc: per-label location vectors, indexed as loc[labels]
    scale: per-label scales, indexed as scale[labels]
           (presumably one scalar per label -- confirm against the caller)
    lam: weight applied to the nll term
    eps: margin keeping the cosine inside (-1, 1) before acos
    returns: scalar tensor
    '''
    class_loc = loc[labels]
    class_scale = scale[labels]

    # Row-wise dot product of unit vectors via a (M, 1, D) x (M, D, 1) batched matmul.
    unit_feats = F.normalize(feats, p=2, dim=-1).unsqueeze(1)
    unit_loc = F.normalize(class_loc, p=2, dim=-1).unsqueeze(2)
    cos_sim = torch.bmm(unit_feats, unit_loc).squeeze()

    theta = torch.acos(cos_sim.clamp(-1 + eps, 1 - eps))

    likelihood_term = nll(theta, class_scale)
    distance_term = md(theta, class_scale)
    return lam * likelihood_term + distance_term
#     return lam * nll(theta, sample_scale)

def ccg_inv(feats_neg, feats_pos, labels, loc, scale, lam, eps=1e-7):
    '''
    Hinge on the difference of angular distances to the label's location:
    relu(md(theta_pos, scale) - md(theta_neg, scale) + lam).

    The selected locations are detached, so `loc` receives no gradient from
    this term.

    feats_neg: (M, D) features whose angle feeds the subtracted md term
    feats_pos: (M, D) features whose angle feeds the added md term
    labels: indices used to select rows of `loc` and `scale`
    loc: per-label location vectors, indexed as loc[labels]
    scale: per-label scales, indexed as scale[labels]
    lam: additive margin inside the relu
    eps: margin keeping the cosine inside (-1, 1) before acos
    returns: scalar tensor
    '''
    class_scale = scale[labels]
    # Shared (M, D, 1) operand: unit-length, gradient-free location per sample.
    unit_loc = F.normalize(loc[labels].detach(), p=2, dim=-1).unsqueeze(2)

    def angle_to_loc(feats):
        # Angle between each normalised feature row and its location vector.
        unit_feats = F.normalize(feats, p=2, dim=-1).unsqueeze(1)
        cos_sim = torch.bmm(unit_feats, unit_loc).squeeze()
        return torch.acos(torch.clamp(cos_sim, -1 + eps, 1 - eps))

    theta_neg = angle_to_loc(feats_neg)
    theta_pos = angle_to_loc(feats_pos)

    margin_gap = md(theta_pos, class_scale) - md(theta_neg, class_scale) + lam
    return torch.relu(margin_gap)
#     return -torch.mean((theta ** 2) / (2 * var)) 

# def ccg(feats, labels, means, scale_tril, lam):
#     feats = F.normalize(feats, p=2, dim=1)
#     return lam * nll(feats, labels, means, scale_tril) + md(feats, labels, means, scale_tril)
