from models.lip_convnets import LipConvNet
from models.lip_resnets import LipResNet


class BaseTrainer:
    """Common base for trainers: keeps the run configuration and builds its model.

    Args:
        config: Attribute-style configuration object. ``init_model`` reads
            ``dataset``, ``model``, ``conv_layer``, ``activation``,
            ``init_channels``, ``block_size``, ``num_classes`` and ``lln`` from
            it, plus ``kernel``, ``num_dense`` and ``cphh_rank`` when the
            ``'lipconvnet'`` architecture is selected.
    """

    def __init__(self, config):
        self.config = config

    def init_model(self):
        """Instantiate the network named by ``self.config.model``.

        The input resolution passed to the network is 32 when
        ``config.dataset`` is ``"cifar10"`` or ``"cifar100"`` and 64 for any
        other dataset name.

        Returns:
            A ``LipResNet`` for ``'lipresnet'`` or a ``LipConvNet`` for
            ``'lipconvnet'``.

        Raises:
            ValueError: If ``config.model`` is neither supported name.
        """
        cfg = self.config

        # Only the two CIFAR datasets use the smaller resolution; every other
        # dataset name falls through to 64.
        resolution = 64
        if cfg.dataset in ("cifar10", "cifar100"):
            resolution = 32

        arch = cfg.model
        if arch == 'lipresnet':
            return LipResNet(
                cfg.conv_layer,
                cfg.activation,
                init_channels=cfg.init_channels,
                input_size=resolution,
                block_size=cfg.block_size,
                num_classes=cfg.num_classes,
                lln=cfg.lln,
            )
        if arch == 'lipconvnet':
            return LipConvNet(
                cfg.conv_layer,
                cfg.activation,
                init_channels=cfg.init_channels,
                block_size=cfg.block_size,
                num_classes=cfg.num_classes,
                lln=cfg.lln,
                kernel_size=cfg.kernel,
                num_dense=cfg.num_dense,
                # The config field ``cphh_rank`` is forwarded as ``mask_level``.
                mask_level=cfg.cphh_rank,
                input_size=resolution,
            )
        raise ValueError(f'Unknown model: {arch}')


if __name__ == "__main__":
    # NOTE(review): no command-line entry point is implemented here; running
    # this file directly does nothing.
    ...
