import polar_util as util

import numpy as np
import torch
from torch import nn

import math
import warnings
from enum import Enum





def _sparse_rbf_polar_to_cartesian(params):
    """Build the sparse RBF matrix mapping polar coefficients to Cartesian pixels.

    For every polar grid point ``(j, k)`` a Gaussian is evaluated on the
    Cartesian pixels inside a square window of half-width ``6 * params["b"]``
    pixels around the nearest pixel (clipped to the image). Pixels outside
    the window are omitted, which keeps the matrix sparse.

    Args:
        params: RBF parameter dict; reads ``"L"`` (Cartesian side length),
            ``"b"`` (RBF width in pixels) and ``"quadrature"``, plus whatever
            ``util._polar_grid_size`` needs.

    Returns:
        ``(indices, vals)`` in COO form. ``indices`` is the pair
        ``(rows, cols)`` with ``row = j * num_angular + k`` (polar point) and
        ``col = n * L + m`` (Cartesian pixel ``(n, m)``). ``vals`` holds the
        matching float32 Gaussian values.
    """
    num_radial, num_angular = util._polar_grid_size(params)

    # Only the grid points are needed here; callers apply the quadrature
    # weights themselves.
    polar_grid, _ = util._polar_grid(num_radial, num_angular,
                                     params["quadrature"])
    cartesian_grid = util._cartesian_grid(params["L"])

    L = params["L"]
    # RBF width rescaled from pixels to the grid's normalized units.
    _b = params["b"] * 2 / L
    # Half-width (in pixels) of the truncation window.
    radius = 6 * params["b"]

    indices = []
    vals = []

    for j in range(num_radial):
        for k in range(num_angular):
            row = j * num_angular + k
            point = polar_grid[j, k]

            # Nearest pixel to the polar point and a clipped window around
            # it; idx_end is exclusive.
            center = np.round(point * L / 2) + L // 2
            idx_start = np.maximum(center - radius, 0).astype("int64")
            idx_end = np.minimum(center + radius, L - 1).astype("int64") + 1

            subgrid = cartesian_grid[idx_start[0]:idx_end[0],
                                     idx_start[1]:idx_end[1]]

            coefs_row = np.exp(-np.sum((subgrid - point) ** 2, axis=-1)
                               / (2 * _b ** 2))

            # Row-major enumeration of the window matches ``ravel`` below.
            indices.extend((row, n * L + m)
                           for n in range(idx_start[0], idx_end[0])
                           for m in range(idx_start[1], idx_end[1]))
            vals.extend(coefs_row.ravel().astype("float32"))

    indices = list(zip(*indices))

    return indices, vals


class CartesianToRbfPolar(nn.Module):
    """Map Cartesian images to coefficients on the polar RBF grid.

    Input ``(num_images, num_channels, L, L)`` is flattened, multiplied by the
    transpose of the sparse matrix from ``_sparse_rbf_polar_to_cartesian``,
    and reshaped to ``(num_images, num_channels, num_radial, num_angular)``.
    The result is then scaled by the square-root polar quadrature weights
    and two fixed normalization constants.

    Args:
        params: RBF parameter dict, completed by ``util.default_rbf_params``.
            ``params["normalize"]`` must be truthy. ``None`` (the default)
            means "all defaults".
    """
    def __init__(self, params=None):
        super(CartesianToRbfPolar, self).__init__()

        # Build a fresh dict per call. A ``dict()`` default would be shared
        # between all default-constructed instances, and ``self.params``
        # keeps a reference to it.
        params = util.default_rbf_params({} if params is None else params)

        assert params["normalize"]

        num_radial, num_angular = util._polar_grid_size(params)

        L = params["L"]
        # RBF width rescaled from pixels to normalized grid units.
        _b = params["b"] * 2 / L

        _, polar_weights = util._polar_grid(num_radial,
                num_angular, params["quadrature"])

        # Swapping (row, col) of the polar->Cartesian COO indices yields the
        # (L**2, num_radial * num_angular) analysis matrix.
        indices, vals = _sparse_rbf_polar_to_cartesian(params)
        indices = [indices[1], indices[0]]
        coefs = torch.sparse_coo_tensor(indices, vals, (L ** 2, num_radial * num_angular))

        # PyTorch emits a UserWarning starting with "Sparse CSR tensor
        # support" on conversion; silence only that one, only here.
        with warnings.catch_warnings():
            warnings.filterwarnings(action="ignore",
                                    category=UserWarning,
                                    message=r"Sparse CSR tensor support")
            coefs = coefs.to_sparse_csr()

        self.params = params

        self.L = L
        self._b = _b

        self.num_radial = num_radial
        self.num_angular = num_angular

        polar_weights = torch.from_numpy(polar_weights.astype("float32"))

        self.register_buffer("polar_weights", polar_weights)
        self.register_buffer("coefs", coefs)

    def forward(self, x):
        """Transform ``(N, C, L, L)`` images to ``(N, C, num_radial, num_angular)``."""
        num_images, num_channels, *im_shape = x.shape

        L = self.L
        _b = self._b

        assert im_shape[0] == im_shape[1]
        assert im_shape[0] == L

        x = torch.reshape(x, (num_images, num_channels, L ** 2))
        y = x @ self.coefs
        y = torch.reshape(y, (num_images, num_channels, self.num_radial, self.num_angular))

        # polar_weights must broadcast against (num_radial, num_angular).
        y *= torch.sqrt(self.polar_weights[np.newaxis, np.newaxis])
        y /= np.sqrt(np.pi * _b ** 2)

        y /= np.sqrt((L / 2) ** 2 * 2 * np.pi * 2 * _b ** 2)

        return y


