
import numpy as np
import torch

from tabicl.config.config_pretrain import ConfigPretrain
from tabicl.config.config_run import ConfigRun
from tabicl.core.enums import ModelName
from tabicl.models.foundation.foundation_transformer import FoundationTransformer
from tabicl.models.foundation_flash.foundation_flash_transformer import FoundationFlashTransformer
from tabicl.models.ft_transformer.ft_transformer import FTTransformer
from tabicl.models.tab2d.tab2d import Tab2D
from tabicl.models.tabPFN.tabpfn_transformer import TabPFN


def get_model(cfg: ConfigRun, x_train: np.ndarray, y_train: np.ndarray, categorical_indicator: np.ndarray) -> torch.nn.Module:

    match cfg.model_name:
        case ModelName.FT_TRANSFORMER:
            return FTTransformer(cfg, x_train, y_train, categorical_indicator)
        case ModelName.AUTOGLUON:
            return torch.nn.Module()
        case ModelName.TABPFN:
            return TabPFN(
                cfg.hyperparams['use_pretrained_weights'], 
                path_to_weights=cfg.hyperparams['path_to_weights']
            )
        case ModelName.FOUNDATION:
            return FoundationTransformer(
                dim_model=cfg.hyperparams['dim_model'],
                dim_embedding=cfg.hyperparams['dim_embedding'],
                dim_output=cfg.hyperparams['dim_output'],
                n_layers=cfg.hyperparams['n_layers'],
                n_heads=cfg.hyperparams['n_heads'],
                y_as_float_embedding=cfg.hyperparams['y_as_float_embedding'],
                quantile_embedding=cfg.hyperparams['quantile_embedding_gpu'],
                feature_count_scaling=cfg.hyperparams['feature_count_scaling_gpu'],
                use_pretrained_weights=cfg.hyperparams['use_pretrained_weights'],
                path_to_weights=cfg.hyperparams['path_to_weights']
            )
        case ModelName.FOUNDATION_FLASH:
            return FoundationFlashTransformer(
                dim_model=cfg.hyperparams['dim_model'],
                dim_embedding=cfg.hyperparams['dim_embedding'],
                dim_output=cfg.hyperparams['dim_output'],
                n_layers=cfg.hyperparams['n_layers'],
                n_heads=cfg.hyperparams['n_heads'],
                y_as_float_embedding=cfg.hyperparams['y_as_float_embedding'],
                quantile_embedding=cfg.hyperparams['quantile_embedding_gpu'],
                feature_count_scaling=cfg.hyperparams['feature_count_scaling_gpu'],
                use_pretrained_weights=cfg.hyperparams['use_pretrained_weights'],
                path_to_weights=cfg.hyperparams['path_to_weights']
            )
        case ModelName.TAB2D:
            return Tab2D(
                n_features=cfg.hyperparams['n_features'],
                n_classes=cfg.hyperparams['n_classes'],
                dim=cfg.hyperparams['dim'],
                n_layers=cfg.hyperparams['n_layers'],
                n_heads=cfg.hyperparams['n_heads'],
                use_pretrained_weights=cfg.hyperparams['use_pretrained_weights'],
                path_to_weights=cfg.hyperparams['path_to_weights']
            )
        case _:
            raise NotImplementedError(f"Model {cfg.model_name} not implemented yet")
            
        


def get_model_pretrain(cfg: ConfigPretrain) -> torch.nn.Module:

    match cfg.model_name:
        case ModelName.TABPFN:
            return TabPFN(
                use_pretrained_weights=cfg.optim.use_pretrained_weights,
                path_to_weights=cfg.optim.path_to_weights
            )
        case ModelName.FOUNDATION:
            return FoundationTransformer(
                use_pretrained_weights=cfg.optim.use_pretrained_weights,
                path_to_weights=cfg.optim.path_to_weights,
                **cfg.model
            )
        case ModelName.FOUNDATION_FLASH:
            return FoundationFlashTransformer(
                use_pretrained_weights=cfg.optim.use_pretrained_weights,
                path_to_weights=cfg.optim.path_to_weights,
                **cfg.model
            )
        case ModelName.TAB2D:
            return Tab2D(
                n_features=cfg.data.max_features,
                n_classes=cfg.data.max_classes,
                use_pretrained_weights=cfg.optim.use_pretrained_weights,
                path_to_weights=cfg.optim.path_to_weights,
                **cfg.model
            )
        case _:
            raise NotImplementedError(f"Model {cfg.model['name']} not implemented yet")
