import logging
from typing import Type, Dict, Tuple

from torch import nn

from models.config import MiMoEConfig


# Name -> class registries. Each is filled by the matching ``register_*``
# decorator below and looked up by string key from the config.
POSITION_EMBEDDING_REGISTRY: Dict[str, Type[nn.Module]] = {}
ROUTER_REGISTRY: Dict[str, Type[nn.Module]] = {}
BUFFER_REGISTRY: Dict[str, Type[nn.Module]] = {}
PRETRAIN_LOSS_REGISTRY: Dict[str, Type[nn.Module]] = {}
AUXILIARY_LOSS_REGISTRY: Dict[str, Type[nn.Module]] = {}
DOWNSTREAM_TASK_LOSS_REGISTRY: Dict[str, Type[nn.Module]] = {}
FEATURE_SELECTION_LOSS_REGISTRY: Dict[str, Type[nn.Module]] = {}
COMPACTNESS_LOSS_REGISTRY: Dict[str, Type[nn.Module]] = {}


def _make_registrar(registry: Dict[str, Type[nn.Module]], kind: str, name: str):
    """
    Build a class decorator that stores the decorated class in ``registry`` under ``name``.

    Args:
        registry: The registry dict to insert into.
        kind: Human-readable label used in the duplicate-name error message.
        name: Key under which the decorated class is stored.

    Returns:
        A decorator that registers the class and returns it unchanged.
    """
    def decorator(cls: Type[nn.Module]) -> Type[nn.Module]:
        # The duplicate check runs when the decorator is applied (i.e. at class
        # definition time), not when the factory is called.
        if name in registry:
            raise ValueError(f"{kind} '{name}' is already registered.")
        registry[name] = cls
        return cls
    return decorator


def register_position_embedding(name: str):
    """
    Decorator to register a position embedding class in the POSITION_EMBEDDING_REGISTRY.

    Args:
        name (str): Key under which the decorated class is stored.

    Returns:
        A class decorator that registers the class and returns it unchanged.

    Raises:
        ValueError: When applied, if ``name`` is already registered.
    """
    return _make_registrar(POSITION_EMBEDDING_REGISTRY, "Position embedding", name)


def register_router(name: str):
    """
    Decorator to register a router class in the ROUTER_REGISTRY.

    Args:
        name (str): Key under which the decorated class is stored.

    Returns:
        A class decorator that registers the class and returns it unchanged.

    Raises:
        ValueError: When applied, if ``name`` is already registered.
    """
    return _make_registrar(ROUTER_REGISTRY, "Router", name)


def register_buffer(name: str):
    """
    Decorator to register a buffer class in the BUFFER_REGISTRY.

    Args:
        name (str): Key under which the decorated class is stored.

    Returns:
        A class decorator that registers the class and returns it unchanged.

    Raises:
        ValueError: When applied, if ``name`` is already registered.
    """
    return _make_registrar(BUFFER_REGISTRY, "Buffer", name)


def register_pretrain_loss(name: str):
    """
    Decorator to register a pretraining loss class in the PRETRAIN_LOSS_REGISTRY.

    Args:
        name (str): Key under which the decorated class is stored.

    Returns:
        A class decorator that registers the class and returns it unchanged.

    Raises:
        ValueError: When applied, if ``name`` is already registered.
    """
    return _make_registrar(PRETRAIN_LOSS_REGISTRY, "Pretrain loss", name)


def register_auxiliary_loss(name: str):
    """
    Decorator to register an auxiliary loss class in the AUXILIARY_LOSS_REGISTRY.

    Args:
        name (str): Key under which the decorated class is stored.

    Returns:
        A class decorator that registers the class and returns it unchanged.

    Raises:
        ValueError: When applied, if ``name`` is already registered.
    """
    return _make_registrar(AUXILIARY_LOSS_REGISTRY, "Auxiliary loss", name)


