import functools
import math
from collections.abc import Iterable
from itertools import repeat
from typing import Optional

import torch
from torch import nn as nn
from torch.nn import functional as F

from experiments.models import stochastic_model
from experiments.models.stochastic_model import get_current_num_samples


def _ntuple(n):
    """Build a converter that expands a scalar argument into an ``n``-tuple.

    The returned ``parse(x)`` hands back any iterable ``x`` (tuple, list,
    string, ...) unchanged and turns any other value into a tuple containing
    ``x`` repeated ``n`` times, e.g. ``_ntuple(2)(3) == (3, 3)``.
    """

    def parse(x):
        if not isinstance(x, Iterable):
            return tuple(repeat(x, n))
        return x

    return parse


# Converter for one-dimensional arguments: 3 -> (3,); iterables pass through.
_single = _ntuple(1)


class _StochasticDropout(torch.nn.Module):
    """Dropout that is applied on every forward call, independent of ``self.training``.

    ``forward`` zeroes elements selected by a Bernoulli(``p``) mask and scales
    the surviving elements by ``1 / (1 - p)``.

    Two masking modes exist, selected by the ``consistent`` attribute:

    * ``consistent = False`` (default): a fresh mask is drawn on every call,
      with one mask row per entry of the input's dimension 0.
    * ``consistent = True``: a mask of shape ``[1, k, *sample_mask_shape]``
      (``k = stochastic_model.get_current_num_samples()``) is drawn once, cached
      in ``mask_1_k_`` and reused until ``reset_mask()`` is called, the module is
      switched to eval mode, or the cached mask no longer fits the input.

    Subclasses override ``_get_sample_mask_shape`` to choose which per-sample
    dimensions share one mask entry.
    """

    # When True, ``forward`` reuses one cached mask (``mask_1_k_``) across calls.
    consistent: bool = False

    def __init__(self, p):
        super().__init__()

        if p < 0 or p > 1:
            raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
        self.p = p
        # Cached mask for ``consistent`` mode; ``None`` means "draw a new one".
        self.mask_1_k_ = None

    def extra_repr(self):
        return "p={}".format(self.p)

    def reset_mask(self):
        """Forget the cached consistent-mode mask so the next call draws a new one."""
        self.mask_1_k_ = None

    def train(self, mode=True):
        """Set train/eval mode; switching to eval mode also resets the cached mask.

        Returns ``self`` like :meth:`torch.nn.Module.train`, so ``m.eval()`` and
        ``m.train()`` can be chained or assigned.
        """
        super().train(mode)
        if not mode:
            self.reset_mask()
        return self

    def _get_sample_mask_shape(self, sample_shape):
        """Mask shape for one sample; by default every element gets its own entry."""
        return sample_shape

    def _drop_mask_shape(self, input_xk_, k):
        """Shape ``[1, k, *sample_mask_shape]`` of the mask drawn for ``input_xk_``."""
        return [1, k] + list(self._get_sample_mask_shape(input_xk_.shape[1:]))

    def _create_drop_mask(self, input_xk_, k, *, persistent):
        """Draw a boolean mask of shape ``[1, k, *sample_mask_shape]`` (True = drop).

        Args:
            input_xk_: tensor whose dims ``1:`` define the per-sample shape and
                whose device the mask is created on.
            k: size of the mask's second dimension.
            persistent: whether the mask will be cached; if so it is pinned when
                that is possible (CPU tensor with a CUDA runtime available).
        """
        drop_mask_shape = self._drop_mask_shape(input_xk_, k)
        # Pinning only applies to CPU tensors and needs a CUDA runtime; requesting
        # it for a CUDA tensor or on a CPU-only machine raises a RuntimeError.
        pin = persistent and input_xk_.device.type == "cpu" and torch.cuda.is_available()
        # noinspection PyArgumentList
        drop_mask = torch.empty(drop_mask_shape, dtype=torch.bool, device=input_xk_.device, pin_memory=pin)
        return drop_mask.bernoulli_(self.p)

    def forward(self, input_xk_: torch.Tensor):
        """Zero masked elements of ``input_xk_`` and rescale the rest by ``1 / (1 - p)``."""
        if self.p == 0.0:
            return input_xk_ * 1
        if self.p == 1.0:
            # Every element is dropped; the rescale below would compute 0 / 0 = NaN.
            # Multiplying keeps dtype, device and the autograd graph.
            return input_xk_ * 0

        if not self.consistent:
            # Fresh mask per call, one mask row per entry of dimension 0.
            mask_xk_ = self._create_drop_mask(input_xk_, input_xk_.shape[0], persistent=False)[0]
            return input_xk_.masked_fill(mask_xk_, 0) / (1.0 - self.p)

        k = stochastic_model.get_current_num_samples()
        expected_shape = self._drop_mask_shape(input_xk_, k)
        if self.mask_1_k_ is None or list(self.mask_1_k_.shape) != expected_shape:
            # No cached mask yet, or it was drawn for a different number of samples
            # or input shape and cannot be applied to this input: draw a new one.
            self.mask_1_k_ = self._create_drop_mask(input_xk_, k, persistent=True)

        mask_1_k_ = self.mask_1_k_ = self.mask_1_k_.to(device=input_xk_.device, non_blocking=True)
        input_x_k_ = stochastic_model.unflatten_tensor(input_xk_, k)
        output_x_k_ = input_x_k_.masked_fill(mask_1_k_, 0) / (1.0 - self.p)
        return stochastic_model.flatten_tensor(output_x_k_)


