from typing import Optional

import torch


def margin_layer(model: torch.nn.Module,
                 x: torch.Tensor,
                 eps: float,
                 label: torch.Tensor,
                 subL: Optional[float] = None,
                 lambda_kl: float = 1.0,
                 return_loss: bool = True,
                 **kwargs):
    """Run the model and append one worst-case margin logit per sample.

    For every sample the predicted class ``j = argmax(y)`` is taken, and the
    appended logit is ``max_{i != j} (y_i + K * eps * ||w_j - w_i||)``, where
    ``w`` are the rows of ``head.get_weight()`` and ``K`` is ``subL`` (or
    ``sub_lipschitz()`` when ``subL`` is None).

    Args:
        model (torch.nn.Module): the network. It must provide
            ``sub_lipschitz()`` and ``head.get_weight()``; when ``model`` has
            a ``module`` attribute (as DP/DDP wrappers do) both are looked up
            on ``model.module`` instead.
        x (torch.Tensor): the input of the model.
        eps (float): the robustness radius.
        label (torch.Tensor): the target of the model. Only read when
            ``return_loss`` is True.
        subL (float, optional): If not None, it is used in place of
            ``sub_lipschitz()``, e.g. at test time when the value has been
            computed ahead instead of at every call.
        lambda_kl (float): weight of the log-softmax gap term in the loss.
        return_loss (bool): If False, no loss is computed.
        **kwargs: ignored.

    Returns:
        tuple: ``(y, y_, loss)`` where ``y`` is the raw model output, ``y_``
        is ``y`` with the extra margin logit concatenated as a last column,
        and ``loss`` is the cross entropy of ``y`` plus ``lambda_kl`` times
        the gap term, or None when ``return_loss`` is False.
    """
    logits = model(x)

    # Reach through a wrapper that stores the real network as `.module`.
    net = model.module if hasattr(model, 'module') else model
    lipschitz = subL if subL is not None else net.sub_lipschitz()
    weight = net.head.get_weight()

    top = logits.argmax(1)
    # (batch, 1, dim) - (1, num_class, dim) -> (batch, num_class, dim),
    # then reduce to per-class distances of shape (batch, num_class).
    gaps = (weight[top].unsqueeze(1) - weight.unsqueeze(0)).norm(dim=-1)

    # Mask the predicted class with a very negative value so that it can
    # never be picked as the strongest competitor.
    rivals = (logits + lipschitz * eps * gaps).scatter(
        1, top.view(-1, 1), -1e8)
    worst = rivals.max(1, keepdim=True).values
    augmented = torch.cat([logits, worst], dim=1)

    if not return_loss:
        return logits, augmented, None

    ce_term = torch.nn.functional.cross_entropy(logits, label)
    # Column 0 is identical in `logits` and `augmented`, so this difference
    # equals logsumexp(augmented) - logsumexp(logits).
    gap_term = (logits.log_softmax(dim=-1)[:, 0] -
                augmented.log_softmax(dim=-1)[:, 0])
    total = ce_term + gap_term.mean() * lambda_kl
    return logits, augmented, total


def margin_layer_v2(model: torch.nn.Module,
                    x: torch.Tensor,
                    label: torch.Tensor,
                    eps: float,
                    subL: Optional[float] = None,
                    return_loss: bool = True,
                    **kwargs):
    """Shift label-relative logits by a clipped, per-class margin.

    The model output is first made relative to the logit of the true class
    (``y_i - y_label``). Each entry is then raised by ``e_i * K * ||w_label -
    w_i||``, where ``w`` are the rows of ``head.get_weight()``, ``K`` is
    ``subL`` (or ``sub_lipschitz()`` when ``subL`` is None) and ``e_i`` is
    ``-y_i / margin_i`` clipped to ``[0.05 * eps, eps]`` and detached from
    the graph.

    Args:
        model (torch.nn.Module): the network. It must provide
            ``sub_lipschitz()`` and ``head.get_weight()``; when ``model`` has
            a ``module`` attribute (as DP/DDP wrappers do) both are looked up
            on ``model.module`` instead.
        x (torch.Tensor): the input of the model.
        label (torch.Tensor): the target of the model.
        eps (float): the robustness radius.
        subL (float, optional): If not None, it is used in place of
            ``sub_lipschitz()``, e.g. at test time when the value has been
            computed ahead instead of at every call.
        return_loss (bool): If False, no loss is computed.
        **kwargs: ignored.

    Returns:
        tuple: ``(y, y_, loss)`` where ``y`` is the label-relative model
        output, ``y_`` is ``y`` after the margin shift, and ``loss`` is the
        cross entropy of ``y_`` against ``label``, or None when
        ``return_loss`` is False.
    """
    raw = model(x)

    # Reach through a wrapper that stores the real network as `.module`.
    net = model.module if hasattr(model, 'module') else model
    lipschitz = subL if subL is not None else net.sub_lipschitz()
    weight = net.head.get_weight()

    # Express every logit relative to the true-class logit.
    relative = raw - raw.gather(dim=1, index=label.reshape(-1, 1))

    # (batch, 1, dim) - (1, num_class, dim) -> (batch, num_class, dim),
    # then reduce to per-class distances of shape (batch, num_class).
    dists = (weight[label].unsqueeze(1) - weight.unsqueeze(0)).norm(dim=-1)
    margin = lipschitz * dists

    # The per-class scale is treated as a constant (no gradient); the lower
    # bound on the divisor avoids dividing by the zero margin of the true
    # class against itself.
    scale = (-relative / margin.clamp(min=1e-7)).detach()
    scale = scale.clamp(min=eps * 0.05, max=eps)
    shifted = relative + scale * margin

    loss = (torch.nn.functional.cross_entropy(shifted, label)
            if return_loss else None)
    return relative, shifted, loss
