from runners.pytorch_tabular_runner import PT_Runner
import os
from tqdm import tqdm
from pytorch_tabnet.metrics import Metric
import torch.nn as nn
from pytorch_tabular.models import FTTransformerConfig
from pytorch_tabular.config import OptimizerConfig
import pandas as pd
import logging
from types import SimpleNamespace
from typing import List, Dict, Any
class FTTransformerRunner(PT_Runner):
    """Runner that produces pytorch_tabular FT-Transformer configurations.

    The class translates a flat hyper-parameter dictionary into the
    ``FTTransformerConfig`` and ``OptimizerConfig`` objects. It reads
    ``self.random_seed``, ``self.metric`` and ``self.config``, which are not
    assigned in this class (presumably set up by ``PT_Runner`` -- confirm
    against the base class).
    """

    def __init__(
        self,
        config: SimpleNamespace,
        data: pd.DataFrame,
        labels: pd.Series,
        numeric_cols: List[str],
        category_cols: List[str],
        logger: logging.Logger,
    ) -> None:
        """Hand every constructor argument over to ``PT_Runner``.

        The base initializer is called with ``logger`` placed *before* the
        two column lists, which differs from this signature's order.
        """
        super().__init__(config, data, labels, logger, numeric_cols, category_cols)

    def get_model_config(
        self,
        hparams: Dict[str, Any],
    ) -> FTTransformerConfig:
        """Build the FT-Transformer model configuration for one trial.

        Args:
            hparams: Hyper-parameter mapping. Must contain
                ``learning_rate``, ``input_embed_dim``,
                ``embedding_dropout``, ``share_embedding``, ``num_heads``,
                ``num_attn_blocks``, ``transformer_activation`` and
                ``batch_norm_continuous_input``.

        Returns:
            An ``FTTransformerConfig`` for a classification task, seeded with
            ``self.random_seed`` and evaluated with ``self.metric``.

        Raises:
            KeyError: If ``hparams`` is a dict lacking a required key.
        """
        # Hyper-parameters whose config field name equals their hparams key.
        architecture_keys = (
            'input_embed_dim',
            'embedding_dropout',
            'share_embedding',
            'num_heads',
            'num_attn_blocks',
            'transformer_activation',
            'batch_norm_continuous_input',
        )

        settings: Dict[str, Any] = {
            'task': 'classification',
            'learning_rate': hparams['learning_rate'],
            'seed': self.random_seed,
        }
        for key in architecture_keys:
            settings[key] = hparams[key]

        # A single metric, parameterised with the task from the data config.
        settings['metrics'] = [self.metric]
        settings['metrics_params'] = [{'task': self.config.data.task}]

        return FTTransformerConfig(**settings)

    def get_optimizer_config(
        self,
        hparams: Dict[str, Any],
    ) -> OptimizerConfig:
        """Build the optimizer / LR-scheduler configuration for one trial.

        Uses Adam with zero weight decay and a ``StepLR`` scheduler. No
        learning rate is set here; ``get_model_config`` passes
        ``hparams['learning_rate']`` to the model configuration instead.

        Args:
            hparams: Hyper-parameter mapping. Must contain
                ``scheduler_step_size`` and ``scheduler_gamma``.

        Returns:
            The populated ``OptimizerConfig``.

        Raises:
            KeyError: If ``hparams`` is a dict lacking a required key.
        """
        step_lr_params = {
            'step_size': hparams['scheduler_step_size'],
            'gamma': hparams['scheduler_gamma'],
        }
        adam_params = {'weight_decay': 0}

        return OptimizerConfig(
            optimizer='Adam',
            optimizer_params=adam_params,
            lr_scheduler='StepLR',
            lr_scheduler_params=step_lr_params,
        )

