from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F

# For each discriminatory class, orthogonalize samples
def end_orthogonal(gram: torch.Tensor, bias_labels: torch.Tensor, exp=False, t=0.07) -> torch.Tensor:
    """Penalize similarity between samples that share the same bias label.

    For every bias class, the pairs of samples (i, j) with i > j that both
    carry that label are selected. The selected entries of ``gram`` are
    accumulated class by class and the total is divided by the overall
    number of selected pairs.

    Args:
        gram (torch.Tensor): pairwise similarity matrix. It is multiplied
            element-wise by a (B, B) pair mask, with B = ``len(bias_labels)``.
        bias_labels (torch.Tensor): 1-D tensor holding one bias label per
            sample.
        exp (bool): when True, each bias class contributes
            ``log(sum(exp(|gram| / t)))`` over its pairs instead of
            ``sum(|gram|)``.
        t (float): temperature, only used when ``exp`` is True.

    Returns:
        torch.Tensor: scalar orthogonal loss; zero when no pair is selected.
    """
    # Accumulated in place so the accumulator keeps its own dtype.
    loss = torch.tensor(0., device=gram.device)
    n_pairs = 0.

    for label in torch.unique(bias_labels):
        members = (bias_labels == label).float().unsqueeze(1)
        # The outer product marks same-class pairs; the strict lower triangle
        # keeps each pair once and discards the diagonal (self-pairs).
        pair_mask = torch.tril(members @ members.t(), diagonal=-1)

        count = pair_mask.sum()
        n_pairs += count
        if count <= 0:
            continue

        if exp:
            loss += torch.log(torch.sum(torch.exp(torch.abs(gram) / t) * pair_mask))
        else:
            loss += torch.sum(torch.abs(gram * pair_mask))

    if n_pairs > 0:
        loss /= n_pairs
    return loss

def end_orthogonal_sup(gram: torch.Tensor, target_labels: torch.Tensor, bias_labels: torch.Tensor, exp=False, t=0.07) -> torch.Tensor:
    """Penalize similarity between samples sharing both target and bias label.

    Same accumulation as the unsupervised orthogonal term, except that a pair
    (i, j), i > j, is only selected when the two samples have the same
    ``target_labels`` value AND the same ``bias_labels`` value.

    Args:
        gram (torch.Tensor): pairwise similarity matrix. It is multiplied
            element-wise by a (B, B) pair mask, with B the number of samples.
        target_labels (torch.Tensor): 1-D tensor with one target label per
            sample.
        bias_labels (torch.Tensor): 1-D tensor with one bias label per sample.
        exp (bool): when True, each (target, bias) group contributes
            ``log(sum(exp(|gram| / t)))`` over its pairs instead of
            ``sum(|gram|)``.
        t (float): temperature, only used when ``exp`` is True.

    Returns:
        torch.Tensor: scalar orthogonal loss; zero when no pair is selected.
    """
    def same_label_pairs(labels, value):
        # Strict lower-triangular indicator of pairs whose labels both equal value.
        members = (labels == value).float().unsqueeze(1)
        return torch.tril(members @ members.t(), diagonal=-1)

    loss = torch.tensor(0., device=gram.device)
    n_pairs = 0.

    # The bias pair masks do not depend on the target class: build them once.
    bias_pair_masks = [same_label_pairs(bias_labels, b) for b in torch.unique(bias_labels)]

    for target in torch.unique(target_labels):
        target_pairs = same_label_pairs(target_labels, target)

        for bias_pairs in bias_pair_masks:
            pair_mask = target_pairs * bias_pairs

            count = pair_mask.sum()
            n_pairs += count
            if count <= 0:
                continue

            if exp:
                loss += torch.log(torch.sum(torch.exp(torch.abs(gram) / t) * pair_mask))
            else:
                loss += torch.sum(torch.abs(gram * pair_mask))

    if n_pairs > 0:
        loss /= n_pairs
    return loss

