import numpy as np
import torch
from torch import nn


def to_activation(name, dim):
    """Build the activation module registered under ``name``.

    Args:
        name: One of ``'softmax'``, ``'spherical'`` or ``'sigsoftmax'``.
        dim: Dimension handed to the activation's constructor.

    Returns:
        A freshly constructed ``nn.Module``.

    Raises:
        ValueError: If ``name`` is not a known activation; the offending
            name is the exception argument.
    """
    # Guard clauses keep each constructor lookup lazy: only the selected
    # branch touches its class.
    if name == 'softmax':
        return nn.Softmax(dim)
    if name == 'spherical':
        return SphericalSoftmax(dim)
    if name == 'sigsoftmax':
        return SigSoftmax(dim)
    raise ValueError(name)


def to_masking_value(name):
    """Return the fill value used to mask entries before activation ``name``.

    Args:
        name: One of ``'softmax'``, ``'spherical'`` or ``'sigsoftmax'``.

    Returns:
        ``-np.inf`` for ``'softmax'`` and ``'sigsoftmax'``, ``0`` for
        ``'spherical'``.

    Raises:
        ValueError: If ``name`` is not a known activation.
    """
    if name in ('softmax', 'sigsoftmax'):
        return -np.inf
    if name == 'spherical':
        return 0
    raise ValueError(name)


def normalize_scores(scores, bias, spherical, mode='positive'):
    """Filter scores by sign and scale them to unit L1 mass along dim 2.

    The input tensor is left untouched; all work happens on a copy.

    Args:
        scores: Tensor with at least three dimensions; normalization runs
            along dimension 2.
        bias: Tensor added to ``scores.sum(2)`` when ``spherical`` is true.
            It is not read otherwise.
        spherical: If true, every slice along dimension 2 is first multiplied
            by the sign of ``scores.sum(2) + bias``.
        mode: ``'positive'`` zeroes the negative entries, ``'negative'``
            zeroes the positive entries, ``'both'`` keeps every entry.

    Returns:
        A new tensor shaped like ``scores`` whose absolute values sum to 1
        along dimension 2. Slices that end up all-zero are replaced by a
        uniform slice.

    Raises:
        ValueError: If ``mode`` is not one of the three supported values.
    """
    if mode not in ('positive', 'negative', 'both'):
        raise ValueError(mode)

    # Clone first: the masked assignments below are in-place and must not
    # leak into the tensor owned by the caller.
    scores = scores.clone()

    if spherical:
        scores *= (scores.sum(2) + bias).sign().unsqueeze(2)

    if mode == 'positive':
        scores[scores < 0] = 0
    elif mode == 'negative':
        scores[scores > 0] = 0

    # All-zero slices would divide by zero; set them to ones so that they
    # normalize to a uniform distribution instead.
    scores[scores.abs().sum(2) == 0] = 1
    return scores / scores.abs().sum(2, keepdim=True)


class SigSoftmax(nn.Module):
    """Softmax applied to ``x + log(sigmoid(x) + epsilon)`` along ``dim``."""

    def __init__(self, dim):
        """Create the activation.

        Args:
            dim: Dimension along which the softmax is taken.
        """
        super().__init__()
        # Added inside the logarithm so that log never receives exactly zero.
        self.epsilon = 1e-12
        # Registered as a submodule but not referenced by forward().
        self.sigmoid = nn.LogSigmoid()
        self.softmax = nn.Softmax(dim)

    def forward(self, x):
        """Return the sigsoftmax of ``x``; the output has the shape of ``x``."""
        gate = torch.sigmoid(x) + self.epsilon
        logits = x + torch.log(gate)
        return self.softmax(logits)


