"""
By Silvia Pintea, 2021. Adapted for FlexConv by _anonymized_, with
permission.

Silvia L Pintea, Nergis Tomen, Stanley F Goes, Marco Loog, and Jan C van Gemert.
Resolution learning in deep convolutional networks using scale-space theory.
arXiv preprint arXiv:2106.03412, 2021.
"""

# Import general dependencies
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
from torchvision import transforms
from torch.autograd import Function
from torch.distributions import normal
from srf.nn.gaussian_basis_filters import *

import torch.nn.functional as F
import time


def entropy(p, dim=-1, keepdim=False):
    """Return the Shannon entropy (in nats) of ``|p|`` normalised along ``dim``.

    The magnitudes of ``p`` are divided by their sum over ``dim`` and the
    entropy ``-sum(q * log(q))`` of the result is taken over the same
    dimension.  Entries that are not strictly positive after normalisation
    (zeros, or NaNs produced by an all-zero slice) contribute ``0``.

    Args:
        p: Tensor of (possibly negative) weights.
        dim: Dimension that is normalised and reduced.
        keepdim: Whether the reduced dimension is kept with size 1.

    Returns:
        Tensor of entropies with ``dim`` reduced.
    """
    p = torch.abs(p)
    p = p / p.sum(dim=dim, keepdim=True)
    positive = p > 0
    # Mask the argument of the log, not only its result: ``torch.where`` alone
    # still back-propagates through ``log(0)``, which turns the gradient of
    # every zero entry into NaN (0 * -inf and 0 / 0).
    safe_p = torch.where(positive, p, torch.ones_like(p))
    plogp = torch.where(positive, p * safe_p.log(), torch.zeros_like(p))
    return -plogp.sum(dim=dim, keepdim=keepdim)