# For each target class, parallelize samples belonging to
# different discriminatory classes
def end_parallel(gram: torch.Tensor, target_labels: torch.Tensor, 
                 bias_labels: torch.Tensor, max_val=1, exp=False, t=1) -> torch.Tensor:
    """Encourage similarity between same-target samples with different bias labels.

    For every target class and every pair of distinct bias classes
    (first, second), the entries ``gram[i, j]`` with i > j, sample i in
    (target, first) and sample j in (target, second) are selected. Each
    selected entry contributes ``(max_val + gram[i, j]) / (2 * max_val)``; the
    result is ``max_val`` minus the mean contribution.

    NOTE(review): because of the strict lower-triangular mask, a cross-bias
    pair is only counted when the sample of the lower-sorted bias class has
    the larger batch index; the mirrored pairs are not visited.

    Args:
        gram (torch.Tensor): Gram (pairwise similarity) matrix, multiplied
            element-wise by a (B, B) pair mask.
        target_labels (torch.Tensor): 1-D tensor of target labels.
        bias_labels (torch.Tensor): 1-D tensor of bias labels.
        max_val: upper bound assumed for the similarity values; divided by
            ``t`` when ``exp`` is True.
        exp (bool): when True, ``gram`` is divided by ``t`` before use.
        t: temperature, only used when ``exp`` is True.

    Returns:
        torch.Tensor: scalar parallel loss; zero when no pair is selected.
    """
    loss = torch.tensor(0.).to(gram.device)
    n_pairs = 0.

    if exp:
        max_val /= t

    # One column indicator per bias class, reused for every target class.
    bias_members = [(bias_labels == b).float().unsqueeze(1)
                    for b in torch.unique(bias_labels)]
    n_bias = len(bias_members)

    for target in torch.unique(target_labels):
        target_members = (target_labels == target).float().unsqueeze(1)

        for first in range(n_bias):
            in_first = target_members * bias_members[first]

            # torch.unique yields distinct values, so starting at first + 1
            # visits every later bias class and never the class itself.
            for second in range(first + 1, n_bias):
                in_second = target_members * bias_members[second]
                pair_mask = torch.tril(torch.mm(in_first, in_second.t()), diagonal=-1)

                count = pair_mask.sum()
                n_pairs += count
                if count <= 0:
                    continue

                similarity = gram / t if exp else gram
                loss += torch.sum(((max_val + similarity) * pair_mask) / (max_val * 2.0))

    if n_pairs > 0:
        loss = max_val - (loss / n_pairs)

    return loss

def end_parallel_weighted(gram: torch.Tensor, target_labels: torch.Tensor, 
                          bias_labels: torch.Tensor, max_val=1, exp=False, t=1) -> torch.Tensor:
    """Parallelize samples belonging to different discriminatory classes,
       weighted by target similarity.

    Instead of restricting pairs to the same target class, every cross-bias
    pair is weighted by ``1 / (|target_i - target_j| + 1e-5)``, normalized by
    the largest such weight in the batch. Each pair contributes
    ``weight * (max_val + gram[i, j]) / (2 * max_val)``; the result is
    ``max_val`` minus the sum of contributions divided by the number of
    pairs with a non-zero weight.

    NOTE(review): as the mask is strictly lower-triangular, a cross-bias pair
    is only counted when the sample of the lower-sorted bias class has the
    larger batch index.

    Args:
        gram (torch.Tensor): Gram matrix, shape compatible with (B, B).
        target_labels (torch.Tensor): 1-D tensor of target values. Integer
            (or boolean) tensors are converted to float, since
            ``torch.cdist`` only supports floating-point inputs.
        bias_labels (torch.Tensor): 1-D tensor of bias labels.
        max_val: upper bound assumed for the similarity values; divided by
            ``t`` when ``exp`` is True.
        exp (bool): when True, ``gram`` is divided by ``t`` before use.
        t: temperature, only used when ``exp`` is True.

    Returns:
        torch.Tensor: scalar parallel loss; zero when no pair is selected.

    """
    bias_classes = torch.unique(bias_labels)
    parallel_loss = torch.tensor(0.).to(gram.device)
    M_tot = 0.
    
    if exp:
        max_val /= t

    assert len(target_labels.shape) == 1

    # torch.cdist raises on non-floating dtypes (e.g. int64 class labels);
    # floating-point inputs are left untouched to preserve their precision.
    if not torch.is_floating_point(target_labels):
        target_labels = target_labels.to(torch.float)

    # Pairwise target distance turned into a similarity weight in (0, 1].
    kernel_dist = torch.cdist(target_labels[:, None], target_labels[:, None])
    kernel_dist = (1. / (kernel_dist + 1e-5))
    kernel_dist /= kernel_dist.max()

    # Temperature scaling does not depend on the bias pair: apply it once.
    similarity = gram / t if exp else gram

    for idx, bias_class in enumerate(bias_classes):
        bias_mask = (bias_labels == bias_class).type(torch.float).unsqueeze(dim=1)

        # bias_classes holds distinct values, so idx + 1 skips only the class itself.
        for other_bias_class in bias_classes[idx + 1:]:
            other_bias_mask = (bias_labels == other_bias_class).type(torch.float).unsqueeze(dim=1)
            mask = torch.mm(bias_mask, torch.transpose(other_bias_mask, 0, 1))
            assert mask.shape == kernel_dist.shape
            mask = torch.tril(mask * kernel_dist, diagonal=-1)

            # Count the selected pairs, not their (fractional) weights.
            M_tot += (mask != 0).sum()

            parallel_loss += torch.sum(((max_val + similarity) * mask) / (max_val * 2.0))
    
    if M_tot > 0:
        parallel_loss = max_val - (parallel_loss / M_tot)
    
    return parallel_loss

