"""
Model factory for creating different types of neural network models

This module provides a unified interface for creating models for
irregular meshes (GNNs), regular grids (UNet, FNO), and various
encoder-decoder architectures.
"""

import torch
import torch.nn as nn
from typing import Union, Dict, Any, Optional
from enum import Enum

from .irregular_mesh import create_gnn_model, GNNPipeline
from .regular_grid import create_unet, create_fno, UNet, FourierNeuralOperator
from .components import create_encoder, create_decoder
from .base import Identity


class ModelType(Enum):
    """
    Enumeration of supported model types

    Values are lower-case strings. Lookup by value is case- and
    whitespace-insensitive, so ``ModelType("UNet")`` and ``ModelType(" fno ")``
    resolve to ``ModelType.UNET`` and ``ModelType.FNO`` respectively. Unknown
    values still raise ``ValueError``.
    """
    GNN = "gnn"
    UNET = "unet"
    FNO = "fno"
    ENCODER_DECODER = "encoder_decoder"

    @classmethod
    def _missing_(cls, value):
        # Called by Enum only after an exact value match has failed; returning
        # None makes Enum raise its standard ValueError for unknown values.
        if isinstance(value, str):
            normalized = value.strip().lower()
            for member in cls:
                if member.value == normalized:
                    return member
        return None