class Srf_layer(nn.Module):
    """Multi-scale structured receptive field (SRF) convolution layer.

    The layer owns one scale parameter per scale (``self.scales``) and a
    single coefficient tensor ``self.alphas`` of shape ``[F, inC, outC]`` with
    ``F = (init_order + 1) * (init_order + 2) / 2``.  For every scale a filter
    bank is requested from ``gaussian_basis_filters`` with
    ``sigma = 2 ** scale``; ``forward`` convolves the input with each bank and
    sums the responses.

    Args:
        inC: Number of input channels.
        outC: Number of output channels.
        num_scales: Number of scale parameters (must be at least 1).
        init_k: Passed to ``gaussian_basis_filters`` as ``k``.
        init_order: Passed to ``gaussian_basis_filters`` as ``order``.
        init_scale: Value of every scale when ``learn_sigma`` is false.
        learn_sigma: If true the scales are trainable and drawn from
            ``N(0, 2^2)``; otherwise they are frozen at ``init_scale``.
        use_cuda: Create the parameters on ``"cuda"`` instead of ``"cpu"``.

    Attributes:
        extra_reg: Regulariser refreshed by every ``forward`` call (the
            negated mean entropy of the alphas over the basis dimension).
    """

    def __init__(
        self,
        inC,
        outC,
        num_scales,
        init_k,
        init_order,
        init_scale,
        learn_sigma,
        use_cuda,
    ):

        super(Srf_layer, self).__init__()
        if num_scales < 1:
            raise ValueError("num_scales must be at least 1")

        self.init_k = init_k
        self.init_order = init_order
        self.init_scale = init_scale
        self.in_channels = inC
        self.out_channels = outC

        # Size of the first axis of ``alphas``.  Named ``num_basis`` so that it
        # does not shadow the module-level ``torch.nn.functional as F``.
        num_basis = int((self.init_order + 1) * (self.init_order + 2) / 2)

        # Create weight variables.
        self.use_cuda = use_cuda
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.scales = nn.ParameterList([])

        self.alphas = torch.nn.Parameter(
            torch.zeros([num_basis, inC, outC], device=self.device),
            requires_grad=True,
        )

        scales_init = np.random.normal(loc=0.0, scale=2.0, size=num_scales)
        basis_tensor = None
        for i in range(num_scales):
            # Both branches use float32 so that sigma has the same dtype as
            # the alphas regardless of ``learn_sigma``.
            scale_value = scales_init[i] if learn_sigma else self.init_scale
            self.scales.append(
                torch.nn.Parameter(
                    torch.tensor(
                        scale_value, device=self.device, dtype=torch.float32
                    ),
                    requires_grad=bool(learn_sigma),
                )
            )

            # Accumulate the basis of every scale for the alpha initialisation.
            _, a_basis_tensor = gaussian_basis_filters(
                order=self.init_order,
                sigma=2.0 ** self.scales[i],
                k=self.init_k,
                alphas=None,
                use_cuda=self.use_cuda,
            )
            if basis_tensor is None:
                basis_tensor = a_basis_tensor
            else:
                basis_tensor = basis_tensor + a_basis_tensor
        basis_tensor = basis_tensor / float(num_scales)

        # Scale the std of each alpha slice by the norm of its averaged basis
        # filter.  The slice index must follow ``j``: indexing with the stale
        # scale index ``i`` re-initialised one slice over and over and left
        # all the others at zero.
        # NOTE(review): assumes basis_tensor.shape[0] equals the first alpha
        # axis -- confirm against gaussian_basis_filters.
        for j in range(basis_tensor.shape[0]):
            basis_norm2 = torch.norm(basis_tensor[j, :, :], p=2)
            var = 2.0 / (num_basis * inC * basis_norm2)
            torch.nn.init.normal_(
                self.alphas[j, :, :],
                mean=0.0,
                std=torch.sqrt(var).item(),
            )
        self.extra_reg = 0

    def forward_no_input(self):
        """Build and return the list of filter banks, one per scale.

        Also prints the wall-clock time spent building them.
        """
        t = time.time()
        filters = []
        for scale in self.scales:
            one_filter, _ = gaussian_basis_filters(
                order=self.init_order,
                sigma=2.0 ** scale,
                k=self.init_k,
                alphas=self.alphas,
                use_cuda=self.use_cuda,
            )
            filters.append(one_filter)

        elapsed = time.time()
        print("Forward pass timing ", elapsed - t)
        return filters

    def forward(self, indata):
        """Convolve ``indata`` with the filters of every scale and sum them.

        The input is first resized by ``safe_sample`` according to the largest
        scale.  ``self.extra_reg`` is updated as a side effect.

        Args:
            indata: 4-D input tensor (indexed as ``N, C, H, W``).

        Returns:
            Tensor with ``out_channels`` channels and the spatial size of the
            resized input.
        """
        self.filters = []
        # Reduce with torch rather than numpy so that the scales never have to
        # leave their device.
        max_scale = torch.stack([s.detach() for s in self.scales]).max()
        indata = safe_sample(indata, 2 ** max_scale)
        # Allocate next to the data instead of on the construction-time device,
        # which is stale once the module has been moved.
        conv_output = torch.zeros(
            [indata.shape[0], self.out_channels, indata.shape[2], indata.shape[3]],
            device=indata.device,
        )  # no subsample

        for scale in self.scales:
            one_filter, _ = gaussian_basis_filters(
                order=self.init_order,
                sigma=2.0 ** scale,
                k=self.init_k,
                alphas=self.alphas,
                use_cuda=self.use_cuda,
            )
            out = F.conv2d(
                input=indata,
                weight=one_filter,  # KCHW
                bias=None,
                stride=1,
                padding=(one_filter.shape[2] - 1) // 2,
            )
            conv_output += out

        # Extra regularization: negated entropy of the alphas over the basis
        # dimension, averaged over all alpha entries.
        alpha_reshaped = self.alphas.view(self.alphas.shape[0], -1)
        self.extra_reg = -(entropy(alpha_reshaped, dim=0, keepdim=False).sum()) / float(
            alpha_reshaped.shape[0] * alpha_reshaped.shape[1]
        )
        return conv_output

    def listParams(self):
        """Print the total number of elements over all parameters."""
        # ``numel`` keeps the count an integer; ``np.prod`` of an empty size
        # (0-dim scale parameters) is the float ``1.0``.
        total_params = sum(p.numel() for p in self.parameters())

        print("Total parameters: ", total_params)


