from typing import Optional
from loguru import logger
import numpy as np
from .base import VectorMapper
from .gating_moe import (
    FlatMoEMapper,
    HierarchicalMoEMapper,
    HierarchicalLoRAMoEMapper,
    GatingMechanism
)
from src.config.models import *

try:
    from .linear_mapper import LinearMapper
except ImportError:
    LinearMapper = None

try:
    from .procrustes_mapper import ProcrustesMapper
except ImportError:
    ProcrustesMapper = None

try:
    from .diffusion_mapper import DiffusionMapper
except ImportError:
    DiffusionMapper = None

try:
    from .spnt_mapper import SPNTMapper
except ImportError:
    SPNTMapper = None

try:
    from .spnt_diffusion_mapper import SPNTDiffusionMapper
except ImportError:
    SPNTDiffusionMapper = None

try:
    from .simple_la2m_mapper import SimpleLA2MMapper
except ImportError:
    SimpleLA2MMapper = None

try:
    from .simple_linear_mapper import SimpleLinearMapper
except ImportError:
    try:
        from .gating_moe.core.mlp import SimpleLinearMapper
    except ImportError:
        SimpleLinearMapper = None

try:
    from .embedding_converter import EmbeddingConverterMapper
except ImportError:
    EmbeddingConverterMapper = None

try:
    from .vec2vec import Vec2VecMapper
except ImportError:
    Vec2VecMapper = None
def create_gating_moe_mapper(config: GatingMoEConfig, *args, **kwargs):
    """Build a gating mixture-of-experts mapper chosen by ``config.moe_type``.

    Supported ``moe_type`` values are ``'flat'``, ``'hierarchical'`` and
    ``'hierarchical_lora'``. All variants receive the shared clustering settings
    (``mapper_config``, ``clustering_method``, ``distance_metric``,
    ``random_state``, ``clustering_sample_size``) plus their variant-specific
    options read from ``config``.

    Extra positional/keyword arguments are accepted and ignored, so callers may
    pass dimensions uniformly alongside other factories.

    Raises:
        ValueError: if ``config.moe_type`` is not one of the supported values.
    """
    kind = config.moe_type
    inner_config = config.mapper_config
    logger.info(f"Creating {kind} MoE Mapper")

    # Options forwarded to every MoE variant.
    shared = dict(
        mapper_config=inner_config,
        clustering_method=config.clustering_method,
        distance_metric=config.distance_metric,
        random_state=config.random_state,
        clustering_sample_size=config.clustering_sample_size,
    )

    if kind == 'flat':
        logger.info(f"Creating FlatMoEMapper with {config.num_experts} experts")
        return FlatMoEMapper(
            num_experts=config.num_experts,
            use_soft_routing=config.use_soft_routing,
            gating_temperature=config.gating_temperature,
            **shared,
        )

    if kind == 'hierarchical':
        logger.info(
            f"Creating HierarchicalMoEMapper with {config.num_levels} levels, "
            f"branch_factor={config.branch_factor}"
        )
        return HierarchicalMoEMapper(
            num_levels=config.num_levels,
            branch_factor=config.branch_factor,
            **shared,
        )

    if kind == 'hierarchical_lora':
        logger.info(
            f"Creating HierarchicalLoRAMoEMapper with {config.num_levels} levels, "
            f"branch_factor={config.branch_factor}, lora_rank={config.lora_rank}"
        )
        return HierarchicalLoRAMoEMapper(
            num_levels=config.num_levels,
            branch_factor=config.branch_factor,
            lora_rank=config.lora_rank,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            share_base_model=config.share_base_model,
            **shared,
        )

    raise ValueError(
        f"Unknown moe_type: {kind}. "
        f"Must be one of: 'flat', 'hierarchical', 'hierarchical_lora'"
    )
def create_gating_mechanism(config: GatingMoEConfig, centroids: np.ndarray) -> GatingMechanism:
    """Construct a ``GatingMechanism`` over the given ``centroids``.

    The distance metric, gating temperature and soft/hard routing flag are taken
    from ``config`` (``distance_metric``, ``gating_temperature``,
    ``use_soft_routing``).
    """
    centroid_count = len(centroids)
    logger.info(f"Creating GatingMechanism with {centroid_count} centroids")
    routing_options = {
        'distance_metric': config.distance_metric,
        'temperature': config.gating_temperature,
        'use_soft_routing': config.use_soft_routing,
    }
    return GatingMechanism(centroids=centroids, **routing_options)