def register_downstream_task_loss(name: str):
    """
    Decorator to register a downstream task loss class in the DOWNSTREAM_TASK_LOSS_REGISTRY.

    Args:
        name (str): Key under which the decorated class is stored.

    Returns:
        A class decorator that registers the class and returns it unchanged.

    Raises:
        ValueError: When applied, if ``name`` is already registered.
    """
    return _make_registrar(DOWNSTREAM_TASK_LOSS_REGISTRY, "Downstream task loss", name)


def register_feature_selection_loss(name: str):
    """
    Decorator to register a feature selection loss class in the FEATURE_SELECTION_LOSS_REGISTRY.

    Args:
        name (str): Key under which the decorated class is stored.

    Returns:
        A class decorator that registers the class and returns it unchanged.

    Raises:
        ValueError: When applied, if ``name`` is already registered.
    """
    return _make_registrar(FEATURE_SELECTION_LOSS_REGISTRY, "Feature selection loss", name)


def register_compactness_loss(name: str):
    """
    Decorator to register a compactness loss class in the COMPACTNESS_LOSS_REGISTRY.

    Args:
        name (str): Key under which the decorated class is stored.

    Returns:
        A class decorator that registers the class and returns it unchanged.

    Raises:
        ValueError: When applied, if ``name`` is already registered.
    """
    return _make_registrar(COMPACTNESS_LOSS_REGISTRY, "Compactness loss", name)


def get_position_embedding(config: MiMoEConfig) -> nn.Module:
    """
    Instantiate the position embedding selected by ``config.position_embedding``.

    Args:
        config: Model configuration. ``config.position_embedding`` is the
            registry key, and the whole config is passed to the class constructor.

    Returns:
        nn.Module: A new instance of the registered position embedding class.

    Raises:
        ValueError: If ``config.position_embedding`` is not a key of
            POSITION_EMBEDDING_REGISTRY.
    """
    name = config.position_embedding
    if name not in POSITION_EMBEDDING_REGISTRY:
        raise ValueError(f"Position embedding '{name}' is not registered.")
    return POSITION_EMBEDDING_REGISTRY[name](config)


def get_router(config: MiMoEConfig) -> nn.Module:
    """
    Instantiate the router selected by ``config.router``.

    Args:
        config: Model configuration. ``config.router`` is the registry key,
            and the whole config is passed to the class constructor.

    Returns:
        nn.Module: A new instance of the registered router class.

    Raises:
        ValueError: If ``config.router`` is not a key of ROUTER_REGISTRY.
    """
    name = config.router
    if name not in ROUTER_REGISTRY:
        raise ValueError(f"Router '{name}' is not registered.")
    return ROUTER_REGISTRY[name](config)


def get_buffer(config: MiMoEConfig) -> nn.Module:
    """
    Instantiate the buffer selected by ``config.buffer``.

    Args:
        config: Model configuration. ``config.buffer`` is the registry key,
            and the whole config is passed to the class constructor.

    Returns:
        nn.Module: A new instance of the registered buffer class.

    Raises:
        ValueError: If ``config.buffer`` is not a key of BUFFER_REGISTRY.
    """
    name = config.buffer
    if name not in BUFFER_REGISTRY:
        raise ValueError(f"Buffer '{name}' is not registered.")
    return BUFFER_REGISTRY[name](config)


def get_pretrain_loss(config: MiMoEConfig) -> nn.Module:
    """
    Instantiate the pretraining loss selected by ``config.pretrain_loss``.

    Args:
        config: Model configuration. ``config.pretrain_loss`` is the registry
            key, and the whole config is passed to the class constructor.

    Returns:
        nn.Module: A new instance of the registered pretraining loss class.

    Raises:
        ValueError: If ``config.pretrain_loss`` is not a key of
            PRETRAIN_LOSS_REGISTRY.
    """
    name = config.pretrain_loss
    if name not in PRETRAIN_LOSS_REGISTRY:
        raise ValueError(f"Pretrain loss '{name}' is not registered.")
    return PRETRAIN_LOSS_REGISTRY[name](config)


