import torch
from torch import nn


class Root(nn.Module):
    """Entry stage of the soft decision tree.

    Passes the input features through unchanged and seeds every sample with
    routing mass 1 at the single root node, producing the ``(x, state)`` pair
    that each following ``Layer`` consumes.
    """

    def forward(self, x):
        """Return ``(x, state)`` where ``state`` is an ``N x 1`` tensor of ones.

        ``N`` is ``x.size(0)``. ``state`` is built with ``x.new_ones`` so it
        shares both the device *and* the dtype of ``x``; a plain
        ``torch.ones`` is always float32 and would silently promote the
        routing tensors of half/bfloat16 models to float32.
        """
        return x, x.new_ones((x.size(0), 1))


class Layer(nn.Module):
    """One depth level of the soft decision tree.

    Holds one sigmoid gate per node at this depth. Each gate splits the
    routing mass arriving at its node between a left and a right child, so
    a level with ``num_nodes`` nodes produces ``2 * num_nodes`` outputs.

    Args:
        in_features: Size of the feature dimension of ``x``.
        num_nodes: Number of nodes (gates) at this depth.
    """

    def __init__(self, in_features, num_nodes):
        super().__init__()
        self.layer = nn.Linear(in_features, num_nodes)

    def forward(self, inputs):
        """Route the incoming mass one level deeper.

        Args:
            inputs: ``(x, state)``, where ``x`` is ``N x in_features`` and
                ``state`` is ``N x num_nodes``, the mass reaching each node.

        Returns:
            ``(x, new_state)``. ``x`` is returned unchanged. ``new_state`` has
            shape ``N x 2*num_nodes``; node ``i``'s left child is at column
            ``2*i`` and its right child at ``2*i + 1``.
        """
        x, state = inputs
        prob_right = torch.sigmoid(self.layer(x))  # N x D
        prob_left = 1 - prob_right
        split = torch.stack((prob_left, prob_right), dim=2)  # N x D x 2
        # Broadcasting (N x D x 1) against (N x D x 2) already hands each child
        # its parent's mass; an explicit .repeat would only add a copy.
        new_state = state.unsqueeze(2) * split
        # flatten (unlike view((N, -1))) also works for an empty batch, N == 0.
        return x, new_state.flatten(start_dim=1)  # N x 2D, children interleaved


class SDT(nn.Module):
    """Soft decision tree classifier.

    The tree has ``num_layers`` levels of sigmoid gates and
    ``2 ** num_layers`` leaves. Each leaf owns a row of learnable class
    logits, and softmax over that row gives the leaf's class distribution.
    The prediction for a sample is the mixture of the leaf distributions,
    weighted by the probability that the sample reaches each leaf.

    Args:
        in_features: Size of the feature dimension of the input.
        num_classes: Number of output classes.
        num_layers: Tree depth. Must be non-negative; 0 gives a single leaf.

    Raises:
        ValueError: If ``num_layers`` is negative.
    """

    def __init__(self, in_features, num_classes, num_layers=8):
        super().__init__()
        if num_layers < 0:
            raise ValueError(
                f"num_layers must be non-negative, got {num_layers}")

        num_leaves = int(2 ** num_layers)
        # One row of class logits per leaf (requires_grad defaults to True).
        self.logit = nn.Parameter(torch.randn(num_leaves, num_classes))
        self.num_leaves = num_leaves

        layers = [Root()]
        for depth in range(num_layers):
            # Depth d holds 2**d gates, doubling the number of routed nodes.
            layers.append(Layer(in_features, int(2 ** depth)))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Return an ``N x num_classes`` tensor of class probabilities.

        Each row is a convex combination of the leaf softmax distributions,
        so it sums to 1.
        """
        _, leaf_prob = self.layers(x)  # N x num_leaves
        # Softmax the shared leaf logits once instead of over N expanded copies.
        leaf_dist = self.logit.softmax(dim=1)  # num_leaves x num_classes
        # matmul needs matching dtypes. Promote the same way elementwise
        # multiplication would, e.g. float32 routing mass with half logits.
        dtype = torch.promote_types(leaf_prob.dtype, leaf_dist.dtype)
        # sum_l leaf_prob[n, l] * leaf_dist[l, c] == (leaf_prob @ leaf_dist)[n, c]
        return leaf_prob.to(dtype) @ leaf_dist.to(dtype)
