import numpy as np
import torch


def sharpen_prob(p, temperature=2):
    """Sharpen a probability matrix by raising it to a power.

    Every entry is raised to ``temperature`` and each row is then
    divided by its own sum so that it sums to one again.

    Args:
        p (torch.Tensor): probability matrix (batch_size, n_classes).
        temperature (float): exponent applied to every entry.

    Returns:
        torch.Tensor: tensor of the same shape as ``p`` whose rows are
        re-normalized along dimension 1.
    """
    powered = p.pow(temperature)
    # keepdim=True keeps a trailing axis so the division broadcasts per row.
    row_totals = powered.sum(dim=1, keepdim=True)
    return powered / row_totals


def reverse_index(data, label):
    """Return ``data`` and ``label`` with their first dimension reversed.

    Both inputs are indexed with the same descending index tensor, so
    the pairing between samples and labels is kept.

    Args:
        data: indexable batch; its ``size(0)`` sets the number of items.
        label: indexable batch aligned with ``data``.

    Returns:
        tuple: ``(data, label)`` in reversed order.
    """
    last = data.size(0) - 1
    # Descending indices last, last - 1, ..., 0 (empty when the batch is empty).
    descending = torch.arange(last, -1, -1, dtype=torch.long)
    flipped_data = data[descending]
    flipped_label = label[descending]
    return flipped_data, flipped_label


def shuffle_index(data, label):
    """Return ``data`` and ``label`` permuted by one random ordering.

    A single random permutation of the first dimension is drawn with
    ``torch.randperm`` and applied to both inputs, so each sample stays
    paired with its label.

    Args:
        data: indexable batch; its ``shape[0]`` sets the number of items.
        label: indexable batch aligned with ``data``.

    Returns:
        tuple: ``(data, label)`` in the shuffled order.
    """
    num_items = data.shape[0]
    permutation = torch.randperm(num_items)
    shuffled_data = data[permutation]
    shuffled_label = label[permutation]
    return shuffled_data, shuffled_label


def create_onehot(label, num_classes):
    """Create one-hot tensor.

    We suggest using nn.functional.one_hot.

    The result is built directly on the device of ``label``, so no
    round trip through host memory is needed for GPU labels.

    Args:
        label (torch.Tensor): 1-D tensor of integer class indices, each
            in ``[0, num_classes)``.
        num_classes (int): number of classes.

    Returns:
        torch.Tensor: tensor of shape ``(label.shape[0], num_classes)``
        with the default floating dtype, on the same device as
        ``label``, holding 1 at each row's class index and 0 elsewhere.
    """
    # scatter needs an int64 index; casting also accepts e.g. int32 labels.
    index = label.detach().long().unsqueeze(1)
    onehot = torch.zeros(label.shape[0], num_classes, device=label.device)
    # In-place scatter is safe here: the zero buffer was just allocated.
    onehot.scatter_(1, index, 1)
    return onehot


def sigmoid_rampup(current, rampup_length):
    """Sigmoid-shaped rampup computed as ``exp(-5 * (1 - t) ** 2)``.

    ``t`` is ``current / rampup_length`` after ``current`` has been
    clipped to ``[0, rampup_length]``. The value therefore rises from
    ``exp(-5)`` at step 0 to 1.0 once ``current`` reaches
    ``rampup_length``.

    Args:
        current (int): current step.
        rampup_length (int): maximum step; must be positive.

    Returns:
        float: rampup coefficient.
    """
    assert rampup_length > 0
    step = np.clip(current, 0.0, rampup_length)
    progress = step / rampup_length
    remaining = 1.0 - progress
    exponent = -5.0 * remaining * remaining
    return float(np.exp(exponent))


def linear_rampup(current, rampup_length):
    """Linear rampup from 0 to 1.

    Returns ``current / rampup_length`` clipped to the interval
    ``[0, 1]``.

    Args:
        current (int): current step.
        rampup_length (int): maximum step; must be positive.

    Returns:
        float: rampup coefficient.
    """
    assert rampup_length > 0
    progress = current / rampup_length
    clipped = np.clip(progress, 0.0, 1.0)
    return float(clipped)


def ema_model_update(model, ema_model, alpha):
    """Update ``ema_model`` in place as a moving average of ``model``.

    Each parameter of ``ema_model`` becomes
    ``alpha * ema_param + (1 - alpha) * param``. Parameters are paired
    in the order returned by ``parameters()``; only parameters are
    visited.

    Args:
        model (nn.Module): model being trained.
        ema_model (nn.Module): ema of the model; modified in place.
        alpha (float): ema decay rate.
    """
    paired_params = zip(ema_model.parameters(), model.parameters())
    for averaged, live in paired_params:
        # Work on .data so the in-place update is not tracked by autograd.
        averaged.data.mul_(alpha)
        averaged.data.add_(live.data, alpha=1 - alpha)