class StochasticDropout(_StochasticDropout):
    r"""Element-wise dropout that is applied on every forward call.

    Each element of the input is zeroed with probability :attr:`p` using samples
    from a Bernoulli distribution, and the surviving elements are scaled by
    :math:`\frac{1}{1-p}`. Unlike :class:`torch.nn.Dropout`, the module is not
    disabled in eval mode: dropout is applied regardless of ``self.training``.

    By default a fresh mask is drawn on every call. When the ``consistent``
    attribute is set to ``True``, a mask is drawn once and reused until
    ``reset_mask()`` is called or the module is switched to eval mode.

    This is the regularization technique described in the paper
    `Improving neural networks by preventing co-adaptation of feature
    detectors`_ .

    Args:
        p: probability of an element to be zeroed. Must lie in ``[0, 1]``. Default: 0.5
        inplace: accepted for signature compatibility with :class:`torch.nn.Dropout`
            (``inject_stochastic_dropout`` substitutes this class for it); it is
            ignored and the operation is never performed in-place. Default: ``False``

    Shape:
        - Input: `Any`. Input can be of any shape
        - Output: `Same`. Output is of the same shape as input

    Examples::
        >>> m = StochasticDropout(p=0.2)
        >>> input = torch.randn(20, 16)
        >>> output = m(input)

    .. _Improving neural networks by preventing co-adaptation of feature
        detectors: https://arxiv.org/abs/1207.0580
    """

    def __init__(self, p=0.5, inplace=False):
        # ``inplace`` is accepted (and ignored) so that code written against
        # ``torch.nn.Dropout(p, inplace)`` keeps working when this class is
        # injected in its place.
        super().__init__(p)


class StochasticDropout2d(_StochasticDropout):
    r"""Channel-wise dropout that is applied on every forward call.

    Zeroes whole channels of the input. For a per-sample shape ``(C, H, W)``, the
    mask has shape ``(C, 1, 1)`` and is broadcast over the remaining dimensions.
    Surviving channels are scaled by :math:`\frac{1}{1-p}`. Unlike
    :class:`torch.nn.Dropout2d`, dropout is applied regardless of
    ``self.training``.

    By default a fresh mask is drawn on every call, independently for every row
    of the input. When the ``consistent`` attribute is ``True``, a mask is drawn
    once and reused until ``reset_mask()`` is called or the module is switched
    to eval mode.

    Usually the input comes from :class:`nn.Conv2d` modules. As described in the
    paper `Efficient Object Localization Using Convolutional Networks`_ ,
    if adjacent pixels within feature maps are strongly correlated
    (as is normally the case in early convolution layers) then i.i.d. dropout
    will not regularize the activations and will otherwise just result
    in an effective learning rate decrease. In this case, channel-wise dropout
    helps promote independence between feature maps and should be used instead.

    Args:
        p (float, optional): probability of a channel to be zeroed. Must lie in
            ``[0, 1]``. Default: 0.5
        inplace (bool, optional): accepted for signature compatibility with
            :class:`torch.nn.Dropout2d` (``inject_stochastic_dropout`` substitutes
            this class for it); it is ignored and the operation is never
            performed in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::
        >>> m = StochasticDropout2d(p=0.2)
        >>> input = torch.randn(20, 16, 32, 32)
        >>> output = m(input)

    .. _Efficient Object Localization Using Convolutional Networks:
       http://arxiv.org/abs/1411.4280
    """

    def __init__(self, p=0.5, inplace=False):
        # ``inplace`` is accepted (and ignored) so that code written against
        # ``torch.nn.Dropout2d(p, inplace)`` keeps working when this class is
        # injected in its place.
        super().__init__(p)

    def _get_sample_mask_shape(self, sample_shape):
        # Keep the channel dimension; every other per-sample dimension broadcasts.
        return [sample_shape[0]] + [1] * (len(sample_shape) - 1)


