import einops
import torch
from torch import nn

from kappamodules.layers import ContinuousSincosEmbed, LinearProjection

class SupernodePoolingMock(nn.Module):
    """Stand-in for supernode pooling that skips every graph operation.

    Each input point is embedded independently as
    ``pos_embed(input_pos) + field_proj(field)`` and the flat result is folded
    into a dense ``(batch_size, points_per_sample, hidden_dim)`` tensor.

    Compared to a real pooling layer this mock does NOT build a radius graph,
    does NOT create or aggregate messages, and does NOT use ``supernode_index``
    to select points: every row of ``input_pos`` ends up in the output.

    ``message`` and ``rel_pos_embed`` are constructed but never called in
    ``forward``.
    NOTE(review): presumably they are kept so that the parameter layout mirrors
    the real pooling module -- confirm before removing them.
    """

    # modes accepted by the constructor -> width multiplier of the message MLP input
    _MESSAGE_INPUT_FACTOR = {"abspos": 2, "absrelpos": 3, "relpos": 1}

    def __init__(
            self,
            radius: float,
            hidden_dim: int,
            ndim: int,
            field_dim: int,
            mode: str = "abspos",
            init_weights: str = "torch",
            **kwargs
    ):
        """Build the embedding and projection layers.

        Args:
            radius: Stored on the module; not used by ``forward``.
            hidden_dim: Width of all embeddings and of the output (``output_dim``).
            ndim: Passed as ``ndim`` to the positional ``ContinuousSincosEmbed``.
            field_dim: Input width of ``field_proj``.
            mode: One of ``"abspos"``, ``"absrelpos"`` or ``"relpos"``. Determines
                the input width of the (unused) message MLP and whether
                ``rel_pos_embed`` is created.
            init_weights: Forwarded to the ``LinearProjection`` layers of the
                message MLP.
            **kwargs: Accepted and ignored.

        Raises:
            NotImplementedError: If ``mode`` is not one of the supported values.
        """
        super().__init__()
        # fail fast on a bad config, before any submodule is constructed
        if mode not in self._MESSAGE_INPUT_FACTOR:
            supported = ", ".join(self._MESSAGE_INPUT_FACTOR)
            raise NotImplementedError(
                f"unsupported mode '{mode}' (supported modes: {supported})"
            )
        self.radius = radius
        self.hidden_dim = hidden_dim
        self.ndim = ndim
        self.field_dim = field_dim
        self.init_weights = init_weights
        self.mode = mode

        # NOTE: the construction order below fixes the parameter/state-dict order
        self.pos_embed = ContinuousSincosEmbed(dim=hidden_dim, ndim=ndim)
        if mode == "abspos":
            self.rel_pos_embed = None
        else:
            # "absrelpos" and "relpos" embed ndim + 1 values per edge
            self.rel_pos_embed = ContinuousSincosEmbed(dim=hidden_dim, ndim=ndim + 1)
        message_input_dim = hidden_dim * self._MESSAGE_INPUT_FACTOR[mode]
        self.message = nn.Sequential(
            LinearProjection(message_input_dim, hidden_dim, init_weights=init_weights),
            nn.GELU(),
            LinearProjection(hidden_dim, hidden_dim, init_weights=init_weights),
        )
        self.field_proj = LinearProjection(self.field_dim, hidden_dim)

        self.output_dim = hidden_dim

    def forward(
            self,
            field: torch.Tensor,
            input_pos: torch.Tensor,
            supernode_index: torch.Tensor,
            batch_index: "torch.Tensor | None",
    ) -> dict[str, torch.Tensor]:
        """Embed all points and return them as a dense batch.

        Args:
            field: Per-point features fed to ``field_proj``; the result is added
                to the positional embedding.
            input_pos: 2-D tensor of point positions fed to ``pos_embed``.
            supernode_index: 1-D tensor; only its rank is checked, its values
                are not used.
            batch_index: Per-point sample index, or ``None`` for a single
                sample. The batch size is taken as ``batch_index.max() + 1``.

        Returns:
            A dict with
            ``main``: tensor of shape ``(batch_size, num_points // batch_size, dim)``;
            ``degree`` and ``coverage``: the float placeholder ``0.``.

        Raises:
            RuntimeError: If the embedded points are not 2-D or their count is
                not divisible by the batch size.
        """
        assert input_pos.ndim == 2
        assert supernode_index.ndim == 1

        field_emb = self.field_proj(field)
        x = self.pos_embed(input_pos)
        x = x + field_emb  # add field embedding

        if batch_index is None:
            batch_size = 1
        else:
            # convert once to a plain int: a 0-d tensor as a shape argument hides
            # the divisibility check and is converted implicitly further down
            batch_size = int(batch_index.max()) + 1

        # convert to dense tensor (dim last):
        # "(batch_size num_points) dim -> batch_size num_points dim"
        if x.ndim != 2 or x.shape[0] % batch_size != 0:
            raise RuntimeError(
                f"cannot fold embeddings of shape {tuple(x.shape)} into "
                f"{batch_size} equally sized samples"
            )
        x = x.reshape(batch_size, x.shape[0] // batch_size, x.shape[1])

        return dict(
            main=x,
            degree=0.,
            coverage=0.,
        )