class Srf_layer_shared_alpha(Srf_layer):
    """SRF layer whose alphas are predicted from the input itself.

    A 1x1 convolution (``pre_alpha``) maps the spatial mean of each sample to
    ``F * inC * outC`` coefficients, which are soft-maxed over the basis
    dimension and handed to ``gaussian_basis_filters_shared_alpha`` together
    with a single scale.  The resulting per-sample filters are applied with
    one grouped convolution.

    Args:
        inC: Number of input channels.
        outC: Number of output channels.
        init_k: Passed to ``gaussian_basis_filters_shared_alpha`` as ``k``.
        init_order: Passed to it as ``order``.
        init_scale: Initial value of the single scale parameter.
        learn_sigma: Whether the scale parameter is trainable.
        use_cuda: Create the parameters on ``"cuda"`` instead of ``"cpu"``.
    """

    def __init__(
        self, inC, outC, init_k, init_order, init_scale, learn_sigma, use_cuda
    ):
        # Deliberately bypasses Srf_layer.__init__ (different signature and
        # parameters); only nn.Module is initialised.
        super(Srf_layer, self).__init__()

        self.init_k = init_k
        self.init_order = init_order
        self.init_scale = init_scale
        self.inC = inC
        self.outC = outC

        # ---------------
        self.F = int((self.init_order + 1) * (self.init_order + 2) / 2)

        # Create weight variables.
        self.use_cuda = use_cuda
        self.device = torch.device("cuda" if use_cuda else "cpu")
        # Placed on the same device as the scale parameter; previously this
        # was the only member left on the CPU when ``use_cuda`` was set.
        self.pre_alpha = nn.Conv2d(inC, inC * outC * self.F, 1).to(self.device)
        self.soft = torch.nn.Softmax(dim=1)

        # The two ``learn_sigma`` cases differ only in ``requires_grad``.
        self.scales = torch.nn.Parameter(
            torch.tensor(
                np.full((1), self.init_scale),
                device=self.device,
                dtype=torch.float32,
            ),
            requires_grad=bool(learn_sigma),
        )
        self.extra_reg = 0
        self.filters = None
        self.alpha = None

    def forward(self, data):
        """Apply per-sample SRF filters to ``data``.

        Args:
            data: 4-D input tensor (indexed as ``N, C, H, W``).

        Returns:
            Tensor reshaped to ``(N, outC, H', W')``.  ``self.alpha``,
            ``self.sigma``, ``self.filters`` and ``self.extra_reg`` are
            updated as side effects.
        """
        # Global average over both spatial axes, keeping them as size 1.
        self.alpha = data.mean(dim=3, keepdim=True).mean(dim=2, keepdim=True)
        self.alpha = self.pre_alpha(self.alpha)
        # N, C*K*F, 1, 1
        self.alpha = torch.reshape(
            self.alpha, (self.alpha.shape[0], self.F, self.inC, self.outC)
        )
        # N, F, C, K -- softmax over the basis dimension (dim=1)
        self.alpha = self.soft(self.alpha)

        # Define sigma from the scale
        self.sigma = 2.0 ** self.scales
        self.filters, _ = gaussian_basis_filters_shared_alpha(
            order=self.init_order,
            sigma=self.sigma,
            k=self.init_k,
            alphas=self.alpha,
            use_cuda=self.use_cuda,
        )
        data = safe_sample(data, self.sigma)
        # N C H W --> 1 N*C H W
        data_shape = data.shape
        data = torch.reshape(
            data, (1, data_shape[0] * data_shape[1], data_shape[2], data_shape[3])
        )
        # N K C h w --> N*K C h w
        num_samples, num_out, num_in, kernel_h, kernel_w = self.filters.shape
        self.filters = torch.reshape(
            self.filters, (num_samples * num_out, num_in, kernel_h, kernel_w)
        )

        # One group per sample, so every sample is convolved only with the
        # filters predicted from it.
        final_conv = F.conv2d(
            input=data,  # 1 NC H W
            weight=self.filters,  # NK C H W
            bias=None,
            stride=1,
            padding=self.filters.shape[2] // 2,
            groups=data_shape[0],
        )

        # 1 N*K H W --> N K H W
        final_conv = torch.reshape(
            final_conv,
            (data_shape[0], self.outC, final_conv.shape[2], final_conv.shape[3]),
        )

        self.extra_reg = 0
        return final_conv


