"""
Unified Models Interface for Transformer-Graph Models

This module provides a unified interface to access all model architectures
with consistent parameters and configuration options.
"""

import torch
from torch import nn
from typing import Dict, Any, Optional, Union, List
from abc import ABC, abstractmethod

# Import model classes
from .roberta.robertaModels import RobertaModelForGraph
from .loopedTransformer.loopedTransformerModel import LoopedTransformer
from .disentangledTransformer.disentangledTransformerModels import (
    DisentangledTransformer,
)


class BaseGraphModel(ABC):
    """Abstract contract implemented by every graph-model wrapper here.

    A concrete subclass has to override both :meth:`forward` and
    :meth:`get_hidden_states`; a class that leaves either one abstract
    cannot be instantiated (``TypeError``, the usual ``abc`` behaviour).
    """

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the model on ``x`` and return its output tensor."""

    @abstractmethod
    def get_hidden_states(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Return the model's intermediate hidden states for ``x``."""


class RobertaGraphModel(BaseGraphModel):
    """Wrapper for RobertaModelForGraph with standardized interface"""

    # NOTE: get_model_info() reports the class docstring above verbatim as the
    # model "description", so it is deliberately kept as that exact one-liner.

    def __init__(
        self,
        num_nodes: int,
        num_attention_heads: int = 1,
        hidden_size: int = 128,
        num_layers: int = 12,
        roberta_type: str = "relu",
        layer_norm_type: str = "pre",
        **kwargs,
    ):
        """Construct the wrapped ``RobertaModelForGraph``.

        Args:
            num_nodes: Number of nodes in the graph.
            num_attention_heads: Number of attention heads.
            hidden_size: Hidden dimension size.
            num_layers: Number of transformer layers.
            roberta_type: Variant type, e.g. "relu", "softmax" or "tie_qk".
            layer_norm_type: Layer norm placement, "pre" or "post".
            **kwargs: Accepted for signature compatibility; not forwarded
                to the wrapped model.
        """
        # Kept on the wrapper so callers can read them without reaching
        # into the wrapped model.
        self.num_nodes = num_nodes
        self.hidden_size = hidden_size

        backbone_options = dict(
            num_nodes=num_nodes,
            num_attention_heads=num_attention_heads,
            hidden_size=hidden_size,
            num_layers=num_layers,
            roberta_type=roberta_type,
            layer_norm_type=layer_norm_type,
        )
        self.model = RobertaModelForGraph(**backbone_options)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Call the wrapped model on ``x`` and return its result unchanged."""
        result = self.model(x)
        return result

    def get_hidden_states(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Delegate to the wrapped model's own ``get_hidden_states``."""
        hidden_states = self.model.get_hidden_states(x)
        return hidden_states


class LoopedTransformerModel(BaseGraphModel):
    """Wrapper for LoopedTransformer with standardized interface"""

    # NOTE: get_model_info() reports the class docstring above verbatim as the
    # model "description", so it is deliberately kept as that exact one-liner.

    def __init__(
        self,
        num_nodes: int,
        hidden_size: int = 128,
        num_layers: int = 5,
        num_attention_heads: int = 1,
        read_in_method: str = "linear",
        layer_norm_type: str = "pre",
        tie_qk: bool = False,
        **kwargs,
    ):
        """Construct the wrapped ``LoopedTransformer``.

        Args:
            num_nodes: Number of nodes in the graph.
            hidden_size: Hidden dimension size.
            num_layers: Number of loops/iterations.
            num_attention_heads: Number of attention heads.
            read_in_method: Input method, "linear" or "zero_pad".
            layer_norm_type: Layer norm placement, "pre" or "post".
            tie_qk: Whether to tie query and key weights.
            **kwargs: Accepted for signature compatibility; not forwarded
                to the wrapped model.
        """
        self.model = LoopedTransformer(
            num_nodes=num_nodes,
            num_layers=num_layers,
            hidden_size=hidden_size,
            read_in_method=read_in_method,
            layer_norm_type=layer_norm_type,
            num_attention_heads=num_attention_heads,
            tie_qk=tie_qk,
        )
        self.num_nodes = num_nodes
        self.hidden_size = hidden_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return only the model output.

        The wrapped model returns a pair; the second element (the per-layer
        outputs) is discarded here to match the common interface.
        """
        output, _ = self.model(x)
        return output

    # The original file defined this method twice with identical bodies; the
    # second definition silently replaced the first. Only one is kept.
    def get_hidden_states(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Get hidden states from all layers of the looped transformer.

        Runs a full pass of the wrapped model and returns the second element
        of its result, dropping the first.
        """
        _, layer_outputs = self.model(x)
        return layer_outputs