def create_linear_mapper(config: LinearMapperConfig, input_dim: int, output_dim: int) -> LinearMapper:
    """Instantiate a ``LinearMapper`` from ``input_dim`` to ``output_dim``.

    ``hidden_dim``, ``num_epochs`` and ``batch_size`` are read from ``config``.

    Raises:
        ImportError: if the optional ``linear_mapper`` module failed to import.
    """
    mapper_cls = LinearMapper
    if mapper_cls is None:
        raise ImportError("LinearMapper is not available. Please ensure linear_mapper.py exists.")
    hidden = config.hidden_dim
    logger.info(f"Creating LinearMapper with {hidden} hidden dimension")
    training = dict(num_epochs=config.num_epochs, batch_size=config.batch_size)
    return mapper_cls(input_dim=input_dim, output_dim=output_dim, hidden_dim=hidden, **training)
def create_procrustes_mapper(config: ProcrustesMapperConfig) -> ProcrustesMapper:
    """Instantiate a ``ProcrustesMapper`` from its configuration.

    Forwards ``approximate``, ``q``, ``with_rotation`` and ``use_norm`` from
    ``config`` unchanged. No dimensions are needed by this factory.

    Raises:
        ImportError: if the optional ``procrustes_mapper`` module failed to import.
    """
    if ProcrustesMapper is None:
        raise ImportError("ProcrustesMapper is not available. Please ensure procrustes_mapper.py exists.")
    logger.info(f"Creating ProcrustesMapper with {config.approximate} approximate method")
    forwarded_fields = ("approximate", "q", "with_rotation", "use_norm")
    return ProcrustesMapper(**{field: getattr(config, field) for field in forwarded_fields})
def create_diffusion_mapper(config: DiffusionMapperConfig, embedding_dim: int) -> DiffusionMapper:
    """Instantiate a ``DiffusionMapper`` for vectors of size ``embedding_dim``.

    Timestep/beta settings (``num_timesteps``, ``beta_start``, ``beta_end``),
    network size (``hidden_dim``, ``num_layers``) and training settings
    (``num_epochs``, ``batch_size``, ``learning_rate``) come from ``config``.

    Raises:
        ImportError: if the optional ``diffusion_mapper`` module failed to import.
    """
    if DiffusionMapper is None:
        raise ImportError("DiffusionMapper is not available. Please ensure diffusion_mapper.py exists.")
    logger.info(f"Creating DiffusionMapper with {config.num_timesteps} timesteps")
    schedule = dict(
        num_timesteps=config.num_timesteps,
        beta_start=config.beta_start,
        beta_end=config.beta_end,
    )
    network = dict(
        embedding_dim=embedding_dim,
        hidden_dim=config.hidden_dim,
        num_layers=config.num_layers,
    )
    training = dict(
        num_epochs=config.num_epochs,
        batch_size=config.batch_size,
        learning_rate=config.learning_rate,
    )
    return DiffusionMapper(**schedule, **network, **training)
def create_spnt_mapper(
    config: SPNTConfig,
    input_dim: int,
    output_dim: int,
    linear_config: LinearMapperConfig,
    *,
    hidden_dim: int = 1024,
    num_layers: int = 3,
    learning_rate: float = 1e-4,
) -> SPNTMapper:
    """Instantiate an ``SPNTMapper`` from ``input_dim`` to ``output_dim``.

    Args:
        config: SPNT settings; supplies ``lambda_struct``, ``k_neighbors``,
            ``num_anchors`` and ``num_projections``.
        input_dim: forwarded as the mapper's ``input_dim``.
        output_dim: forwarded as the mapper's ``output_dim``.
        linear_config: reused only for its ``num_epochs`` and ``batch_size``.
        hidden_dim: hidden size of the mapper network (keyword-only).
        num_layers: number of network layers (keyword-only).
        learning_rate: optimizer learning rate (keyword-only).

    The three keyword-only parameters were previously hard-coded to these same
    default values. Callers that do not pass them get identical behavior.

    Raises:
        ImportError: if the optional ``spnt_mapper`` module failed to import.
    """
    if SPNTMapper is None:
        raise ImportError("SPNTMapper is not available. Please ensure spnt_mapper.py exists.")
    logger.info(f"Creating SPNTMapper with {config.lambda_struct} lambda_struct")
    return SPNTMapper(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        num_layers=num_layers,
        learning_rate=learning_rate,
        num_epochs=linear_config.num_epochs,
        batch_size=linear_config.batch_size,
        lambda_struct=config.lambda_struct,
        k_neighbors=config.k_neighbors,
        num_anchors=config.num_anchors,
        num_projections=config.num_projections,
    )