class RbfPolarToCartesian(nn.Module):
    """Map polar RBF coefficients back to Cartesian images.

    Input ``(num_images, num_channels, num_radial, num_angular)`` is scaled by
    the square-root polar quadrature weights, synthesized on the ``L x L``
    grid with the sparse matrix from ``_sparse_rbf_polar_to_cartesian``, and
    rescaled. It is then divided in the 2D Fourier domain by ``kernel_f``,
    the FFT of ``exp(-|r|^2 / (2 _b)^2)`` sampled on the Cartesian grid.

    Args:
        params: RBF parameter dict, completed by ``util.default_rbf_params``.
            ``params["normalize"]`` must be truthy. ``None`` (the default)
            means "all defaults".
    """
    def __init__(self, params=None):
        super(RbfPolarToCartesian, self).__init__()

        # Fresh dict per call instead of a shared mutable default.
        params = util.default_rbf_params({} if params is None else params)

        assert params["normalize"]

        num_radial, num_angular = util._polar_grid_size(params)

        L = params["L"]
        # RBF width rescaled from pixels to normalized grid units.
        _b = params["b"] * 2 / L

        _, polar_weights = util._polar_grid(num_radial,
                num_angular, params["quadrature"])
        cartesian_grid = util._cartesian_grid(L)

        indices, vals = _sparse_rbf_polar_to_cartesian(params)
        coefs = torch.sparse_coo_tensor(indices, vals, (num_radial *
            num_angular, L ** 2))

        # PyTorch emits a UserWarning starting with "Sparse CSR tensor
        # support" on conversion; silence only that one, only here.
        with warnings.catch_warnings():
            warnings.filterwarnings(action="ignore",
                                    category=UserWarning,
                                    message=r"Sparse CSR tensor support")
            coefs = coefs.to_sparse_csr()

        # Deconvolution kernel exp(-|r|^2 / (4 _b^2)). NOTE(review): this is
        # presumably the self-convolution of the width-_b RBF (variance
        # 2 _b^2), matching the 2 * _b ** 2 normalization below; confirm.
        kernel = np.exp(-np.sum(cartesian_grid ** 2, axis=-1)
                        / (2 * _b) ** 2)
        # Move the kernel centre to index (0, 0) before the FFT.
        kernel = np.fft.ifftshift(kernel, axes=(-2, -1))
        kernel_f = np.fft.fft2(kernel, axes=(-2, -1))

        self.L = L
        self._b = _b

        self.params = params

        self.num_radial = num_radial
        self.num_angular = num_angular

        polar_weights = torch.from_numpy(polar_weights.astype("float32"))
        kernel_f = torch.from_numpy(kernel_f)

        self.register_buffer("polar_weights", polar_weights)
        self.register_buffer("kernel_f", kernel_f)
        self.register_buffer("coefs", coefs)

    def forward(self, y):
        """Transform ``(N, C, num_radial, num_angular)`` to ``(N, C, L, L)``."""
        num_images, num_channels, *polar_shape = y.shape

        L = self.L
        _b = self._b

        assert polar_shape[0] == self.num_radial
        assert polar_shape[1] == self.num_angular

        y = torch.sqrt(self.polar_weights) * y

        y = torch.reshape(y, (num_images, num_channels, self.num_radial * self.num_angular))
        x = y @ self.coefs
        x = torch.reshape(x, (num_images, num_channels, L, L))

        x /= np.sqrt(np.pi * _b ** 2)

        x *= np.sqrt((L / 2) ** 2 * 2 * np.pi * 2 * _b ** 2)

        # Deconvolve by the Gaussian kernel. NOTE(review): small kernel_f
        # entries amplify the corresponding frequencies.
        x_f = torch.fft.fft2(x, dim=(-2, -1))
        x_f /= self.kernel_f
        x = torch.fft.ifft2(x_f, dim=(-2, -1)).real

        return x


class RbfPolarToOneDim(nn.Module):
    """Merge channel and radial axes: ``(N, C, R, A) -> (N, C * R, A)``.

    The flat channel index is ``channel * num_radial + radius``.

    Args:
        params: RBF parameter dict (``None`` means defaults); only used to
            determine ``num_radial`` and ``num_angular``.
    """
    def __init__(self, params=None):
        super(RbfPolarToOneDim, self).__init__()

        # Fresh dict per call instead of a shared mutable default.
        params = util.default_rbf_params({} if params is None else params)

        num_radial, num_angular = util._polar_grid_size(params)

        self.num_radial = num_radial
        self.num_angular = num_angular

    def forward(self, x):
        num_images, num_channels, *polar_shape = x.shape

        assert polar_shape[0] == self.num_radial
        assert polar_shape[1] == self.num_angular

        new_shape = (num_images, num_channels * self.num_radial,
                     self.num_angular)

        return torch.reshape(x, new_shape)