def correlation_matrix(feats):
    """Compute the Gram matrix (pairwise dot products) between samples.

    Args:
        feats (torch.Tensor): either a 2-D tensor (B, D) or a 3-D tensor
            (B, C, N).

    Returns:
        torch.Tensor: a (B, B) matrix for 2-D input; for 3-D input a
        (B, B, C) tensor holding one Gram matrix per index of dimension 1.

    Raises:
        ValueError: if ``feats`` is neither 2-D nor 3-D.
    """
    n_dims = feats.dim()

    if n_dims == 2:
        # (B, D) x (D, B) -> (B, B)
        return torch.mm(feats, torch.transpose(feats, 0, 1))

    if n_dims == 3:
        per_channel = feats.transpose(0, 1)  # (B, C, N) -> (C, B, N)
        G = torch.matmul(per_channel, per_channel.transpose(1, 2))  # (C, B, B)
        return G.permute(1, 2, 0)  # -> (B, B, C)

    # Raise instead of print + exit(1): a library helper must not terminate
    # the whole process on invalid input.
    raise ValueError('Activation must be 3D or 2D matrix')

def cosine_similarity_matrix(feats, eps=1e-8):
    """Compute the pairwise cosine similarity between the rows of ``feats``.

    Args:
        feats (torch.Tensor): 2-D tensor (B, D), one feature vector per row.
        eps (float): lower bound applied to the product of the two norms to
            avoid dividing by zero.

    Returns:
        torch.Tensor: (B, B) matrix of cosine similarities.

    Raises:
        ValueError: if ``feats`` is not 2-D (per-channel 3-D features are not
            supported).
    """
    # Raise instead of print + exit(1): a library helper must not terminate
    # the whole process on invalid input.
    if feats.dim() > 2:
        raise ValueError('Cosine similarity not supported when per_channel=True')
    if feats.dim() < 2:
        raise ValueError(f'Expected a 2D feature matrix, got {feats.dim()}D input')

    # Pairwise dot products between rows: (B, D) x (D, B) -> (B, B)
    prod = torch.mm(feats, feats.t())

    # Outer product of the row norms: entry (i, j) is ||f_i|| * ||f_j||.
    row_norms = torch.norm(feats, dim=1, p='fro')
    norm = row_norms[:, None] * row_norms[None, :]

    # clamp applies the eps floor without allocating a tensor for eps.
    return prod / torch.clamp(norm, min=eps)

def compute_gram(feats, metric, per_channel, gram_mean_abs, 
                 occurence_matrix=None, bias_labels=None, target_labels=None):
    """Build the (optionally re-weighted) similarity matrix of a batch.

    Args:
        feats (torch.Tensor): batch features handed to the similarity function.
        metric (str): ``'correlation'`` or ``'cosine'``.
        per_channel (bool): when True, the similarity tensor is averaged over
            its last dimension (presumably the channel axis of 3-D features —
            confirm against the caller).
        gram_mean_abs (bool): with ``per_channel``, take absolute values
            before averaging.
        occurence_matrix (torch.Tensor, optional): co-occurrence counts,
            indexed as ``[bias_label, target_label]``. Its element-wise
            inverse provides one weight per sample.
        bias_labels (torch.Tensor, optional): bias labels of the batch; only
            read when ``occurence_matrix`` is given.
        target_labels (torch.Tensor, optional): target labels of the batch;
            only read when ``occurence_matrix`` is given.

    Returns:
        tuple: ``(gram_matrix, max_val)`` where ``max_val`` is always ``1.``.

    Raises:
        ValueError: if ``metric`` is not one of the supported names.
    """
    max_val = 1.

    # Per-sample weight = inverse frequency of its (bias, target) combination.
    # Combinations that never occur give 1/0 in the table, but they are never
    # selected by the batch labels.
    sample_weights = None
    if occurence_matrix is not None:
        inverse_frequency = 1. / occurence_matrix
        sample_weights = inverse_frequency[bias_labels, target_labels].unsqueeze(1)

    if metric == 'cosine':
        gram_matrix = cosine_similarity_matrix(feats)
    elif metric == 'correlation':
        gram_matrix = correlation_matrix(feats)
    else:
        raise ValueError(f'Unknown metric choice {metric}')

    # Bound the similarities to [-1, 1].
    gram_matrix = gram_matrix.clamp(-1, 1.)

    if per_channel:
        if gram_mean_abs:
            gram_matrix = gram_matrix.abs()
        gram_matrix = gram_matrix.mean(dim=-1)

    if sample_weights is None:
        return gram_matrix, max_val

    # Pair weight (i, j) is the product of the two sample weights.
    pair_weights = torch.mm(sample_weights, sample_weights.T)
    return gram_matrix * pair_weights, max_val

