from model.methods.base import Method

class FTTMethod(Method):
    """Method that builds an FT-Transformer model as ``self.model``.

    The network class is ``Transformer`` from ``model.models.ftt``; all other
    training/evaluation behaviour is inherited from ``Method``.
    """

    def __init__(self, args, is_regression):
        # No state of its own; construction is delegated entirely to ``Method``.
        super().__init__(args, is_regression)

    def construct_model(self, model_config=None):
        """Instantiate the Transformer, move it to ``args.device`` and set its precision.

        Args:
            model_config: Keyword arguments forwarded to ``Transformer``.
                When ``None``, ``self.args.config['model']`` is used.

        Side effects:
            Assigns the constructed network to ``self.model``.

        Notes:
            Reads ``self.d_in``, ``self.categories`` and ``self.d_out``, which
            are presumably populated by ``Method`` before this is called --
            NOTE(review): confirm in the base class. An optional
            ``ple_mapping`` is taken from ``self.shared_state`` if present.
        """
        # Imported lazily so the model module is only loaded when needed.
        from model.models.ftt import Transformer

        if model_config is None:
            model_config = self.args.config['model']

        # ``shared_state`` is optional: tolerate it being absent *or* set to
        # None (a getattr default alone only covers the absent case).
        shared_state = getattr(self, 'shared_state', None) or {}
        ple_mapping = shared_state.get('ple_mapping', None)

        self.model = Transformer(
            d_numerical=self.d_in,
            categories=self.categories,
            d_out=self.d_out,
            ple_mapping=ple_mapping,
            **model_config
        ).to(self.args.device)

        # Select parameter precision: float32 when use_float is set, else float64.
        if self.args.use_float:
            self.model.float()
        else:
            self.model.double()

    