import numpy as np
import torch
import torch.nn as nn


def to_one_hot(label, num, *, device="cuda"):
    """Encode integer class labels as float one-hot rows.

    Args:
        label: Integer tensor of class indices. Only ``label.size(0)`` is used
            for the row count, so it is assumed to be 1-D (TODO confirm with
            callers). Values are wrapped with ``label % num``; tensor ``%``
            takes the sign of the divisor, so ``-1`` maps to ``num - 1``.
        num: Number of classes, i.e. the width of each one-hot row.
        device: Device on which the result is allocated. Defaults to
            ``"cuda"`` to preserve the original behaviour.

    Returns:
        A float32 tensor of shape ``(label.size(0), num)`` on ``device`` with a
        single ``1`` per row.
    """
    rows = label.size(0)
    vec = torch.zeros(rows, num, device=device)
    # Keep both index tensors on the same device as the destination tensor.
    cols = (label % num).to(device)
    vec[torch.arange(rows, device=device), cols] = 1
    return vec


def weight_initialize_nn(model, sigma=None, a=None, gain=None):
    """Re-initialise every ``nn.Linear`` layer of ``model`` in place.

    Each option that is not ``None`` is applied to the layer weights, in this
    order; if several are given, the last one applied wins:

    Args:
        model: Any ``nn.Module``; all submodules from ``model.modules()`` are
            visited.
        sigma: If set, ``xavier_normal_`` is applied with ``gain=sigma``.
            Note that this is Xavier's gain multiplier, not a raw std.
        a: If set, ``kaiming_uniform_(weight, a=a, nonlinearity='relu')`` is
            applied.
        gain: If set, ``xavier_uniform_`` is applied with this gain.

    Biases, when the layer has one, are always reset to zero.
    """
    for m in model.modules():
        if not isinstance(m, nn.Linear):
            continue
        if sigma is not None:
            # The second parameter of xavier_normal_ is `gain`.
            nn.init.xavier_normal_(m.weight, gain=sigma)
        if a is not None:
            # NOTE(review): with nonlinearity='relu' PyTorch ignores `a` (it only
            # affects 'leaky_relu'), so `a` currently acts as an on/off switch.
            # Kept as-is to preserve existing initialisations.
            nn.init.kaiming_uniform_(m.weight, a=a, nonlinearity='relu')
        if gain is not None:
            nn.init.xavier_uniform_(m.weight, gain=gain)
        # Linear layers created with bias=False have `bias is None`.
        if m.bias is not None:
            nn.init.zeros_(m.bias)


def generate_noise(mean, std, *size, device="cuda"):
    """Sample a tensor of shape ``size`` from a normal distribution.

    Args:
        mean: Mean of the distribution. It is multiplied into a ones tensor, so
            a scalar or anything broadcastable to ``size`` works.
        std: Standard deviation, with the same broadcasting rules as ``mean``.
        *size: Output dimensions, e.g. ``generate_noise(0, 1, 4, 8)``.
        device: Keyword-only device for the result. Defaults to ``"cuda"`` to
            preserve the original behaviour.

    Returns:
        A float32 tensor of shape ``size`` on ``device``.
    """
    means = torch.ones(size, device=device) * mean
    stds = torch.ones(size, device=device) * std
    return torch.normal(means, stds)


def argmax_uniform(x, dim=-1):
    """Return argmax indices along ``dim``, breaking ties uniformly at random.

    Args:
        x: Input tensor; the result follows ``x.device``.
        dim: Dimension to reduce.

    Returns:
        A long tensor of ``x``'s shape with ``dim`` removed. Each entry is the
        index of one of the maximal elements, chosen uniformly among ties.
    """
    max_x = torch.max(x, dim=dim, keepdim=True)[0]
    is_max = x == max_x
    # Draw from the CPU generator (as before) so seeded runs pick the same
    # ties, then move the scores next to `x`.
    scores = torch.rand(x.size()).to(x.device)
    # rand() is in [0, 1), so -1 can never win. Only maximal positions are
    # eligible, even if rand() returns exactly 0 for one of them.
    scores = scores.masked_fill(~is_max, -1.0)
    return torch.argmax(scores, dim=dim)


def argmax_minimum(x, dim=-1):
    """Return argmax indices along ``dim``, choosing the smallest index on ties.

    Args:
        x: Input tensor; the result follows ``x.device``.
        dim: Dimension to reduce (negative values allowed).

    Returns:
        A long tensor of ``x``'s shape with ``dim`` removed, holding the lowest
        index at which the maximum along ``dim`` occurs.
    """
    n = x.size(dim)
    max_x = torch.max(x, dim=dim, keepdim=True)[0]
    is_max = x == max_x
    # Weights n, n-1, ..., 1 reward earlier positions. They must be laid out
    # along `dim`; a flat (n,) vector would broadcast against the last axis.
    shape = [1] * x.dim()
    shape[dim] = n
    weights = torch.arange(n, 0, -1, device=x.device).view(shape)
    # Integer weights avoid float32 precision loss for very long dimensions.
    return torch.argmax(is_max.long() * weights, dim=dim)


def float2str(f):
    """Format ``f`` with at most one decimal digit, using ``_`` as the separator.

    The value is truncated toward zero to one decimal, and a zero tenths digit
    is omitted. Examples: ``2.0 -> "2"``, ``0.5 -> "0_5"``, ``12.34 -> "12_3"``,
    ``-1.3 -> "-1_3"``.

    Args:
        f: A real number.

    Returns:
        The formatted string. A minus sign is kept only when the printed value
        is non-zero, so ``-0.05 -> "0"``.
    """
    # Work on the magnitude: Python's `%` on a negative int yields the
    # complementary digit (e.g. -13 % 10 == 7), which garbles negatives.
    magnitude = abs(f)
    whole = int(magnitude)
    tenth = int(magnitude * 10) % 10
    sign = "-" if f < 0 and (whole or tenth) else ""
    if tenth == 0:
        return f"{sign}{whole}"
    return f"{sign}{whole}_{tenth}"