def get_auxiliary_losses(config: MiMoEConfig) -> Dict[str, Tuple[nn.Module, float]]:
    """
    Instantiate every auxiliary loss listed in ``config.auxiliary_losses``.

    Args:
        config: Model configuration. ``config.auxiliary_losses`` must be an
            iterable of ``(name, weight)`` pairs. Each ``name`` is a key of
            AUXILIARY_LOSS_REGISTRY, and the whole config is passed to each
            loss constructor.

    Returns:
        Dict[str, Tuple[nn.Module, float]]: Maps each loss name to a tuple of
        its instantiated module and its weight. If a name appears more than
        once, the later entry replaces the earlier one.

    Raises:
        ValueError: On the first name that is not registered. Losses
            instantiated before that point are discarded.
    """
    losses: Dict[str, Tuple[nn.Module, float]] = {}
    for name, weight in config.auxiliary_losses:
        if name not in AUXILIARY_LOSS_REGISTRY:
            raise ValueError(f"Auxiliary loss '{name}' is not registered.")
        losses[name] = (AUXILIARY_LOSS_REGISTRY[name](config), weight)
    return losses


def get_downstream_task_loss(config: MiMoEConfig) -> nn.Module:
    """
    Instantiate the downstream task loss selected by ``config.downstream_task_loss``.

    Args:
        config: Model configuration. ``config.downstream_task_loss`` is the
            registry key, and the whole config is passed to the class constructor.

    Returns:
        nn.Module: A new instance of the registered downstream task loss class.

    Raises:
        ValueError: If ``config.downstream_task_loss`` is not a key of
            DOWNSTREAM_TASK_LOSS_REGISTRY.
    """
    name = config.downstream_task_loss
    if name not in DOWNSTREAM_TASK_LOSS_REGISTRY:
        raise ValueError(f"Downstream task loss '{name}' is not registered.")
    return DOWNSTREAM_TASK_LOSS_REGISTRY[name](config)


def get_feature_selection_losses(config: MiMoEConfig) -> Dict[str, Tuple[nn.Module, float]]:
    """
    Instantiate the feature selection loss named in ``config.feature_selection_losses``.

    Args:
        config: Model configuration. ``config.feature_selection_losses`` is a
            single ``(name, weight)`` pair. The whole config is passed to the
            loss constructor, and ``config.device`` is the placement target.

    Returns:
        Dict[str, Tuple[nn.Module, float]]: Either ``{name: (module, weight)}``,
        with the module moved to ``config.device``, or an empty dict if ``name``
        is not registered.

    Unlike most getters in this module, an unknown name is not an error. A
    warning is logged and no feature selection loss is returned.
    """
    name, weight = config.feature_selection_losses
    losses: Dict[str, Tuple[nn.Module, float]] = {}
    if name not in FEATURE_SELECTION_LOSS_REGISTRY:
        # Best-effort: report via logging rather than stdout, and continue without the loss.
        logging.getLogger(__name__).warning(
            "Feature selection loss '%s' is not registered.", name
        )
        return losses
    losses[name] = (FEATURE_SELECTION_LOSS_REGISTRY[name](config).to(config.device), weight)
    return losses

def get_compactness_loss(config: MiMoEConfig) -> Dict[str, Tuple[nn.Module, float]]:
    """
    Instantiate the compactness loss named in ``config.compactness_loss``.

    Args:
        config: Model configuration. ``config.compactness_loss`` is a single
            ``(name, weight)`` pair, and the whole config is passed to the loss
            constructor.

    Returns:
        Dict[str, Tuple[nn.Module, float]]: Either ``{name: (module, weight)}``,
        or an empty dict if ``name`` is not registered.

    An unknown name is not an error. A warning is logged and no compactness
    loss is returned.
    """
    name, weight = config.compactness_loss
    losses: Dict[str, Tuple[nn.Module, float]] = {}
    if name not in COMPACTNESS_LOSS_REGISTRY:
        # Best-effort: report via logging rather than stdout, and continue without the loss.
        logging.getLogger(__name__).warning(
            "Compactness loss '%s' is not registered.", name
        )
        return losses
    # NOTE(review): unlike get_feature_selection_losses, the module is not moved
    # to config.device here. Confirm the caller handles device placement.
    losses[name] = (COMPACTNESS_LOSS_REGISTRY[name](config), weight)
    return losses