class DisentangledTransformerModel(BaseGraphModel):
    """Wrapper for DisentangledTransformer with standardized interface"""

    # NOTE: get_model_info() reports the class docstring above verbatim as the
    # model "description", so it is deliberately kept as that exact one-liner.

    def __init__(
        self,
        num_nodes: int,
        heads: Optional[List[int]] = None,
        extra_pos_id: bool = True,
        init_type: str = "randn",
        readout_type: str = "linear",
        **kwargs,
    ):
        """Construct the wrapped ``DisentangledTransformer``.

        Args:
            num_nodes: Number of nodes in the graph.
            heads: Number of heads per layer. ``None`` selects the default
                ``[4, 4, 4]``.
            extra_pos_id: Whether to add positional embeddings.
            init_type: Weight initialization, e.g. "randn", "zeros", "eye",
                "psd" or "sym".
            readout_type: Readout method, e.g. "linear", "sum" or "last".
            **kwargs: Accepted for signature compatibility; not forwarded
                to the wrapped model.
        """
        # A fresh list is built on every call, so the default is never
        # shared between instances.
        if heads is None:
            heads = [4, 4, 4]  # Default architecture

        self.model = DisentangledTransformer(
            num_nodes=num_nodes,
            heads=heads,
            extra_pos_id=extra_pos_id,
            init_type=init_type,
            readout_type=readout_type,
        )
        self.num_nodes = num_nodes
        self.heads = heads

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Call the wrapped model on ``x`` and return its result unchanged."""
        return self.model(x)

    def get_hidden_states(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Delegate to the wrapped model's own ``get_hidden_states``."""
        return self.model.get_hidden_states(x)


