import einops
import torch
import torch_geometric
from torch import nn

from kappamodules.layers import ContinuousSincosEmbed, LinearProjection

from src.utils.subsample import vec_sample_indices_from_classes


class SupernodePoolingSubsampled(nn.Module):
    """Supernode pooling layer with a degree-subsampled radius graph and field conditioning.

    Every supernode aggregates messages from the input points within ``radius`` of it. Before messages are created,
    the per-point field values are linearly projected to hidden_dim and added to each point's positional embedding.

    Default settings:
    - Messages from incoming nodes are averaged (aggregation="mean")
    - Linear layers are initialized with default torch initialization (init_weights="torch")
    - Messages use the absolute positional embeddings (mode="abspos")
    - After the message passing, the original positional embedding of the supernode is concatenated to the
      supernode vector, followed by a projection back to the original hidden dimension (readd_supernode_pos=True)
    - After the message passing, a projection of each supernode's own field values is added (readd_field=True)
    The permutation of the supernodes is preserved through the message passing (contrary to the (GP-)UPT code).
    Additionally, radius is used instead of radius_graph, which is more efficient.

    Args:
        radius (float): Radius around each supernode. From points within this radius, messages are passed to the
          supernode.
        hidden_dim (int): Hidden dimension for positional embeddings, messages and the resulting output vector.
        ndim (int): Number of positional dimensions (e.g., ndim=2 for a 2D position, ndim=3 for a 3D position).
        field_dim (int): Number of field channels per input point.
        max_degree (int): Maximum number of incoming edges per supernode. Neighbors returned by the radius query
          are subsampled down to this number.
        mode (str): Are positions embedded in absolute space ("abspos"), relative space ("relpos") or both
          ("absrelpos"). "readd_supernode_pos" will always use the absolute position. "relpos" is accepted here but
          not implemented in forward (field values cannot yet be added to purely relative messages).
        init_weights (str): Weight initialization of the message and supernode-position projections.
        readd_supernode_pos (bool): If true, the absolute positional encoding of the supernode is concatenated to
          the supernode vector after message passing and linearly projected back to hidden_dim.
        readd_field (bool): If true, a linear projection of each supernode's own field values is added to the
          supernode vector after message passing.
        aggregation (str): Aggregation for message passing ("mean" or "sum").
        implementation (str): Whether to use the "radius_graph" implementation of the (GP-)UPT codebase (does not
          preserve ordering of supernodes; subsampling is not implemented for it) or the "radius" implementation
          (faster and supernode order preserving).
        max_num_neighbors (int): Maximum number of neighbors the radius query returns per supernode before
          subsampling to max_degree. The resulting degree can therefore never exceed this value.
    """

    def __init__(
            self,
            radius: float,
            hidden_dim: int,
            ndim: int,
            field_dim: int,
            max_degree: int = 8,
            mode: str = "abspos",
            init_weights: str = "torch",
            readd_supernode_pos: bool = True,
            readd_field: bool = True,
            aggregation: str = "mean",
            implementation: str = "radius",
            max_num_neighbors: int = 64,
    ):
        super().__init__()
        self.radius = radius
        self.max_degree = max_degree
        self.hidden_dim = hidden_dim
        self.ndim = ndim
        self.field_dim = field_dim
        self.init_weights = init_weights
        self.mode = mode
        self.readd_supernode_pos = readd_supernode_pos
        self.readd_field = readd_field
        self.aggregation = aggregation
        self.implementation = implementation
        self.max_num_neighbors = max_num_neighbors

        self.pos_embed = ContinuousSincosEmbed(dim=hidden_dim, ndim=ndim)
        if mode == "abspos":
            # message input: [src embedding, dst embedding]
            message_input_dim = hidden_dim * 2
            self.rel_pos_embed = None
        elif mode == "absrelpos":
            # message input: [src embedding, dst embedding, relative position embedding]
            message_input_dim = hidden_dim * 3
            # relative position is embedded as (offset vector, offset magnitude) -> ndim + 1 inputs
            self.rel_pos_embed = ContinuousSincosEmbed(dim=hidden_dim, ndim=ndim + 1)
        elif mode == "relpos":
            message_input_dim = hidden_dim
            self.rel_pos_embed = ContinuousSincosEmbed(dim=hidden_dim, ndim=ndim + 1)
        else:
            raise NotImplementedError
        self.message = nn.Sequential(
            LinearProjection(message_input_dim, hidden_dim, init_weights=init_weights),
            nn.GELU(),
            LinearProjection(hidden_dim, hidden_dim, init_weights=init_weights),
        )
        # NOTE(review): the field projections do not receive init_weights and use LinearProjection's default
        # initialization -- confirm this is intended
        self.field_proj = LinearProjection(self.field_dim, hidden_dim)
        if readd_supernode_pos:
            self.proj = LinearProjection(2 * hidden_dim, hidden_dim, init_weights=init_weights)
        else:
            self.proj = None
        if readd_field:
            self.field_proj_skip = LinearProjection(self.field_dim, hidden_dim)
        else:
            self.field_proj_skip = None

        self.output_dim = hidden_dim

    def forward(
            self,
            field: torch.Tensor,
            input_pos: torch.Tensor,
            supernode_index: torch.Tensor,
            batch_index: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Pool all input points into the selected supernodes.

        Args:
            field: Field values with one row per input point (field_dim columns).
            input_pos: 2D tensor with the position of every input point (one row per point).
            supernode_index: 1D tensor of indices into input_pos selecting the supernodes. The aggregated output is
              asserted to be ordered exactly like this tensor (NOTE(review): presumably unique and grouped by batch).
            batch_index: Batch assignment of every input point, or None for a single sample.

        Returns:
            dict with
            - "main": supernode vectors of shape (batch_size, num_supernodes, hidden_dim)
            - "degree": average number of incoming edges per supernode (Python float)
            - "coverage": fraction of input points that send at least one message (Python float)

        Raises:
            NotImplementedError: for mode="relpos" or implementation="radius_graph", which are not implemented.
        """
        assert input_pos.ndim == 2
        assert supernode_index.ndim == 1
        if self.mode == "relpos":
            # relpos messages exist only per edge, so the per-point field embedding has nowhere to be added;
            # fail before running the (expensive) radius query
            raise NotImplementedError("adding field values is not implemented for mode='relpos'")

        # radius graph
        if self.implementation == "radius":
            batch_y = None if batch_index is None else batch_index[supernode_index]
            # query up to max_num_neighbors neighbors per supernode, they are subsampled to max_degree below
            edges = torch_geometric.nn.pool.radius(
                x=input_pos,
                y=input_pos[supernode_index],
                r=self.radius,
                max_num_neighbors=self.max_num_neighbors,
                batch_x=batch_index,
                batch_y=batch_y,
            )
            # subsample: keep at most max_degree edges per supernode (edges[0] holds the supernode of each edge)
            subsample_idcs = vec_sample_indices_from_classes(edges[0, :], self.max_degree)
            edges = edges[:, subsample_idcs]
            dst_idx, src_idx = edges.unbind()
            # dst indices refer to rows of input_pos[supernode_index] -> remap them to indices into input_pos
            dst_idx = supernode_index[dst_idx]
        elif self.implementation == "radius_graph":
            edges = torch_geometric.nn.pool.radius_graph(
                x=input_pos,
                r=self.radius,
                max_num_neighbors=None,
                batch=batch_index,
                loop=True,
                # inverted flow direction is required to have sorted dst_indices
                flow="target_to_source",
            )
            is_supernode_edge = torch.isin(edges[0], supernode_index)
            edges = edges[:, is_supernode_edge]
            dst_idx, src_idx = edges.unbind()
            # explicit raise instead of `assert False`, which would be stripped under `python -O`
            raise NotImplementedError("subsampling is not implemented nor tested for implementation='radius_graph'")
        else:
            raise NotImplementedError

        # node embedding = absolute positional embedding + projected field values
        field_emb = self.field_proj(field)
        x = self.pos_embed(input_pos)
        # keep the pure positional embedding of the supernodes for readd_supernode_pos
        supernode_pos_embed = x[supernode_index]
        x = x + field_emb

        # create one message input per edge
        if self.mode == "abspos":
            x = torch.concat([x[src_idx], x[dst_idx]], dim=1)
        elif self.mode == "absrelpos":
            dist = input_pos[dst_idx] - input_pos[src_idx]
            mag = dist.norm(dim=1).unsqueeze(-1)
            relemb = self.rel_pos_embed(torch.concat([dist, mag], dim=1))
            x = torch.concat([x[src_idx], x[dst_idx], relemb], dim=1)
        else:
            raise ValueError('positional embedding not correctly specified')

        x = self.message(x)

        # accumulate messages
        # indptr is a tensor of indices between which to aggregate
        # i.e. a tensor of [0, 2, 5] would result in [src[0] + src[1], src[2] + src[3] + src[4]]
        if self.implementation == "radius":
            dst_indices, counts = dst_idx.unique_consecutive(return_counts=True)
            # the aggregated rows must line up with supernode_index (relied upon by the readd steps below)
            assert torch.all(supernode_index == dst_indices)
        elif self.implementation == "radius_graph":
            dst_indices, counts = dst_idx.unique(return_counts=True)
        else:
            raise NotImplementedError
        # first index has to be 0
        # NOTE: padding for target indices that don't occur is not needed as self-loop is always present
        padded_counts = torch.zeros(len(counts) + 1, device=counts.device, dtype=counts.dtype)
        padded_counts[1:] = counts
        indptr = padded_counts.cumsum(dim=0)
        # import here to avoid torch_scatter throwing import errors if it's not in the environment
        import torch_scatter
        x = torch_scatter.segment_csr(src=x, indptr=indptr, reduce=self.aggregation)

        # sanity check: dst_indices has len of batch_size * num_supernodes and has to be divisible by batch_size
        # if num_supernodes is not set in dataset this assertion fails
        if batch_index is None:
            batch_size = 1
        else:
            batch_size = batch_index.max() + 1
            assert dst_indices.numel() % batch_size == 0

        # calculate metrics for logging
        degree = src_idx.numel() / supernode_index.numel()
        coverage = src_idx.unique().numel() / len(input_pos)

        if self.readd_supernode_pos:
            # concatenate aggregated messages and supernode positional embeddings, project back to hidden_dim
            x = self.proj(torch.concat([x, supernode_pos_embed], dim=-1))

        if self.readd_field:
            # x has one row per supernode (ordered like supernode_index), so only the supernodes' own field values
            # are added -- projecting the full per-point field would not match x's shape
            x = x + self.field_proj_skip(field[supernode_index])

        # convert to dense tensor (dim last)
        x = einops.rearrange(
            x,
            "(batch_size num_supernodes) dim -> batch_size num_supernodes dim",
            batch_size=batch_size,
        )

        return dict(
            main=x,
            degree=degree,
            coverage=coverage,
        )