class SphericalSoftmax(nn.Module):
    """Normalize squared inputs so they sum to one along ``dim``."""

    def __init__(self, dim):
        """Create the activation.

        Args:
            dim: Dimension along which the squared values are normalized.
        """
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``x**2 / sum(x**2)`` with the sum taken over ``self.dim``."""
        energy = x.pow(2)
        total = energy.sum(dim=self.dim, keepdim=True)
        return energy / total


class RootLayer(nn.Module):
    """First layer: maps features to an activated vector over ``out_nodes``."""

    def __init__(self, in_features, out_nodes, activation):
        """Create the layer.

        Args:
            in_features: Size of the input feature vector.
            out_nodes: Number of output nodes.
            activation: Activation name understood by ``to_activation``;
                it is applied along dimension 1.
        """
        super().__init__()
        self.num_features = in_features
        self.out_nodes = out_nodes
        self.linear = nn.Linear(in_features, out_nodes)
        self.activation = to_activation(activation, dim=1)

    def forward(self, x):
        """Return ``(x, activation(linear(x)))``.

        The input is passed through unchanged so that later layers can read
        it as well.
        """
        return x, self.activation(self.linear(x))

    def to_scores(self, x):
        """Return normalized per-feature contributions ``weight * x``.

        Args:
            x: Tensor holding ``num_features`` values; it is flattened.

        Returns:
            Tensor of shape ``(out_nodes, 1, num_features)`` produced by
            ``normalize_scores`` with ``spherical=False``.
        """
        features = x.view((1, 1, -1))
        weight = self.linear.weight.data.view(self.out_nodes, 1, self.num_features)
        bias = self.linear.bias.data.view(self.out_nodes, 1)
        return normalize_scores(weight * features, bias, spherical=False)


class DenseLayer(nn.Module):
    """Layer with a full ``out_nodes x curr_nodes`` transition per input."""

    def __init__(self, in_features, curr_nodes, out_nodes, activation, fix_weight=False, prune=None):
        """Create the layer.

        Args:
            in_features: Size of the input feature vector.
            curr_nodes: Number of nodes feeding into this layer.
            out_nodes: Number of nodes this layer produces.
            activation: Activation name understood by ``to_activation``;
                applied along dimension 1 (the ``out_nodes`` axis).
            fix_weight: If true, the linear weight is frozen and set to zero.
            prune: ``None`` for no masking, otherwise the number of randomly
                drawn entries per column that are cleared in the mask.
        """
        super().__init__()
        self.in_features = in_features
        self.curr_nodes = curr_nodes
        self.out_nodes = out_nodes
        self.prune = prune
        self.activation = to_activation(activation, dim=1)
        self.linear = nn.Linear(in_features, curr_nodes * out_nodes)

        if fix_weight:
            # Gradients are switched off first so the in-place fill is legal
            # on this leaf parameter.
            frozen = self.linear.weight
            frozen.requires_grad = False
            frozen.fill_(0)

        if prune is None:
            return

        # Random noise decides which `prune` rows are picked in each column.
        noise = torch.zeros(out_nodes, curr_nodes, dtype=torch.float32)
        noise.uniform_()
        picked = torch.topk(noise, k=prune, dim=0, sorted=False)[1]

        # Start fully masked, then clear the picked entries.
        # NOTE(review): True entries are the ones overwritten in
        # to_transition, so `prune` is effectively the number of entries per
        # column left unmasked - confirm this is the intended meaning.
        mask = torch.ones(out_nodes, curr_nodes, dtype=torch.float32)
        columns = torch.arange(curr_nodes)
        for rows in picked:
            mask[rows, columns] = 0

        self.masking_value = to_masking_value(activation)
        self.mask = nn.Parameter(mask.bool(), requires_grad=False)

    def forward(self, inputs):
        """Propagate the node vector through the transition.

        Args:
            inputs: Pair ``(x, pi_in)`` of input features and the node vector
                of the previous layer.

        Returns:
            ``(x, pi_out)`` where ``pi_out`` is the transition applied to
            ``pi_in``.
        """
        x, pi_in = inputs
        transition = self.to_transition(x)
        weighted = transition * pi_in.unsqueeze(1)
        return x, weighted.sum(2)

    def to_transition(self, x):
        """Return the activated transition, shaped ``(-1, out_nodes, curr_nodes)``."""
        logits = self.linear(x).view(-1, self.out_nodes, self.curr_nodes)
        if self.prune is not None:
            logits = logits.masked_fill(self.mask, self.masking_value)
        return self.activation(logits)

    def to_scores(self, x):
        """Return normalized per-feature contributions ``weight * x``.

        Args:
            x: Tensor holding ``in_features`` values; it is flattened.

        Returns:
            Tensor of shape ``(out_nodes, curr_nodes, in_features)`` produced
            by ``normalize_scores`` with ``spherical=False``.
        """
        features = x.view((1, 1, -1))
        weight = self.linear.weight.data.view(self.out_nodes, self.curr_nodes, self.in_features)
        bias = self.linear.bias.data.view(self.out_nodes, self.curr_nodes)
        return normalize_scores(weight * features, bias, spherical=False)


class SparseLayer(nn.Module):
    """Layer whose transition is banded: each node feeds only neighbours.

    Every source node owns ``window = 2 * width + 1`` slots. Slot ``k`` of
    source node ``i`` feeds target node ``i + k - width``; this is the mapping
    realised by the fixed transposed convolution used in ``forward``.
    """

    def __init__(self, in_features, curr_nodes, activation, width):
        """Create the layer.

        Args:
            in_features: Size of the input feature vector.
            curr_nodes: Number of nodes (identical for input and output).
            activation: Activation name understood by ``to_activation``;
                applied along dimension 1 (the window axis).
            width: Half-width of the band; the window is ``2 * width + 1``.
        """
        super().__init__()
        window = 2 * width + 1
        self.window = window
        self.num_features = in_features
        self.curr_nodes = curr_nodes
        self.activation = to_activation(activation, dim=1)
        self.linear = nn.Linear(in_features, window * curr_nodes)

        # Mask the slots of border nodes: the first `width` slots of the
        # first `width` nodes and the last `width` slots of the last `width`
        # nodes.
        # NOTE(review): for width > 1 this rectangular corner also masks some
        # slots whose target is still in range (e.g. node 1, slot 1 when
        # width is 2). Left unchanged to preserve forward behaviour.
        mask = torch.zeros(curr_nodes, window, dtype=torch.bool)
        mask[:width, :width] = True
        mask[max(curr_nodes - width, 0):, window - width:] = True
        mask = mask.t().view(1, window, curr_nodes)
        self.mask = nn.Parameter(mask, requires_grad=False)
        self.masking_value = to_masking_value(activation)

        # Fixed, non-trainable transposed convolution with an identity kernel
        # and zero bias: it only routes slot k of node i to node i + k - width.
        self.conv = nn.ConvTranspose1d(
            in_channels=window, out_channels=1, kernel_size=window, padding=width)
        self.conv.weight.requires_grad = False
        self.conv.weight.fill_(0)
        self.conv.weight.squeeze(1).fill_diagonal_(1)
        self.conv.bias.requires_grad = False
        self.conv.bias.fill_(0)

    def _band_slices(self, node):
        """Return ``(target, source)`` slices for source node ``node``.

        ``target`` selects the in-range target nodes reachable from ``node``;
        ``source`` selects the window slots that feed exactly those targets,
        so both slices always have the same length.
        """
        half = self.window // 2
        tgt_from = max(node - half, 0)
        tgt_to = min(node + half + 1, self.curr_nodes)
        # Slot k maps to target node + k - half, hence k = target - node + half.
        src_from = tgt_from - node + half
        src_to = src_from + (tgt_to - tgt_from)
        return slice(tgt_from, tgt_to), slice(src_from, src_to)

    def to_conv_transition(self, x):
        """Return the activated banded transition, ``(-1, window, curr_nodes)``."""
        t_conv = self.linear(x) \
            .view(-1, self.window, self.curr_nodes) \
            .masked_fill(self.mask, self.masking_value)
        return self.activation(t_conv)

    def forward(self, inputs):
        """Propagate the node vector through the banded transition.

        Args:
            inputs: Pair ``(x, pi_in)`` of input features and the node vector
                of the previous layer.

        Returns:
            ``(x, pi_out)`` with ``pi_out`` computed by the fixed transposed
            convolution.
        """
        x, pi_in = inputs
        t_conv = self.to_conv_transition(x)
        pi_out = self.conv(t_conv * pi_in.unsqueeze(1)).squeeze(1)
        return x, pi_out

    def to_transition(self, x):
        """Expand the banded transition into a dense matrix.

        Returns:
            Tensor of shape ``(x.size(0), curr_nodes, curr_nodes)`` indexed
            ``[batch, target, source]``; entries outside the band are zero.
        """
        num_nodes = self.curr_nodes
        t_conv = self.to_conv_transition(x)
        t_full = torch.zeros(x.size(0), num_nodes, num_nodes, device=x.device)
        for i in range(num_nodes):
            target, source = self._band_slices(i)
            t_full[:, target, i] = t_conv[:, source, i]
        return t_full

    def to_scores(self, x):
        """Return normalized per-feature contributions in dense layout.

        Args:
            x: Tensor holding ``num_features`` values; it is flattened.

        Returns:
            Tensor of shape ``(curr_nodes, curr_nodes, num_features)``
            produced by ``normalize_scores``.
        """
        window = self.window
        num_nodes = self.curr_nodes
        num_features = self.num_features
        w_temp = self.linear.weight.data.view(window, num_nodes, num_features)
        b_temp = self.linear.bias.data.view(window, num_nodes)
        weight = torch.zeros(num_nodes, num_nodes, num_features, device=x.device)
        bias = torch.zeros(num_nodes, num_nodes, device=x.device)
        for i in range(num_nodes):
            target, source = self._band_slices(i)
            weight[target, i, :] = w_temp[source, i, :]
            bias[target, i] = b_temp[source, i]
        scores = weight * x.view((1, 1, -1))
        # NOTE(review): spherical=True is passed regardless of the configured
        # activation - confirm this is intended for non-spherical layers.
        return normalize_scores(scores, bias, spherical=True)


class DTN(nn.Module):
    """Stack of one root layer, ``num_layers - 1`` hidden layers and a head.

    Hidden layers are ``SparseLayer`` when the window implied by ``width`` is
    narrower than ``num_units`` and ``DenseLayer`` otherwise. The head is a
    softmax ``DenseLayer`` onto ``num_classes``.
    """

    def __init__(self, in_features, num_classes, num_layers=8, num_units=128, width=None,
                 activation='softmax', prune=None):
        """Build the layer stack.

        Args:
            in_features: Size of the input feature vector.
            num_classes: Number of output nodes of the final layer.
            num_layers: Number of layers before the final one, root included.
            num_units: Number of nodes in the root and hidden layers.
            width: ``None`` for dense hidden layers, or an int band
                half-width.
            activation: Activation name for the root and hidden layers.
            prune: Forwarded to hidden ``DenseLayer`` instances.

        Raises:
            ValueError: If ``width`` is neither ``None`` nor an int.
        """
        super().__init__()
        self.num_units = num_units
        self.num_layers = num_layers
        self.num_features = in_features
        self.num_classes = num_classes

        if width is None:
            self.window = num_units
        elif isinstance(width, int):
            self.window = min(num_units, 2 * width + 1)
        else:
            raise ValueError()

        use_sparse = self.window < num_units
        stack = [RootLayer(in_features, num_units, activation)]
        for _ in range(num_layers - 1):
            if use_sparse:
                stack.append(SparseLayer(in_features, num_units, activation, width))
            else:
                stack.append(DenseLayer(in_features, num_units, num_units, activation, prune=prune))
        stack.append(DenseLayer(in_features, num_units, num_classes, activation='softmax'))
        self.layers = nn.Sequential(*stack)

    def decision_paths(self, x):
        """Collect detached transitions and node vectors of every layer.

        Returns:
            ``(trs_list, pi_list)``, one entry per layer. The first entry of
            ``trs_list`` is the root output with a trailing singleton
            dimension; the rest come from each layer's ``to_transition``.
        """
        _, pi = self.layers[0](x)
        transitions = [pi.detach().unsqueeze(-1)]
        node_vectors = [pi.detach()]
        for layer in self.layers[1:]:
            pi = layer((x, pi))[1]
            transitions.append(layer.to_transition(x).detach())
            node_vectors.append(pi.detach())
        return transitions, node_vectors

    def explain_decisions(self, x):
        """Attribute every layer's output to the input features.

        Returns:
            ``(pi_list, out)`` where ``pi_list`` comes from
            ``decision_paths`` and ``out`` holds one tensor per layer,
            stacked over the examples of ``x``.
        """
        t_list, pi_list = self.decision_paths(x)
        feature_count = x.size(-1)
        depth_count = self.num_layers + 1

        per_example = []
        for row in range(x.size(0)):
            sample = x[row]
            # Start from a uniform attribution over the input features.
            s_prev = torch.full((1, 1, feature_count), 1 / feature_count, device=x.device)
            trace = []
            for depth in range(depth_count):
                t_curr = t_list[depth][row].unsqueeze(2)
                if depth == 0:
                    pi_prev = torch.ones(1, 1, 1).to(x.device)
                else:
                    pi_prev = pi_list[depth - 1][row]
                    pi_prev = pi_prev.view(1, pi_prev.size(0), 1)
                s_curr = self.layers[depth].to_scores(sample)
                pi_curr = pi_list[depth][row]
                pi_curr = pi_curr.view(pi_curr.size(0), 1)

                s_prev = (t_curr * pi_prev * s_curr * s_prev).sum(dim=1)
                # Rows without any mass become uniform after normalization.
                s_prev[s_prev.sum(dim=1) == 0] = 1
                s_prev = s_prev / s_prev.sum(dim=1, keepdim=True)
                trace.append((s_prev * pi_curr).sum(dim=0))
            per_example.append(trace)

        # Regroup from example-major to layer-major.
        out = [
            torch.stack([trace[depth] for trace in per_example])
            for depth in range(len(per_example[0]))
        ]
        return pi_list, out

    def forward(self, x):
        """Return the node vector produced by the final layer."""
        _, pi = self.layers(x)[0:2]
        return pi


class DTNS(DTN):
    """``DTN`` preset with ``width=1`` and the ``'spherical'`` activation."""

    def __init__(self, in_features, num_classes, num_layers=8, num_units=128):
        """Build the preset; all arguments are forwarded to ``DTN``."""
        super().__init__(
            in_features=in_features,
            num_classes=num_classes,
            num_layers=num_layers,
            num_units=num_units,
            width=1,
            activation='spherical',
        )