def create_spnt_diffusion_mapper(config: SPNTConfig, diffusion_config: DiffusionMapperConfig, embedding_dim: int) -> SPNTDiffusionMapper:
    """Instantiate an ``SPNTDiffusionMapper`` combining diffusion and SPNT settings.

    Diffusion/training options come from ``diffusion_config``. The structural
    options (``lambda_struct``, ``struct_loss_type``, ``k_neighbors``,
    ``num_anchors``, ``num_projections``) come from ``config``. Note that
    ``diffusion_config.learning_rate`` is not forwarded here.

    Raises:
        ImportError: if the optional ``spnt_diffusion_mapper`` module failed to import.
    """
    if SPNTDiffusionMapper is None:
        raise ImportError("SPNTDiffusionMapper is not available. Please ensure spnt_diffusion_mapper.py exists.")
    logger.info(f"Creating SPNTDiffusionMapper with {config.lambda_struct} lambda_struct")
    diffusion_fields = (
        "num_timesteps", "beta_start", "beta_end",
        "hidden_dim", "num_layers", "num_epochs", "batch_size",
    )
    structure_fields = (
        "lambda_struct", "struct_loss_type", "k_neighbors",
        "num_anchors", "num_projections",
    )
    options = {field: getattr(diffusion_config, field) for field in diffusion_fields}
    options["embedding_dim"] = embedding_dim
    options.update({field: getattr(config, field) for field in structure_fields})
    return SPNTDiffusionMapper(**options)
def create_simple_la2m_mapper(config: LA2MMapperConfig, linear_config: LinearMapperConfig) -> SimpleLA2MMapper:
    """Instantiate a ``SimpleLA2MMapper`` with every option taken from ``config``.

    ``linear_config`` is accepted for interface compatibility but is not read
    by this factory.

    Raises:
        ImportError: if the optional ``simple_la2m_mapper`` module failed to import.
    """
    if SimpleLA2MMapper is None:
        raise ImportError("SimpleLA2MMapper is not available. Please ensure simple_la2m_mapper.py exists.")
    logger.info(f"Creating SimpleLA2MMapper with {config.d_prime} d_prime")
    la2m_fields = (
        "d_prime",
        "pca_mapping",
        "pca_dim",
        "use_norm",
        "device",
        "batch_size",
        "verbose",
        "hidden_dims",
        "alignment_strategy",
        "learning_rate",
        "num_epochs",
        "loss_type",
    )
    return SimpleLA2MMapper(**{field: getattr(config, field) for field in la2m_fields})
def create_simple_linear_mapper(config: SimpleLinearMapperConfig, input_dim: int, output_dim: int) -> SimpleLinearMapper:
    """Instantiate a ``SimpleLinearMapper`` from ``input_dim`` to ``output_dim``.

    If ``config.device`` is falsy, the device falls back to CUDA when
    ``torch.cuda.is_available()`` and to CPU otherwise. Optimizer, scheduler and
    early-stopping settings are forwarded from ``config``.

    Raises:
        ImportError: if neither ``simple_linear_mapper`` nor
            ``gating_moe.core.mlp`` provided ``SimpleLinearMapper``.
    """
    if SimpleLinearMapper is None:
        raise ImportError("SimpleLinearMapper is not available. Please ensure simple_linear_mapper.py exists or gating_moe/core/mlp.py is available.")
    logger.info(f"Creating SimpleLinearMapper with {config.learning_rate} learning rate")
    # Imported lazily so the module can be loaded without torch installed.
    import torch
    if config.device:
        target_device = torch.device(config.device)
    else:
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    training_fields = (
        "learning_rate",
        "num_epochs",
        "batch_size",
        "gradient_clip",
        "weight_decay",
        "scheduler_patience",
        "scheduler_factor",
        "early_stopping_patience",
        "min_delta",
    )
    training = {field: getattr(config, field) for field in training_fields}
    return SimpleLinearMapper(
        input_dim=input_dim,
        output_dim=output_dim,
        device=target_device,
        **training,
    )