class Srf_layer_shared(Srf_layer):
    """SRF layer with a single scale shared by all filters.

    Holds coefficients ``alphas`` of shape
    ``[F, in_channels / groups, out_channels]`` and one scale parameter; the
    filters are produced by ``gaussian_basis_filters_shared`` with
    ``sigma = 2 ** scale`` (optionally multiplied by ``scale_sigma``).

    Note that in this adaptation ``forward`` does not convolve its input: it
    only uses the input height to size the sampling grid and returns a single
    generated kernel.  The convolution of the original implementation
    (``F.conv2d`` with ``padding = kernel_size // 2`` and optional ``groups``)
    is disabled.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; must be divisible by
            ``groups``.
        init_k: Filter extent factor used by ``forward_no_input``
            (``filtersize = ceil(init_k * sigma + 0.5)``).
        init_order: Passed to ``gaussian_basis_filters_shared`` as ``order``.
        init_scale: Initial value of the scale parameter.
        learn_sigma: Whether the scale parameter is trainable.
        use_cuda: Create the parameters on ``"cuda"`` instead of ``"cpu"``.
        groups: Number of channel groups.
        scale_sigma: If positive, sigma is multiplied by this factor.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        init_k=None,
        init_order=None,
        init_scale=None,
        learn_sigma=True,
        use_cuda=True,
        groups=1,
        scale_sigma=0.0,
    ):
        # Deliberately bypasses Srf_layer.__init__ (different signature and
        # parameters); only nn.Module is initialised.
        super(Srf_layer, self).__init__()

        self.init_k = init_k
        self.init_order = init_order
        self.init_scale = init_scale
        self.inC = in_channels

        assert out_channels % groups == 0
        self.outC = out_channels
        self.groups = groups
        self.scale_sigma = scale_sigma

        # Size of the first axis of ``alphas``.  Named ``num_basis`` so that it
        # does not shadow the module-level ``torch.nn.functional as F``.
        num_basis = int((self.init_order + 1) * (self.init_order + 2) / 2)

        # Create weight variables.
        self.use_cuda = use_cuda
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.alphas = torch.nn.Parameter(
            torch.zeros(
                [num_basis, int(in_channels / groups), out_channels],
                device=self.device,
            ),
            requires_grad=True,
        )

        # Plain unit-variance initialisation; the basis-norm-scaled
        # initialisation used by Srf_layer is not applied here.
        torch.nn.init.normal_(self.alphas, mean=0.0, std=1)

        # The two ``learn_sigma`` cases differ only in ``requires_grad``.
        self.scales = torch.nn.Parameter(
            torch.tensor(
                np.full((1), self.init_scale),
                device=self.device,
                dtype=torch.float32,
            ),
            requires_grad=bool(learn_sigma),
        )
        self.extra_reg = 0

    def _scaled_sigma(self):
        """Return ``self.sigma``, multiplied by ``scale_sigma`` when positive."""
        # DEBUG(rjbruin)
        if self.scale_sigma > 0.0:
            return self.sigma * self.scale_sigma
        return self.sigma

    def forward_no_input(self):
        """Generate the filters on a grid sized from sigma and return them.

        The grid spans ``[-filtersize, filtersize]`` in unit steps with
        ``filtersize = ceil(init_k * sigma + 0.5)``.  Stores ``sigma``,
        ``filtersize``, ``x``, ``filters``, ``basis``, ``gauss`` and
        ``hermite`` on the module.

        Raises:
            Exception: Whatever ``torch.arange`` raises for an unusable sigma;
                the offending sigma is printed first.
        """
        # Define sigma from the scale
        self.sigma = 2.0 ** self.scales
        sigma = self._scaled_sigma()

        self.filtersize = torch.ceil(self.init_k * sigma[0] + 0.5)

        try:
            self.x = torch.arange(
                start=-self.filtersize.detach().cpu().float(),
                end=self.filtersize.detach().cpu().float() + 1,
                step=1,
            )
        except Exception:
            # Report the bad sigma, then propagate: continuing would use an
            # undefined or stale ``self.x``.
            print("Sigma value is off:", sigma)
            raise

        self.hermite = self.x
        (
            self.filters,
            self.basis,
            self.gauss,
            self.hermite,
        ) = gaussian_basis_filters_shared(
            x=self.x,
            hermite=self.hermite,
            order=self.init_order,
            sigma=sigma,
            alphas=self.alphas,
            use_cuda=self.use_cuda,
        )
        return self.filters

    def forward(self, data):
        """Generate the filters on a grid sized from ``data`` and return one.

        ``data`` is only read for ``data.shape[2]``: the grid runs from
        ``-(H - 1) / 2`` to ``(H - 1) / 2`` in unit steps, i.e. ``H`` points.
        No convolution is performed.  ``self.extra_reg`` is set to the L2 norm
        of the alphas plus sigma.

        Args:
            data: Tensor with at least three dimensions.

        Returns:
            ``self.filters[0, 0]``, the first entry along the two leading axes
            of the generated filters (presumably the kernel of the first
            output/input channel pair -- confirm against
            ``gaussian_basis_filters_shared``).
        """
        # Define sigma from the scale.  ``2.0 ** scales`` stays on the device
        # of the parameter, unlike a freshly created CPU base tensor, and
        # matches forward_no_input.
        self.sigma = 2.0 ** self.scales
        sigma = self._scaled_sigma()

        # The grid size follows the input height rather than sigma.
        self.filtersize = torch.tensor((data.shape[2] - 1) / 2)
        self.x = torch.arange(
            start=-self.filtersize,
            end=self.filtersize + 1,
            step=1,
        )

        self.hermite = self.x

        (
            self.filters,
            self.basis,
            self.gauss,
            self.hermite,
        ) = gaussian_basis_filters_shared(
            x=self.x,
            hermite=self.hermite,
            order=self.init_order,
            sigma=sigma,
            alphas=self.alphas,
            use_cuda=self.use_cuda,
        )

        # Regulariser: L2 norm of the alphas plus sigma (the entropy-based
        # variant used by Srf_layer is not applied here).
        self.extra_reg = torch.norm(self.alphas, p=2) + sigma

        return self.filters[0, 0]

    def num_params(self):
        """Return the number of trainable parameter elements."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


def safe_sample(current, sigma, r=4.0):
    """Downsample ``current`` by a factor derived from ``sigma``.

    The factor is ``max(1, 2 ** sigma / r)``, so the input is never enlarged.
    Each spatial size is divided by the factor, truncated to an integer and
    clamped to at least 1.

    Args:
        current: 4-D tensor whose axes 2 and 3 are resized.
        sigma: Scalar or single-element tensor; note that it is exponentiated
            here (``2 ** sigma``).
        r: Divisor applied to ``2 ** sigma``.

    Returns:
        ``current`` resized with ``F.interpolate`` (default mode).
    """
    update_val = max(1.0, torch.div(2 ** sigma, r))
    shape = current.shape
    # Clamp each axis on its own: ``max([1, 1], [h, w])`` compares the lists
    # lexicographically, which let a zero width through (e.g. [4, 0]) and
    # collapsed [0, w] to [1, 1].
    shape_out = [
        max(1, int(float(shape[2]) / update_val)),
        max(1, int(float(shape[3]) / update_val)),
    ]
    current_out = F.interpolate(current, shape_out)
    return current_out