class ModelFactory:
    """
    Factory class for creating neural network models

    This factory supports creating four types of models:
    1. Graph Neural Networks for irregular meshes ("gnn")
    2. UNet for regular grids ("unet")
    3. Fourier Neural Operator for regular grids ("fno")
    4. A simple encoder-decoder pipeline ("encoder_decoder")

    All builders are static methods; the class holds no state.
    """

    @staticmethod
    def create_model(model_type: Union[str, ModelType],
                     num_features: int,
                     num_classes: int,
                     config: Optional[object] = None,
                     **kwargs) -> nn.Module:
        """
        Create a model of the specified type

        Parameters:
        -----------
            model_type: str or ModelType
                Type of model to create ("gnn", "unet", "fno", "encoder_decoder").
                String values are matched case-insensitively and ignoring
                surrounding whitespace.
            num_features: int
                Number of input features
            num_classes: int
                Number of output classes
            config: object, optional
                Configuration object with model parameters
            **kwargs: dict
                Additional keyword arguments for model creation

        Returns:
        --------
            nn.Module
                Created neural network model

        Raises:
        -------
            ValueError
                If ``model_type`` does not name a supported model type.
        """
        if isinstance(model_type, str):
            try:
                model_type = ModelType(model_type.strip().lower())
            except ValueError:
                # Replace the bare enum error with one that lists valid choices.
                supported = ", ".join(repr(m.value) for m in ModelType)
                raise ValueError(
                    f"Unknown model type: {model_type!r} "
                    f"(supported: {supported})"
                ) from None

        if model_type == ModelType.GNN:
            return ModelFactory._create_gnn_model(num_features, num_classes, config, **kwargs)
        elif model_type == ModelType.UNET:
            return ModelFactory._create_unet_model(num_features, num_classes, config, **kwargs)
        elif model_type == ModelType.FNO:
            return ModelFactory._create_fno_model(num_features, num_classes, config, **kwargs)
        elif model_type == ModelType.ENCODER_DECODER:
            return ModelFactory._create_encoder_decoder_model(num_features, num_classes, config, **kwargs)
        else:
            # Reached only for values that are neither a str nor a ModelType.
            raise ValueError(f"Unknown model type: {model_type}")

    @staticmethod
    def _create_gnn_model(num_features: int, num_classes: int,
                          config: Optional[object] = None, **kwargs) -> GNNPipeline:
        """
        Create a Graph Neural Network model for irregular meshes

        When ``config`` is given it is passed to ``create_gnn_model`` as-is and
        ``kwargs`` are ignored. Otherwise a config namespace is built from
        ``kwargs``, falling back to the defaults listed below.
        """

        if config is not None:
            return create_gnn_model(num_features, num_classes, config)

        # Use default parameters if no config provided
        from types import SimpleNamespace
        default_config = SimpleNamespace(
            encoder=kwargs.get('encoder', 'mlp'),
            gnn=kwargs.get('gnn', 'sage'),
            decoder=kwargs.get('decoder', 'mlp'),
            n_hidden=kwargs.get('n_hidden', 64),
            n_layers=kwargs.get('n_layers', 3),
            activation=kwargs.get('activation', 'relu'),
            encoder_n_layers=kwargs.get('encoder_n_layers', 3),
            decoder_n_layers=kwargs.get('decoder_n_layers', 3),
            encoder_frequency=kwargs.get('encoder_frequency', 1),
            decoder_frequency=kwargs.get('decoder_frequency', 1),
            window_size=kwargs.get('window_size', 4),
            encoder_use_bn=kwargs.get('encoder_use_bn', False),
            encoder_use_res=kwargs.get('encoder_use_res', False),
            decoder_use_bn=kwargs.get('decoder_use_bn', False),
            decoder_use_res=kwargs.get('decoder_use_res', False),
            num_heads=kwargs.get('num_heads', 4),
            num_hops=kwargs.get('num_hops', 3),
            dropout=kwargs.get('dropout', 0.),
            use_input_norm=kwargs.get('use_input_norm', True)
        )

        return create_gnn_model(num_features, num_classes, default_config)

    @staticmethod
    def _create_unet_model(num_features: int, num_classes: int,
                           config: Optional[object] = None, **kwargs) -> UNet:
        """
        Create a UNet model for regular grids

        Parameters are resolved as: defaults < ``kwargs`` < ``config``.
        The feature list is read from ``config.unet_features``.
        """

        unet_params = {
            'features': kwargs.get('features', [64, 128, 256, 512]),
            'activation': kwargs.get('activation', 'relu'),
            'use_batchnorm': kwargs.get('use_batchnorm', True),
            'dropout': kwargs.get('dropout', 0.0)
        }

        if config is not None:
            # Override with config values if available
            unet_params.update({
                'features': getattr(config, 'unet_features', unet_params['features']),
                'activation': getattr(config, 'activation', unet_params['activation']),
                'use_batchnorm': getattr(config, 'use_batchnorm', unet_params['use_batchnorm']),
                'dropout': getattr(config, 'dropout', unet_params['dropout'])
            })

        return create_unet(num_features, num_classes, **unet_params)

    @staticmethod
    def _create_fno_model(num_features: int, num_classes: int,
                          config: Optional[object] = None, **kwargs) -> FourierNeuralOperator:
        """
        Create a Fourier Neural Operator model for regular grids

        Parameters are resolved as: defaults < ``kwargs`` < ``config``.
        On the config object, ``width`` is read from ``n_hidden`` and
        ``num_layers`` from ``n_layers``; the FNO-specific fields are
        ``fno_modes1``, ``fno_modes2`` and ``fno_padding``.
        """

        fno_params = {
            'modes1': kwargs.get('modes1', 12),
            'modes2': kwargs.get('modes2', 12),
            'width': kwargs.get('width', 64),
            'num_layers': kwargs.get('num_layers', 4),
            'activation': kwargs.get('activation', 'gelu'),
            'padding': kwargs.get('padding', 8)
        }

        if config is not None:
            # Override with config values if available
            fno_params.update({
                'modes1': getattr(config, 'fno_modes1', fno_params['modes1']),
                'modes2': getattr(config, 'fno_modes2', fno_params['modes2']),
                'width': getattr(config, 'n_hidden', fno_params['width']),
                'num_layers': getattr(config, 'n_layers', fno_params['num_layers']),
                'activation': getattr(config, 'activation', fno_params['activation']),
                'padding': getattr(config, 'fno_padding', fno_params['padding'])
            })

        return create_fno(num_features, num_classes, **fno_params)

    @staticmethod
    def _create_encoder_decoder_model(num_features: int, num_classes: int,
                                      config: Optional[object] = None, **kwargs) -> nn.Module:
        """
        Create a simple encoder-decoder model

        Parameters are resolved as: defaults < ``kwargs`` < ``config``. The
        same precedence is used by the UNet and FNO builders. The fields read
        are ``encoder``, ``decoder``, ``n_hidden``, ``encoder_n_layers``,
        ``decoder_n_layers`` and ``activation``.
        """

        # Defaults, optionally overridden by kwargs
        encoder_type = kwargs.get('encoder', 'mlp')
        decoder_type = kwargs.get('decoder', 'mlp')
        hidden_dim = kwargs.get('n_hidden', 64)
        encoder_n_layers = kwargs.get('encoder_n_layers', 3)
        decoder_n_layers = kwargs.get('decoder_n_layers', 3)
        activation = kwargs.get('activation', 'relu')

        if config is not None:
            encoder_type = getattr(config, 'encoder', encoder_type)
            decoder_type = getattr(config, 'decoder', decoder_type)
            hidden_dim = getattr(config, 'n_hidden', hidden_dim)
            encoder_n_layers = getattr(config, 'encoder_n_layers', encoder_n_layers)
            decoder_n_layers = getattr(config, 'decoder_n_layers', decoder_n_layers)
            activation = getattr(config, 'activation', activation)

        # NOTE(review): hidden_dim is passed both positionally and as
        # `num_hidden`; confirm create_encoder/create_decoder accept this
        # without a "multiple values for argument" TypeError.
        encoder_params = {
            'num_hidden': hidden_dim,
            'num_layers': encoder_n_layers,
            'activation': activation
        }

        encoder = create_encoder(encoder_type, num_features, hidden_dim, **encoder_params)

        decoder_params = {
            'num_hidden': hidden_dim,
            'num_layers': decoder_n_layers,
            'activation': activation
        }

        decoder = create_decoder(decoder_type, hidden_dim, num_classes, **decoder_params)

        # Chain encoder -> decoder in a single module
        return EncoderDecoderPipeline(encoder, decoder)

    @staticmethod
    def get_supported_models() -> Dict[str, Dict[str, Any]]:
        """
        Get information about supported model types and their parameters

        Returns:
        --------
            dict
                Dictionary keyed by model type string; each entry holds a
                'description', a 'parameters' mapping of name -> description,
                and (for 'gnn') the supported GNN/encoder/decoder names.
        """
        # NOTE(review): 'mpnp' may be intended as 'mpnn'; confirm against the
        # names accepted by create_gnn_model before relying on this list.
        return {
            'gnn': {
                'description': 'Graph Neural Networks for irregular meshes',
                'supported_gnns': ['gcn', 'gat', 'sage', 'sign', 'mpnp'],
                'supported_encoders': ['identity', 'mlp', 'freq', 'lstm', 'gru', 'rnn'],
                'supported_decoders': ['identity', 'mlp', 'freq', 'conv1d'],
                'parameters': {
                    'n_hidden': 'Hidden dimension size',
                    'n_layers': 'Number of GNN layers',
                    'activation': 'Activation function',
                    'encoder': 'Type of encoder',
                    'decoder': 'Type of decoder',
                    'gnn': 'Type of GNN processor'
                }
            },
            'unet': {
                'description': 'UNet for regular grids',
                'parameters': {
                    'features': 'List of feature sizes for each level',
                    'activation': 'Activation function',
                    'use_batchnorm': 'Whether to use batch normalization',
                    'dropout': 'Dropout rate'
                }
            },
            'fno': {
                'description': 'Fourier Neural Operator for regular grids',
                'parameters': {
                    'modes1': 'Number of Fourier modes in first dimension',
                    'modes2': 'Number of Fourier modes in second dimension',
                    'width': 'Hidden channel width',
                    'num_layers': 'Number of FNO layers',
                    'activation': 'Activation function',
                    'padding': 'Padding size'
                }
            },
            'encoder_decoder': {
                'description': 'Simple encoder-decoder architecture',
                'parameters': {
                    'encoder': 'Type of encoder',
                    'decoder': 'Type of decoder',
                    'n_hidden': 'Hidden dimension size'
                }
            }
        }


