import torch
from torch import nn

from torch_geometric.utils import to_dense_batch

from kappamodules.transformer import DitBlock

from models.registry import register_model
from models.utils import MLP
from models.condition import ContinuousSincosEmbed

from .upt_utils.deep_perceiver_decoder_conditioned import DeepPerceiverDecoderConditioned
from .upt_utils.supernode_pooling import SupernodePooling
from .upt_utils.linear_projection import LinearProjection


@register_model()
class UPT(nn.Module):
    """UPT model with a conditioned, field-based perceiver decoder.

    The model embeds the conditioning parameters into a latent vector, pools the input mesh into randomly sampled
    supernodes, processes the supernode tokens with a stack of conditioned transformer (DiT) blocks and finally
    queries the field-based decoder at the mesh positions.

    NOTE: ``forward`` returns the field-based prediction only, so the model must be built with ``dec_depth > 0`` to
    be usable. ``point_decoder`` is always constructed but is not called by ``forward`` in this file.

    Args:
        n_conds: Number of conditioning parameters (passed as ``ndim`` to the sin/cos embedding).
        latent_channels: Latent (conditioning) vector dimension. Default: 8
        supernodes_radius: Radius for supernode pooling. Default: 0.05
        output_channels: Number of output features to predict (e.g., 1 for surface pressure prediction). Default: 5
        upt_base: Hidden dimension to use for the whole model. Typical values are 192, 384 or 768. Default: 192
        app_depth: How many approximator blocks (i.e., transformer blocks) to use. Default: 12
        num_heads: Number of attention heads for the transformer blocks and the perceiver decoder. Default: 3
        space: Number of dimensions in the domain. Typically set to 3 for 3D coordinates. Default: 3
        input_dim: If input positions have additional features (e.g., pressure or velocity in CFD), input_dim defines
            how many additional features are used. It is forwarded to the supernode pooling encoder. Default: None
        dec_depth: How many decoder blocks to use for the field-based perceiver decoder. If set to 0, no field-based
            decoder is created and ``forward`` raises a ``RuntimeError``. Default: 0
        mlp_expansion_factor: Expansion factor for the MLP layers of the perceiver decoder blocks. Default: 4
        num_supernodes: Number of supernodes to sample per graph. Default: 8000
        supernodes_max_neighbours: Passed as ``max_degree`` to the supernode pooling encoder. Default: 32
        out_deformation: Currently unused by this class; kept for interface compatibility. Default: False
        n_materials: Currently unused by this class; kept for interface compatibility. Default: 1
    """

    def __init__(
        self,
        n_conds: int,
        latent_channels: int = 8,
        supernodes_radius: float = 0.05,
        output_channels: int = 5,
        upt_base: int = 192,
        app_depth: int = 12,
        num_heads: int = 3,
        space: int = 3,
        input_dim: int | None = None,
        dec_depth: int = 0,
        mlp_expansion_factor: int = 4,
        num_supernodes: int = 8000,
        supernodes_max_neighbours: int = 32,
        out_deformation: bool = False,
        n_materials: int = 1,
    ):
        super().__init__()
        self.num_supernodes = num_supernodes

        # conditioning parameters -> sin/cos embedding -> MLP -> latent vector of size `latent_channels`
        self.conditioning = nn.Sequential(
            ContinuousSincosEmbed(dim=256, ndim=n_conds),
            MLP(
                [256, 256 // 2, 256 // 4, latent_channels],
                act_fn=nn.SiLU,
                last_act_fn=nn.SiLU,
                dropout_prob=0.1,
            ),
        )

        # supernode pooling
        self.encoder = SupernodePooling(
            input_dim=input_dim,
            hidden_dim=upt_base,
            ndim=space,
            radius=supernodes_radius,
            init_weights="truncnormal002",
            max_degree=supernodes_max_neighbours,
        )

        # approximator: transformer blocks conditioned on the latent vector
        self.blocks = nn.ModuleList(
            [
                DitBlock(
                    dim=upt_base,
                    num_heads=num_heads,
                    cond_dim=latent_channels,
                    drop_path=0,
                )
                for _ in range(app_depth)
            ],
        )

        # decoders
        if dec_depth > 0:
            self.field_decoder = DeepPerceiverDecoderConditioned(
                dim=upt_base,
                num_heads=num_heads,
                input_dim=upt_base,
                output_dim=output_channels,
                ndim=space,
                depth=dec_depth,
                mlp_expansion_factor=mlp_expansion_factor,
                init_weights="truncnormal002",
            )
        else:
            self.field_decoder = None
        # NOTE: constructed (and therefore part of the state dict) but not called in `forward`.
        self.point_decoder = LinearProjection(upt_base, output_channels, init_weights="truncnormal002")

    def _sample_supernodes(self, batch_index: torch.Tensor):
        """Randomly pick ``self.num_supernodes`` nodes from every graph in the batch.

        Assumes the nodes of each graph are stored contiguously and in ascending graph order in ``batch_index``
        (the per-graph offsets below are derived from cumulative node counts).

        Args:
            batch_index: Integer tensor assigning each node to its graph.

        Returns:
            Tuple ``(supernode_index, super_node_batch_index)``: the flat (batch-global) indices of the sampled
            supernodes with ``num_graphs * num_supernodes`` entries, and the graph id of every sampled supernode.

        Raises:
            ValueError: If a graph has fewer nodes than ``self.num_supernodes``.
        """
        device = batch_index.device
        num_graphs = int(batch_index.max().item()) + 1
        # One host transfer for all node counts instead of one `.item()` call per graph.
        node_counts = torch.bincount(batch_index, minlength=num_graphs).tolist()

        supernode_index = []
        offset = 0  # index of the first node of the current graph in the flat node list
        for i, count in enumerate(node_counts):
            if count < self.num_supernodes:
                raise ValueError(f"Graph {i} has only {count} nodes, but {self.num_supernodes} supernodes requested.")
            # sample without replacement inside the graph, then shift to batch-global indices
            idx = torch.randperm(count, device=device)[:self.num_supernodes]
            supernode_index.append(idx + offset)
            offset += count
        supernode_index = torch.cat(supernode_index)
        super_node_batch_index = torch.arange(num_graphs, device=device).repeat_interleave(self.num_supernodes)
        return supernode_index, super_node_batch_index

    def forward(
        self,
        cond: torch.Tensor,
        mesh_coords: torch.Tensor,
        mesh_edges: torch.Tensor,
        batch_index: torch.Tensor,
    ) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Forward pass: condition, pool into supernodes, run the transformer blocks and decode the field.

        Args:
            cond: Conditioning parameters, fed through the sin/cos embedding and the conditioning MLP.
            mesh_coords: Node positions. They are used as input positions of the supernode pooling encoder and as
                query positions of the field-based decoder. A 2-D tensor is padded to a dense per-graph batch (using
                ``batch_index``) before it is handed to the decoder.
            mesh_edges: Currently unused; kept for interface compatibility.
            batch_index: Index tensor that assigns every node in ``mesh_coords`` to its sample in the batch.

        Returns:
            ``((field, coords), latent_vector)`` where ``field`` is the output of the field-based decoder, ``coords``
            is a copy of the ``mesh_coords`` that were passed in, and ``latent_vector`` is the conditioning embedding.

        Raises:
            RuntimeError: If the model was built without a field-based decoder (``dec_depth=0``).
            ValueError: If a graph has fewer nodes than ``num_supernodes``.
        """
        # Fail fast: the return value is the field prediction, which cannot be produced without a field decoder.
        # (Previously this surfaced as a bare KeyError only after the full encoder/transformer pass.)
        if self.field_decoder is None:
            raise RuntimeError(
                "UPT.forward returns the field-based prediction, but no field decoder was created. "
                "Construct the model with dec_depth > 0."
            )

        # keep an untouched copy of the input positions for the caller
        coords = mesh_coords.clone()

        # parameter conditioning mlp
        latent_vector = self.conditioning(cond)

        # super node pooling (the per-supernode batch index is not needed here)
        supernode_idx, _ = self._sample_supernodes(batch_index=batch_index)

        # encoder
        x = self.encoder(
            input_pos=mesh_coords,
            supernode_idx=supernode_idx,
            batch_idx=batch_index,
        )

        # blocks
        for block in self.blocks:
            x = block(x, latent_vector)

        # The perceiver query needs dense positions: pad flat coordinates to the max number of nodes per graph
        # and remember which entries are real so the decoder can strip the padding again.
        query_pos = mesh_coords
        unbatch_mask = None
        if mesh_coords.ndim == 2:
            query_pos, pad_mask = to_dense_batch(mesh_coords, batch_index)
            unbatch_mask = pad_mask.flatten()
        # NOTE(review): for non-2-D coordinates no padding mask exists, so None is passed; this assumes the
        # decoder treats `unbatch_mask=None` as "nothing to unpad" - confirm against the decoder implementation.

        # field-based decoder
        field = self.field_decoder(
            x=x,
            pos=query_pos,
            cond=latent_vector,
            unbatch_mask=unbatch_mask,
        )

        return (field, coords), latent_vector

