"""
The code has been modified from https://github.com/wOOL/DNDT.
"""

from functools import reduce

import numpy as np
import torch
from torch import nn


def read_importance(path, dataset):
    """Look up the feature-importance scores recorded for ``dataset``.

    The file at ``path`` is read line by line. Each line is split on tab
    characters; the first field is compared with ``dataset`` and the
    remaining fields are parsed as floats.

    Returns:
        A NumPy array with the scores from the first matching line, or
        ``None`` when no line starts with ``dataset``.
    """
    with open(path) as handle:
        for row in handle:
            name, *scores = row.split('\t')
            if name == dataset:
                # Only the first matching row is used.
                return np.array([float(score) for score in scores])
    return None


def product(a, b):
    """Return the row-wise outer product of ``a`` and ``b``, flattened.

    For ``a`` of shape (n, j) and ``b`` of shape (n, k) the result has shape
    (n, j * k): row ``i`` holds every pairwise product ``a[i, p] * b[i, q]``
    with ``q`` varying fastest.
    """
    outer = torch.einsum('ij,ik->ijk', a, b)
    return outer.view(outer.shape[0], -1)


class FeatureMask(nn.Module):
    """Fixed projection that keeps the highest-importance input features.

    The mask is an ``(in_features, max_features)`` 0/1 matrix whose column
    ``k`` has a single one in the row of the ``k``-th most important feature.
    It is stored as a parameter with gradients disabled, so it is never
    trained.
    """

    def __init__(self, in_features, max_features, importance):
        super().__init__()
        # Feature indices ordered by descending importance, truncated to
        # the number of features to keep.
        ranked = np.argsort(-importance)
        kept = ranked[:max_features]
        columns = np.arange(max_features)

        selector = np.zeros((in_features, max_features), dtype=np.float32)
        selector[kept, columns] = 1
        self.mask = nn.Parameter(
            torch.from_numpy(selector), requires_grad=False)

    def forward(self, x):
        """Project ``x`` onto the kept features by multiplying with the mask."""
        return x @ self.mask


class DNDT(nn.Module):
    def __init__(self, in_features, num_classes, num_cuts=1, temperature=1.,
                 max_features=20, dataset=None, selection=None):
        super().__init__()

        # Set the feature mask if necessary.
        self.mask = None
        if in_features > max_features \
                and dataset is not None and selection is not None:
            importance = read_importance(selection, dataset)
            self.mask = FeatureMask(in_features, max_features, importance)
            in_features = max_features

        num_leaves = (num_cuts + 1) ** in_features
        self.loss = nn.CrossEntropyLoss()
        self.temperature = temperature
        self.zero = nn.Parameter(
            torch.zeros([in_features, 1]), requires_grad=False)
        self.weight = nn.Parameter(
            torch.arange(num_cuts + 1).to(torch.float32).unsqueeze(0) + 1,
            requires_grad=False)
        self.cut_points = nn.Parameter(
            torch.rand(in_features, num_cuts), requires_grad=True)
        self.leaf = nn.Linear(num_leaves, num_classes, bias=False)

    def forward(self, x):
        if self.mask is not None:
            x = self.mask(x)

        bias = torch.sort(self.cut_points, dim=1)[0]
        bias = torch.cat([self.zero, bias], dim=1)  # Append zero at the first.
        bias = torch.cumsum(bias, dim=1)
        h = torch.matmul(x.unsqueeze(2), self.weight) - bias
        h = torch.softmax(h / self.temperature, dim=2)
        leaf = reduce(product, torch.unbind(h, dim=1))
        return self.leaf(leaf)