def create_embedding_converter_mapper(config: EmbeddingConverterMapperConfig, input_dim: int, output_dim: int) -> EmbeddingConverterMapper:
    """Instantiate an ``EmbeddingConverterMapper`` from ``input_dim`` to ``output_dim``.

    The network shape is controlled by ``config.hidden_multiplier`` and
    ``config.num_hidden_layers``. Activation, output L2-normalization, dropout
    and training settings are also forwarded from ``config``.

    Raises:
        ImportError: if the optional ``embedding_converter`` module failed to import.
    """
    if EmbeddingConverterMapper is None:
        raise ImportError("EmbeddingConverterMapper is not available. Please ensure embedding_converter.py exists.")
    # Log the fields that actually configure the network. This factory never
    # forwards a ``hidden_dim``, so reporting one was misleading and could raise
    # AttributeError on configs that do not define it.
    logger.info(
        f"Creating EmbeddingConverterMapper with {config.num_hidden_layers} hidden layers, "
        f"hidden_multiplier={config.hidden_multiplier}"
    )
    return EmbeddingConverterMapper(
        input_dim=input_dim,
        output_dim=output_dim,
        device=config.device,
        hidden_multiplier=config.hidden_multiplier,
        num_hidden_layers=config.num_hidden_layers,
        activation=config.activation,
        out_l2_normalize=config.out_l2_normalize,
        dropout=config.dropout,
        learning_rate=config.learning_rate,
        num_epochs=config.num_epochs,
        batch_size=config.batch_size,
    )
def create_vec2vec_mapper(config: Vec2VecMapperConfig, input_dim: int, output_dim: int) -> Vec2VecMapper:
    """Instantiate a ``Vec2VecMapper`` from ``input_dim`` to ``output_dim``.

    Unlike the other factories here, the config object is passed through whole
    (as ``config=``) rather than unpacked into individual keyword arguments.

    Raises:
        ImportError: if the optional ``vec2vec`` module failed to import.
    """
    mapper_cls = Vec2VecMapper
    if mapper_cls is None:
        raise ImportError("Vec2VecMapper is not available. Please ensure vec2vec.py exists.")
    dims = {"input_dim": input_dim, "output_dim": output_dim}
    return mapper_cls(config=config, **dims)
def create_mapper(config: MapperConfig, input_dim: int, output_dim: int) -> VectorMapper:
    """Build the mapper selected by ``config.mapper_name``.

    Each branch hands the matching sub-config (e.g. ``config.linear`` for
    ``"linear"``) to its ``create_*_mapper`` factory. The arguments are adapted to
    that factory's actual signature, because the factories do not share one
    signature.

    Args:
        config: top-level mapper configuration. ``mapper_name`` selects the branch.
        input_dim: forwarded to factories that take an input dimension.
        output_dim: forwarded to factories that take an output dimension.

    Returns:
        The mapper instance produced by the selected factory.

    Raises:
        ValueError: if ``config.mapper_name`` is not handled here.
        ImportError: propagated from a factory whose optional module is missing.
    """
    name = config.mapper_name
    logger.info(f"Creating Mapper with {name} mapper")
    if name == "linear":
        return create_linear_mapper(config.linear, input_dim, output_dim)
    elif name == "procrustes":
        # create_procrustes_mapper takes only its config. Passing the
        # dimensions raised a TypeError.
        return create_procrustes_mapper(config.procrustes)
    elif name == "diffusion":
        # create_diffusion_mapper takes a single embedding_dim. Passing both
        # dimensions raised a TypeError. input_dim is used, so warn when the
        # target dimensionality differs.
        if input_dim != output_dim:
            logger.warning(
                f"DiffusionMapper uses a single embedding_dim; using input_dim={input_dim} "
                f"although output_dim={output_dim}"
            )
        return create_diffusion_mapper(config.diffusion, input_dim)
    elif name == "spnt":
        # create_spnt_mapper requires the linear config (for num_epochs and
        # batch_size). Omitting it raised a TypeError.
        return create_spnt_mapper(config.spnt, input_dim, output_dim, config.linear)
    elif name == "simple_linear":
        return create_simple_linear_mapper(config.simple_linear, input_dim, output_dim)
    elif name == "gating-moe":
        # The gating-MoE factory accepts and ignores extra positional args.
        return create_gating_moe_mapper(config.gating_moe, input_dim, output_dim)
    elif name == "emb_conv":
        return create_embedding_converter_mapper(config.emb_conv, input_dim, output_dim)
    elif name == "vec2vec":
        return create_vec2vec_mapper(config.vec2vec, input_dim, output_dim)
    else:
        raise ValueError(f"Unknown mapper_name: {config.mapper_name}. Must be one of: {config.supported_mappers}")
