import torch
import torch.nn.functional as F
from torch import nn

class OnlineWeightLayer(nn.Module):
    """Learnable per-example branch weights squashed into (0, 1) by a sigmoid.

    Holds one free logit per (example, branch) pair. ``forward`` looks up the
    rows for the requested examples and returns their sigmoid, transposed so
    the branch axis comes first.

    Args:
        num_example: number of examples (rows of ``weight``).
        num_branch: number of branches (columns of ``weight``).
        scale: logits are initialised uniformly in ``[-scale/2, scale/2)``.
    """

    def __init__(self, num_example, num_branch, scale=10) -> None:
        super().__init__()
        self.num_example = num_example
        self.num_branch = num_branch
        # torch.rand is uniform on [0, 1); shift and stretch it to be zero-centred.
        init_weight = (torch.rand((num_example, num_branch)) - 0.5) * scale
        self.weight = nn.Parameter(init_weight)  # requires_grad=True by default

    def forward(self, index) -> torch.Tensor:
        """Return sigmoid weights for the examples selected by ``index``.

        Args:
            index: a row index for ``weight`` (int, slice, list or LongTensor
                of example ids).

        Returns:
            For a selection of ``k`` examples, a ``(num_branch, k)`` tensor;
            for a single integer index, a ``(num_branch,)`` tensor.
        """
        probs = torch.sigmoid(self.weight[index, :])
        # Reverse all dimensions -- exactly what ``probs.T`` does -- but without
        # PyTorch's deprecation of ``.T`` on non-2-D tensors, which is hit when
        # ``index`` is a plain int and ``probs`` is therefore 1-D.
        return probs.permute(tuple(reversed(range(probs.dim()))))


class OnlineSoftmaxWeightLayer(nn.Module):
    """Learnable per-example branch weights normalised by a softmax over branches.

    Holds one free logit per (example, branch) pair. ``forward`` looks up the
    rows for the requested examples, turns each row into a distribution over
    branches, and returns the result with the branch axis first.

    Args:
        num_example: number of examples (rows of ``weight``).
        num_branch: number of branches (columns of ``weight``).
        scale: logits are initialised uniformly in ``[-scale/2, scale/2)``.
    """

    def __init__(self, num_example, num_branch, scale=10) -> None:
        super().__init__()
        self.num_example = num_example
        self.num_branch = num_branch
        # torch.rand is uniform on [0, 1); shift and stretch it to be zero-centred.
        init_weight = (torch.rand((num_example, num_branch)) - 0.5) * scale
        self.weight = nn.Parameter(init_weight)  # requires_grad=True by default

    def forward(self, index) -> torch.Tensor:
        """Return per-example softmax weights over branches.

        Args:
            index: a row index for ``weight`` (int, slice, list or LongTensor
                of example ids).

        Returns:
            For a selection of ``k`` examples, a ``(num_branch, k)`` tensor
            whose columns each sum to 1; for a single integer index, a
            ``(num_branch,)`` tensor summing to 1.
        """
        rows = self.weight[index, :]
        # Normalise over the branch axis, which is always the last one. A fixed
        # dim=1 crashes when an int index yields a 1-D row.
        probs = torch.softmax(rows, dim=-1)
        # Reverse all dimensions (same as ``probs.T``) without the deprecated
        # ``.T``-on-non-2-D path hit by the 1-D case.
        return probs.permute(tuple(reversed(range(probs.dim()))))

class OnlineMLPNet(nn.Module):
    """Two-layer perceptron producing a softmax distribution over ``output`` units.

    Args:
        input: number of input features. The name shadows the builtin but is
            kept so keyword callers (``input=...``) keep working.
        hidden: width of the hidden ReLU layer.
        output: number of output units the softmax is taken over.
    """

    def __init__(self, input, hidden, output) -> None:
        super().__init__()
        self.linear1 = nn.Linear(input, hidden)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(hidden, output)

    def forward(self, x) -> torch.Tensor:
        """Map ``x`` to per-sample distributions, with the output axis first.

        Args:
            x: a ``(batch, input)`` tensor, or a single ``(input,)`` sample.

        Returns:
            A ``(output, batch)`` tensor whose columns each sum to 1, or an
            ``(output,)`` tensor for a single unbatched sample.
        """
        hidden = self.relu(self.linear1(x))
        logits = self.linear2(hidden)
        # nn.Linear puts features on the last axis; normalise there so that
        # unbatched 1-D input works (a fixed dim=1 would raise on it).
        probs = torch.softmax(logits, dim=-1)
        # Reverse all dimensions (same as ``probs.T``) without the deprecated
        # ``.T``-on-non-2-D path hit by unbatched input.
        return probs.permute(tuple(reversed(range(probs.dim()))))