"""
Factory for creating time series forecasting models with proper parameter mapping.
Handles the different parameter naming conventions across model architectures.
"""
from typing import Optional, Dict, Any
from pathlib import Path

from .spikernn import SpikeRNN
from .itransformer import iTransformer
from .spikformer import Spikformer
from .ispikeformer import iSpikeformer
from .timemixer import TimeMixer
from .dlinear import DLinear


def create_backbone_model(
    model_name: str,
    input_size: int,
    hidden_size: int = 64,
    max_length: int = 100,
    num_steps: int = 4,
    layers: int = 2,
    window: int = 96,
    horizon: int = 24,
    **kwargs
) -> Any:
    """
    Build the forecasting backbone selected by ``model_name``.

    The shared arguments are translated into the keyword names that each
    architecture's constructor is called with. Not every shared argument is
    used by every model (see below); extra ``**kwargs`` are forwarded, after
    filtering for some models.

    Args:
        model_name: One of 'spikernn', 'itransformer', 'spikformer',
            'ispikeformer', 'timemixer', 'dlinear' (exact, case-sensitive match).
        input_size: Size of input features; passed to every model.
        hidden_size: Hidden dimension; passed as ``hidden_size`` (spikernn),
            ``d_model`` (itransformer, timemixer) or ``dim`` (spikformer,
            ispikeformer). Not passed to dlinear.
        max_length: Maximum sequence length; passed to spikernn, itransformer
            and ispikeformer only.
        num_steps: Number of time steps; passed to spikernn and ispikeformer only.
        layers: Number of layers; passed as ``layers`` (spikernn) or
            ``e_layers`` (itransformer, timemixer). Not passed to the others.
        window: Input sequence length; passed as ``seq_len`` to timemixer and dlinear.
        horizon: Forecast horizon; passed as ``pred_len`` to timemixer and dlinear.
        **kwargs: Additional model-specific parameters.

    Returns:
        The instantiated model.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == "spikernn":
        # SpikeRNN takes the shared names unchanged.
        return SpikeRNN(
            hidden_size=hidden_size,
            layers=layers,
            num_steps=num_steps,
            input_size=input_size,
            max_length=max_length,
            **kwargs
        )

    if model_name == "itransformer":
        # Only these optional settings are forwarded to iTransformer; any
        # other extra keyword is silently dropped.
        forwarded_names = (
            'n_heads', 'd_ff', 'dropout', 'activation',
            'output_attention', 'factor', 'embed', 'freq',
            'class_strategy', 'weight_file',
        )
        itransformer_extras = {}
        for name, value in kwargs.items():
            if name in forwarded_names:
                itransformer_extras[name] = value

        return iTransformer(
            d_model=hidden_size,  # hidden size is called d_model here
            e_layers=layers,
            max_length=max_length,
            input_size=input_size,
            **itransformer_extras
        )

    if model_name == "spikformer":
        return Spikformer(
            dim=hidden_size,  # hidden size is called dim here
            input_size=input_size,
            **kwargs
        )

    if model_name == "ispikeformer":
        # iSpikeformer is given 'heads' (not 'num_heads'), so rename that key.
        if 'num_heads' in kwargs:
            head_kwargs = {'heads': kwargs['num_heads']}
        else:
            head_kwargs = {}

        # QAP-related options are not wanted by this model and are discarded.
        discarded_names = (
            'num_queries', 'use_side_channel', 'enable_time_features',
            'time_dim', 'F_client', 'd_qap', 'alignment_method',
        )
        passthrough = {
            name: value
            for name, value in kwargs.items()
            if name != 'num_heads' and name not in discarded_names
        }

        return iSpikeformer(
            dim=hidden_size,  # hidden size is called dim here
            input_size=input_size,
            max_length=max_length,
            num_steps=num_steps,
            **head_kwargs,
            **passthrough
        )

    if model_name == "timemixer":
        return TimeMixer(
            input_size=input_size,
            seq_len=window,
            pred_len=horizon,
            d_model=hidden_size,
            e_layers=layers,
            channel_independence=False,
            **kwargs
        )

    if model_name == "dlinear":
        return DLinear(
            input_size=input_size,
            seq_len=window,
            pred_len=horizon,
            **kwargs
        )

    # No branch matched: report the offending name and the accepted ones.
    supported = "spikernn, itransformer, spikformer, ispikeformer, timemixer, dlinear"
    raise ValueError(
        f"Unknown model name: {model_name}. " + "Supported models: " + supported
    )


def get_model_parameter_mapping(model_name: str) -> Dict[str, str]:
    """
    Get parameter name mappings for different models.

    Each mapping has three keys -- ``input_param``, ``hidden_param`` and
    ``layers_param`` -- whose values are the model-specific keyword names for
    the input size, hidden size and layer count respectively.

    Args:
        model_name: Name of the model ('spikernn', 'itransformer',
            'spikformer', 'ispikeformer', 'timemixer', 'dlinear').

    Returns:
        Dictionary mapping common parameter names to model-specific names.
        An empty dictionary is returned for an unrecognised ``model_name``
        (no exception is raised). The table is rebuilt on every call, so the
        returned dictionary can be mutated without affecting later calls.
    """
    mappings = {
        "spikernn": {
            "input_param": "input_size",
            "hidden_param": "hidden_size",
            "layers_param": "layers"
        },
        "itransformer": {
            "input_param": "input_size",
            "hidden_param": "d_model",
            "layers_param": "e_layers"
        },
        "spikformer": {
            # Kept in step with create_backbone_model, which constructs
            # Spikformer with `input_size=` and `dim=`.
            "input_param": "input_size",
            "hidden_param": "dim",
            # NOTE(review): the factory never passes a layer count to
            # Spikformer, so this name could not be confirmed -- verify
            # against the Spikformer constructor.
            "layers_param": "e_layers"
        },
        "ispikeformer": {
            "input_param": "input_size",
            "hidden_param": "dim",
            "layers_param": "depths"
        },
        "timemixer": {
            "input_param": "input_size",
            "hidden_param": "d_model",
            "layers_param": "e_layers"
        },
        "dlinear": {
            "input_param": "input_size",
            "hidden_param": "hidden_size",  # DLinear doesn't use hidden_size but keeping for consistency
            "layers_param": "layers"
        }
    }

    return mappings.get(model_name, {})