import torch
import torch.nn as nn

from .rapidash_fiber_grids import GridGenerator_S1_S2, GridGenerator_SO3
from .rapidash_invariants import Invariants
from .rapidash_modules import (ConvNext, PolynomialFeatures,
                               SeparableFiberBundleConvNext)
from .rapidash_utils import (fps_edge_index, fully_connected_edge_index,
                             knn_graph, scatter_add)


class Rapidash(nn.Module):
    """Steerable E(3) equivariant (non-linear) convolutional network.

    The model embeds the input features, runs them through a stack of
    convolution blocks and averages the linear readouts attached to those
    blocks. The blocks can be organised in several scales; consecutive scales
    are connected by transition layers whose connectivity comes from
    ``fps_edge_index``. For non-graph-level tasks with more than one scale the
    down-sampling path is mirrored by an up-sampling path.

    Args:
        input_dim: Number of input feature channels. The embedding layer
            expects ``input_dim + 1`` channels when
            ``last_feature_conditioning`` is enabled.
        hidden_dim: Hidden width, either one int (shared by all scales) or a
            list with one width per scale.
        output_dim: Number of scalar output channels.
        num_layers: Number of interaction layers, either one int (single
            scale) or a list with one entry per scale.
        edge_types: Connectivity per scale: ``"fc"`` or ``"knn-<k>"``. A single
            string is shared by all scales. The default list is never mutated.
        ratios: Sampling ratio per scale transition. A single number is shared
            by all scales; a sequence gets ``1.0`` appended for the final
            scale. The default list is never mutated.
        output_dim_vec: Number of vector output channels.
        dim: Forwarded to the fiber grid generator.
        base_space: Forwarded to ``Invariants``.
        fiber_space: ``"S1"``, ``"S2"`` or ``"SO3"``; forced to ``None`` when
            ``fiber_dim`` is 0.
        fiber_dim: Number of fiber grid points. ``0`` disables the fiber and
            selects ``ConvNext`` instead of ``SeparableFiberBundleConvNext``.
        equivariance: Forwarded to ``Invariants``.
        basis_dim: Width of the kernel basis; defaults to the hidden width of
            the last scale.
        degree: Degree passed to ``PolynomialFeatures``.
        widening_factor: Forwarded to every convolution block.
        layer_scale: Forwarded to every convolution block.
        task_level: ``"graph"`` enables global pooling of the outputs; any
            other value keeps per-node outputs.
        multiple_readouts: Attach a readout to every layer instead of only to
            the last one.
        last_feature_conditioning: Use the last input channel as an extra
            input of the spatial kernel basis.
        attention: Forwarded to every convolution block.
        fully_connected: Accepted for compatibility; not used by this class.
        residual_connections: Add skip connections between matching
            down-sampling and up-sampling layers.
        global_basis: Share one set of basis functions across all scales.
        **kwargs: Ignored.
    """

    _supported_edge_types = ["fc", "knn"]

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        num_layers,
        edge_types=["fc"],
        ratios=[],
        output_dim_vec=0,
        dim=3,
        base_space="R3",
        fiber_space="S2",
        fiber_dim=12,
        equivariance="SEn",
        basis_dim=None,
        degree=2,
        widening_factor=4,
        layer_scale=None,
        task_level="node",
        multiple_readouts=False,
        last_feature_conditioning=True,
        attention=False,
        fully_connected=False,
        residual_connections=True,
        global_basis=False,
        **kwargs
    ):
        super().__init__()

        # Pair-wise invariant attribute module (no fiber when fiber_dim is 0)
        if fiber_dim == 0:
            fiber_space = None
        self.compute_invariants = Invariants(base_space, equivariance, fiber_space)
        # Whether or not to use the same basis functions for every scale
        self.global_basis = global_basis

        # Layer settings
        self.residual_connections = residual_connections
        self.last_feature_conditioning = last_feature_conditioning

        # Broadcast scalar settings to one entry per scale
        self.layers_per_scale = (
            [num_layers] if isinstance(num_layers, int) else num_layers
        )
        n = len(self.layers_per_scale)

        self.edge_types = n * [edge_types] if isinstance(edge_types, str) else edge_types
        if isinstance(ratios, (int, float)):
            self.ratios = n * [ratios]
        else:
            # The final scale has no transition, so it gets a ratio of 1.0
            self.ratios = list(ratios) + [1.0]
        self.hidden_dims = n * [hidden_dim] if isinstance(hidden_dim, int) else hidden_dim
        self.hidden_dim = self.hidden_dims[-1]
        self._parse_edge_types()

        assert len(self.layers_per_scale) == len(self.edge_types) == len(self.ratios), (
            "num_layers, edge_types and ratios (plus the implicit final 1.0) "
            "must describe the same number of scales"
        )
        self.num_scales = len(self.layers_per_scale)

        # Set in / out dimensions
        self.in_out_dims_per_layer = []
        # Interaction layers plus one transition layer per scale boundary
        num_down_layers = sum(self.layers_per_scale) + self.num_scales - 1
        self.up_sample = task_level != "graph" and self.num_scales > 1
        # The up-sampling path mirrors the down-sampling path
        self.effective_num_layers = (
            2 * num_down_layers if self.up_sample else num_down_layers
        )

        # Construct the grids
        self.fiber_dim = fiber_dim
        if self.fiber_dim != 0:
            if fiber_space in ["S1", "S2"]:
                self.grid_generator = GridGenerator_S1_S2(dim, fiber_dim, steps=1000)
            elif fiber_space == "SO3":
                self.grid_generator = GridGenerator_SO3(dim, fiber_dim, steps=1000)
            else:
                raise ValueError(
                    f"Unsupported fiber_space {fiber_space!r}; "
                    "expected 'S1', 'S2' or 'SO3'"
                )
            self.register_buffer('fiber_grid', self.grid_generator())

            ConvBlock = SeparableFiberBundleConvNext
        else:
            self.register_buffer('fiber_grid', None)
            ConvBlock = ConvNext

        # Input output settings
        self.output_dim, self.output_dim_vec = output_dim, output_dim_vec
        self.global_pooling = task_level == "graph"

        # Activation function to use internally (one shared instance)
        act_fn = nn.GELU()

        # Kernel basis functions and spatial window. Use the resolved integer
        # width here: `hidden_dim` itself may be a per-scale list.
        basis_dim = self.hidden_dim if basis_dim is None else basis_dim
        num_bases = 1 if global_basis else len(self.ratios)

        def make_basis_fn():
            # LazyLinear infers its input width on the first forward pass
            return nn.Sequential(
                PolynomialFeatures(degree),
                nn.LazyLinear(self.hidden_dim),
                act_fn,
                nn.Linear(self.hidden_dim, basis_dim),
                act_fn,
            )

        self.basis_fn = nn.ModuleList(make_basis_fn() for _ in range(num_bases))
        if self.fiber_dim != 0:
            self.fiber_basis_fn = nn.ModuleList(
                make_basis_fn() for _ in range(num_bases)
            )

        # Initial node embedding
        self.x_embedder = nn.Linear(
            input_dim + int(last_feature_conditioning),
            self.hidden_dims[0],
            bias=False,
        )

        # Make feedforward network
        self.interaction_layers = nn.ModuleList()
        self.read_out_layers = nn.ModuleList()

        # Channel width at every layer boundary: each scale contributes its
        # width once per interaction layer plus once for the transition.
        channels = list(
            map(
                int,
                torch.repeat_interleave(
                    torch.Tensor(self.hidden_dims),
                    torch.tensor(self.layers_per_scale).int() + 1,
                ).tolist(),
            )
        )

        if self.up_sample:
            channels = channels + channels[:-1][::-1]

        last_layer = self.effective_num_layers - 1
        for i in range(self.effective_num_layers):
            in_channels, out_channels = channels[i : i + 2]
            self.interaction_layers.append(
                ConvBlock(
                    in_channels,
                    basis_dim,
                    out_channels=out_channels,
                    act=act_fn,
                    layer_scale=layer_scale,
                    widening_factor=widening_factor,
                    attention=attention,
                )
            )
            if multiple_readouts or i == last_layer:
                self.read_out_layers.append(
                    nn.Linear(out_channels, output_dim + output_dim_vec)
                )
            else:
                # Placeholder keeps readout indices aligned with layer indices
                self.read_out_layers.append(None)

    def _parse_edge_types(self):
        """Split ``self.edge_types`` into type names and keyword arguments.

        Every entry must be ``"fc"`` or ``"knn-<k>"`` (case-insensitive).
        Afterwards ``self.edge_types`` holds the lower-case type names and
        ``self.edge_types_kwargs`` the matching keyword dictionaries
        (``{}`` for ``"fc"``, ``{"k": <k>}`` for ``"knn"``).

        Raises:
            ValueError: If an entry is not a supported edge type
                specification.
        """
        edge_types = []
        edge_types_kwargs = []

        for edge_type in self.edge_types:
            spec = edge_type.lower()
            if spec == "fc":
                edge_types.append(spec)
                edge_types_kwargs.append({})
                continue

            parts = spec.split("-")
            if len(parts) != 2 or parts[0] != "knn":
                raise ValueError(
                    f"Unsupported edge type {edge_type!r}; expected 'fc' or "
                    f"'knn-<k>' (supported types: {self._supported_edge_types})"
                )

            try:
                k = int(parts[1])
            except ValueError:
                raise ValueError(
                    f"Invalid neighbour count in edge type {edge_type!r}; "
                    "expected 'knn-<k>' with an integer k"
                ) from None

            edge_types.append("knn")
            edge_types_kwargs.append({"k": k})

        self.edge_types = edge_types
        self.edge_types_kwargs = edge_types_kwargs

    def precompute_interaction_layers(self, edge_type, edge_type_kwargs, pos, batch):
        """Build the edge index for one scale.

        Args:
            edge_type: ``"fc"`` or ``"knn"``.
            edge_type_kwargs: Extra keyword arguments for ``knn_graph``.
            pos: Node positions.
            batch: Graph assignment of every node.

        Returns:
            The edge index from ``fully_connected_edge_index`` or the flipped
            edge index from ``knn_graph`` (self-loops included).

        Raises:
            ValueError: If ``edge_type`` is not supported.
        """
        if edge_type == "fc":
            return fully_connected_edge_index(batch)
        if edge_type == "knn":
            return knn_graph(pos, batch=batch, loop=True, **edge_type_kwargs).flip(0)
        raise ValueError(
            f"Unsupported edge type {edge_type!r}; "
            f"expected one of {self._supported_edge_types}"
        )

    def _compute_kernel_bases(
        self, pos_send, pos_receive, send_index, spatial_cond, basis_idx
    ):
        """Compute the spatial and fiber kernel bases for a set of edges.

        Returns a ``(kernel_basis, fiber_kernel_basis)`` pair; the second entry
        is ``None`` when the model has no fiber grid.
        """
        spatial_invariants, fiber_invariants = self.compute_invariants(
            pos_send, pos_receive, self.fiber_grid
        )

        if spatial_cond is not None:
            # Append the conditioning of the sending node to the invariants
            cond = spatial_cond[send_index].repeat(1, self.fiber_dim, 1)
            spatial_invariants = torch.cat((spatial_invariants, cond), dim=-1)

        kernel_basis = self.basis_fn[basis_idx](spatial_invariants)
        if self.fiber_grid is not None:
            fiber_kernel_basis = self.fiber_basis_fn[basis_idx](fiber_invariants)
        else:
            fiber_kernel_basis = None
        return kernel_basis, fiber_kernel_basis

    def precompute_interaction_transition_layers(
        self, pos, batch, spatial_cond=None
    ):
        """Precompute the per-layer kernel bases and connectivity.

        Args:
            pos: Node positions.
            batch: Graph assignment of every node.
            spatial_cond: Optional conditioning that is concatenated to the
                spatial invariants before they enter the basis functions.

        Returns:
            A list with one ``(kernel_basis, fiber_kernel_basis, edge_index,
            batch)`` tuple per layer: the down-sampling path followed by the
            mirrored up-sampling path (empty unless ``self.up_sample``).
        """
        data_per_layer = []
        data_per_layer_up = []

        basis_idx = 0

        for i in range(self.num_scales):
            edge_type, edge_type_kwargs = self.edge_types[i], self.edge_types_kwargs[i]
            ratio = self.ratios[i]

            # Interaction layers of this scale share one graph and one basis
            edge_index = self.precompute_interaction_layers(
                edge_type, edge_type_kwargs, pos, batch
            )
            # NOTE(review): spatial_cond is indexed with the current scale's
            # edge index but is never subsampled together with pos; confirm
            # this is intended for scales after the first transition.
            kernel_basis, fiber_kernel_basis = self._compute_kernel_bases(
                pos[edge_index[0]],
                pos[edge_index[1]],
                edge_index[0],
                spatial_cond,
                basis_idx,
            )

            layer_data = (kernel_basis, fiber_kernel_basis, edge_index, batch)
            data_per_layer += [layer_data] * self.layers_per_scale[i]

            if self.up_sample:
                # The up path visits the scales in reverse order
                data_per_layer_up = (
                    [layer_data] * self.layers_per_scale[i] + data_per_layer_up
                )

            # Transition layer
            # NOTE(review): effective_num_layers counts one transition per
            # scale boundary, but one is only emitted when ratio < 1.0; a
            # non-final ratio >= 1.0 yields fewer entries than forward() reads.
            if ratio < 1.0 and i < self.num_scales - 1:
                edge_index, fps_pos, fps_batch = fps_edge_index(pos, batch, ratio=ratio)
                kernel_basis, fiber_kernel_basis = self._compute_kernel_bases(
                    pos[edge_index[0]],
                    fps_pos[edge_index[1]],
                    edge_index[0],
                    spatial_cond,
                    basis_idx,
                )

                # Everything after the transition lives on the sampled nodes
                pos, batch = fps_pos, fps_batch

                data_per_layer.append(
                    (kernel_basis, fiber_kernel_basis, edge_index, batch)
                )

                if self.up_sample:
                    # Flipping the edges reverses the direction of the transition
                    data_per_layer_up = [
                        (kernel_basis, fiber_kernel_basis, edge_index.flip(0), batch)
                    ] + data_per_layer_up

                basis_idx += 0 if self.global_basis else 1

        return data_per_layer + data_per_layer_up

    def forward(self, x, pos, edge_index, batch=None, vec=None):
        """Run the network.

        Args:
            x: Scalar node features ([B*N,Cs]). With
                ``last_feature_conditioning`` the last channel is also used to
                condition the spatial kernel basis.
            pos: Node positions.
            edge_index: Ignored; the connectivity is recomputed from ``pos``
                and ``batch`` according to the configured edge types.
            batch: Graph assignment of every node.
            vec: Optional vector node features ([B*N,C,3]). Only used when the
                model has a fiber grid.

        Returns:
            A tuple ``(output_scalar, output_vector)``. ``output_vector`` is
            ``None`` when there is no fiber grid or no vector output channel.
            With graph-level tasks both outputs are sum-pooled per graph.
        """
        fiber_grid = self.fiber_grid

        # Precompute the interaction and transition layers
        data_per_layer = self.precompute_interaction_transition_layers(
            pos,
            batch,
            spatial_cond=x[..., None, -1:] if self.last_feature_conditioning else None,
        )

        # Initial feature embedding
        if fiber_grid is not None:
            lifted = []
            if x is not None:
                # Lift the scalars ([B*N,Cs]->[B*N,O,Cs]):
                lifted.append(
                    x.unsqueeze(-2).repeat_interleave(self.fiber_dim, dim=-2)
                )  # [B*N,O,C]
            if vec is not None:
                # Lift the vectors ([B*N,C,3]->[B*N,O,C])
                lifted.append(torch.einsum('bcd,nd->bnc', vec, fiber_grid))
            x = torch.cat(lifted, dim=-1)

        # Embed the inputs
        x = self.x_embedder(x)

        # Interaction + transition + readout
        readouts = []
        residuals = []
        for i in range(self.effective_num_layers):
            residual = x
            # `batch` follows the node set of the current layer; after the loop
            # it matches the nodes of the final output.
            kernel, fiber_kernel, edge_index, batch = data_per_layer[i]
            x = self.interaction_layers[i](x, kernel, fiber_kernel, edge_index)
            if self.residual_connections:
                if residual.shape[0] > x.shape[0]:
                    # Downsampling, so save the residual for the up path
                    residuals.append(residual)
                elif residual.shape[0] < x.shape[0]:
                    # Upsampling, so add the most recently saved residual
                    x = x + residuals.pop(-1)
            if self.read_out_layers[i] is not None:
                readouts.append(self.read_out_layers[i](x))
        readout = sum(readouts) / len(readouts)

        # Split the output into scalar and vector predictions
        readout_scalar, readout_vec = torch.split(
            readout, [self.output_dim, self.output_dim_vec], dim=-1
        )

        # Read out scalar and vector predictions
        output_vector = None
        if fiber_grid is not None:
            output_scalar = readout_scalar.mean(dim=-2)  # [B*N,C]
            if readout_vec.numel() != 0:
                # TODO: does not work for SO3 fiber grid
                output_vector = (
                    torch.einsum("boc,od->bcd", readout_vec, fiber_grid)
                    / self.fiber_dim
                )  # [B*N,C,3]
        else:
            output_scalar = readout_scalar

        if self.global_pooling:
            num_graphs = batch.max().item() + 1
            if output_scalar is not None:
                output_scalar = scatter_add(
                    src=output_scalar, index=batch, dim_size=num_graphs
                )
            if output_vector is not None:
                output_vector = scatter_add(
                    src=output_vector, index=batch, dim_size=num_graphs
                )

        # Return predictions
        return output_scalar, output_vector