def end_loss(feats: torch.Tensor,
             target_labels: torch.Tensor,
             bias_labels: torch.Tensor,
             alpha: Optional[float] = 1.0,
             beta: Optional[float] = 1.0,
             sup: bool = False,
             exp: bool = False,
             temperature: float = 0.07,
             kernel: bool = False,
             metric: str = "correlation",
             reduction: Optional[Callable] = None,
             per_channel: bool = False,
             sum: Optional[bool] = True,
             gram_mean_abs: bool = False,
             occurence_matrix: Optional[torch.Tensor] = None) -> Union[torch.Tensor, tuple]:
    """Computes EnD regularization term: https://arxiv.org/abs/2103.02023.

    Examples:
        ```python
        feats = ...  # activations of the layer to regularize, shape (B, ...)

        def criterion(outputs, target, bias_labels):
            ce = F.cross_entropy(outputs, target)
            end = end_loss(feats, target, bias_labels, alpha=0.1, beta=0.1)
            return ce + end
        ```

    Args:
        feats (torch.Tensor): batch activations, first dimension is the batch.
        target_labels (torch.Tensor): ground truth labels of the batch.
        bias_labels (torch.Tensor): bias labels given for the current batch.
        alpha (float, optional): weight of the disentangling (orthogonal)
            term. When 0 the term is not computed. Defaults to 1.0.
        beta (float, optional): weight of the entangling (parallel) term.
            When 0 the term is not computed. Defaults to 1.0.
        sup (bool): if True, the orthogonal term only considers pairs that
            share both the target and the bias label. Defaults to False.
        exp (bool): forwarded to the orthogonal and parallel terms to select
            their temperature-scaled variants. Defaults to False.
        temperature (float): temperature forwarded to both terms; only used
            by them when ``exp`` is True. Defaults to 0.07.
        kernel (bool): if True, the parallel term weights pairs by target
            similarity instead of requiring equal targets. Defaults to False.
        metric (str): similarity used for the Gram matrix, ``"correlation"``
            or ``"cosine"``. Defaults to ``"correlation"``.
        reduction (Callable, optional): applied to ``feats`` first (i.e.
            pooling, mean, etc). Defaults to None.
        per_channel (bool): if True, features with more than two dimensions
            are reshaped to (B, C, -1) and the Gram matrix is averaged over
            C; otherwise they are flattened to (B, -1). Defaults to False.
        sum (bool, optional): if False, returns the contributions of the two
            terms separately, otherwise sum. Defaults to True.
        gram_mean_abs (bool): with ``per_channel``, average absolute
            similarities. Defaults to False.
        occurence_matrix (torch.Tensor, optional): co-occurrence counts
            indexed as ``[bias_label, target_label]``, used to re-weight the
            Gram matrix. Defaults to None.

    Returns:
        Union[torch.Tensor, tuple]: value of the EnD term, or the tuple
        ``(alpha * R_ortho, beta * R_parallel)`` when ``sum`` is False.
    """
    
    # Apply reduction (i.e. pooling, mean, etc)
    if reduction is not None:
        feats = reduction(feats)
    
    if feats.dim() > 2:
        # reshape (unlike view) also accepts non-contiguous activations,
        # e.g. the output of a permute/transpose.
        if per_channel:
            feats = feats.reshape(feats.shape[0], feats.shape[1], -1)
        else:
            feats = feats.reshape(feats.shape[0], -1)

    gram_matrix, max_val = compute_gram(feats, metric, per_channel,
                                        gram_mean_abs=gram_mean_abs, 
                                        occurence_matrix=occurence_matrix,
                                        bias_labels=bias_labels,
                                        target_labels=target_labels)
    # Just the lower triangular part is needed
    gram_matrix = torch.tril(gram_matrix, diagonal=-1)

    zero = torch.tensor(0., device=feats.device)

    # A zero weight disables the term entirely, whichever variant is selected.
    if alpha == 0:
        R_ortho = zero
    elif sup:
        R_ortho = end_orthogonal_sup(gram_matrix, target_labels, bias_labels, 
                                     exp=exp, t=temperature)
    else:
        R_ortho = end_orthogonal(gram_matrix, bias_labels, exp=exp, t=temperature)
    
    if beta == 0:
        R_parallel = zero
    elif kernel:
        R_parallel = end_parallel_weighted(gram_matrix, target_labels, bias_labels, 
                                           max_val, exp=exp, t=temperature)
    else:
        R_parallel = end_parallel(gram_matrix, target_labels, bias_labels, max_val, 
                                  exp=exp, t=temperature)

    if sum:
        return alpha * R_ortho + beta * R_parallel
    return alpha * R_ortho, beta * R_parallel