def inject_stochastic_dropout(callable, *args, **kwargs):
    """Call ``callable(*args, **kwargs)`` with torch's dropout layers swapped out.

    While ``callable`` runs:

    * ``torch.nn.Dropout`` is replaced by :class:`StochasticDropout`;
    * ``torch.nn.Dropout2d`` is replaced by :class:`StochasticDropout2d`;
    * ``Dropout3d``, ``AlphaDropout`` and ``FeatureAlphaDropout`` are set to
      ``None``, so constructing one raises ``TypeError``.

    The original classes are restored afterwards, also when ``callable`` raises.
    The return value of ``callable`` is passed through.

    Only lookups through the ``torch.nn`` attribute are affected. A name bound
    earlier via ``from torch.nn import Dropout`` keeps the original class.

    Note: this temporarily patches module-level state, so concurrent use from
    several threads is not safe.
    """
    replacements = {
        "Dropout": StochasticDropout,
        "Dropout2d": StochasticDropout2d,
        "Dropout3d": None,
        "AlphaDropout": None,
        "FeatureAlphaDropout": None,
    }
    originals = {name: getattr(torch.nn, name) for name in replacements}

    try:
        for name, replacement in replacements.items():
            setattr(torch.nn, name, replacement)
        return callable(*args, **kwargs)
    finally:
        # Always undo the global patch, even if ``callable`` failed.
        for name, original in originals.items():
            setattr(torch.nn, name, original)


def wrap_stochastic_dropout(func):
    """Decorator that runs ``func`` under ``inject_stochastic_dropout``.

    Each call of the returned function is forwarded as
    ``inject_stochastic_dropout(func, *args, **kwargs)``. ``functools.wraps``
    copies ``func``'s name, docstring and other metadata onto the wrapper.
    """

    def _call_with_stochastic_dropout(*args, **kwargs):
        return inject_stochastic_dropout(func, *args, **kwargs)

    return functools.wraps(func)(_call_with_stochastic_dropout)


