"""Implements a malicious block that can be inserted at the front on normal models to break them."""
from statistics import NormalDist
import torch

from scipy.stats import norm


class ImprintBlock(torch.nn.Module):
    """Linear -> ReLU -> Linear block that thresholds the input mean at normal quantiles.

    Every row of the first layer averages the input features, and hidden unit
    ``i`` receives the bias ``-bins[i]``, so it outputs ``relu(mean(x) - bins[i])``.
    The second layer has all-ones weights, i.e. every output feature is the sum
    of the hidden activations plus the output bias.
    """

    def __init__(self, image_size, num_bins, alpha=0):
        """
        TODO: Get rid of annoying alpha argument
        image_size is the size of the input images
        num_bins is how many "paths" to include in the model
        alpha is accepted but not used by this class
        """
        super().__init__()
        self.image_size = image_size
        self.num_bins = num_bins
        # Creation order matters for reproducibility: each Linear draws its
        # default initialisation from the global RNG.
        self.linear0 = torch.nn.Linear(image_size, num_bins)
        self.linear2 = torch.nn.Linear(num_bins, image_size)
        self.bins = self._get_bins()
        with torch.no_grad():
            first, second = self.linear0, self.linear2
            first.weight.data[...] = self._make_average_layer()
            first.bias.data[...] = self._make_biases()
            second.weight.data = torch.ones_like(second.weight.data)
            # The default-initialised output bias is kept and shifted by the mean bin edge.
            second.bias.data.sub_(torch.as_tensor(self.bins).mean())
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply averaging layer, ReLU and the all-ones output layer to ``x``."""
        hidden = self.relu(self.linear0(x))
        return self.linear2(hidden)

    def _get_bins(self):
        """Return ``num_bins`` standard-normal quantiles at ``i / num_bins``; the first is set to -10."""
        mass_per_bin = 1 / (self.num_bins)
        edges = [norm.ppf(idx * mass_per_bin) for idx in range(self.num_bins)]
        # ppf(0) is -Inf, which is not usable as a bias; the standard normal CDF at 10 is approx 1.
        edges[0] = -10
        return edges

    def _make_average_layer(self):
        """Return a weight matrix for ``linear0`` whose entries are all ``1 / in_features``."""
        weight = self.linear0.weight.data
        return torch.ones_like(weight) * (1 / weight.shape[-1])

    def _make_biases(self):
        """Return the negated bin edges as a tensor matching ``linear0.bias``."""
        bias = self.linear0.bias.data
        negated_edges = [-float(edge) for edge in self.bins]
        return torch.as_tensor(negated_edges, dtype=bias.dtype, device=bias.device)


class DifferentialBlock(torch.nn.Module):
    """Recover data in v-notation instead of u, i.e. from differences in gradients instead of 1-hot.

    The block is Linear -> ReLU -> Linear. Row ``i`` of the first layer has all
    entries equal to ``bin_sizes[i] / in_features`` and bias ``-bins[i]``, where
    ``bins`` are normal quantiles and ``bin_sizes`` the gaps between them.
    """

    def __init__(self, input_length, num_bins, alpha=None):
        """Create the layers and initialise them via :meth:`reset_weights`.

        Args:
            input_length: Number of input (and output) features.
            num_bins: Number of hidden units, one per bin.
            alpha: Accepted but not used by this class.
        """
        super().__init__()
        self.linear = torch.nn.Linear(input_length, num_bins)
        self.nonlin = torch.nn.ReLU()

        # the linear_out layer is just to plug-and-play in any location. This is not strictly necessary.
        # You could just as well just connect from num_bins to the next layer
        self.linear_out = torch.nn.Linear(num_bins, input_length)
        self.bins, self.bin_sizes = self.get_bins_by_mass(num_bins)
        # Initialize:
        self.reset_weights()

    def reset_weights(self):
        """Re-initialise all parameters except ``linear_out.bias`` from ``self.bins`` / ``self.bin_sizes``."""
        with torch.no_grad():
            setup = dict(device=self.linear.weight.device, dtype=self.linear.weight.dtype)
            bin_sizes = torch.as_tensor(self.bin_sizes, **setup)
            # ``self.bins`` holds num_bins + 1 edges but the layer has only num_bins
            # biases. Use the left edge of each bin so the bias lines up with
            # ``bin_sizes[i] = bins[i + 1] - bins[i]``; assigning all edges would
            # give the bias a shape that makes every forward pass fail.
            left_edges = torch.as_tensor(self.bins[:-1], **setup)

            # In-place copies (rather than rebinding ``.data``) fail immediately on a shape mismatch.
            scaled_average = torch.ones_like(self.linear.weight) / (self.linear.in_features / bin_sizes[:, None])
            self.linear.weight.copy_(scaled_average)
            self.linear.bias.copy_(-left_edges)

            torch.nn.init.orthogonal_(self.linear_out.weight, gain=1.0)

    def get_bins_by_mass(self, num_bins, mu=0, sigma=1):
        """Return normal quantile edges and the gaps between consecutive edges.

        The edges are the quantiles of ``NormalDist(mu, sigma)`` at cumulative
        mass ``k / (num_bins + 2)`` for ``k = 1 .. num_bins + 1``.

        Returns:
            A tuple ``(bins, bin_sizes)`` with ``num_bins + 1`` edges and
            ``num_bins`` widths, where ``bin_sizes[i] = bins[i + 1] - bins[i]``.
        """
        distribution = NormalDist(mu=mu, sigma=sigma)
        bins = []
        mass = 0
        for _ in range(num_bins + 1):
            mass += 1 / (num_bins + 2)
            bins.append(distribution.inv_cdf(mass))
        bin_sizes = [right - left for left, right in zip(bins, bins[1:])]
        return bins, bin_sizes

    def forward(self, x):
        """Apply the binning layer, ReLU and the output projection to ``x``."""
        x = self.nonlin(self.linear(x))
        x = self.linear_out(x)
        return x


class SparseImprintBlock(torch.nn.Module):
    """Linear -> hardtanh -> Linear block with per-bin scaled averaging rows."""

    def __init__(self, image_size, num_bins, alpha=0):
        """
        image_size is the size of the input images
        num_bins is how many "paths" to include in the model
        alpha is accepted but not used by this class

        This block uses a single hardtanh for simplicity of presentation.
        This is not at all necessary and can be replaced with
        another simple linear+relu layer which replicates the same computation.
        """
        super().__init__()
        self.image_size = image_size
        self.num_bins = num_bins
        # Creation order matters for reproducibility: each Linear draws its
        # default initialisation from the global RNG.
        self.linear0 = torch.nn.Linear(image_size, num_bins)
        self.linear2 = torch.nn.Linear(num_bins, image_size)

        self.bins, self.bin_sizes = self._get_bins(num_bins)
        with torch.no_grad():
            self.linear0.weight.data[...] = self._make_scaled_average_layer()
            self.linear0.bias.data[...] = self._make_biases()
            # Uniform output weights; the default-initialised output bias is left untouched.
            uniform = torch.ones_like(self.linear2.weight.data)
            self.linear2.weight.data = uniform / image_size / num_bins
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=1)

    def forward(self, x):
        """Clip the scaled averages to [0, 1], map them to [0.5, 1] and project back to ``image_size``."""
        clipped = self.hardtanh(self.linear0(x))
        return self.linear2((clipped + 1) / 2)

    def _get_bins(self, num_bins, mu=0, sigma=1):
        """Return ``num_bins + 1`` normal quantile edges and the ``num_bins`` gaps between them.

        Edge ``k`` (1-based) is the quantile of ``NormalDist(mu, sigma)`` at the
        accumulated mass ``k / (num_bins + 2)``.
        """
        edges = []
        cumulative = 0
        for _ in range(num_bins + 1):
            cumulative += 1 / (num_bins + 2)
            edges.append(NormalDist(mu=mu, sigma=sigma).inv_cdf(cumulative))
        widths = [upper - lower for lower, upper in zip(edges, edges[1:])]
        return edges, widths

    def _make_scaled_average_layer(self):
        """Return ``linear0`` weights where every entry of row ``i`` is ``1 / in_features / bin_sizes[i]``."""
        weight = self.linear0.weight.data
        scaled = torch.ones_like(weight) * (1 / weight.shape[-1])
        for row_index in range(scaled.shape[0]):
            scaled[row_index] /= torch.as_tensor(self.bin_sizes[row_index], device=scaled.device)
        return scaled

    def _make_biases(self):
        """Return ``linear0`` biases ``-bins[i] / bin_sizes[i]`` for the interior units.

        Only indices ``1 .. num_bins - 2`` are filled; the first and the last
        bias stay zero.
        """
        bias = torch.zeros_like(self.linear0.bias.data)
        interior = [-edge / width for edge, width in zip(self.bins[1:-1], self.bin_sizes[1:-1])]
        if interior:
            bias[1:1 + len(interior)] = torch.as_tensor(interior, dtype=bias.dtype, device=bias.device)
        return bias


class EquispacedImprintBlock(torch.nn.Module):
    """The old implementation.

    Linear -> hardtanh -> Linear block whose first layer averages the input and
    adds equally spaced biases centred on zero.
    """

    def __init__(self, image_size, num_bins, alpha=0.375):
        """
        image_size is the size of the input images
        num_bins is how many "paths" to include in the model (must be at least 2)
        alpha is the constant ``a`` of the order-statistics approximation used to derive the bin spacing
        """
        super().__init__()
        self.image_size = image_size
        self.num_bins = num_bins
        self.linear0 = torch.nn.Linear(image_size, num_bins)
        self.linear1 = torch.nn.Linear(num_bins, image_size)

        self.alpha = alpha
        self.bins, self.bin_val = self._get_bins()
        with torch.no_grad():
            self.linear0.weight.data[:, :] = self._make_average_layer()
            self.linear0.bias.data[:] = self._make_biases()
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=self.bin_val)

    def forward(self, x):
        """Clip the shifted averages to ``[0, bin_val]``, rescale via ``(h + 1) / 2`` and apply ``linear1``."""
        x = self.linear0(x)
        x = (self.hardtanh(x) + 1) / 2
        output = self.linear1(x)
        return output

    def _get_bins(self):
        """Return ``num_bins`` equally spaced bin values centred on zero and their spacing.

        The spacing ``bin_val`` is the mean gap between consecutive approximate
        normal order statistics. Bin ``i`` is ``(i - num_bins // 2) * bin_val``,
        which for an even ``num_bins`` is ``-n/2 * v, ..., 0, ..., (n/2 - 1) * v``
        and for an odd ``num_bins`` is symmetric around zero.

        Raises:
            ValueError: If ``num_bins`` is smaller than 2, because no gap
                between order statistics exists to derive the spacing from.
        """
        if self.num_bins < 2:
            raise ValueError(f"num_bins must be at least 2 to derive a bin spacing, got {self.num_bins}.")
        order_stats = [self._get_order_stats(r + 1, self.num_bins) for r in range(self.num_bins)]
        diffs = [order_stats[i] - order_stats[i + 1] for i in range(len(order_stats) - 1)]
        bin_val = -sum(diffs) / len(diffs)
        # One value per hidden unit for every num_bins. Building separate left and right
        # halves of length n // 2 + 1 and n // 2 - 1 drops one value when num_bins is odd.
        center = self.num_bins // 2
        bins = [(i - center) * bin_val for i in range(self.num_bins)]
        return bins, bin_val

    def _get_order_stats(self, r, n):
        r"""
        Approximate the expected r-th order statistic out of n standard normal samples:
        E(r:n) \approx \mu + \sigma \Phi^{-1}\left( \frac{r-a}{n-2a+1} \right)
        with \mu = 0, \sigma = 1 and a = self.alpha (0.375 by default).
        """
        mu, sigma = 0, 1
        return mu + norm.ppf((r - self.alpha) / (n - 2 * self.alpha + 1)) * sigma

    def _make_average_layer(self):
        """Return ``linear0`` weights whose entries are all ``1 / in_features``."""
        new_data = 1 / self.linear0.weight.data.shape[-1] * torch.ones_like(self.linear0.weight.data)
        return new_data

    def _make_biases(self):
        """Return the bin values as a tensor matching ``linear0.bias`` (note: not negated)."""
        bias = self.linear0.bias.data
        return torch.as_tensor([float(edge) for edge in self.bins], dtype=bias.dtype, device=bias.device)
