from torch import nn

from .encoder import OSRTEncoder, BSAEncoder
from .decoder import SlotMixerDecoder

from . import layers


class OSRT(nn.Module):
    """Top-level OSRT model holding an encoder and a decoder selected by name.

    Config keys read from ``cfg``:
        encoder: ``'osrt'`` (``OSRTEncoder``) or ``'3ddir'`` (``BSAEncoder``).
        decoder: ``'slot_mixer'`` (``SlotMixerDecoder``).
        encoder_kwargs: mapping forwarded as keyword arguments to the encoder class.
        decoder_kwargs: mapping forwarded as keyword arguments to the decoder class.
        use_default_init: optional, defaults to ``False``; copied onto the
            ``layers`` module as ``__USE_DEFAULT_INIT__``.

    Raises:
        KeyError: if a required config key is missing.
        ValueError: if ``encoder`` or ``decoder`` names an unsupported type.
    """

    def __init__(self, cfg):
        super().__init__()

        # Ordered (name, class) candidates per role; matched with ``==`` in
        # this order, first hit wins.
        encoder_choices = (
            ('osrt', OSRTEncoder),
            ('3ddir', BSAEncoder),
        )
        decoder_choices = (
            ('slot_mixer', SlotMixerDecoder),
        )

        def lookup(requested, choices):
            """Return the class paired with ``requested`` in ``choices``, else None."""
            return next((cls for name, cls in choices if requested == name), None)

        enc_name = cfg['encoder']
        dec_name = cfg['decoder']

        # Module-level switch on `layers`; assigned before any sub-module is
        # constructed so it is already set while they are built. Presumably it
        # selects the weight-initialisation scheme used in `layers` — confirm.
        layers.__USE_DEFAULT_INIT__ = cfg.get('use_default_init', False)

        enc_cls = lookup(enc_name, encoder_choices)
        if enc_cls is None:
            raise ValueError(f'Unknown encoder type: {enc_name}')
        # Attribute assignment registers the sub-module with nn.Module.
        self.encoder = enc_cls(**cfg['encoder_kwargs'])

        dec_cls = lookup(dec_name, decoder_choices)
        if dec_cls is None:
            raise ValueError(f'Unknown decoder type: {dec_name}')
        self.decoder = dec_cls(**cfg['decoder_kwargs'])