class UnifiedModelRegistry:
    """Central lookup from model-type names to their wrapper classes."""

    # Model-type name -> wrapper class instantiated by create_model().
    MODEL_MAP = {
        "roberta": RobertaGraphModel,
        "looped_transformer": LoopedTransformerModel,
        "disentangled_transformer": DisentangledTransformerModel,
    }

    @classmethod
    def list_models(cls) -> List[str]:
        """Return the registered model-type names, in registration order."""
        return list(cls.MODEL_MAP)

    @classmethod
    def create_model(cls, model_type: str, **kwargs) -> BaseGraphModel:
        """
        Instantiate the wrapper registered under ``model_type``.

        Args:
            model_type (str): One of
                - 'roberta': RoBERTa-based graph model
                - 'looped_transformer': Looped transformer model
                - 'disentangled_transformer': Disentangled attention model
            **kwargs: Forwarded as-is to the wrapper's constructor.
                ``num_nodes`` is mandatory for every model type.

        Parameters shared by several models:
            num_nodes (int): Number of nodes in the graph
            hidden_size (int): Hidden dimension size (default: 128)
            num_attention_heads (int): Number of attention heads (default: 1)

        Parameters per model:
            RoBERTa:
                - num_layers (int): Number of transformer layers (default: 12)
                - roberta_type (str): "relu", "softmax" or "tie_qk" (default: "relu")
                - layer_norm_type (str): "pre" or "post" (default: "pre")

            Looped Transformer:
                - num_layers (int): Number of loops/iterations (default: 5)
                - read_in_method (str): "linear" or "zero_pad" (default: "linear")
                - layer_norm_type (str): "pre" or "post" (default: "pre")
                - tie_qk (bool): Tie query and key weights (default: False)

            Disentangled Transformer:
                - heads (List[int]): Heads per layer (default: [4, 4, 4])
                - extra_pos_id (bool): Add positional embeddings (default: True)
                - init_type (str): "randn", "zeros", "eye", "psd" or "sym" (default: "randn")
                - readout_type (str): "linear", "sum" or "last" (default: "linear")

        Returns:
            BaseGraphModel: The newly constructed wrapper.

        Raises:
            ValueError: If ``model_type`` is not registered, or if
                ``num_nodes`` is missing from ``kwargs``.

        Example:
            model = UnifiedModelRegistry.create_model(
                'roberta',
                num_nodes=32,
                hidden_size=128,
                num_layers=6,
                roberta_type='relu'
            )

            model = UnifiedModelRegistry.create_model(
                'looped_transformer',
                num_nodes=32,
                hidden_size=64,
                num_layers=3
            )
        """
        registry = cls.MODEL_MAP

        # The model type is validated before the arguments so that an unknown
        # type is always the first error reported.
        if model_type not in registry:
            message = (
                f"Unknown model_type: {model_type}. Available types: {cls.list_models()}"
            )
            raise ValueError(message)

        if "num_nodes" not in kwargs:
            raise ValueError("Required parameter 'num_nodes' not provided")

        return registry[model_type](**kwargs)

    @classmethod
    def get_model_info(cls, model_type: str) -> Dict[str, Any]:
        """Describe a registered model type.

        Returns a dict with the keys ``model_type``, ``class_name`` and
        ``description`` (the wrapper class docstring, or a placeholder when
        it has none), plus ``parameters`` for the model types documented
        below.

        Raises:
            ValueError: If ``model_type`` is not registered.
        """
        if model_type not in cls.MODEL_MAP:
            raise ValueError(f"Unknown model_type: {model_type}")

        wrapper_cls = cls.MODEL_MAP[model_type]

        # Per-model parameter descriptions, keyed by model-type name. Built on
        # every call so each caller gets its own mutable dict.
        parameter_docs = {
            "roberta": {
                "num_nodes": "Number of nodes in the graph",
                "hidden_size": "Hidden dimension size (default: 128)",
                "num_attention_heads": "Number of attention heads (default: 1)",
                "num_layers": "Number of transformer layers (default: 12)",
                "roberta_type": 'Variant type - "relu", "softmax", "tie_qk" (default: "relu")',
                "layer_norm_type": 'Layer norm type - "pre" or "post" (default: "pre")',
            },
            "looped_transformer": {
                "num_nodes": "Number of nodes in the graph",
                "hidden_size": "Hidden dimension size (default: 128)",
                "num_layers": "Number of loops (default: 5)",
                "num_attention_heads": "Number of attention heads (default: 1)",
                "read_in_method": 'Input method - "linear" or "zero_pad" (default: "linear")',
                "layer_norm_type": 'Layer norm type - "pre" or "post" (default: "pre")',
                "tie_qk": "Whether to tie query and key weights (default: False)",
            },
            "disentangled_transformer": {
                "num_nodes": "Number of nodes in the graph",
                "heads": "Number of heads per layer (default: [4, 4, 4])",
                "extra_pos_id": "Whether to add positional embeddings (default: True)",
                "init_type": 'Weight initialization - "randn", "zeros", "eye", "psd", "sym" (default: "randn")',
                "readout_type": 'Readout method - "linear", "sum", "last" (default: "linear")',
            },
        }

        info = {
            "model_type": model_type,
            "class_name": wrapper_cls.__name__,
            "description": wrapper_cls.__doc__ or "No description available",
        }

        # A registered type with no entry in the table gets no "parameters" key.
        if model_type in parameter_docs:
            info["parameters"] = parameter_docs[model_type]

        return info