class EncoderDecoderPipeline(nn.Module):
    """
    Two-stage model: the input goes through ``encoder``, then ``decoder``.

    Extra positional/keyword arguments given to ``forward`` are accepted so
    callers can use the same call signature as other pipelines; they are
    ignored here.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super().__init__()
        # Assigning nn.Module instances as attributes registers them as
        # submodules, so their parameters are visible via .parameters().
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, *args, **kwargs):
        """
        Parameters:
        -----------
            x: torch.FloatTensor [..., n_node, n_feature]
                Input handed directly to the encoder
            *args, **kwargs:
                Ignored
        Returns:
        --------
            y: torch.FloatTensor [..., n_node, n_class]
                Decoder output computed from the encoder's output
        """
        latent = self.encoder(x)
        return self.decoder(latent)


# Convenience functions for backward compatibility
def _infer_model_type(config: object) -> "Union[str, ModelType]":
    """
    Decide which model type a legacy config describes.

    Resolution order:
    1. ``config.model_type`` if present and not None;
    2. ``'fno'`` or ``'unet'`` when ``config.use_regular_grid`` is truthy,
       picked by the truthiness of ``config.use_fno``;
    3. ``'gnn'`` otherwise (irregular meshes).
    """
    explicit = getattr(config, 'model_type', None)
    if explicit is not None:
        return explicit
    if getattr(config, 'use_regular_grid', False):
        return 'fno' if getattr(config, 'use_fno', False) else 'unet'
    return 'gnn'  # Default to GNN for irregular meshes


def init_model(num_features: int, num_classes: int, config: object) -> nn.Module:
    """
    Legacy function for creating models

    This function provides backward compatibility with the old model creation
    interface. It infers the model type from the config and delegates to
    ``ModelFactory.create_model``. A ``model_type`` attribute set to None
    (e.g. an unset command-line option) is treated as absent, so the type is
    still inferred from ``use_regular_grid`` / ``use_fno``.
    """
    model_type = _infer_model_type(config)
    return ModelFactory.create_model(model_type, num_features, num_classes, config)