from typing import Optional

import torch as t

from auto_encoder import device
from auto_encoder.config import AutoEncoderConfig
from auto_encoder.config_enums import AutoEncoderType
from auto_encoder.models.base_ae import AutoEncoderBase
from auto_encoder.models.mutual_choice_ae import MutualChoiceSAE
from auto_encoder.models.topk_ae import TopKSAE
from auto_encoder.models.vanilla_ae import VanillaAE


def create_autoencoder(
    config: AutoEncoderConfig,
    medoid_initial_tensor_N: Optional[t.Tensor],
    scaling_factor: Optional[float],
    expert_initial_tensors: Optional[t.Tensor],
    autoencoder_type: AutoEncoderType = AutoEncoderType.VANILLA,
    device: str = device,
) -> AutoEncoderBase:
    """Construct the autoencoder implementation selected by ``autoencoder_type``.

    Args:
        config: Autoencoder configuration, forwarded unchanged to the constructor.
        medoid_initial_tensor_N: Optional tensor forwarded as
            ``medoid_initial_tensor_N`` to the constructor (its expected shape
            is defined by the model classes — not visible here).
        scaling_factor: Optional scalar forwarded as
            ``preprocess_scaling_factor``.
        expert_initial_tensors: Currently unused by every branch of this
            factory; kept so existing callers keep working.
            NOTE(review): confirm whether a model type should consume it.
        autoencoder_type: Which implementation to build (``VANILLA`` ->
            ``VanillaAE``, ``TOPK`` -> ``TopKSAE``, ``MUTUAL_CHOICE`` ->
            ``MutualChoiceSAE``).
        device: Device string forwarded to the constructor. Defaults to the
            package-level ``device`` from ``auto_encoder``, captured when this
            module is imported.

    Returns:
        A freshly constructed instance of the selected ``AutoEncoderBase``
        subclass.

    Raises:
        ValueError: If ``autoencoder_type`` matches none of the supported types.
    """
    # Use explicit ``==`` comparisons (rather than a dict lookup) so matching
    # semantics stay identical to equality on AutoEncoderType members.
    if autoencoder_type == AutoEncoderType.VANILLA:
        ae_class = VanillaAE
    elif autoencoder_type == AutoEncoderType.TOPK:
        ae_class = TopKSAE
    elif autoencoder_type == AutoEncoderType.MUTUAL_CHOICE:
        ae_class = MutualChoiceSAE
    else:
        # Include the offending value so misconfiguration is easy to diagnose.
        raise ValueError(f"Unknown autoencoder type: {autoencoder_type!r}")

    # All supported implementations share the same constructor signature.
    return ae_class(
        config=config,
        medoid_initial_tensor_N=medoid_initial_tensor_N,
        preprocess_scaling_factor=scaling_factor,
        device=device,
    )