class _StochasticDropoutConvNd(nn.Module):
    """Base class holding the parameters of a convolution with stochastic weight dropout.

    Stores the convolution hyper-parameters, the ``weight`` and optional ``bias``
    parameters, and ``dropout_rate``. Subclasses create the dropout module and
    implement ``forward``.

    ``kernel_size``, ``stride``, ``padding``, ``dilation`` and ``output_padding``
    are expected to be sequences with one entry per spatial dimension
    (``StochasticDropoutConv2d`` normalizes them with ``_pair``).
    """

    __constants__ = [
        "stride",
        "padding",
        "dilation",
        "groups",
        "padding_mode",
        "output_padding",
        "in_channels",
        "out_channels",
        "kernel_size",
    ]
    __annotations__ = {"bias": Optional[torch.Tensor]}

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        bias,
        padding_mode,
        dropout_rate,
    ):
        super(_StochasticDropoutConvNd, self).__init__()
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        valid_padding_modes = {"zeros", "reflect", "replicate", "circular"}
        if padding_mode not in valid_padding_modes:
            raise ValueError(
                "padding_mode must be one of {}, but got padding_mode='{}'".format(valid_padding_modes, padding_mode)
            )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        self.padding_mode = padding_mode
        # Padding in the layout ``F.pad`` expects, used for non-"zeros" padding
        # modes. ``F.pad`` starts with the *last* dimension, i.e.
        # (left, right, top, bottom) for 2-D inputs. The per-dimension padding
        # must therefore be reversed before each entry is duplicated.
        self._padding_repeated_twice = tuple(pad for pad in reversed(tuple(self.padding)) for _ in range(2))
        if transposed:
            self.weight = nn.Parameter(torch.Tensor(in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias", None)
        self.dropout_rate = dropout_rate
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize ``weight`` and ``bias`` uniformly in ``±1/sqrt(fan_in)``."""
        # kaiming_uniform_ with a=sqrt(5) has bound sqrt(3) * sqrt(2 / 6) / sqrt(fan_in)
        # = 1 / sqrt(fan_in).
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self):
        """Describe the layer; options at their default value are left out."""
        s = "{in_channels}, {out_channels}, kernel_size={kernel_size}" ", stride={stride}"
        if self.padding != (0,) * len(self.padding):
            s += ", padding={padding}"
        if self.dilation != (1,) * len(self.dilation):
            s += ", dilation={dilation}"
        if self.output_padding != (0,) * len(self.output_padding):
            s += ", output_padding={output_padding}"
        if self.groups != 1:
            s += ", groups={groups}"
        if self.bias is None:
            s += ", bias=False"
        if self.padding_mode != "zeros":
            s += ", padding_mode={padding_mode}"
        s += ", dropout_rate={dropout_rate}"
        return s.format(**self.__dict__)

    def __setstate__(self, state):
        """Restore from a pickle; states without ``padding_mode`` default to "zeros"."""
        super(_StochasticDropoutConvNd, self).__setstate__(state)
        if not hasattr(self, "padding_mode"):
            self.padding_mode = "zeros"


class StochasticDropoutConv2d(_StochasticDropoutConvNd):
    r"""Applies a 2D convolution whose weights are dropped out separately for every sample.

    The leading dimension of the input holds ``B * K`` rows, where
    ``K = get_current_num_samples()``. The ``K`` rows that belong to one batch
    element are stored next to each other, at row ``b * K + k``.

    On every forward call, ``K`` copies of :attr:`weight` are passed through
    :class:`StochasticDropout` with probability :attr:`dropout_rate`. Surviving
    weights are scaled by :math:`\frac{1}{1 - \text{dropout\_rate}}`. Row
    ``b * K + k`` is convolved with the ``k``-th dropped-out copy, and all copies
    are evaluated by a single grouped convolution. The bias is shared by all
    copies and is not dropped.

    Apart from the weight dropout, the layer computes the same function as
    :class:`torch.nn.Conv2d`. In the simplest case, the output value of the layer
    with input size :math:`(N, C_{\text{in}}, H, W)` and output
    :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` can be precisely
    described as:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)


    where :math:`\star` is the valid 2D `cross-correlation`_ operator,
    :math:`N` is a batch size, :math:`C` denotes a number of channels,
    :math:`H` is a height of input planes in pixels, and :math:`W` is
    width in pixels.

    * :attr:`stride` controls the stride for the cross-correlation, a single
      number or a tuple.

    * :attr:`padding` controls the amount of implicit padding on both
      sides for :attr:`padding` number of points for each dimension.

    * :attr:`dilation` controls the spacing between the kernel points; also
      known as the à trous algorithm. It is harder to describe, but this `link`_
      has a nice visualization of what :attr:`dilation` does.

    * :attr:`groups` controls the connections between inputs and outputs.
      :attr:`in_channels` and :attr:`out_channels` must both be divisible by
      :attr:`groups`. For example,

        * At groups=1, all inputs are convolved to all outputs.
        * At groups=2, the operation becomes equivalent to having two conv
          layers side by side, each seeing half the input channels,
          and producing half the output channels, and both subsequently
          concatenated.
        * At groups= :attr:`in_channels`, each input channel is convolved with
          its own set of filters, of size:
          :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    .. note::

         Depending on the size of your kernel, several (of the last)
         columns of the input might be lost, because it is a valid `cross-correlation`_,
         and not a full `cross-correlation`_.
         It is up to the user to add proper padding.

    .. note::

        When `groups == in_channels` and `out_channels == K * in_channels`,
        where `K` is a positive integer, this operation is also termed in
        literature as depthwise convolution.

        In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`,
        a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments
        :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        padding_mode (string, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
            Default: ``'zeros'``
        dropout_rate (float, optional): Probability of zeroing each weight element,
            drawn independently for every sample copy. Default: 0.2

    Shape:
        - Input: :math:`(B \times K, C_{in}, H_{in}, W_{in})` with
          :math:`K` = ``get_current_num_samples()``
        - Output: :math:`(B \times K, C_{out}, H_{out}, W_{out})` where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[0] - \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[1] - \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
                         :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
                         :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
                         The values of these weights are sampled from
                         :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
        bias (Tensor):   the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,
                         then the values of these weights are
                         sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`

    Examples::

        >>> # With square kernels and equal stride
        >>> m = StochasticDropoutConv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = StochasticDropoutConv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = StochasticDropoutConv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> # the leading dimension must be a multiple of get_current_num_samples()
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        dropout_rate=0.2,
    ):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(StochasticDropoutConv2d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,
            _pair(0),
            groups,
            bias,
            padding_mode,
            dropout_rate=dropout_rate,
        )
        # Applied to the K stacked weight copies in ``forward``.
        self.dropout = StochasticDropout(dropout_rate)

    def _conv_forward(self, input, weight, outer_groups):
        """Run one grouped conv2d over ``outer_groups`` stacked weight copies.

        ``weight`` holds ``outer_groups`` copies of the layer weights stacked
        along dim 0. The convolution therefore produces
        ``outer_groups * out_channels`` channels, and the bias is tiled to match.
        """
        # Copies are stacked copy-major, so output channel ``k * out_channels + o``
        # needs ``bias[o]``; ``repeat`` tiles the bias in exactly that order.
        bias = None if self.bias is None else self.bias.repeat(outer_groups)
        groups = self.groups * outer_groups
        if self.padding_mode != "zeros":
            return F.conv2d(
                F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                weight,
                bias,
                self.stride,
                _pair(0),
                self.dilation,
                groups,
            )
        return F.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, groups)

    def forward(self, input):
        """Convolve every sample with its own dropped-out copy of the weights.

        ``input`` has shape ``(B * K, C_in, H, W)`` with
        ``K = get_current_num_samples()``; the ``view(-1, K, ...)`` below assumes
        the ``K`` samples of one batch element are contiguous. Returns a tensor
        of shape ``(B * K, C_out, H_out, W_out)`` in the same row order.
        """
        num_masks = get_current_num_samples()
        # (B*K, C, H, W) -> (B, K*C, H, W): sample k of each batch element
        # becomes channel group k.
        input_B_KC_ = input.view(-1, num_masks, *input.shape[1:]).flatten(1, 2)
        expanded_weights = self.weight.unsqueeze(0).expand(num_masks, *self.weight.shape)
        # (K, C_out, ...) -> (K*C_out, ...) so weight copy k serves channel group k.
        dropped_weights_KC_ = self.dropout(expanded_weights).flatten(0, 1)
        output_B_KC_ = self._conv_forward(input_B_KC_, dropped_weights_KC_, num_masks)
        # (B, K*C_out, H', W') -> (B*K, C_out, H', W').
        output_BK_ = output_B_KC_.view(input.shape[0], -1, *output_B_KC_.shape[2:])
        return output_BK_


def stochastic_dropout_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, dropout_rate=0.2):
    """Functional form of a conv2d whose weights are dropped out per sample.

    ``input`` has shape ``(B * K, C_in, H, W)`` with ``K = get_current_num_samples()``.
    The ``view(-1, K, ...)`` below assumes the ``K`` samples of one batch element
    are contiguous.

    ``K`` copies of ``weight`` are passed through a new ``StochasticDropout(dropout_rate)``.
    Sample ``k`` is then convolved with copy ``k`` in a single grouped
    ``F.conv2d`` call. ``bias`` (shape ``(C_out,)``, optional) is shared by all
    copies and not dropped.

    Returns a tensor of shape ``(B * K, C_out, H_out, W_out)``.
    """
    num_masks = get_current_num_samples()
    input_B_KC_ = input.view(-1, num_masks, *input.shape[1:]).flatten(1, 2)
    expanded_weights = weight.unsqueeze(0).expand(num_masks, *weight.shape)
    dropped_weights_KC_ = StochasticDropout(dropout_rate)(expanded_weights).flatten(0, 1)
    # The grouped conv has K * C_out output channels ordered ``k * C_out + o``;
    # tile the bias so every weight copy gets its own (identical) bias slice.
    bias_KC_ = None if bias is None else bias.repeat(num_masks)
    output_B_KC_ = F.conv2d(input_B_KC_, dropped_weights_KC_, bias_KC_, stride, padding, dilation, groups * num_masks)
    output_BK_ = output_B_KC_.view(input.shape[0], -1, *output_B_KC_.shape[2:])
    return output_BK_


def _repeat_tuple(t, n):
    r"""Return a tuple in which every element of `t` appears `n` times in a row.

    For example ``_repeat_tuple((1, 2), 2) == (1, 1, 2, 2)``. Note that
    `F.pad` lists the *last* dimension first. Per-dimension padding of
    Conv/Pooling modules must therefore be reversed before it is expanded with
    this helper for use with `F.pad`.
    """
    repeated = []
    for item in t:
        repeated.extend([item] * n)
    return tuple(repeated)


# Converter for two-dimensional arguments such as kernel_size, stride, padding
# and dilation: 3 -> (3, 3); iterables pass through unchanged.
_pair = _ntuple(n=2)