# Convenience functions
def create_model(model_type: str, **kwargs) -> BaseGraphModel:
    """
    Module-level shortcut for ``UnifiedModelRegistry.create_model``.

    Both ``model_type`` and the keyword arguments are passed through
    unchanged, and the registry's result is returned as-is.

    Example usage:
        # RoBERTa model
        model = create_model('roberta', num_nodes=32, hidden_size=128)

        # Looped transformer
        model = create_model('looped_transformer', num_nodes=32, num_layers=5)

        # Disentangled transformer
        model = create_model('disentangled_transformer', num_nodes=32, heads=[8, 8])
    """
    build = UnifiedModelRegistry.create_model
    return build(model_type, **kwargs)


def list_models() -> List[str]:
    """Return the model-type names known to ``UnifiedModelRegistry``."""
    available = UnifiedModelRegistry.list_models()
    return available


def get_model_info(model_type: str) -> Dict[str, Any]:
    """Return the registry's description dict for ``model_type``.

    Module-level shortcut for ``UnifiedModelRegistry.get_model_info``.
    """
    details = UnifiedModelRegistry.get_model_info(model_type)
    return details


# Legacy compatibility aliases
def create_roberta_model(**kwargs) -> RobertaGraphModel:
    """Legacy entry point: build a "roberta" model via ``create_model``.

    All keyword arguments are forwarded unchanged.
    """
    model = create_model("roberta", **kwargs)
    return model


def create_looped_transformer(**kwargs) -> LoopedTransformerModel:
    """Legacy entry point: build a "looped_transformer" model via ``create_model``.

    All keyword arguments are forwarded unchanged.
    """
    model = create_model("looped_transformer", **kwargs)
    return model


def create_disentangled_transformer(**kwargs) -> DisentangledTransformerModel:
    """Legacy entry point: build a "disentangled_transformer" model via ``create_model``.

    All keyword arguments are forwarded unchanged.
    """
    model = create_model("disentangled_transformer", **kwargs)
    return model


if __name__ == "__main__":
    # Demo script: build one model of each kind, push a random batch through
    # all of them, then print the registry's description of every model type.
    print("=== Unified Model Interface Examples ===\n")

    print("Available models:", list_models())

    # (banner to print, registry key, constructor arguments) for examples 1-3.
    demo_specs = [
        (
            "\n1. Creating RoBERTa model...",
            "roberta",
            dict(num_nodes=32, hidden_size=128, num_layers=6, roberta_type="relu"),
        ),
        (
            "\n2. Creating Looped Transformer...",
            "looped_transformer",
            dict(num_nodes=32, hidden_size=64, num_layers=3),
        ),
        (
            "\n3. Creating Disentangled Transformer...",
            "disentangled_transformer",
            dict(num_nodes=32, heads=[4, 4], init_type="randn"),
        ),
    ]

    demo_models = []
    for banner, registry_key, ctor_args in demo_specs:
        print(banner)
        built = create_model(registry_key, **ctor_args)
        print(f"   Model type: {type(built.model).__name__}")
        demo_models.append(built)

    # Example 4: one shared input for every model.
    print("\n4. Testing forward pass...")
    batch = torch.randn(2, 32, 32)  # Batch size 2, 32x32 adjacency matrices

    # Every forward pass runs before any shape is printed, so a failure in a
    # later model leaves no partial report.
    outputs = [wrapped.forward(batch) for wrapped in demo_models]

    for label, result in zip(("RoBERTa", "Looped", "Disentangled"), outputs):
        print(f"   {label} output shape: {result.shape}")

    # Example 5: registry metadata for each registered type.
    print("\n5. Model information:")
    for registered_name in list_models():
        description = get_model_info(registered_name)["description"]
        print(f"   {registered_name}: {description}")
