from runners.pytorch_tabular_runner import PT_Runner
import os
from tqdm import tqdm
from pytorch_tabnet.metrics import Metric
import torch.nn as nn
from pytorch_tabular.models import CategoryEmbeddingModelConfig
from pytorch_tabular.config import OptimizerConfig
import pandas as pd
import logging
from types import SimpleNamespace
from typing import List, Dict, Any

class CategoryEmbeddingRunner(PT_Runner):
    """Runner for pytorch_tabular's Category Embedding model.

    This subclass only supplies the two config builders below (model and
    optimizer). All other behaviour is inherited from ``PT_Runner``.
    """

    def __init__(
        self,
        config: SimpleNamespace,
        data: pd.DataFrame,
        labels: pd.Series,
        numeric_cols: List[str],
        category_cols: List[str],
        logger: logging.Logger,
    ) -> None:
        """Forward every constructor argument to ``PT_Runner``.

        This constructor accepts ``logger`` last, but it is handed to the
        base class in fourth position (right after ``labels``).
        NOTE(review): presumably this matches PT_Runner's positional order;
        confirm against PT_Runner before touching the ordering.
        """
        forwarded = (config, data, labels, logger, numeric_cols, category_cols)
        super().__init__(*forwarded)

    def get_model_config(
        self,
        hparams: Dict[str, Any],
    ) -> CategoryEmbeddingModelConfig:
        """Translate one hyperparameter set into a model config.

        Args:
            hparams: mapping that must provide the keys 'learning_rate',
                'embedding_dropout', 'layers' and 'activation'.

        Returns:
            A ``CategoryEmbeddingModelConfig`` for a classification task.
            It is seeded with ``self.random_seed`` and tracks ``self.metric``,
            with that metric receiving ``task=self.config.data.task``.
        """
        lr = hparams['learning_rate']
        seed = self.random_seed
        dropout = hparams['embedding_dropout']
        hidden_layers = hparams['layers']
        act_fn = hparams['activation']
        tracked_metrics = [self.metric]
        # One kwargs dict per metric; the single metric only needs the task type.
        per_metric_kwargs = [{'task': self.config.data.task}]
        # use_batch_norm is deliberately not passed, so the library's own
        # default is used.
        return CategoryEmbeddingModelConfig(
            task='classification',
            learning_rate=lr,
            seed=seed,
            embedding_dropout=dropout,
            layers=hidden_layers,
            activation=act_fn,
            metrics=tracked_metrics,
            metrics_params=per_metric_kwargs,
        )

    def get_optimizer_config(
        self,
        hparams: Dict[str, Any],
    ) -> OptimizerConfig:
        """Build an Adam optimizer config with a StepLR schedule.

        Weight decay is pinned to 0. The learning rate is not configured here;
        it is supplied through ``get_model_config`` instead.

        Args:
            hparams: mapping that must provide 'scheduler_step_size' and
                'scheduler_gamma'.

        Returns:
            An ``OptimizerConfig`` using 'Adam' and the 'StepLR' scheduler.
        """
        step = hparams['scheduler_step_size']
        decay_factor = hparams['scheduler_gamma']
        scheduler_kwargs = {'step_size': step, 'gamma': decay_factor}
        return OptimizerConfig(
            optimizer='Adam',
            optimizer_params={'weight_decay': 0},
            lr_scheduler='StepLR',
            lr_scheduler_params=scheduler_kwargs,
        )



