"""
Base classes for diffusion denoising models in MDShortcut.

This module provides the foundational BaseDenoiser class that implements common functionality
for all denoising models, including element embeddings, property conditioning, and
node feature assembly for graph neural networks.
"""
import torch
from torch import nn
from models.modules.embeddings import OneHotElementEmbedding


class BaseDenoiser(nn.Module):
    """Base class for all material diffusion denoising models.
    
    Provides common functionality for denoising models including:
    - Element embedding (one-hot encoding)
    - Property embedding and conditioning
    - Node feature assembly for GNN input
    - Edge computation with cutoff radius
    - Common interfaces for prediction
    
    This class should be subclassed to implement specific architectures (EGNN, etc.).
    
    Attributes:
        cutoff_radius (float): Cutoff distance for atomic interactions.
        elements (list): List of element symbols/numbers to embed.
        node_attrs (bool): Whether to use node attributes in the model.
        properties (dict): Property specifications for conditioning.
        d_prop_embed (int): Default embedding dimension for properties.
        use_dt (bool): Whether to use timestep differences in node features.
        element_embedding (OneHotElementEmbedding): Element embedding layer.
        prop_embed (nn.ModuleDict): Property embedding layers.
        null_prop (nn.ParameterDict): Learnable null property embeddings.
    """

    def __init__(self, cutoff_radius, elements=None, properties=None, d_prop_embed=8, node_attrs=False, use_dt=False):
        """Initialize BaseDenoiser with configuration for embeddings and properties.
        
        Args:
            cutoff_radius (float): Cutoff radius for neighbor list construction and edge computation.
            elements (list, optional): List of element symbols/numbers to create embeddings for.
                If None, no element embeddings are used. Defaults to None.
            properties (dict, optional): Dictionary specifying property conditioning setup.
                Each key is a property name, value is a dict with 'dim', 'd_prop_embed', etc.
                If None, no property conditioning is used. Defaults to None.
            d_prop_embed (int, optional): Default embedding dimension for properties. Defaults to 8.
            node_attrs (bool, optional): Whether to use node attributes in the GNN. Defaults to False.
            use_dt (bool, optional): Whether to include timestep differences in node features.
                Defaults to False.
        """
        super().__init__()
        self.cutoff_radius = cutoff_radius
        self.elements = elements
        self.node_attrs = node_attrs
        self.properties = properties
        self.d_prop_embed = d_prop_embed
        self.use_dt = use_dt

        if self.elements is not None:
            self.element_embedding = OneHotElementEmbedding(self.elements)
        else:
            self.element_embedding = None

        if self.properties is not None:
            self.prop_embed = nn.ModuleDict({
                k: nn.Sequential(
                    nn.Linear(v['dim'], v.get('d_prop_embed', self.d_prop_embed)),
                    nn.LayerNorm(v.get('d_prop_embed', self.d_prop_embed))
                )
                for k, v in self.properties.items()
            })
            self.null_prop = nn.ParameterDict({
                k: nn.Parameter(torch.randn(v.get('d_prop_embed', self.d_prop_embed)), requires_grad=True)
                for k, v in self.properties.items()
            })

    def predict_noise(self, sample, t):
        """Wrapper for forward pass."""
        return self(sample, t)

    @torch.compiler.disable
    def _get_edges(self, sample):
        """Get edges with compiler disabled."""
        return sample.get_edges(self.cutoff_radius)

    def embed_properties(self, sample):
        """Embed sample properties."""
        props = []
        for prop_name in self.properties:
            prop, null_mask = sample.get_property_arr(
                prop_name,
                null_placeholder=torch.zeros(
                    self.properties[prop_name]['dim'],
                    device=sample.positions.device
                )
            )

            if 'offset' in self.properties[prop_name]:
                prop = prop - self.properties[prop_name]['offset']
            if 'scale' in self.properties[prop_name]:
                prop = prop / self.properties[prop_name]['scale']

            prop_emb_arr = self.prop_embed[prop_name](prop)
            prop_emb_arr[null_mask, :] = torch.randn(prop_emb_arr[null_mask, :].shape, device=prop_emb_arr.device)
            props.append(prop_emb_arr)
        return props

    def get_properties(self):
        """Get properties dictionary."""
        return self.properties

    def assemble_h(self, sample, element_emb, t, dt):
        """Assemble node features for GNN input from various sources.
        
        Combines timestep information, element embeddings, and property embeddings
        into node feature vectors for graph neural network processing.
        
        Args:
            sample (Sample): Material sample containing properties and batch indices.
            element_emb (torch.Tensor): Element embeddings, shape (n_atoms, embed_dim).
            t (torch.Tensor): Diffusion timesteps, shape (batch_size,).
            dt (torch.Tensor): Timestep differences, shape (batch_size,).
            
        Returns:
            torch.Tensor: Assembled node features, shape (n_atoms, total_feature_dim).
        """
        batch_idx = sample.get_batch_indices().to(sample.get_positions().device)
        if self.use_dt:
            h = torch.stack([t[batch_idx], dt[batch_idx]], -1)  # (n_atoms, 2)
        else:
            h = t[batch_idx].unsqueeze(-1)  # (n_atoms, 1)

        if self.element_embedding is not None:
            h = torch.concat((h, element_emb), dim=1)

        if self.properties is not None:
            props = self.embed_properties(sample)
            h = torch.concat([h] + props, dim=1)
        return h

    def forward(self, sample, t):
        """
        Forward pass to be implemented by subclasses.
        Should return (position_noise, element_noise).
        """
        raise NotImplementedError
