import torch
from torch import nn


from transformers import RobertaConfig
from .postLayerNorm.tie_qk import RobertaModel as TieQKRobertaModel
from .postLayerNorm.relu import RobertaModel as ReluRobertaModel
from .postLayerNorm.softmax import RobertaModel as SoftmaxRobertaModel
from .preLayerNorm.relu import (
    RobertaPreLayerNormModel as ReluPreLNRobertaModel,
)
from .preLayerNorm.softmax import (
    RobertaPreLayerNormModel as SoftmaxPreLNRobertaModel,
)
from .preLayerNorm.tie_qk import (
    RobertaPreLayerNormModel as TieQKPreLNRobertaModel,
)
from .attnOnly.relu import (
    RobertaAttnOnlyModel as ReluAttnOnlyRobertaModel,
)
from .attnOnly.softmax import (
    RobertaAttnOnlyModel as SoftmaxAttnOnlyRobertaModel,
)
from .attnOnly.tie_qk import (
    RobertaAttnOnlyModel as TieQKAttnOnlyRobertaModel,
)


class RobertaModelForGraph(nn.Module):
    """RoBERTa-style encoder with linear read-in / read-out layers for graph inputs.

    The last input dimension (``num_nodes``) is projected to ``hidden_size`` by
    ``_read_in``. The result is run through one of several RoBERTa backbone
    variants and projected back to ``num_nodes`` by ``_read_out``.

    Args:
        num_nodes: Width of the input/output feature dimension. It also sets the
            backbone's ``max_position_embeddings`` to ``num_nodes + 1``.
        num_attention_heads: Attention heads per encoder layer.
        hidden_size: Backbone hidden width.
        num_layers: Number of encoder layers.
        roberta_type: Attention variant: "softmax", "relu" or "tie_qk".
        layer_norm_type: LayerNorm placement: "pre" or "post".
        attention_only: If True, use the attention-only backbones. This
            requires ``layer_norm_type="pre"``.

    Raises:
        ValueError: If ``roberta_type`` or ``layer_norm_type`` is unsupported,
            or if ``attention_only`` is combined with ``layer_norm_type != "pre"``.
    """

    # Accepted values for the constructor's string options. They are checked
    # before any config or submodule is created.
    _ROBERTA_TYPES = ("tie_qk", "relu", "softmax")
    _LAYER_NORM_TYPES = ("pre", "post")

    def __init__(
        self,
        num_nodes,
        num_attention_heads=1,
        hidden_size=128,
        num_layers=12,
        roberta_type="relu",  # softmax, relu, or tie_qk
        layer_norm_type="pre",  # pre or post
        attention_only=False,  # if True, use attention-only models (only supports layer_norm_type="pre")
    ):
        super().__init__()

        # Validate every option up front so a bad argument fails fast with a
        # ValueError, before any backbone weights are allocated.
        if attention_only and layer_norm_type != "pre":
            raise ValueError("attention_only=True only supports layer_norm_type='pre'")
        if layer_norm_type not in self._LAYER_NORM_TYPES:
            # Format the argument, not self.layer_norm_type: the attribute is
            # not assigned yet and would raise AttributeError instead.
            raise ValueError(f"Unknown layer_norm_type: {layer_norm_type}")
        if roberta_type not in self._ROBERTA_TYPES:
            raise ValueError("Invalid roberta_type")

        configuration = RobertaConfig(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_hidden_layers=num_layers,
            max_position_embeddings=num_nodes + 1,
        )

        variant = "_attn_only" if attention_only else ""
        self.name = f"{roberta_type}{variant}_roberta_embd={hidden_size}_layer={num_layers}_head={num_attention_heads}"

        self.n_dims = num_nodes
        self.hidden_size = hidden_size
        # Submodules are registered in the same order as before (read-in,
        # backbone, read-out). This keeps seeded initialization and state_dict
        # key order unchanged.
        self._read_in = nn.Linear(num_nodes, hidden_size)
        backbone_cls = self._backbone_class(roberta_type, layer_norm_type, attention_only)
        self._backbone = backbone_cls(configuration, add_pooling_layer=False)

        self.layer_norm_type = layer_norm_type
        self.attention_only = attention_only
        self._read_out = nn.Linear(hidden_size, num_nodes)

    @staticmethod
    def _backbone_class(roberta_type, layer_norm_type, attention_only):
        """Return the backbone class for an already-validated option combination."""
        if attention_only:
            # Attention-only backbones exist only in the pre-LayerNorm flavour.
            table = {
                "tie_qk": TieQKAttnOnlyRobertaModel,
                "relu": ReluAttnOnlyRobertaModel,
                "softmax": SoftmaxAttnOnlyRobertaModel,
            }
        elif layer_norm_type == "post":
            table = {
                "tie_qk": TieQKRobertaModel,
                "relu": ReluRobertaModel,
                "softmax": SoftmaxRobertaModel,
            }
        else:  # "pre"
            table = {
                "tie_qk": TieQKPreLNRobertaModel,
                "relu": ReluPreLNRobertaModel,
                "softmax": SoftmaxPreLNRobertaModel,
            }
        return table[roberta_type]

    def _readout_layer_norm(self):
        """Return the LayerNorm applied to hidden states before read-out, or None.

        Lookup order:
          * pre-LN / attention-only: ``encoder.layer[0].attention.output.LayerNorm``,
            then ``embeddings.LayerNorm``.
          * post-LN: ``encoder.layer[-1].output.LayerNorm``.

        Returns None when no candidate exists, or when the found attribute is
        None; in that case hidden states are read out unnormalized.

        Raises:
            ValueError: If ``self.layer_norm_type`` is not recognised.
        """
        backbone = self._backbone
        if self.attention_only or self.layer_norm_type == "pre":
            # NOTE(review): the first layer's norm is applied to every hidden
            # state. This heuristic is kept from the original code; confirm it
            # is intended.
            lookups = (
                lambda: backbone.encoder.layer[0].attention.output.LayerNorm,
                lambda: backbone.embeddings.LayerNorm,
            )
        elif self.layer_norm_type == "post":
            lookups = (lambda: backbone.encoder.layer[-1].output.LayerNorm,)
        else:
            raise ValueError(f"Unknown layer_norm_type: {self.layer_norm_type}")

        for lookup in lookups:
            try:
                return lookup()
            # AttributeError: this backbone has no such module.
            # IndexError: encoder.layer is empty (num_layers=0).
            except (AttributeError, IndexError):
                continue
        return None

    def get_hidden_states(self, inputs):
        """Read out every hidden state returned by the backbone.

        Each entry of the backbone's ``hidden_states`` is passed through the
        LayerNorm chosen by ``_readout_layer_norm`` (if any) and then through
        ``_read_out``.

        Args:
            inputs: Tensor whose last dimension is ``num_nodes``.

        Returns:
            A list with one read-out tensor per entry of ``output.hidden_states``,
            in the same order.

        Raises:
            ValueError: If ``self.layer_norm_type`` is not recognised.
        """
        # Resolve the norm before the expensive forward pass so that a bad
        # configuration fails fast.
        layer_norm = self._readout_layer_norm()

        embeds = self._read_in(inputs)
        output = self._backbone(
            inputs_embeds=embeds,
            output_attentions=True,
            output_hidden_states=True,
        )

        out_hs = []
        for hs in output.hidden_states:
            if layer_norm is not None:
                hs = layer_norm(hs)
            out_hs.append(self._read_out(hs))
        return out_hs

    def forward(self, inputs):
        """Standard forward pass used during training.

        Attentions and per-layer hidden states are not requested. Only the
        final hidden state is needed, so this avoids allocating and retaining
        large intermediate tensors on every batch. Call get_hidden_states() when
        per-layer outputs are needed for analysis.

        Args:
            inputs: Tensor whose last dimension is ``num_nodes``.

        Returns:
            ``_read_out`` applied to the backbone's ``last_hidden_state``.
        """
        embeds = self._read_in(inputs)
        output = self._backbone(
            inputs_embeds=embeds,
            output_attentions=False,
            output_hidden_states=False,
        )
        return self._read_out(output.last_hidden_state)


# Public API of this module: only the graph wrapper is exported.
__all__ = [
    "RobertaModelForGraph",
]