class OneDimToRbfPolar(nn.Module):
    """Split a merged axis back: ``(N, C * R, A) -> (N, C, R, A)``.

    This is the inverse of ``RbfPolarToOneDim``; the flat channel index is
    interpreted as ``channel * num_radial + radius``.

    Args:
        params: RBF parameter dict (``None`` means defaults); only used to
            determine ``num_radial`` and ``num_angular``.
    """
    def __init__(self, params=None):
        super(OneDimToRbfPolar, self).__init__()

        # Fresh dict per call instead of a shared mutable default.
        params = util.default_rbf_params({} if params is None else params)

        num_radial, num_angular = util._polar_grid_size(params)

        self.num_radial = num_radial
        self.num_angular = num_angular

    def forward(self, y):
        num_images, num_channels_radial, num_angular = y.shape

        assert num_channels_radial % self.num_radial == 0
        assert num_angular == self.num_angular

        num_channels = num_channels_radial // self.num_radial

        new_shape = (num_images, num_channels, self.num_radial, num_angular)

        return torch.reshape(y, new_shape)


class AngularConv(nn.Module):
    """Circular 1D convolution along the angular axis of polar features.

    Input and output have shape ``(num_images, channels, num_radial,
    num_angular)``. Each (channel, radius) pair becomes one 1D channel, so
    the weight couples every input channel and radius to every output
    channel and radius. The angular axis is circularly padded so the output
    keeps ``num_angular`` samples.

    Args:
        in_channels, out_channels: Feature channels per radius.
        kernel_size: Angular kernel length; must be odd if ``symmetric``.
        bias: Add a learnable bias (one per output channel and radius).
        polar_params: RBF parameter dict (``None`` means defaults); only used
            to determine ``num_radial`` and ``num_angular``.
        symmetric: Store only the centre tap and the taps after it, and
            mirror them so the kernel is symmetric in angle.
        device, dtype: Forwarded to the parameter factories.

    Raises:
        ValueError: If ``symmetric`` is true and ``kernel_size`` is even.
    """
    def __init__(self, in_channels, out_channels, kernel_size, bias,
            polar_params=None, symmetric=False, device=None, dtype=None):
        super(AngularConv, self).__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        if symmetric and kernel_size % 2 == 0:
            # Mirroring half of an even-length kernel yields kernel_size - 1
            # taps, which breaks the "same" circular padding below.
            raise ValueError(
                f"symmetric=True requires an odd kernel_size, got {kernel_size}")

        # Fresh dict per call instead of a shared mutable default.
        polar_params = util.default_rbf_params(
            {} if polar_params is None else polar_params)
        num_radial, num_angular = util._polar_grid_size(polar_params)

        self.num_radial = num_radial
        self.num_angular = num_angular
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.symmetric = symmetric

        # "Same" padding; extra tap goes to the right for even kernels.
        total_padding = kernel_size - 1
        left_pad = total_padding // 2
        self._padding = (left_pad, total_padding - left_pad)

        # Symmetric kernels store only taps centre..end.
        effective_kernel_size = (kernel_size + 1) // 2 if symmetric else kernel_size

        self.weight = nn.Parameter(torch.empty(
            (num_radial * out_channels, num_radial * in_channels, effective_kernel_size),
            **factory_kwargs))

        if bias:
            self.bias = nn.Parameter(torch.empty(num_radial * out_channels, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize like ``nn.Conv1d`` (Kaiming-uniform weight, fan-in bias)."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            if fan_in != 0:
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        weight = self.weight

        if self.symmetric:
            # [w_k .. w_1, w_0, w_1 .. w_k] from the stored [w_0 .. w_k].
            weight = torch.cat((torch.flip(weight[..., 1:], dims=(-1,)), weight), dim=-1)

        num_images, in_channels, num_radial, num_angular = x.shape

        assert in_channels == self.in_channels
        assert num_radial == self.num_radial
        assert num_angular == self.num_angular

        x_1d = torch.reshape(x, (num_images, in_channels * num_radial, num_angular))
        x_1d_pad = nn.functional.pad(x_1d, self._padding, mode="circular")

        y_1d = nn.functional.conv1d(x_1d_pad, weight, self.bias)

        return torch.reshape(y_1d, (num_images, self.out_channels, num_radial, num_angular))


class BandedAngularConv(nn.Module):
    """Angular convolution whose radial coupling is banded.

    This behaves like ``AngularConv``, but output radius ``r`` only mixes
    input radii ``r + band`` for ``band`` in
    ``[-(bands - 1) // 2, (bands - 1) // 2]``. Band ``band_idx`` is stored as
    parameter ``weight_band{band_idx}`` of shape ``(out_channels,
    in_channels, effective_kernel_size, num_radial - |band|)``; the dense
    weight is assembled with ``torch.diag_embed`` in ``forward``.

    Args:
        in_channels, out_channels: Feature channels per radius.
        kernel_size: Angular kernel length; must be odd if ``symmetric``.
        bias: Add a learnable bias (one per output channel and radius).
        bands: Odd number of radial diagonals.
        polar_params: RBF parameter dict (``None`` means defaults).
        symmetric: Store half the kernel and mirror it in angle.
        device, dtype: Forwarded to the parameter factories.

    Raises:
        ValueError: If ``symmetric`` is true and ``kernel_size`` is even.
    """
    def __init__(self, in_channels, out_channels, kernel_size, bias,
            bands=1, polar_params=None, symmetric=False, device=None,
            dtype=None):
        super(BandedAngularConv, self).__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        assert bands % 2 == 1

        if symmetric and kernel_size % 2 == 0:
            # Mirroring half of an even-length kernel yields kernel_size - 1
            # taps, which breaks the "same" circular padding below.
            raise ValueError(
                f"symmetric=True requires an odd kernel_size, got {kernel_size}")

        # Fresh dict per call instead of a shared mutable default.
        polar_params = util.default_rbf_params(
            {} if polar_params is None else polar_params)
        num_radial, num_angular = util._polar_grid_size(polar_params)

        self.num_radial = num_radial
        self.num_angular = num_angular
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.bands = bands
        self.symmetric = symmetric

        total_padding = kernel_size - 1
        left_pad = total_padding // 2
        self._padding = (left_pad, total_padding - left_pad)

        effective_kernel_size = (kernel_size + 1) // 2 if symmetric else kernel_size

        half = (bands - 1) // 2
        for band_idx in range(bands):
            band = band_idx - half
            # Diagonal ``band`` of a num_radial x num_radial matrix has
            # num_radial - |band| entries.
            weight_band = nn.Parameter(torch.empty(
                (out_channels, in_channels, effective_kernel_size,
                 num_radial - abs(band)),
                **factory_kwargs))
            self.register_parameter(f"weight_band{band_idx:d}", weight_band)

        if bias:
            self.bias = nn.Parameter(torch.empty(num_radial * out_channels, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init with bound ``1 / sqrt(fan_in)``.

        The weight fan-in counts only the non-zero (banded) connections:
        ``in_channels * bands * kernel_size``. The bias keeps the dense
        fan-in ``num_radial * in_channels * kernel_size``.
        """
        fan_in = self.in_channels * self.bands * self.kernel_size
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0.0
        for band_idx in range(self.bands):
            weight_band = getattr(self, f"weight_band{band_idx:d}")
            nn.init.uniform_(weight_band, -bound, bound)

        if self.bias is not None:
            bias_fan_in = self.num_radial * self.in_channels * self.kernel_size
            if bias_fan_in != 0:
                bias_bound = 1 / math.sqrt(bias_fan_in)
                nn.init.uniform_(self.bias, -bias_bound, bias_bound)

    def forward(self, x):
        weight_bands = [getattr(self, f"weight_band{band_idx:d}")
                        for band_idx in range(self.bands)]

        effective_kernel_size = weight_bands[0].shape[-2]
        device = weight_bands[0].device
        dtype = weight_bands[0].dtype

        # Dense weight indexed (out, out_radius, in, in_radius, tap).
        weight = torch.zeros((self.out_channels, self.num_radial,
            self.in_channels, self.num_radial, effective_kernel_size),
            device=device, dtype=dtype)

        half = (self.bands - 1) // 2
        for band_idx, weight_band in enumerate(weight_bands):
            band = band_idx - half
            # Place each band's last axis on diagonal ``band`` of the
            # (out_radius, in_radius) plane.
            weight += torch.diag_embed(weight_band, offset=band, dim1=-4, dim2=-2)

        weight = torch.reshape(weight, (self.out_channels * self.num_radial,
            self.in_channels * self.num_radial, effective_kernel_size))

        if self.symmetric:
            weight = torch.cat((torch.flip(weight[..., 1:], dims=(-1,)), weight), dim=-1)

        num_images, in_channels, num_radial, num_angular = x.shape

        assert in_channels == self.in_channels
        assert num_radial == self.num_radial
        assert num_angular == self.num_angular

        x_1d = torch.reshape(x, (num_images, in_channels * num_radial, num_angular))
        x_1d_pad = nn.functional.pad(x_1d, self._padding, mode="circular")

        y_1d = nn.functional.conv1d(x_1d_pad, weight, self.bias)

        return torch.reshape(y_1d, (num_images, self.out_channels, num_radial, num_angular))


class AngularBatchNorm(nn.Module):
    """Batch normalization for polar features ``(N, C, R, A)``.

    Each (channel, radius) pair is a separate ``BatchNorm1d`` feature, so it
    is normalized over the batch and angular axes. Extra keyword arguments
    are forwarded to ``nn.BatchNorm1d``.

    Args:
        num_channels: Number of feature channels ``C``.
        polar_params: RBF parameter dict (``None`` means defaults); only used
            to determine ``num_radial`` and ``num_angular``.
    """
    def __init__(self, num_channels, polar_params=None, **kwargs):
        super(AngularBatchNorm, self).__init__()

        # Fresh dict per call instead of a shared mutable default.
        polar_params = util.default_rbf_params(
            {} if polar_params is None else polar_params)
        num_radial, num_angular = util._polar_grid_size(polar_params)

        self.num_channels = num_channels
        self.num_radial = num_radial
        self.num_angular = num_angular

        self._model = nn.BatchNorm1d(num_channels * num_radial, **kwargs)

    def forward(self, x):
        num_images, num_channels, num_radial, num_angular = x.shape

        assert num_channels == self.num_channels
        assert num_radial == self.num_radial
        assert num_angular == self.num_angular

        x_1d = torch.reshape(x, (num_images, num_channels * num_radial, num_angular))

        y_1d = self._model(x_1d)

        return torch.reshape(y_1d, (num_images, num_channels, num_radial, num_angular))

class RbfChannelGroupingStrategy(Enum):
    RadiiFirst = 'r'
    ChannelsFirst = 'c'

class AngularGroupNorm(nn.Module):
    def __init__( self, num_channels
                , num_groups
                , grouping_strategy=RbfChannelGroupingStrategy.RadiiFirst
                , polar_params=dict()
                , **kwargs ):
        super().__init__()

        polar_params = util.default_rbf_params(polar_params)
        num_radial, num_angular = util._polar_grid_size(polar_params)

        self.num_channels = num_channels
        assert(num_channels % num_groups == 0
              ), f"Cannot group {num_channels} channels into {num_groups} groups."
        self.num_groups = num_groups
        self.grouping_strategy = grouping_strategy
        self.num_radial = num_radial
        self.num_angular = num_angular

        self._model = nn.GroupNorm(num_groups, num_channels * num_radial, **kwargs)

    def forward(self, x):
        orig_shape = x.shape
        num_images, num_channels, num_radial, num_angular = orig_shape

        assert num_channels == self.num_channels
        assert num_radial == self.num_radial
        assert num_angular == self.num_angular

        groupable_shape = (num_images, num_channels * num_radial, num_angular)

        match self.grouping_strategy:
            case RbfChannelGroupingStrategy.RadiiFirst:
                x_1d = torch.reshape(x, groupable_shape)
            case RbfChannelGroupingStrategy.ChannelsFirst:
                swapped_shape = (num_images, num_radial, num_channels, num_angular)
                x_1d = torch.reshape(torch.transpose(x, -2, -1), groupable_shape)

        y_1d = self._model(x_1d)

        match self.grouping_strategy:
            case RbfChannelGroupingStrategy.RadiiFirst:
                y = torch.reshape(y_1d, orig_shape)
            case RbfChannelGroupingStrategy.ChannelsFirst:
                y = torch.transpose(torch.reshape(y_1d, swapped_shape), -1, -2)

        return y


class RbfPolarDnCnn(nn.Module):
    """Network architecture that can train denoising of single projection
    images. It is equivariant under rotations but avoids the pitfalls
    of a naïve polar grid or of non-local decompositions.

    Layout: conv + ReLU, then ``depth - 2`` blocks of conv (no bias) +
    optional batch/group norm + ReLU, then a final conv (no bias). All
    convolutions act along the angular axis of the polar RBF representation
    (``AngularConv``, or ``BandedAngularConv`` when ``bands`` is given).

    Args:
        polar_params: RBF parameter dict (``None`` means defaults).
        depth: Total number of convolution layers.
        n_channels: Hidden channel count.
        in_channels, out_channels: Channels of the input and output.
        kernel_size: Angular kernel size.
        bands: ``None`` for dense radial coupling, otherwise the (odd)
            number of radial bands for ``BandedAngularConv``.
        symmetric: Use angularly symmetric kernels.
        batch_norm: Insert ``AngularBatchNorm`` in the hidden blocks.
        group_norm_num_groups: If set, insert ``AngularGroupNorm`` with this
            many groups (requires ``batch_norm=False`` when ``depth > 2``).
        residual: Return ``residual_mask * x - z`` instead of the network
            output ``z``.
        cartesian: Wrap the network in Cartesian <-> polar transforms.
    """
    def __init__(self, polar_params=None, depth=17, n_channels=64, in_channels=1,
            out_channels=1, kernel_size=3, bands=None, symmetric=False,
            batch_norm=True, group_norm_num_groups=None,
            residual=False, cartesian=True):
        super(RbfPolarDnCnn, self).__init__()

        # Fresh dict per call instead of a shared mutable default.
        polar_params = util.default_rbf_params(
            {} if polar_params is None else polar_params)

        L = polar_params["L"]

        if bands is None:
            Conv = AngularConv
        else:
            def Conv(*args, **kwargs):
                return BandedAngularConv(*args, bands=bands, **kwargs)

        layers = []

        if cartesian:
            layers.append(CartesianToRbfPolar(polar_params))

        layers.append(Conv(in_channels=in_channels,
            out_channels=n_channels, kernel_size=kernel_size,
            bias=True, symmetric=symmetric, polar_params=polar_params))
        layers.append(nn.ReLU(inplace=True))

        for _ in range(depth - 2):
            layers.append(Conv(in_channels=n_channels,
                out_channels=n_channels, kernel_size=kernel_size,
                bias=False, symmetric=symmetric, polar_params=polar_params))
            if batch_norm:
                assert(group_norm_num_groups is None)
                layers.append(AngularBatchNorm(n_channels, polar_params=polar_params))
            if group_norm_num_groups is not None:
                layers.append(AngularGroupNorm(n_channels, num_groups=group_norm_num_groups,
                                               polar_params=polar_params))

            layers.append(nn.ReLU(inplace=True))

        layers.append(Conv(in_channels=n_channels,
            out_channels=out_channels, kernel_size=kernel_size, bias=False,
            symmetric=symmetric, polar_params=polar_params))

        if cartesian:
            layers.append(RbfPolarToCartesian(polar_params))

        self._model = nn.Sequential(*layers)

        self.residual = residual

        if self.residual:
            # Round trip of an all-ones L x L image through the SAME polar
            # transforms as the network. Default params would disagree with
            # polar_params in L, b, quadrature, ...
            # NOTE(review): the mask is L x L, so residual=True presumably
            # only makes sense with cartesian=True; confirm.
            ones = torch.ones(1, 1, L, L)
            mask = RbfPolarToCartesian(polar_params)(
                CartesianToRbfPolar(polar_params)(ones))

            self.register_buffer("residual_mask", mask)

        self._initialize_weights()

    def forward(self, x):
        z = self._model(x)

        if self.residual:
            return self.residual_mask * x - z

        return z

    def _initialize_weights(self):
        """Orthogonal init for all angular conv weights; zero biases."""
        for m in self.modules():
            if isinstance(m, AngularConv):
                torch.nn.init.orthogonal_(m.weight)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, BandedAngularConv):
                for band_idx in range(m.bands):
                    weight_band = getattr(m, f"weight_band{band_idx:d}")
                    torch.nn.init.orthogonal_(weight_band)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)


def _set_to_singletons(x):
    """Merge the two leading (batch, set) axes of *x* into one axis.

    Returns the reshaped tensor and the original two-axis leading shape, so
    that ``_singletons_to_set`` can undo the operation.
    """
    leading_shape = x.shape[:2]
    merged_shape = (np.prod(leading_shape),) + tuple(x.shape[2:])
    return x.reshape(merged_shape), leading_shape


def _singletons_to_set(x, batch_set_shape):
    """Split the leading axis of *x* back into ``batch_set_shape``."""
    target_shape = tuple(batch_set_shape) + tuple(x.shape[1:])
    return x.reshape(target_shape)


def _apply_to_set(model, x):
    """Run a per-item *model* over every member of a batch of sets.

    The ``(batch, set)`` axes of *x* are flattened, *model* is called once
    on the flat batch, and the result is reshaped back.
    """
    flat, leading_shape = _set_to_singletons(x)
    return _singletons_to_set(model(flat), leading_shape)


class SetRbfPolarDnCnn(nn.Module):
    """Apply an ``RbfPolarDnCnn`` independently to every set member.

    All constructor arguments are forwarded unchanged to ``RbfPolarDnCnn``.
    """

    is_set_model = True

    def __init__(self, *args, **kwargs):
        super().__init__()
        self._model = RbfPolarDnCnn(*args, **kwargs)

    def forward(self, x):
        # Flatten (batch, set), run the per-image network, restore the axes.
        items, set_shape = _set_to_singletons(x)
        return _singletons_to_set(self._model(items), set_shape)


class SetSequential(nn.Module):
    """``nn.Sequential`` wrapper tagged as a set model.

    The input is passed unchanged to the wrapped layers; no per-item
    flattening happens here.
    """

    is_set_model = True

    def __init__(self, *args, **kwargs):
        super().__init__()
        self._model = nn.Sequential(*args, **kwargs)

    def forward(self, x):
        chain = self._model
        return chain(x)


class SetCartesianToRbfPolar(nn.Module):
    """Apply ``CartesianToRbfPolar`` to every member of every set.

    Constructor arguments are forwarded unchanged.
    """

    is_set_model = True

    def __init__(self, *args, **kwargs):
        super().__init__()
        self._model = CartesianToRbfPolar(*args, **kwargs)

    def forward(self, x):
        items, set_shape = _set_to_singletons(x)
        return _singletons_to_set(self._model(items), set_shape)


class SetRbfPolarToCartesian(nn.Module):
    """Apply ``RbfPolarToCartesian`` to every member of every set.

    Constructor arguments are forwarded unchanged.
    """

    is_set_model = True

    def __init__(self, *args, **kwargs):
        super().__init__()
        self._model = RbfPolarToCartesian(*args, **kwargs)

    def forward(self, x):
        items, set_shape = _set_to_singletons(x)
        return _singletons_to_set(self._model(items), set_shape)


class SetRbfPolarToOneDim(nn.Module):
    """Apply ``RbfPolarToOneDim`` to every member of every set.

    Constructor arguments are forwarded unchanged.
    """

    is_set_model = True

    def __init__(self, *args, **kwargs):
        super().__init__()
        self._model = RbfPolarToOneDim(*args, **kwargs)

    def forward(self, x):
        items, set_shape = _set_to_singletons(x)
        return _singletons_to_set(self._model(items), set_shape)


class SetConv1d(torch.nn.Conv1d):
    """``Conv1d`` applied over arbitrary leading batch axes.

    Input ``(*batch_shape, in_channels, length)`` is flattened to a single
    batch axis, convolved, and reshaped to ``(*batch_shape, out_channels,
    out_length)``. ``batch_shape`` may be empty (unbatched input).
    """
    is_set_model = True

    def __init__(self, *args, **kwargs):
        super(SetConv1d, self).__init__(*args, **kwargs)

    def forward(self, x):
        *batch_shape, in_channels, in_length = x.shape
        # Compare with in_channels (not weight.shape[1], which is
        # in_channels / groups).
        assert in_channels == self.in_channels

        # math.prod returns the int 1 for an empty batch_shape; np.prod
        # returns the float 1.0, which reshape rejects.
        num_items = math.prod(batch_shape)

        x = x.reshape((num_items, in_channels, in_length))
        y = super(SetConv1d, self).forward(x)

        out_batch, out_channels, out_length = y.shape
        assert out_batch == num_items

        return y.reshape(tuple(batch_shape) + (out_channels, out_length))


class SetBatchNorm1d(torch.nn.Module):
    """Per-channel batch normalization for 4-D set-shaped input.

    ``x`` has its channel axis at position 2. That axis is swapped with axis
    1 so ``BatchNorm2d`` normalizes each channel over all other axes, and
    the swap is then undone.
    """

    is_set_model = True

    def __init__(self, num_channels, **kwargs):
        super().__init__()
        self.num_channels = num_channels
        self._model = torch.nn.BatchNorm2d(num_channels, **kwargs)

    def forward(self, x):
        assert x.shape[2] == self.num_channels
        # Swapping axes 1 and 2 is its own inverse.
        swap = (0, 2, 1, 3)
        return self._model(x.permute(swap)).permute(swap)


class SetCircularCnn1d(torch.nn.Module):
    """1D CNN with circular "same" padding for set-shaped input.

    Layout: conv -> ReLU, then ``depth - 2`` blocks of
    conv -> ``SetBatchNorm1d`` -> ReLU, then a final conv to
    ``out_channels``.
    """

    is_set_model = True

    def __init__(self, depth, num_channels, kernel_size, in_channels=1, out_channels=1):
        super().__init__()

        def make_conv(c_in, c_out):
            return SetConv1d(c_in, c_out, kernel_size,
                             padding="same", padding_mode="circular")

        layers = [make_conv(in_channels, num_channels), torch.nn.ReLU()]

        for _ in range(depth - 2):
            layers += [
                make_conv(num_channels, num_channels),
                SetBatchNorm1d(num_channels),
                torch.nn.ReLU(),
            ]

        layers.append(make_conv(num_channels, out_channels))

        self._model = torch.nn.Sequential(*layers)

    def forward(self, x):
        network = self._model
        return network(x)


class EquivariantAttention(torch.nn.Module):
    """Set attention whose scores and mixing act along the angular axis.

    For each pair of set members ``(i, j)`` the attention score is a
    function of the angular shift. It is computed via FFT as the circular
    cross-correlation of their query and key signals, summed over query
    channels. After scaling and softmax, the output for member ``i`` is the
    sum over ``j`` of its attention weights circularly convolved (along the
    angular axis, via FFT) with ``value_j``.

    Args:
        query_model, key_model, value_model: Modules applied to the input
            ``x``. Queries/keys are unpacked as 4-D
            ``(batch, set, query_channels, num_angular)``. Values are used as
            5-D ``(batch, set, value_channels, num_radial, num_angular)``.
        skip: Add ``x`` to the output in place (shapes must match).
        symmetric_softmax: If true, softmax over shifts only for each
            ``(i, j)``, divided by the set size. Otherwise softmax jointly
            over ``(j, shift)``.
        learned_beta: Make the score scale ``beta`` a learnable parameter
            (initialized to 1); otherwise it is the constant 1.
        mask_diagonal: Zero each member's weights with itself after the
            softmax, so the weights no longer sum to one.
    """
    is_set_model = True

    def __init__(self, query_model, key_model, value_model, skip=False,
            symmetric_softmax=False, learned_beta=False, mask_diagonal=False):
        super(EquivariantAttention, self).__init__()

        self.query_model = query_model
        self.key_model = key_model
        self.value_model = value_model

        self.skip = skip
        self.symmetric_softmax = symmetric_softmax
        self.mask_diagonal = mask_diagonal

        # Multiplicative scale applied to the scores before softmax.
        if learned_beta:
            self.beta = torch.nn.Parameter(torch.ones(1))
        else:
            self.beta = 1

    def forward(self, x, return_extra=False):
        """Return the attended output, plus ``(query, key, value,
        attention)`` when ``return_extra`` is true. ``attention`` holds the
        post-softmax (and, if enabled, masked) weights of shape
        ``(batch, set, set, num_angular)``."""
        query = self.query_model(x)

        # Shared query/key model: reuse its output instead of calling it twice.
        if self.query_model is self.key_model:
            key = query
        else:
            key = self.key_model(x)

        value = self.value_model(x)

        batch_size, set_size, query_channels, num_angular = query.shape
        value_channels, num_radial = value.shape[-3:-1]

        # Angular FFT of queries and keys.
        query_f = torch.fft.fft(query, axis=-1)
        key_f = torch.fft.fft(key, axis=-1)

        # Move the frequency axis forward so matmul contracts the channels:
        # (B, A, S, Cq) @ (B, A, Cq, S) -> (B, A, S, S).
        query_f = query_f.permute((0, 3, 1, 2))
        key_f = key_f.permute((0, 3, 2, 1))

        # Product with the conjugate == circular cross-correlation in angle.
        attention_f = torch.matmul(query_f, key_f.conj())

        # Back to (B, S_i, S_j, A).
        attention_f = attention_f.permute((0, 2, 3, 1))

        # Scores indexed by angular shift; the imaginary part is dropped
        # (assumes real-valued query/key).
        attention = torch.fft.ifft(attention_f, axis=-1).real

        attention = self.beta * attention
        attention = attention / np.sqrt(query_channels * num_angular)

        if self.symmetric_softmax:
            # Softmax over shifts per (i, j); dividing by set_size makes the
            # weights over (j, shift) sum to one.
            attention = torch.nn.functional.softmax(attention, dim=-1)
            attention = attention / set_size
        else:
            # Joint softmax over all (j, shift) pairs for each i.
            attention = attention.reshape((batch_size, set_size, set_size * num_angular))
            attention = torch.nn.functional.softmax(attention, dim=-1)
            attention = attention.reshape((batch_size, set_size, set_size, num_angular))

        if self.mask_diagonal:
            attention = attention * (1 - torch.eye(set_size, device=attention.device, dtype=attention.dtype))[None, :, :, None]

        # Circular convolution of the weights with the values, again via FFT.
        attention_f = torch.fft.fft(attention, axis=-1)
        value_f = torch.fft.fft(value, axis=-1)

        # (B, A, S_i, S_j) and (B, A, S_j, Cv, R).
        attention_f = attention_f.permute((0, 3, 1, 2))
        value_f = value_f.permute((0, 4, 1, 2, 3))

        # Merge (Cv, R) so matmul sums over j for every value feature.
        value_f = value_f.reshape((batch_size, num_angular, set_size, value_channels * num_radial))

        output_f = torch.matmul(attention_f, value_f)

        # (B, S, Cv * R, A), then back to the angular domain.
        output_f = output_f.permute((0, 2, 3, 1))
        output = torch.fft.ifft(output_f, axis=-1).real

        output = output.reshape((batch_size, set_size, value_channels, num_radial, num_angular))

        if self.skip:
            output += x

        if not return_extra:
            return output
        else:
            return output, (query, key, value, attention)


class ScalarMultiplication(torch.nn.Module):
    """Scale the input by one learnable scalar ``alpha`` (initialized to 1)."""

    def __init__(self):
        super().__init__()
        self.alpha = torch.nn.Parameter(torch.ones(1))

    def forward(self, x):
        scale = self.alpha
        return scale * x




class PolarTransformer(torch.nn.Module):
    """A denoising model that operates on sets of projections. At least some
    of these should display the same molecule with the same projection axis,
    regardless of plane rotation (to which the model is equivariant).

    Pipeline, applied to whole sets:
    1. Cartesian -> polar RBF (if ``cartesian``).
    2. Per-image ``RbfPolarDnCnn`` preprocessing to ``value_channels``
       channels.
    3. ``EquivariantAttention`` across the set.
    4. Per-image ``RbfPolarDnCnn`` postprocessing.
    5. Polar RBF -> Cartesian (if ``cartesian``).

    Queries and keys share one model: the preprocessed features are
    flattened over (channel, radius) and passed through a circular 1D CNN
    producing ``query_channels`` channels.
    """

    is_set_model = True

    def __init__(self, polar_params=None,
                 query_depth=3, query_hidden_channels=16,
                 query_channels=4, preproc_depth=5, preproc_channels=8, preproc_bands=3,
                 postproc_depth=9, postproc_channels=8, postproc_bands=3, value_channels=8,
                 kernel_size=5, skip=False, symmetric_softmax=False, cartesian=True,
                 batch_norm=True, group_norm_num_groups=None,
                 mask_diagonal=False):
        super().__init__()

        # Fresh dict per call instead of a shared mutable default.
        polar_params = util.default_rbf_params(
            {} if polar_params is None else polar_params)

        num_radial, _ = util._polar_grid_size(polar_params)

        if cartesian:
            encoding_model = SetCartesianToRbfPolar(polar_params)
        else:
            encoding_model = torch.nn.Identity()

        preproc_model = SetRbfPolarDnCnn(polar_params,
                depth=preproc_depth, n_channels=preproc_channels,
                kernel_size=kernel_size, bands=preproc_bands,
                out_channels=value_channels,
                batch_norm=batch_norm, group_norm_num_groups=group_norm_num_groups,
                cartesian=False)

        # Query/key network: (B, S, Cv, R, A) -> (B, S, Cv * R, A) -> 1D CNN.
        query_model = SetSequential(
            SetRbfPolarToOneDim(polar_params),
            SetCircularCnn1d(query_depth, query_hidden_channels, kernel_size,
                             in_channels=value_channels * num_radial,
                             out_channels=query_channels))
        key_model = query_model

        if skip:
            value_model = ScalarMultiplication()
        else:
            value_model = torch.nn.Identity()

        attention_model = EquivariantAttention(query_model, key_model,
                value_model, skip=skip, symmetric_softmax=symmetric_softmax,
                mask_diagonal=mask_diagonal)

        postproc_model = SetRbfPolarDnCnn(polar_params,
                depth=postproc_depth, n_channels=postproc_channels,
                kernel_size=kernel_size, bands=postproc_bands,
                in_channels=value_channels,
                batch_norm=batch_norm, group_norm_num_groups=group_norm_num_groups,
                cartesian=False)

        if cartesian:
            decoding_model = SetRbfPolarToCartesian(polar_params)
        else:
            decoding_model = torch.nn.Identity()

        model = SetSequential(encoding_model, preproc_model, attention_model, postproc_model, decoding_model)

        self.encoding_model = encoding_model
        self.preproc_model = preproc_model
        self.attention_model = attention_model
        self.postproc_model = postproc_model
        self.decoding_model = decoding_model
        self.model = model

    def forward(self, x):
        return self.model(x)


