from typing import Optional

import torch


def margin_layer(model: torch.nn.Module,
                 x: torch.Tensor,
                 eps: float,
                 label: torch.Tensor,
                 module: bool = False,
                 subL: Optional[float] = None,
                 lambda_kl: float = 1.0,
                 use_lln: bool = True,
                 return_loss: bool = True,
                 **kwargs):
    """Append a worst-case "margin" logit to the model output and build the loss.

    For each sample the predicted class ``j = argmax(y)`` is taken. Every other
    class ``i`` has its logit raised by ``K * eps * ||w_j - w_i||``, where ``w``
    are rows of the (optionally row-normalized) head weight and ``K`` is the
    model's ``sub_lipschitz()`` value (or ``subL``). The largest raised logit,
    excluding ``j``, is appended to ``y`` as one extra column.

    Args:
        model (torch.nn.Module): the trained model. It must expose
            ``head.weight`` and, unless ``subL`` is given, ``sub_lipschitz()``
            (on ``model.module`` when ``module`` is true).
        x (torch.Tensor): the input of the model.
        eps (float): the robustness radius.
        label (torch.Tensor): the target of the model, used by the
            cross-entropy term of the loss.
        module (bool): If true, it indicates the model is wrapped by DP/DDP.
            Use `model.module` instead of `model` to call functions.
        subL (float, optional): If not None, it is used for testing when we
            will compute the sub_lipschitz ahead instead of computing it at
            every call.
        lambda_kl (float): weight of the margin penalty added to the
            cross-entropy loss.
        use_lln (bool): If true, use last layer normalization (each row of the
            head weight is L2-normalized).
        return_loss (bool): If false, the loss is not computed and ``None`` is
            returned in its place.
        **kwargs: accepted and ignored.

    Returns:
        tuple: ``(y, y_, loss)`` where ``y`` is the raw model output
        (batch, num_class), ``y_`` is ``y`` with the worst-case margin logit
        appended (batch, num_class + 1), and ``loss`` is a scalar tensor or
        ``None`` when ``return_loss`` is false.
    """
    y = model(x)
    if module:
        K = model.module.sub_lipschitz() if subL is None else subL
        head = model.module.head.weight
    else:
        K = model.sub_lipschitz() if subL is None else subL
        head = model.head.weight  # num_class, dim

    if use_lln:
        head = torch.nn.functional.normalize(head, dim=1)

    pred = y.argmax(1)
    head_j = head[pred].unsqueeze(1)  # batch, 1, dim
    # Distance from the predicted class row to every class row.
    head_ji = (head_j - head.unsqueeze(0)).norm(dim=-1)  # batch, num_class
    y_ = y + K * eps * head_ji
    # Mask out the predicted class so the max below ranges over the others.
    y_ = y_.scatter(1, pred.view(-1, 1), -1e8)
    y_ = y_.max(1)[0].reshape(-1, 1)
    y_ = torch.cat([y, y_], dim=1)  # batch, num_class + 1

    if not return_loss:
        return y, y_, None

    loss = torch.nn.CrossEntropyLoss()(y, label)
    # The first num_class columns of y_ are exactly y, so for any real class k
    # log_softmax(y)[:, k] - log_softmax(y_)[:, k] == logsumexp(y_) - logsumexp(y),
    # i.e. -log of the probability mass y_ assigns to the real classes.
    margin_penalty = torch.logsumexp(y_, dim=-1) - torch.logsumexp(y, dim=-1)
    loss = loss + margin_penalty.mean() * lambda_kl
    return y, y_, loss


def margin_layer_v2(model: torch.nn.Module,
                    x: torch.Tensor,
                    label: torch.Tensor,
                    eps: float,
                    module: bool = False,
                    subL: Optional[float] = None,
                    use_lln: bool = True,
                    return_loss: bool = True,
                    **kwargs):
    """Raise each logit by its Lipschitz margin to the target and score with CE.

    Logit ``i`` of a sample is increased by ``K * eps * ||w_t - w_i||``, where
    ``t`` is that sample's label, ``w`` are rows of the (optionally
    row-normalized) head weight and ``K`` is ``sub_lipschitz()`` (or ``subL``).
    The distance term is zero for the target column itself.

    Args:
        model (torch.nn.Module): the trained model; must expose
            ``head.weight`` and, unless ``subL`` is given, ``sub_lipschitz()``.
        x (torch.Tensor): the input of the model.
        label (torch.Tensor): the target of the model.
        eps (float): the robustness radius.
        module (bool): If true, the model is wrapped by DP/DDP and
            ``model.module`` is used to reach ``sub_lipschitz``/``head``.
        subL (float, optional): precomputed sub-Lipschitz value (e.g. for
            testing) used instead of calling ``sub_lipschitz()`` every time.
        use_lln (bool): If true, L2-normalize each row of the head weight
            (last layer normalization).
        return_loss (bool): If false, skip the loss and return ``None`` for it.
        **kwargs: accepted and ignored.

    Returns:
        tuple: ``(y, y_, loss)`` -- raw output, margin-shifted output of the
        same shape, and the cross-entropy of ``y_`` against ``label`` (or
        ``None``).
    """
    logits = model(x)
    owner = model.module if module else model
    lip = owner.sub_lipschitz() if subL is None else subL
    weight = owner.head.weight  # num_class, dim

    if use_lln:
        weight = torch.nn.functional.normalize(weight, dim=1)

    target_rows = weight[label][:, None, :]  # batch, 1, dim
    gaps = (target_rows - weight[None, :, :]).norm(dim=-1)  # batch, num_class
    shifted = logits + lip * eps * gaps

    loss = (torch.nn.functional.cross_entropy(shifted, label)
            if return_loss else None)
    return logits, shifted, loss


def _margin_layer_deprecated(model, x, eps, module=False):
    y = model(x)
    pred = y.argmax(1)
    if module:
        K = model.module.lipschitz()
    else:
        K = model.lipschitz()

    y_ = y + K.view(1, -1) * eps
    y_ = y_ + K[pred].view(-1, 1) * eps
    y_ = y_.scatter(1, pred.view(-1, 1), -10**8.)
    y_ = y_.max(1)[0].reshape(-1, 1)
    y_ = torch.cat([y, y_], dim=1)
    return y, y_
